In [76]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [77]:
# Model: 2 Layer feedforward neural network
class MetaModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(MetaModel, self).__init__()

        # hidden layer: 64 neurons
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, output_size)

    # ReLU activation
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [78]:
# MAML algo
def maml_update(model, task_data, alpha=0.01):
    criterion = nn.CrossEntropyLoss()  # classification task
    task_loss = []                     # stores validation losses for each task
    gradients = []                     # stores the gradients after fine-tuning the model


    for (x_train, y_train, x_val, y_val) in task_data:

        # Clone model for task specific update
        temp_model = MetaModel(10, 2)  # input=10; output=2
        temp_model.load_state_dict(model.state_dict())
        optimizer = optim.SGD(temp_model.parameters(), lr=alpha)

        # Forward pass
        y_pred = temp_model(x_train)
        loss = criterion(y_pred, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Validation loss
        y_val_pred = temp_model(x_val)
        val_loss = criterion(y_val_pred, y_val)
        task_loss.append(val_loss)
        gradients.append([p.grad for p in temp_model.parameters()])

    return sum(task_loss) / len(task_loss), gradients

In [79]:
# Synthetic data
num_tasks = 5
input_size = 10
output_size = 2
meta_model = MetaModel(input_size, output_size)

In [80]:
# Random tasks
task_data = []
for _ in range(num_tasks):
    x_train = torch.randn(5, input_size)          # random
    y_train = torch.randint(0, output_size, (5,)) # integers between 0 and 1
    x_val = torch.randn(5, input_size)
    y_val = torch.randint(0, output_size, (5,))
    task_data.append((x_train, y_train, x_val, y_val))

In [81]:
# Meta Training loop
meta_optimizer = optim.Adam(meta_model.parameters(), lr=0.001) # updates model across tasks
for epoch in range(100):
    meta_optimizer.zero_grad()
    meta_loss, grads = maml_update(meta_model, task_data) # Compute validation loss and task gradients from fine tuned copies

    # Apply Avg Meta Gradients
    for param, grad in zip(meta_model.parameters(), grads[0]):
        param.grad = grad

    meta_optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Meta Loss: {meta_loss.item()}")

Epoch 0, Meta Loss: 0.7276244163513184
Epoch 10, Meta Loss: 0.710638165473938
Epoch 20, Meta Loss: 0.6883314847946167
Epoch 30, Meta Loss: 0.6614252924919128
Epoch 40, Meta Loss: 0.6329342126846313
Epoch 50, Meta Loss: 0.6059695482254028
Epoch 60, Meta Loss: 0.5837315320968628
Epoch 70, Meta Loss: 0.5676905512809753
Epoch 80, Meta Loss: 0.5568681955337524
Epoch 90, Meta Loss: 0.5509260892868042
