<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Simple_implementation_of_MAML_(Model_Agnostic_Meta_Learning).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.func import functional_call

# Define a simple neural network
class SimpleNet(nn.Module):
    """Tiny MLP regressor: 1 input -> 40 ReLU hidden units -> 1 output."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, then fc2) fixes the state_dict key order.
        self.fc1 = nn.Linear(1, 40)
        self.fc2 = nn.Linear(40, 1)

    def forward(self, x):
        """Map inputs of shape (..., 1) to predictions of shape (..., 1)."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

# MAML class
class MAML:
    """Model-Agnostic Meta-Learning for regression with an MSE loss.

    The parameters of ``model`` are the meta-parameters. For each task,
    ``inner_steps`` plain gradient-descent steps (step size ``lr_inner``) on
    the task's support set produce task-adapted parameters. The MSE of those
    adapted parameters on the task's query set is back-propagated *through*
    the adaptation steps into the meta-parameters. An Adam optimizer
    (``lr_outer``) then updates the meta-parameters.

    Args:
        model: the ``nn.Module`` whose parameters are meta-learned.
        lr_inner: step size of the per-task adaptation steps.
        lr_outer: Adam learning rate for the meta-update.
        inner_steps: number of adaptation steps per task (default 1).
        first_order: if True, use the first-order approximation: inner-loop
            gradients are treated as constants, so the second-order terms are
            skipped (cheaper, less exact).
    """

    def __init__(self, model, lr_inner=0.01, lr_outer=0.001,
                 inner_steps=1, first_order=False):
        self.model = model
        self.lr_inner = lr_inner
        self.lr_outer = lr_outer
        self.inner_steps = inner_steps
        self.first_order = first_order
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer)
        self.criterion = nn.MSELoss()

    def clone_model(self):
        """Return an independent deep copy of ``self.model`` with cleared grads.

        Works for any module type. The copy shares no tensors with the
        meta-model, so training it (e.g. with ``inner_update``) leaves the
        meta-parameters untouched.
        """
        cloned_model = copy.deepcopy(self.model)
        cloned_model.zero_grad(set_to_none=True)
        return cloned_model

    def inner_update(self, model, x, y):
        """Apply one in-place SGD step (step size ``lr_inner``) to ``model``.

        The update mutates ``model`` and is not recorded by autograd. It is
        therefore suited to fine-tuning a ``clone_model`` copy, but cannot
        produce meta-gradients; ``outer_update`` does not use it.

        Returns:
            ``model`` itself, after the update.
        """
        model.train()
        loss = self.criterion(model(x), y)
        model.zero_grad()
        loss.backward()
        with torch.no_grad():
            for param in model.parameters():
                # Parameters that did not take part in the loss have no grad.
                if param.grad is not None:
                    param -= self.lr_inner * param.grad
        return model

    def _adapted_params(self, x, y):
        """Return ``{name: tensor}`` parameters adapted to one support set.

        The returned tensors stay connected to the meta-parameters' autograd
        graph, so a loss computed with them back-propagates into
        ``self.model``. Second-order terms are included unless
        ``self.first_order``.
        """
        params = dict(self.model.named_parameters())
        for _ in range(self.inner_steps):
            loss = self.criterion(functional_call(self.model, params, (x,)), y)
            # Frozen parameters cannot be differentiated; leave them as-is.
            names = [name for name, p in params.items() if p.requires_grad]
            grads = torch.autograd.grad(
                loss,
                [params[name] for name in names],
                create_graph=not self.first_order,
                allow_unused=True,
            )
            for name, grad in zip(names, grads):
                if grad is not None:
                    params[name] = params[name] - self.lr_inner * grad
        return params

    def outer_update(self, support_data, query_data):
        """Run one meta-update over every task in ``support_data``.

        Args:
            support_data: mapping ``task -> (x_support, y_support)``.
            query_data: mapping ``task -> (x_query, y_query)``. It must
                contain every key of ``support_data``; a missing key raises
                ``KeyError``.

        Returns:
            The query loss summed over tasks, as a float, measured before the
            update. Returns ``0.0`` and skips the update when there are no
            tasks.
        """
        if not support_data:
            return 0.0
        self.model.train()
        self.optimizer.zero_grad()
        total_loss = 0.0
        for task, (x_s, y_s) in support_data.items():
            x_q, y_q = query_data[task]
            adapted = self._adapted_params(x_s, y_s)
            task_loss = self.criterion(
                functional_call(self.model, adapted, (x_q,)), y_q
            )
            # Backward per task so each task's graph is freed immediately;
            # the meta-gradients accumulate in self.model's .grad fields.
            task_loss.backward()
            total_loss += task_loss.item()
        self.optimizer.step()
        return total_loss

# Example training loop
def train_maml(maml, tasks_support, tasks_query, epochs=100):
    """Call ``maml.outer_update`` once per epoch on the same task batch.

    Prints a one-line progress message after each epoch and returns None.
    """
    for done in range(1, epochs + 1):
        maml.outer_update(tasks_support, tasks_query)
        print(f"Epoch {done}/{epochs} completed.")

# Demonstration on toy sinusoid-regression tasks (the standard MAML example):
# task i is y = A_i * sin(x + phi_i), with its own amplitude and phase. The
# tasks therefore share structure that a meta-learned initialisation can
# exploit. Support and query points are sampled independently from each task.
if __name__ == "__main__":
    torch.manual_seed(0)  # make the demo reproducible

    n_tasks, n_points = 5, 10
    support_data, query_data = {}, {}
    for task_id in range(n_tasks):
        amplitude = float(torch.empty(1).uniform_(0.1, 5.0))
        phase = float(torch.empty(1).uniform_(0.0, math.pi))
        for split in (support_data, query_data):
            x = torch.empty(n_points, 1).uniform_(-5.0, 5.0)
            split[task_id] = (x, amplitude * torch.sin(x + phase))

    # Initialize model and MAML
    model = SimpleNet()
    maml = MAML(model)

    # Train MAML
    train_maml(maml, support_data, query_data, epochs=5)