# Meta-Learning with MAML and Omniglot (Optimized)

This notebook implements **Model-Agnostic Meta-Learning (MAML)**. This version has been **optimized** to address common issues with long runtimes and stagnant training loss.

**Key Changes:**
1.  **Reduced Inner Learning Rate:** To help the model learn effectively.
2.  **Reduced Batch Size:** To significantly decrease training time.
3.  **Early Stopping for Debugging:** The training loop stops after 100 batches to allow for quick verification.

## 1. Installation and Imports

In [None]:
!pip install -q comet_ml torch torchvision

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import Omniglot
from PIL import Image
import random
import copy
import numpy as np


## 2. Defining the N-way K-shot Task Dataset

This dataset generates entire few-shot classification tasks. Each item from this dataset is a complete task, containing a support set (for learning) and a query set (for testing).

In [None]:
class OmniglotTaskDataset(Dataset):
    """Dataset for generating N-way K-shot tasks from Omniglot."""
    def __init__(self, n_way, k_shot, k_query, root_dir, background=True, transform=None):
        super().__init__()
        self.omniglot = Omniglot(root=root_dir, background=background, download=True)
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query

        # Group image indices by character class
        self.class_indices = {}
        for i, (_, label) in enumerate(self.omniglot._flat_character_images):
            if label not in self.class_indices:
                self.class_indices[label] = []
            self.class_indices[label].append(i)
        
        self.class_list = list(self.class_indices.keys())

    def __len__(self):
        # The number of possible tasks is very large, so we can set a large number for an epoch
        return 20000

    def __getitem__(self, index):
        # 1. Sample N classes for the task
        task_classes = random.sample(self.class_list, self.n_way)

        support_images, support_labels = [], []
        query_images, query_labels = [], []

        for i, cls_idx in enumerate(task_classes):
            # 2. Sample K-shot + K-query images from each class
            class_images_indices = self.class_indices[cls_idx]
            if len(class_images_indices) < self.k_shot + self.k_query:
                sampled_indices = random.choices(class_images_indices, k=self.k_shot + self.k_query)
            else:
                sampled_indices = random.sample(class_images_indices, self.k_shot + self.k_query)
            
            # 3. Split into support and query sets
            support_indices = sampled_indices[:self.k_shot]
            query_indices = sampled_indices[self.k_shot:]

            # Load and transform images for the support set
            for idx in support_indices:
                img, _ = self.omniglot[idx]
                if self.transform:
                    img = self.transform(img)
                support_images.append(img)
                support_labels.append(i)
            
            # Load and transform images for the query set
            for idx in query_indices:
                img, _ = self.omniglot[idx]
                if self.transform:
                    img = self.transform(img)
                query_images.append(img)
                query_labels.append(i)

        # Convert to tensors
        support_images = torch.stack(support_images)
        support_labels = torch.LongTensor(support_labels)
        query_images = torch.stack(query_images)
        query_labels = torch.LongTensor(query_labels)

        return support_images, support_labels, query_images, query_labels

# --- Model Definition ---
class MetaLearner(nn.Module):
    def __init__(self, input_channels, num_classes):
        super(MetaLearner, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, num_classes)
        )
    def forward(self, x):
        return self.net(x)


## 3. Hyperparameters and Data Loading

**Changes Applied Here:**
- `BATCH_SIZE` reduced from 32 to 8 to speed up training.
- `INNER_LR` reduced from 0.04 to 0.01 to fix the stagnant loss issue.

In [None]:
# Meta-learning parameters
N_WAY = 5          # Number of classes in a task
K_SHOT = 1         # Number of support examples per class
K_QUERY = 15       # Number of query examples per class

# --- OPTIMIZATION 1: Reduce batch size to decrease runtime ---
BATCH_SIZE = 8     # Reduced from 32. Fewer tasks per step = faster training.
NUM_EPOCHS = 1     # One epoch is long (20k tasks), so 1 is enough for a demo

# MAML-specific parameters
META_LR = 0.001       # Outer loop learning rate

# --- OPTIMIZATION 2: Reduce inner learning rate to fix stagnant loss ---
INNER_LR = 0.01       # Reduced from 0.04. MAML is very sensitive to this value.
INNER_STEPS = 1     # Number of gradient steps in inner loop

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

# Create dataset and dataloader
train_dataset = OmniglotTaskDataset(N_WAY, K_SHOT, K_QUERY, root_dir='./data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)


## 4. MAML Initialization and Training Loop

**Change Applied Here:**
- An `if` condition has been added to `break` the loop after 100 batches. This allows for a quick check to see if the loss is now decreasing.

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the meta-model
meta_model = MetaLearner(input_channels=1, num_classes=N_WAY).to(device)
meta_optimizer = optim.Adam(meta_model.parameters(), lr=META_LR)
loss_fn = nn.CrossEntropyLoss()

print(f"Starting MAML training on device: {device}")
for epoch in range(NUM_EPOCHS):
    meta_model.train()
    for batch_idx, batch in enumerate(train_loader):
        # Unpack the batch of tasks
        s_imgs, s_labels, q_imgs, q_labels = batch
        s_imgs, s_labels = s_imgs.to(device), s_labels.to(device)
        q_imgs, q_labels = q_imgs.to(device), q_labels.to(device)

        meta_optimizer.zero_grad()
        batch_meta_loss = 0.0

        # Iterate over each task in the meta-batch
        for i in range(BATCH_SIZE):
            task_s_imgs, task_s_labels = s_imgs[i], s_labels[i]
            task_q_imgs, task_q_labels = q_imgs[i], q_labels[i]

            # Create a temporary 'fast' model for the inner loop
            fast_model = copy.deepcopy(meta_model)
            fast_optimizer = optim.SGD(fast_model.parameters(), lr=INNER_LR)

            # --- Inner Loop: Adapt to the current task ---
            for _ in range(INNER_STEPS):
                support_preds = fast_model(task_s_imgs)
                support_loss = loss_fn(support_preds, task_s_labels)
                fast_optimizer.zero_grad()
                support_loss.backward()
                fast_optimizer.step()
            
            # --- Evaluate the adapted model on the query set ---
            query_preds = fast_model(task_q_imgs)
            query_loss = loss_fn(query_preds, task_q_labels)
            batch_meta_loss += query_loss

        # --- Outer Loop: Update the meta-model ---
        meta_loss = batch_meta_loss / BATCH_SIZE
        meta_loss.backward()
        meta_optimizer.step()

        if (batch_idx + 1) % 50 == 0: # Print progress every 50 batches
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], Meta Loss: {meta_loss.item():.4f}")
        
        # --- OPTIMIZATION 3: Stop early for a quick test run ---
        if (batch_idx + 1) == 200: # Increased to 200 to see more trend
            print("Stopping epoch early for a quick test.")
            break

print("\nMeta-training finished!")


Starting MAML training on device: cpu
Epoch [1/1], Batch [50/2500], Meta Loss: 1.6100
Epoch [1/1], Batch [100/2500], Meta Loss: 1.6087
Epoch [1/1], Batch [150/2500], Meta Loss: 1.6090
Epoch [1/1], Batch [200/2500], Meta Loss: 1.6088
Stopping epoch early for a quick test.

Meta-training finished!


## 5. Evaluating the Meta-Learned Model

Now we can test how well our meta-model works. We'll create a new test task, adapt the model using the support set, and then measure its final accuracy on the query set.

In [37]:
print("\nEvaluating the meta-model on a new task...")

# Create a new dataset for testing (using the evaluation set of characters)
test_dataset = OmniglotTaskDataset(N_WAY, K_SHOT, K_QUERY, root_dir='./data', background=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Get one test task
test_s_imgs, test_s_labels, test_q_imgs, test_q_labels = next(iter(test_loader))
test_s_imgs, test_s_labels = test_s_imgs.squeeze(0).to(device), test_s_labels.squeeze(0).to(device)
test_q_imgs, test_q_labels = test_q_imgs.squeeze(0).to(device), test_q_labels.squeeze(0).to(device)

# Adapt the model on the test support set
meta_model.eval()
fast_model_test = copy.deepcopy(meta_model)
fast_optimizer_test = optim.SGD(fast_model_test.parameters(), lr=INNER_LR)

for _ in range(INNER_STEPS * 3): # Use a few more steps for testing
    support_preds = fast_model_test(test_s_imgs)
    support_loss = loss_fn(support_preds, test_s_labels)
    fast_optimizer_test.zero_grad()
    support_loss.backward()
    fast_optimizer_test.step()

# Evaluate the adapted model
with torch.no_grad():
    query_preds = fast_model_test(test_q_imgs)
    _, predicted_labels = torch.max(query_preds, 1)
    accuracy = (predicted_labels == test_q_labels).float().mean().item()

print(f"Test Task Accuracy ({N_WAY}-way, {K_SHOT}-shot): {accuracy * 100:.2f}%")



Evaluating the meta-model on a new task...
Test Task Accuracy (5-way, 1-shot): 36.00%
