<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Learning_to_Learn).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy

class MAMLModel(nn.Module):
    """Two-layer fully connected network used as the meta-learned model.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim)``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
        hidden_dim: Width of the hidden layer. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``x`` (last dim ``input_dim``)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def maml_step(model, loss_fn, tasks, inner_lr, outer_lr, inner_steps,
              outer_optimizer=None):
    """Perform one MAML meta-update of ``model`` over a batch of tasks.

    For every task the trainable parameters are adapted with ``inner_steps``
    steps of plain gradient descent. The adaptation is kept differentiable,
    so the summed post-adaptation loss back-propagates into the parameters
    of ``model`` itself, which are then updated by the outer optimizer.

    The previous implementation adapted a ``deepcopy`` of the model; the
    copy's parameters are not connected to ``model`` in the autograd graph,
    so ``model`` never received gradients and the outer step was a no-op.

    Args:
        model: ``nn.Module`` whose parameters are meta-learned (updated in
            place).
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar
            tensor.
        tasks: Iterable of ``(data, target)`` pairs. As in the original
            code, the same pair is used for adaptation and for the
            meta-loss.
        inner_lr: Step size of the inner (adaptation) gradient descent.
        outer_lr: Learning rate of the Adam optimizer that is created when
            ``outer_optimizer`` is not supplied.
        inner_steps: Number of adaptation steps per task.
        outer_optimizer: Optional optimizer over ``model.parameters()``.
            Pass the same instance on every call to keep its state between
            meta-updates; when given, ``outer_lr`` is ignored.

    Returns:
        The summed meta-loss over all tasks as a Python float.

    Raises:
        ValueError: If ``tasks`` is empty.
    """
    # Local import keeps the new dependency inside this block; the
    # ``stateless`` module is where the same helper lived before torch 2.0.
    try:
        from torch.func import functional_call
    except ImportError:
        from torch.nn.utils.stateless import functional_call

    tasks = list(tasks)
    if not tasks:
        raise ValueError("maml_step requires at least one task")

    if outer_optimizer is None:
        outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)

    # Only trainable parameters are adapted; anything else is read from the
    # module itself by functional_call.
    meta_params = {
        name: param
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    meta_loss = 0.0
    for data, target in tasks:
        # Start each task from the shared meta-parameters.
        fast_weights = dict(meta_params)

        # Inner loop (adaptation), written functionally so that every
        # update stays part of the autograd graph.
        for _ in range(inner_steps):
            output = functional_call(model, fast_weights, (data,))
            loss = loss_fn(output, target)
            grads = torch.autograd.grad(
                loss,
                list(fast_weights.values()),
                create_graph=True,   # keep second-order terms for the meta-gradient
                allow_unused=True,   # parameters unused in forward get no update
            )
            fast_weights = {
                name: weight if grad is None else weight - inner_lr * grad
                for (name, weight), grad in zip(fast_weights.items(), grads)
            }

        # Meta-loss of the adapted weights for this task.
        output = functional_call(model, fast_weights, (data,))
        meta_loss = meta_loss + loss_fn(output, target)

    # Outer loop (meta-optimization)
    outer_optimizer.zero_grad()
    meta_loss.backward()
    outer_optimizer.step()

    return meta_loss.item()

# Example usage: one meta-update over four randomly generated tasks.
model = MAMLModel(10, 1)
loss_fn = nn.MSELoss()
# Each task is an (inputs, targets) pair: 5 samples, 10 input features, 1 target value.
tasks = [
    tuple(torch.randn(5, width) for width in (10, 1))
    for _ in range(4)
]
maml_step(
    model,
    loss_fn,
    tasks,
    inner_lr=0.01,
    outer_lr=0.001,
    inner_steps=5,
)