In [None]:
!pip install torch torchvision learn2learn

Collecting learn2learn
  Downloading learn2learn-0.2.0.tar.gz (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gsutil (from learn2learn)
  Downloading gsutil-5.27.tar.gz (3.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m61.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting qpth>=0.0.15 (from learn2learn)
  Downloading qpth-0.0.16.tar.gz (16 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting argcomplete>=1.9.4 (from gsutil->learn2learn)
  Downloading argcomplete-3.2.2-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.3/42.3 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting crcmod>=1.7 (from gsutil->learn2learn)
  Downloading crcmod-1.7.tar.gz (89 kB)
[2K     [90m━━━━━━━━━━━━━

In [None]:
import random

def n_way_k_shot_task(dataset, n, k, num_tasks, rng=None):
    """Sample ``num_tasks`` N-way K-shot tasks from ``dataset``.

    Args:
        dataset: object exposing a ``labels`` sequence (the candidate classes)
            and iterable over ``(x, y)`` pairs. Labels are assumed to be
            hashable and to compare by value (e.g. ints or strings).
        n: number of classes (ways) per task.
        k: number of support examples per class (shots).
        num_tasks: number of tasks to generate.
        rng: optional ``random.Random`` used for all sampling. Defaults to the
            module-level ``random`` functions, so ``random.seed`` still applies.

    Returns:
        A list of ``(support_set, query_set)`` tuples. Each set is a list of
        ``(x, y)`` pairs. The support set holds up to ``k`` examples of each
        chosen class; the query set holds the remaining examples of those
        classes.

    Raises:
        ValueError: if ``n`` exceeds the number of distinct labels (raised by
            ``sample``).

    Note:
        ``dataset`` is iterated exactly once. For datasets that return a
        different sample on every access, all tasks share that single pass.
    """
    rng = random if rng is None else rng

    # Deduplicate while keeping first-seen order. Unlike list(set(...)), this
    # order does not depend on hash randomisation, so a seeded rng reproduces
    # the same classes across runs.
    class_pool = list(dict.fromkeys(dataset.labels))

    # Group examples by label once, instead of re-scanning the whole dataset
    # for every class of every task.
    examples_by_label = {}
    for x, y in dataset:
        examples_by_label.setdefault(y, []).append((x, y))

    tasks = []
    for _ in range(num_tasks):
        support_set, query_set = [], []
        for cls in rng.sample(class_pool, n):
            # Copy so shuffling never reorders the shared per-label list.
            data_cls = list(examples_by_label.get(cls, []))
            rng.shuffle(data_cls)
            support_set += data_cls[:k]
            query_set += data_cls[k:]
        tasks.append((support_set, query_set))
    return tasks

In [None]:

from torch.utils.data import Dataset
class RestructuredOmniglot(Dataset):
    """Map-style dataset over a ``{label: [image, ...]}`` dictionary.

    Every image gets its own stable integer index, so ``dataset[i]`` always
    returns the same ``(image, label)`` pair. A label with m images is
    reachable through m distinct indices. This is what index-based samplers
    (presumably including learn2learn's MetaDataset / KShots) need in order to
    draw several distinct examples per class.

    Attributes:
        data: the original ``{label: [image, ...]}`` mapping.
        labels: the distinct labels, in the mapping's iteration order.
    """

    def __init__(self, data):
        self.data = data
        self.labels = list(self.data.keys())
        # Flat index of (label, position-within-label), grouped by label.
        self._index = [
            (label, position)
            for label in self.labels
            for position in range(len(self.data[label]))
        ]

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at flat position ``index``."""
        label, position = self._index[index]
        return self.data[label][position], label

    def __len__(self):
        """Total number of images across all labels."""
        return len(self._index)

In [None]:
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import learn2learn as l2l
from learn2learn.vision.datasets import FullOmniglot
from learn2learn.data import TaskDataset, MetaDataset
from learn2learn.data.transforms import NWays, KShots

# --- Reproducibility ---
# Task sampling (presumably via Python's `random`) and model initialisation /
# any torch RNG use are stochastic; seed both so Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# --- Dataset Preparation ---
# NOTE(review): (0.1307,), (0.3081,) are the usual MNIST mean/std rather than
# statistics computed on Omniglot — confirm or recompute.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the Omniglot Dataset
omniglot = FullOmniglot(root='./data', transform=transform, download=True)

# Group the (already transformed) images by class: {label: [image, ...]}.
# Note: this iterates the full dataset and keeps every transformed image in memory.
restructured_data = {}
for image, label in omniglot:
    restructured_data.setdefault(label, []).append(image)

# Create an instance of our custom dataset
dataset = RestructuredOmniglot(restructured_data)

# Wrap in a MetaDataset so learn2learn task transforms can index it by class
meta_omniglot = MetaDataset(dataset)



Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Create the TaskDataset. Each task description is built by:
#   NWays       -> choose 5 classes,
#   KShots      -> choose 1 sample index per class,
#   LoadData    -> replace each index with the actual (image, label) pair,
#   RemapLabels -> relabel the chosen classes as 0..4 to match the 5-way head.
# Without LoadData, sample() yields only the raw indices (a single 1-D tensor),
# so `data, labels = task` fails with "too many values to unpack".
train_tasks = TaskDataset(meta_omniglot, task_transforms=[
        NWays(meta_omniglot, n=5),
        KShots(meta_omniglot, k=1),
        l2l.data.transforms.LoadData(meta_omniglot),
        l2l.data.transforms.RemapLabels(meta_omniglot),
    ],
    num_tasks=1000
)

In [None]:
# Peek at one sampled task to check what the task pipeline yields.
sampled_task = train_tasks.sample()
sampled_task

tensor([ 684,  299,  974, 1132,  952])

In [None]:
# --- Model Definition ---
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Args:
        input_size: number of input channels (1 for grayscale images). The name
            is kept for backward compatibility; it is a channel count, not a
            spatial size.
        num_classes: number of output logits.

    The feature extractor ends with an adaptive average pool to a fixed 4x4
    grid. As a result, the flattened feature vector always has 32 * 4 * 4 = 512
    entries. The classifier therefore works for any input of at least 10x10
    pixels, instead of only the 22-25 px inputs that previously produced exactly
    512 features. For those sizes the pool is an identity, so outputs are
    unchanged.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_size, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # Fix the spatial size so the Linear layer below matches any input
            # resolution (e.g. un-resized 105x105 Omniglot images).
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Linear(32 * 4 * 4, num_classes)

    def forward(self, x):
        """Map a (batch, input_size, H, W) tensor to (batch, num_classes) logits."""
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)

# --- Meta-Learner Setup ---
NUM_ITERATIONS = 20000   # outer-loop (meta) updates; lower for a quick smoke run
EVAL_EVERY = 100         # evaluate every this many iterations
NUM_EVAL_TASKS = 100     # tasks averaged per evaluation
ADAPTATION_STEPS = 1     # inner-loop gradient steps per task

model = SimpleCNN(input_size=1, num_classes=5)
maml = l2l.algorithms.MAML(model, lr=0.01)
optimizer = optim.Adam(maml.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()


def split_support_query(data, labels):
    """Split one task batch into ((support_x, support_y), (query_x, query_y)).

    The first half of each class's examples goes to the support set and the
    rest to the query set. A class with a single example (1-shot tasks) has
    nothing left to hold out, so that example is used for both sets.
    """
    support_idx, query_idx = [], []
    for cls in torch.unique(labels):
        idx = torch.nonzero(labels == cls, as_tuple=True)[0]
        half = max(1, idx.numel() // 2)
        support_idx.append(idx[:half])
        query_idx.append(idx[half:] if idx.numel() > 1 else idx)
    support_idx = torch.cat(support_idx)
    query_idx = torch.cat(query_idx)
    return (data[support_idx], labels[support_idx]), (data[query_idx], labels[query_idx])


def fast_adapt(task, learner):
    """Adapt `learner` on the task's support set and score it on the query set.

    Returns (query_loss, query_accuracy) as tensors. The loss stays attached to
    the graph, so calling backward() on it yields MAML meta-gradients.
    """
    data, labels = task
    (support_x, support_y), (query_x, query_y) = split_support_query(data, labels)
    for _ in range(ADAPTATION_STEPS):
        # Adapt the clone, never `maml` itself, so the meta-parameters stay leaves.
        learner.adapt(loss_fn(learner(support_x), support_y))
    predictions = learner(query_x)
    query_loss = loss_fn(predictions, query_y)
    query_accuracy = (predictions.argmax(dim=1) == query_y).float().mean()
    return query_loss, query_accuracy


# --- Meta-Training Loop ---
for iteration in range(NUM_ITERATIONS):
    optimizer.zero_grad()
    learner = maml.clone()
    query_loss, _ = fast_adapt(train_tasks.sample(), learner)
    query_loss.backward()   # meta-gradient w.r.t. the original maml parameters
    optimizer.step()

    # --- Evaluation ---
    if iteration % EVAL_EVERY == 0:
        # NOTE(review): these tasks come from train_tasks; no held-out split
        # exists yet, so "Meta Test" figures measure training-class performance.
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for _ in range(NUM_EVAL_TASKS):
            # Same adaptation as in training, but no meta-update.
            evaluation_learner = maml.clone()
            evaluation_loss, evaluation_accuracy = fast_adapt(train_tasks.sample(), evaluation_learner)
            meta_test_error += evaluation_loss.item()
            meta_test_accuracy += evaluation_accuracy.item()

        meta_test_accuracy /= NUM_EVAL_TASKS
        meta_test_error /= NUM_EVAL_TASKS
        print('Iteration:', iteration)
        print('Meta Test Error:', meta_test_error)
        print('Meta Test Accuracy:', meta_test_accuracy)



ValueError: too many values to unpack (expected 2)