<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Model_Agnostic_Meta_Learning_(MAML)_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    """Single-feature linear regressor computing ``w * x + b``.

    The only sub-module is stored as ``self.linear``. Its state-dict keys are
    therefore ``linear.weight`` and ``linear.bias``.
    """

    def __init__(self):
        super().__init__()
        # One input feature -> one output feature.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine map to ``x``.

        ``nn.Linear(1, 1)`` requires the trailing dimension of ``x`` to be 1.
        """
        out = self.linear(x)
        return out

def train_maml(model, tasks, meta_lr=0.001, inner_lr=0.01, inner_steps=5,
               first_order=False):
    """Meta-train ``model`` with Model-Agnostic Meta-Learning (MAML).

    For each task the inner loop adapts a *functional* copy of the model's
    trainable parameters with ``inner_steps`` plain gradient-descent steps on
    the task's training data. The adapted parameters are then evaluated on the
    task's validation data. That meta loss is back-propagated through the
    inner loop into the original parameters of ``model``, and the Adam
    meta-optimizer updates those parameters.

    Args:
        model: The ``nn.Module`` whose initial parameters are meta-learned.
            It is updated in place.
        tasks: Iterable of task tuples. Each task is either
            ``(x_train, y_train)``, in which case the same data is reused for
            the meta loss, or ``(x_train, y_train, x_val, y_val)``.
        meta_lr: Learning rate of the outer Adam optimizer.
        inner_lr: Step size of the inner-loop adaptation.
        inner_steps: Number of inner adaptation steps per task.
        first_order: If True, use first-order MAML. Inner gradients are
            treated as constants, which is cheaper but drops the second-order
            terms of the meta-gradient.

    Returns:
        List of meta (validation) losses as floats, one per task.
    """
    from torch.func import functional_call  # requires PyTorch >= 2.0

    meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)
    meta_losses = []

    for task in tasks:
        if len(task) == 4:
            x_train, y_train, x_val, y_val = task
        else:
            # No held-out split supplied: reuse the training data for the
            # meta loss (the original behaviour).
            x_train, y_train = task
            x_val, y_val = x_train, y_train

        # Start from the current meta-parameters. These tensors ARE the
        # model's parameters, not copies, so the meta loss can be traced
        # back to them. A separately constructed module would break that
        # link and leave model.parameters() without gradients.
        params = {name: p for name, p in model.named_parameters()
                  if p.requires_grad}

        for _ in range(inner_steps):
            preds = functional_call(model, params, (x_train,))
            loss = ((preds - y_train) ** 2).mean()
            # With create_graph=True the inner update itself stays
            # differentiable, which gives the full second-order
            # meta-gradient.
            grads = torch.autograd.grad(
                loss, tuple(params.values()), create_graph=not first_order
            )
            # Out-of-place update: the original parameters stay untouched,
            # and the adapted tensors remain connected to them in the graph.
            params = {
                name: p - inner_lr * g
                for (name, p), g in zip(params.items(), grads)
            }

        meta_optimizer.zero_grad()
        meta_preds = functional_call(model, params, (x_val,))
        meta_loss = ((meta_preds - y_val) ** 2).mean()
        meta_loss.backward()  # gradients land on model.parameters()
        meta_optimizer.step()
        print(f'Meta Loss: {meta_loss.item()}')
        meta_losses.append(meta_loss.item())

    return meta_losses

# Example usage with dummy data.
# Each task is a 1-D regression problem y = a * x with its own slope `a`.
# MAML learns an initialization that adapts quickly across a *distribution*
# of tasks. Ten copies of one identical task would reduce it to ordinary
# single-task regression.
torch.manual_seed(0)  # make the demo reproducible
x_demo = torch.tensor([[1.0], [2.0]])
slopes = torch.linspace(1.0, 3.0, steps=10)
tasks = [(x_demo, slope * x_demo) for slope in slopes]
model = SimpleModel()
train_maml(model, tasks)