<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Learning_to_Learn).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy

class MetaLearner(nn.Module):
    """Wrap a model with the inner-loop and outer-loop steps of MAML.

    The wrapped ``model`` is registered as a submodule, so
    ``MetaLearner.parameters()`` yields the model's parameters; these are
    the meta-parameters that ``meta_update`` optimizes.

    Args:
        model: The ``nn.Module`` to adapt. Any architecture is supported.
        inner_lr: Step size of the inner-loop gradient step.
    """

    def __init__(self, model, inner_lr):
        super().__init__()
        self.model = model
        self.inner_lr = inner_lr

    @staticmethod
    def _functional_call(module, params, x):
        """Run ``module`` on ``x`` with ``params`` replacing its own tensors."""
        try:
            from torch.func import functional_call  # torch >= 2.0
        except ImportError:  # older releases ship it under nn.utils.stateless
            from torch.nn.utils.stateless import functional_call
        return functional_call(module, params, (x,))

    def forward(self, x, params=None):
        """Evaluate the wrapped model on ``x``.

        Args:
            x: Input passed to the wrapped model.
            params: Optional mapping from parameter name (as produced by
                ``model.named_parameters()``) to the tensor to use in its
                place. ``None`` uses the model's current parameters.

        Returns:
            The model output.
        """
        if params is None:  # Use current model parameters
            return self.model(x)
        # Use the supplied (adapted) parameters. Running the module itself
        # functionally keeps this independent of the model architecture and
        # keeps the result differentiable w.r.t. the supplied tensors.
        return self._functional_call(self.model, params, x)

    def inner_update(self, loss, params=None):
        """Take one differentiable gradient step on ``loss``.

        Args:
            loss: Scalar loss computed from ``params``.
            params: Mapping of parameter name to tensor that ``loss`` was
                computed with. ``None`` means the model's own parameters.

        Returns:
            A new dict with the same keys and order as ``params`` holding
            ``param - inner_lr * grad``. Entries that do not require grad,
            or that did not contribute to ``loss``, are passed through
            unchanged.
        """
        if params is None:
            params = dict(self.model.named_parameters())

        # autograd.grad rejects tensors that do not require grad, so frozen
        # parameters are left out of the differentiation.
        trainable = {name: p for name, p in params.items() if p.requires_grad}
        updated_params = dict(params)
        if not trainable:
            return updated_params

        # create_graph=True records the gradient computation itself so the
        # meta-loss can later be back-propagated through this update.
        grads = torch.autograd.grad(
            loss,
            tuple(trainable.values()),
            create_graph=True,
            allow_unused=True,
        )
        for (name, param), grad in zip(trainable.items(), grads):
            if grad is not None:  # None: parameter did not affect the loss
                updated_params[name] = param - self.inner_lr * grad
        return updated_params

    def meta_update(self, meta_loss, meta_optimizer):
        """Back-propagate ``meta_loss`` and take one ``meta_optimizer`` step."""
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

# Example task-specific model
class SimpleNN(nn.Module):
    """Single fully connected layer from ``input_dim`` to ``output_dim``.

    The layer is stored as ``self.fc``, so its parameters are named
    ``fc.weight`` and ``fc.bias``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        out = self.fc(x)
        return out

# MAML training loop
def maml_train(meta_learner, tasks, meta_optimizer, n_inner_steps=1):
    """Run one MAML meta-update over a batch of tasks.

    For each task the learner is adapted on ``(x_train, y_train)`` for
    ``n_inner_steps`` gradient steps, and the adapted parameters are scored
    on ``(x_val, y_val)``. The mean validation loss over all tasks is handed
    to ``meta_learner.meta_update`` together with ``meta_optimizer``.

    Args:
        meta_learner: Object that is callable as ``meta_learner(x, params=...)``
            and provides ``inner_update(loss, params=...)`` and
            ``meta_update(meta_loss, meta_optimizer)``.
        tasks: Iterable (list, generator, ...) of
            ``(x_train, y_train, x_val, y_val)`` tuples. Targets are scored
            with ``nn.CrossEntropyLoss``.
        meta_optimizer: Optimizer forwarded to ``meta_learner.meta_update``.
        n_inner_steps: Number of inner-loop steps per task. ``0`` scores the
            validation data with the unadapted parameters.

    Returns:
        The mean validation loss (computed before the meta-update) as a float.

    Raises:
        ValueError: If ``n_inner_steps`` is negative or ``tasks`` is empty.
    """
    if n_inner_steps < 0:
        raise ValueError(
            f"n_inner_steps must be non-negative, got {n_inner_steps}"
        )

    criterion = nn.CrossEntropyLoss()  # stateless; build once, reuse per task
    meta_loss = 0
    n_tasks = 0  # counted here so that ``tasks`` may be any iterable
    for x_train, y_train, x_val, y_val in tasks:
        # Inner loop: adapt on the training split. ``None`` makes the learner
        # start from its current (meta) parameters on the first step.
        fast_params = None
        for _ in range(n_inner_steps):
            outputs = meta_learner(x_train, params=fast_params)
            task_loss = criterion(outputs, y_train)
            fast_params = meta_learner.inner_update(task_loss, params=fast_params)

        # Validation step using the adapted parameters
        val_outputs = meta_learner(x_val, params=fast_params)
        meta_loss = meta_loss + criterion(val_outputs, y_val)
        n_tasks += 1

    if n_tasks == 0:
        raise ValueError("maml_train requires at least one task")

    # Outer loop (meta-update) on the task-averaged validation loss
    meta_loss = meta_loss / n_tasks
    meta_learner.meta_update(meta_loss, meta_optimizer)
    return meta_loss.item()

# Example usage: meta-train a single linear classifier on random data.
model = SimpleNN(input_dim=10, output_dim=2)
meta_learner = MetaLearner(model, inner_lr=0.01)
meta_optimizer = optim.Adam(meta_learner.parameters(), lr=0.001)

# Ten dummy tasks, each a (x_train, y_train, x_val, y_val) tuple holding
# 5 examples with 10 features and binary labels per split. Tuple elements
# are evaluated left to right, so the random draws happen in that order.
tasks = [
    (
        torch.randn(5, 10),
        torch.randint(0, 2, (5,)),
        torch.randn(5, 10),
        torch.randint(0, 2, (5,)),
    )
    for _ in range(10)
]

# Training loop: one meta-update over the full task list per epoch.
for epoch in range(100):
    maml_train(meta_learner, tasks, meta_optimizer)
    print(f"Epoch {epoch + 1} completed.")