<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStardust/blob/main/Meta_Learning_with_Model_Agnostic_Meta_Learning_(MAML).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim

# Define a simple neural network
class NeuralNet(nn.Module):
    """Minimal one-input / one-output linear model.

    The single ``nn.Linear`` layer is stored under the attribute ``fc``, so the
    module's state-dict keys are ``fc.weight`` and ``fc.bias``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine layer ``fc`` to ``x`` (last dimension must be 1)."""
        out = self.fc(x)
        return out

# MAML setup
def train_maml(model, tasks, inner_steps=1, outer_steps=1000, inner_lr=0.01, outer_lr=0.001):
    """Meta-train ``model`` in place with (second-order) Model-Agnostic Meta-Learning.

    On every outer step, each task's support set is used to derive task-specific
    ("fast") weights from the model's parameters. This takes ``inner_steps``
    plain gradient-descent updates. The fast weights are then scored on the
    task's query set.

    The query losses are summed into the meta-loss. The meta-loss is
    back-propagated *through* the inner updates into ``model``'s own
    parameters, and Adam then applies the meta-update.

    Args:
        model: the ``nn.Module`` whose parameters are meta-learned in place.
        tasks: iterable of ``(x_train, y_train, x_test, y_test)`` tuples. The
            inputs must be accepted by ``model``, and the targets must be
            compatible with its output under ``nn.MSELoss`` (not checked here).
        inner_steps: gradient steps taken on each task's support set.
        outer_steps: number of meta-updates.
        inner_lr: step size of the inner (task-adaptation) updates.
        outer_lr: Adam learning rate for the meta-update.

    Returns:
        List with the meta-loss value (a float) recorded at each outer step.

    Raises:
        ValueError: if ``tasks`` is empty.
    """
    from torch.func import functional_call

    # Materialise once: tasks are re-iterated on every outer step, so a
    # one-shot iterator would otherwise be empty after the first step.
    tasks = list(tasks)
    if not tasks:
        raise ValueError("train_maml requires at least one task")

    loss_fn = nn.MSELoss()
    outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
    history = []

    for step in range(outer_steps):
        meta_loss = 0.0
        for x_train, y_train, x_test, y_test in tasks:
            # Start from the live meta-parameters (NOT a detached copy via
            # load_state_dict). This keeps the adapted weights connected to
            # `model` in the autograd graph, so the meta-loss can reach them.
            fast_weights = {
                name: param for name, param in model.named_parameters() if param.requires_grad
            }

            # Inner loop (task-specific adaptation), done functionally.
            for _ in range(inner_steps):
                y_pred = functional_call(model, fast_weights, (x_train,))
                inner_loss = loss_fn(y_pred, y_train)
                # create_graph=True keeps the update differentiable. This makes
                # the outer gradient include second-order terms, as in MAML.
                grads = torch.autograd.grad(
                    inner_loss,
                    tuple(fast_weights.values()),
                    create_graph=True,
                    allow_unused=True,
                )
                fast_weights = {
                    name: w if g is None else w - inner_lr * g
                    for (name, w), g in zip(fast_weights.items(), grads)
                }

            # Outer objective: evaluate the adapted weights on the query set.
            y_pred = functional_call(model, fast_weights, (x_test,))
            meta_loss = meta_loss + loss_fn(y_pred, y_test)

        # Meta-optimization step: gradients now land on model.parameters().
        outer_optimizer.zero_grad()
        meta_loss.backward()
        outer_optimizer.step()

        loss_value = meta_loss.item()
        history.append(loss_value)
        print(f'Step {step+1}/{outer_steps}, Meta Loss: {loss_value}')

    return history

# Sample tasks for demonstration purposes. Each task is a random 1-D line
# y = slope * x + intercept. A task's support (train) set and query (test) set
# come from the same line, so there is structure for the inner loop to adapt
# to. With independently drawn noise targets there would be nothing to learn.
torch.manual_seed(0)  # make the demo reproducible


def _make_linear_task(num_points=10):
    """Return ``(x_train, y_train, x_test, y_test)`` sampled from one random line.

    Inputs and targets have shape ``(num_points, 1)``, matching ``NeuralNet``'s
    single input/output feature.
    """
    slope, intercept = torch.randn(2)
    x_train = torch.randn(num_points, 1)
    x_test = torch.randn(num_points, 1)
    return x_train, slope * x_train + intercept, x_test, slope * x_test + intercept


tasks = [_make_linear_task() for _ in range(3)]

model = NeuralNet()
train_maml(model, tasks)