In [None]:
!pip install torch



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.func import functional_call
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from tqdm import tqdm

In [None]:
# Set device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Define the model architecture

class SimpleModel(nn.Module):
    """Two-layer fully connected classifier applied to flattened inputs.

    Args:
        input_dim: Number of features per sample after flattening. The default
            (1 * 105 * 105 = 11025) matches the original hard-coded layer size.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits. This is the number of classes the
            network can distinguish at once (the "N" of an N-way episode), so
            every target label must lie in ``[0, num_classes)``.
    """

    def __init__(self, input_dim=1 * 105 * 105, hidden_dim=64, num_classes=5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``."""
        # flatten() also accepts non-contiguous tensors, where view() would raise.
        x = x.flatten(start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Define MAML class
class MAML:
    """Second-order Model-Agnostic Meta-Learning around an arbitrary ``nn.Module``.

    Adaptation is done functionally: the inner loop produces "fast weights"
    (a ``name -> tensor`` dict) that stay connected to the meta-parameters in
    the autograd graph, so the meta-loss can be back-propagated through the
    inner gradient steps into ``self.model``.
    """

    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, num_classes=5):
        """Store hyper-parameters and build the meta-optimizer.

        Args:
            model: Network to meta-train; it is moved to the notebook-level ``device``.
            inner_lr: Step size of the inner (adaptation) gradient steps.
            meta_lr: Learning rate of the Adam meta-optimizer.
            num_classes: Number of classes per task; stored for callers.
        """
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.num_classes = num_classes
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=meta_lr)

    @staticmethod
    def _named_weights(model):
        """Return ``model`` as a ``name -> tensor`` dict (accepts a module or a dict)."""
        if isinstance(model, nn.Module):
            return dict(model.named_parameters())
        return dict(model)

    def inner_update(self, model, loss, step_size):
        """Take one differentiable gradient step on ``loss``.

        Args:
            model: Either an ``nn.Module`` or a ``name -> tensor`` dict of fast
                weights; ``loss`` must have been computed from these tensors.
            loss: Scalar loss to differentiate.
            step_size: Inner-loop learning rate.

        Returns:
            A new ``name -> tensor`` dict of updated weights. Nothing is
            modified in place.
        """
        weights = self._named_weights(model)
        trainable = [w for w in weights.values() if w.requires_grad]
        # create_graph=True keeps the graph of the gradient itself, which is
        # what makes the outer update second-order.
        computed = iter(torch.autograd.grad(loss, trainable, create_graph=True, allow_unused=True))
        # Re-align with the full weight list; frozen tensors get no gradient.
        grads = [next(computed) if w.requires_grad else None for w in weights.values()]
        return self._update_parameters(weights, grads, step_size)

    def _update_parameters(self, model, grads, step_size):
        """Apply ``w - step_size * g`` to every weight and return the new dict.

        ``grads`` must be aligned with the weights of ``model``; a ``None``
        entry leaves the corresponding weight unchanged.
        """
        weights = self._named_weights(model)
        if len(grads) != len(weights):
            raise ValueError(
                f"expected {len(weights)} gradients, got {len(grads)}"
            )
        return {
            name: weight if grad is None else weight - step_size * grad
            for (name, weight), grad in zip(weights.items(), grads)
        }

    def meta_update(self, batch, K=5, N=5):
        """Run one meta-optimisation step and return the meta-loss as a float.

        Args:
            batch: ``(x, y)`` pair; ``y`` holds integer class indices and is
                flattened to 1-D before the loss is computed.
            K: Number of inner adaptation steps per task.
            N: Number of adapted copies whose losses are averaged.

        NOTE(review): every one of the ``N`` copies adapts to and is scored on
        the same ``batch`` (as in the original code); separate support/query
        data would need a richer ``batch`` format.
        """
        x, y = batch
        # Move the data to wherever the model's parameters live.
        model_device = next(self.model.parameters()).device
        x = x.to(model_device)
        y = y.to(model_device).view(-1)
        criterion = nn.CrossEntropyLoss()

        meta_loss = 0.0
        for _task in range(N):
            # Start from the meta-parameters themselves (no state_dict copy),
            # so gradients of the adapted loss reach self.model.
            fast_weights = dict(self.model.named_parameters())
            for _step in range(K):
                logits = functional_call(self.model, fast_weights, (x,))
                fast_weights = self.inner_update(fast_weights, criterion(logits, y), self.inner_lr)

            adapted_logits = functional_call(self.model, fast_weights, (x,))
            meta_loss = meta_loss + criterion(adapted_logits, y)

        meta_loss = meta_loss / N

        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()
        return meta_loss.item()


In [None]:
def train(self, dataloader, epochs=5, K=5, N=5):
    """Meta-train by calling ``self.meta_update`` on every batch of ``dataloader``.

    This is written in method style but defined at module level, so ``self``
    is any object exposing ``meta_update(batch, K, N)``; call it as
    ``train(maml, dataloader)``.

    Args:
        self: Object providing ``meta_update``.
        dataloader: Iterable yielding ``(inputs, targets)`` batches, where
            ``targets`` are integer class indices.
        epochs: Number of passes over ``dataloader``.
        K: Forwarded to ``meta_update``.
        N: Forwarded to ``meta_update``.
    """
    for epoch in range(epochs):
        for inputs, targets in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            # Pass integer labels through unchanged: meta_update flattens the
            # targets and feeds them to CrossEntropyLoss, so a one-hot matrix
            # would be flattened into a wrongly sized float vector.
            # Device placement is handled inside meta_update.
            self.meta_update((inputs, targets), K, N)

In [None]:
# Download Omniglot dataset
# Each image goes through Grayscale() and then ToTensor().
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

train_dataset = Omniglot(
    root='./data',
    background=True,
    transform=transform,
    download=True,
)
# One shuffled sample per batch.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)

# Initialize MAML
model = SimpleModel()
maml = MAML(model)

Files already downloaded and verified


In [None]:
# Train MAML for multiple epochs
# Raw Omniglot labels index every character in the split, far more than the
# 5 logits the model emits, so feeding them directly raises
# "IndexError: Target ... is out of bounds". Instead, sample N-way K-shot
# episodes and relabel the chosen characters as 0..N_WAY-1.
N_WAY = 5                 # must equal the number of model outputs
K_SHOT = 5                # images drawn per character in an episode
EPISODES_PER_EPOCH = 200  # meta-updates per epoch
NUM_EPOCHS = 5

# Group dataset indices by character class so episodes can be drawn per class.
indices_by_class = {}
for idx in tqdm(range(len(train_dataset)), desc="Indexing classes"):
    _, label = train_dataset[idx]
    indices_by_class.setdefault(int(label), []).append(idx)
class_ids = list(indices_by_class)


def sample_episode(n_way=N_WAY, k_shot=K_SHOT):
    """Return ``(images, labels)`` for one episode with labels in ``[0, n_way)``."""
    chosen = torch.randperm(len(class_ids))[:n_way].tolist()
    images, labels = [], []
    for episode_label, class_pos in enumerate(chosen):
        members = indices_by_class[class_ids[class_pos]]
        for pick in torch.randperm(len(members))[:k_shot].tolist():
            image, _ = train_dataset[members[pick]]
            images.append(image)
            labels.append(episode_label)
    return torch.stack(images), torch.tensor(labels, dtype=torch.long)


for epoch in range(NUM_EPOCHS):
    for _ in tqdm(range(EPISODES_PER_EPOCH), desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"):
        inputs, targets = sample_episode()
        maml.meta_update((inputs, targets), K=5, N=5)

Epoch 1/5:   0%|          | 0/19280 [00:00<?, ?it/s]


IndexError: ignored