<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class MAMLModel(nn.Module):
    """Small two-layer MLP regressor used as the meta-learned model.

    Maps inputs with ``input_dim`` features through a ReLU hidden layer of
    width ``hidden_dim`` to a single scalar output per example.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layers are created fc1 then fc2 so parameter initialisation consumes
        # the RNG in the same order.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` prediction for input ``x`` of shape ``(..., input_dim)``."""
        hidden = self.fc1(x)
        return self.fc2(F.relu(hidden))

def maml_step(model, data, labels, inner_lr, meta_optimizer, *,
              query_data=None, query_labels=None, first_order=False):
    """Perform one MAML meta-update of ``model`` on a single regression task.

    Inner loop: take one SGD step of size ``inner_lr`` on the support set
    ``(data, labels)`` to obtain task-adapted parameters ``theta'``.
    Outer loop: evaluate ``theta'`` on the query set and back-propagate that
    loss into the *original* parameters ``theta``. ``meta_optimizer`` then
    updates them.

    Args:
        model: The meta-learned ``nn.Module``. Its parameters are updated in
            place by ``meta_optimizer``.
        data: Support-set inputs.
        labels: Support-set regression targets (MSE loss).
        inner_lr: Step size of the inner adaptation step.
        meta_optimizer: Optimizer over ``model.parameters()``.
        query_data: Query-set inputs for the meta-loss. Defaults to ``data``.
        query_labels: Query-set targets. Defaults to ``labels``.
        first_order: If True, use the first-order MAML approximation. This
            does not differentiate through the inner gradient.

    Returns:
        The meta (query) loss of the adapted parameters as a Python float,
        measured before the meta-update is applied.
    """
    # Local import: torch.func provides a stateless forward with substituted
    # parameters, so the adapted weights stay in the autograd graph.
    from torch.func import functional_call

    if query_data is None:
        query_data = data
    if query_labels is None:
        query_labels = labels

    params = {
        name: p for name, p in model.named_parameters() if p.requires_grad
    }

    # Inner loop: a differentiable SGD step on the support set. With
    # create_graph=True the meta-gradient includes second-order terms.
    support_loss = F.mse_loss(functional_call(model, params, (data,)), labels)
    grads = torch.autograd.grad(
        support_loss, list(params.values()), create_graph=not first_order
    )
    adapted = {
        name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)
    }

    # Outer loop: loss of the *adapted* parameters, differentiated w.r.t. the
    # original parameters. Using gradients taken at the pre-adaptation
    # parameters here would make the inner step a no-op.
    meta_optimizer.zero_grad()
    query_loss = F.mse_loss(
        functional_call(model, adapted, (query_data,)), query_labels
    )
    query_loss.backward()
    meta_optimizer.step()
    return query_loss.item()

# Example usage: five meta-updates of a 10 -> 40 -> 1 MLP on a single random
# regression batch (32 examples), reused as both support and query set.
model = MAMLModel(input_dim=10, hidden_dim=40)
meta_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
data, labels = torch.randn(32, 10), torch.randn(32, 1)

for _ in range(5):
    maml_step(
        model,
        data,
        labels,
        inner_lr=0.01,
        meta_optimizer=meta_optimizer,
    )