<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Learning_to_Learn).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict

# Simple neural network for classification
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the single hidden layer.
        output_size: Number of output units; raw outputs are returned
            without any final activation.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Build the layers in forward order so the parameters keep the
        # names "model.0.*" and "model.2.*" inside the Sequential.
        hidden = nn.Linear(input_size, hidden_size)
        activation = nn.ReLU()
        head = nn.Linear(hidden_size, output_size)
        self.model = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return its output."""
        out = self.model(x)
        return out

# MAML Algorithm
class MAML(nn.Module):
    """Model-Agnostic Meta-Learning (MAML) wrapper around a ``nn.Module``.

    The wrapped model's parameters are the meta-parameters. For each task the
    caller computes a support loss, obtains adapted ("fast") weights from
    :meth:`inner_update`, evaluates the query set with
    :meth:`forward_with_params`, and finally calls :meth:`outer_update` on the
    query loss.

    Args:
        model: The network whose parameters are meta-learned.
        lr_inner: Step size of the single gradient step in ``inner_update``.
        lr_outer: Learning rate of the Adam optimizer used by
            ``outer_update``.
    """

    def __init__(self, model, lr_inner=0.01, lr_outer=0.001):
        super().__init__()
        self.model = model
        self.lr_inner = lr_inner
        self.lr_outer = lr_outer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer)

    def forward(self, x):
        """Run the wrapped model with its current meta-parameters."""
        return self.model(x)

    def inner_update(self, loss):
        """Return parameters after one gradient-descent step on ``loss``.

        The step is built with ``create_graph=True`` so the returned tensors
        remain differentiable with respect to the meta-parameters. The model
        itself is not modified.

        Args:
            loss: Scalar tensor computed from the wrapped model's parameters.

        Returns:
            ``OrderedDict`` mapping every parameter name of the wrapped model
            to its adapted tensor. Parameters that are frozen
            (``requires_grad=False``) or did not contribute to ``loss`` are
            returned unchanged.
        """
        named = list(self.model.named_parameters())
        trainable = [param for _, param in named if param.requires_grad]
        # autograd.grad rejects an empty input list, so skip it in that case.
        # allow_unused=True yields None (instead of raising) for parameters
        # that are not part of the loss graph.
        grads = iter(
            torch.autograd.grad(
                loss, trainable, create_graph=True, allow_unused=True
            )
            if trainable
            else ()
        )
        updated_params = OrderedDict()
        for name, param in named:
            grad = next(grads) if param.requires_grad else None
            if grad is None:
                updated_params[name] = param
            else:
                updated_params[name] = param - self.lr_inner * grad
        return updated_params

    def forward_with_params(self, x, params):
        """Run the wrapped model on ``x`` using ``params`` as its weights.

        This is a functional forward pass: the module's own parameters are
        left untouched and the output stays connected to the autograd graph
        of ``params``. Copying the values into ``param.data`` instead would
        cut that graph and overwrite the meta-parameters with task-specific
        values.

        Args:
            x: Input batch for the wrapped model.
            params: Mapping from parameter name to replacement tensor, e.g.
                the result of :meth:`inner_update`. Names the model does not
                have are ignored; parameters missing from the mapping fall
                back to the model's own values.

        Returns:
            The model output computed with the substituted parameters.
        """
        try:
            from torch.func import functional_call
        except ImportError:  # PyTorch releases that predate torch.func
            from torch.nn.utils.stateless import functional_call

        overrides = {
            name: params[name]
            for name, _ in self.model.named_parameters()
            if name in params
        }
        return functional_call(self.model, overrides, (x,))

    def outer_update(self, loss):
        """Take one optimizer step on the meta-parameters using ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Few-shot classification task example
def train_maml(*, num_epochs=10, num_tasks=5, num_shots=5, query_shots=15,
               input_size=784, hidden_size=256, num_classes=10):
    """Run a small MAML training loop on randomly generated tasks.

    Every task consists of random inputs and random integer labels, so this
    exercises the meta-training loop rather than learning a real problem.
    For each task the support loss drives ``MAML.inner_update``, the query
    loss is computed through ``MAML.forward_with_params``, and
    ``MAML.outer_update`` is called once per task. The mean query loss of
    each meta-epoch is printed.

    Args:
        num_epochs: Number of meta-epochs.
        num_tasks: Tasks sampled per meta-epoch; must be at least 1.
        num_shots: Support samples per task.
        query_shots: Query samples per task.
        input_size: Feature dimension of each sample (784 = flattened 28x28).
        hidden_size: Hidden width of the MLP.
        num_classes: Number of output classes / label range.

    Raises:
        ValueError: If ``num_tasks`` is smaller than 1, since the per-epoch
            average would divide by zero.
    """
    if num_tasks < 1:
        raise ValueError("num_tasks must be at least 1")

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model and MAML
    model = MLP(input_size, hidden_size, num_classes).to(device)
    maml = MAML(model)

    # The loss module is stateless, so one instance serves every task.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):  # Meta-epochs
        total_meta_loss = 0.0
        for _ in range(num_tasks):
            # Simulate a task with support (train) and query (test) data.
            support_x = torch.rand(num_shots, input_size).to(device)
            support_y = torch.randint(0, num_classes, (num_shots,)).to(device)
            query_x = torch.rand(query_shots, input_size).to(device)
            query_y = torch.randint(0, num_classes, (query_shots,)).to(device)

            # Inner update (on task-specific support data)
            support_loss = criterion(maml(support_x), support_y)
            updated_params = maml.inner_update(support_loss)

            # Query loss computed with the adapted parameters
            query_logits = maml.forward_with_params(query_x, updated_params)
            query_loss = criterion(query_logits, query_y)
            total_meta_loss += query_loss.item()

            # Outer update (on the meta-loss from the query set)
            maml.outer_update(query_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Meta Loss: {total_meta_loss / num_tasks:.4f}")

# Run the MAML demo with its default settings as soon as this cell/script
# is executed; progress is printed once per meta-epoch.
train_maml()