<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Model_Agnostic_Meta_Learning_(MAML).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    """Minimal network: one fully connected layer and nothing else.

    Args:
        input_dim: Number of input features handed to ``nn.Linear``.
        output_dim: Number of output features produced by ``nn.Linear``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name `fc` fixes the state_dict keys
        # ("fc.weight", "fc.bias"), so it is part of the public surface.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        prediction = self.fc(x)
        return prediction

def maml_train(model, train_task_data, test_task_data, inner_lr, outer_lr, epochs,
               inner_steps=5):
    """Meta-train ``model`` with Model-Agnostic Meta-Learning (MAML).

    For every task, a temporary copy of the weights ("fast weights") is
    adapted with ``inner_steps`` gradient-descent steps on the task's training
    data. The loss of those adapted weights on the task's test data is then
    back-propagated *through the adaptation steps* to the model's own
    parameters, which are updated by Adam. The model's parameters are only
    ever changed by that outer Adam step; the inner loop never writes to them.

    Args:
        model: ``nn.Module`` whose parameters are the meta-parameters.
        train_task_data: Sequence of tasks used for inner-loop adaptation.
            Each task is indexable: item ``[0]`` is passed to the model and
            item ``[1]`` is the MSE target.
        test_task_data: Sequence of tasks used for the outer (meta) loss, in
            the same layout. Paired with ``train_task_data`` via ``zip``, so
            surplus tasks in the longer sequence are ignored.
        inner_lr: Step size of the inner-loop gradient descent.
        outer_lr: Learning rate of the outer Adam optimizer.
        epochs: Number of passes over all task pairs.
        inner_steps: Number of inner-loop adaptation steps per task
            (defaults to 5).

    Returns:
        None. Prints the average outer loss once per epoch (``nan`` when no
        task pair was available).
    """
    # Function-scope import so the block stays self-contained; needs PyTorch >= 2.0.
    from torch.func import functional_call

    optimizer = optim.Adam(model.parameters(), lr=outer_lr)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        losses = []
        for train_data, test_data in zip(train_task_data, test_task_data):
            # Every task starts from the current meta-parameters. Frozen
            # parameters are left out; functional_call reads them from the
            # module itself.
            fast_weights = {
                name: param
                for name, param in model.named_parameters()
                if param.requires_grad
            }

            # Inner loop: adapt the fast weights without touching the module.
            for _ in range(inner_steps):
                train_pred = functional_call(model, fast_weights, (train_data[0],))
                train_loss = loss_fn(train_pred, train_data[1])
                # create_graph=True keeps the adaptation differentiable so the
                # outer step can back-propagate through it (second-order MAML).
                grads = torch.autograd.grad(
                    train_loss,
                    tuple(fast_weights.values()),
                    create_graph=True,
                    allow_unused=True,
                )
                fast_weights = {
                    name: weight if grad is None else weight - inner_lr * grad
                    for (name, weight), grad in zip(fast_weights.items(), grads)
                }

            # Outer loop: meta-update from the adapted weights' test loss.
            test_pred = functional_call(model, fast_weights, (test_data[0],))
            test_loss = loss_fn(test_pred, test_data[1])
            losses.append(test_loss.item())
            optimizer.zero_grad()
            test_loss.backward()
            optimizer.step()

        # Guard against empty task lists instead of dividing by zero.
        average_loss = sum(losses) / len(losses) if losses else float('nan')
        print(f'Epoch {epoch + 1}, Average Loss: {average_loss:.4f}')

# Synthetic tasks for illustration: five training tasks followed by five test
# tasks, each pairing a random (10, 10) input batch with random (10, 1) targets.
train_task_data = [
    tuple(torch.randn(10, width) for width in (10, 1)) for _ in range(5)
]
test_task_data = [
    tuple(torch.randn(10, width) for width in (10, 1)) for _ in range(5)
]

# Example usage: run maml_train on a 10-feature, 1-output SimpleNN.
model = SimpleNN(10, 1)
maml_train(
    model,
    train_task_data,
    test_task_data,
    inner_lr=0.01,
    outer_lr=0.001,
    epochs=10,
)