<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Learning_to_Learn).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim

# Simple regression model
class RegressionModel(nn.Module):
    """Single-feature linear regressor computing ``fc(x) = x @ W.T + b``.

    The wrapped ``nn.Linear`` has one input and one output feature, so the
    last dimension of ``x`` must be 1.
    """

    def __init__(self):
        super().__init__()
        # Kept under the name ``fc`` so state_dict keys stay "fc.weight"/"fc.bias".
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return the affine prediction for inputs ``x``."""
        prediction = self.fc(x)
        return prediction

def maml_update(model, loss_fn, x, y, lr_inner):
    """Take one differentiable gradient-descent step for a single task.

    Args:
        model: Module whose current parameters are the starting point. Its
            parameters are not modified in place.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        x: Task inputs fed to ``model``.
        y: Task targets passed to ``loss_fn``.
        lr_inner: Inner-loop step size.

    Returns:
        A dict mapping each name from ``model.named_parameters()`` to the
        tensor ``param - lr_inner * grad``. The tensors are still attached to
        the autograd graph.

    Note:
        Every parameter must contribute to the loss; ``torch.autograd.grad``
        raises otherwise (``allow_unused`` is left at its default).
    """
    task_loss = loss_fn(model(x), y)

    named = list(model.named_parameters())
    # create_graph=True records the gradient computation itself, so a loss
    # built from the returned tensors can be differentiated back through
    # this update (second-order MAML).
    grads = torch.autograd.grad(
        task_loss, [p for _, p in named], create_graph=True
    )

    return {name: p - lr_inner * g for (name, p), g in zip(named, grads)}

def apply_updated_params(model, updated_params):
    """Return a forward function that runs ``model`` with substituted parameters.

    The returned callable evaluates ``model``'s own ``forward`` while using
    the tensors in ``updated_params`` in place of the module's parameters.
    Because the substituted tensors are used directly, without copying or
    detaching, any autograd graph attached to them is preserved. In
    particular, when they come from ``maml_update(..., create_graph=True)``,
    a loss computed with the returned callable back-propagates through the
    inner-loop update into ``model``'s original parameters.

    Args:
        model: Module whose architecture (``forward``) is reused. Its
            parameters are not modified.
        updated_params: Mapping from parameter name (as produced by
            ``model.named_parameters()``) to the tensor to use instead.
            Under ``functional_call``'s default non-strict mode, names not
            in the mapping fall back to the module's own tensors.

    Returns:
        A callable ``task_model(*args, **kwargs)`` equivalent to
        ``model(*args, **kwargs)`` evaluated with the substituted parameters.
    """
    try:
        from torch.func import functional_call  # PyTorch >= 2.0
    except ImportError:  # PyTorch 1.12 / 1.13
        from torch.nn.utils.stateless import functional_call

    # Why not copy into a fresh module: assigning ``.data`` detaches the
    # tensors from the graph. The meta-gradient would then land on the copy,
    # and the real model would never be updated by the outer optimizer.
    params = dict(updated_params)  # snapshot; later edits to the caller's dict don't leak in

    def task_model(*args, **kwargs):
        return functional_call(model, params, args, kwargs)

    return task_model

# Outer loop: Meta-learning step
def meta_learning_step(model, optimizer, tasks, loss_fn, lr_inner):
    """Run one MAML outer-loop (meta) update over a batch of tasks.

    For each ``(x, y)`` task this:

    1. takes one inner gradient step from ``model``'s current parameters
       (``maml_update``);
    2. evaluates the adapted parameters on the same ``(x, y)``; and
    3. adds the resulting loss to the meta-loss.

    The summed meta-loss is then back-propagated and ``optimizer`` takes a
    single step.

    Args:
        model: The meta-learned module.
        optimizer: Outer-loop optimizer; presumably built over
            ``model.parameters()`` (verify against the caller).
        tasks: Iterable of ``(x, y)`` pairs. A one-shot iterator such as a
            generator is fine, since it is consumed exactly once.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        lr_inner: Inner-loop step size passed to ``maml_update``.

    Returns:
        The summed meta-loss as a Python float. The value is the one computed
        before the optimizer step.

    Raises:
        ValueError: If ``tasks`` yields no tasks, since there is nothing to
            back-propagate.
    """
    meta_loss = None

    for x, y in tasks:
        # Inner loop: one differentiable adaptation step for this task.
        updated_params = maml_update(model, loss_fn, x, y, lr_inner)
        task_model = apply_updated_params(model, updated_params)

        # NOTE(review): the adapted model is scored on the same data it was
        # adapted on; standard MAML uses a separate query split.
        task_loss = loss_fn(task_model(x), y)
        # Out-of-place add, so no graph tensor is modified in place.
        meta_loss = task_loss if meta_loss is None else meta_loss + task_loss

    if meta_loss is None:
        raise ValueError("meta_learning_step requires at least one task")

    # Outer loop: update the original model's parameters from the meta-loss.
    optimizer.zero_grad()
    meta_loss.backward()
    optimizer.step()

    return meta_loss.item()

# Set up the meta-learner, its outer-loop optimizer, and the loss.
model = RegressionModel()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Two toy tasks, each an (inputs, targets) pair of (2, 1) tensors.
# The points of both tasks lie on the line y = 2x + 1.
tasks = [
    (torch.tensor([[1.0], [2.0]]), torch.tensor([[3.0], [5.0]])),
    (torch.tensor([[3.0], [4.0]]), torch.tensor([[7.0], [9.0]])),
]

# Meta-training: one outer-loop update per epoch, logging every 10th epoch.
epochs = 100
lr_inner = 0.01  # step size of the inner-loop (per-task adaptation) update

for epoch in range(epochs):
    meta_loss = meta_learning_step(model, optimizer, tasks, loss_fn, lr_inner)
    if not epoch % 10:
        print(f"Epoch {epoch}, Meta-Loss: {meta_loss:.4f}")