<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Base network that MAML meta-trains: a tiny 1 -> 64 -> 1 MLP regressor.
class MAMLModel(nn.Module):
    """Two-layer fully connected regressor mapping one input feature to one output.

    The layers live in ``self.network`` (an ``nn.Sequential``), so parameter names
    are ``network.0.*`` and ``network.2.*``.
    """

    def __init__(self):
        super().__init__()
        # Layers are constructed in the same order as listed, so weight
        # initialisation consumes the RNG identically.
        layers = [
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must be 1)."""
        output = self.network(x)
        return output

# MAML inner-loop step
def maml_train_step(model, data, alpha=0.01):
    """Take one differentiable gradient step on a task and return the adapted weights.

    Args:
        model: module whose parameters are adapted.
        data: ``(inputs, targets)`` pair for the task.
        alpha: inner-loop step size.

    Returns:
        A list of tensors ``param - alpha * grad``, one per parameter, in
        ``model.parameters()`` order. ``create_graph=True`` keeps the graph so a
        later loss on these weights can backpropagate through the step.
    """
    inputs, targets = data
    params = list(model.parameters())
    mse = nn.MSELoss()

    support_loss = mse(model(inputs), targets)
    gradients = torch.autograd.grad(support_loss, params, create_graph=True)
    return [p - alpha * g for p, g in zip(params, gradients)]

# Meta-training loop for MAML
def meta_train(model, meta_optimizer, train_loader, num_tasks=4, beta=0.001, alpha=0.01):
    """Run one MAML meta-training pass over ``train_loader`` and apply one meta-update.

    Each batch is split along dim 0 into (up to) ``num_tasks`` chunks, each treated as
    a task. For every task, ``maml_train_step`` takes one inner gradient step to get
    adapted "fast" weights. The task loss is then evaluated *with those fast weights*,
    so the meta-gradient flows through the inner update back to the original parameters.
    The task data serves as both support and query set here.

    Args:
        model: network being meta-trained.
        meta_optimizer: outer-loop optimizer over ``model``'s parameters.
        train_loader: iterable of ``(inputs, labels)`` batches.
        num_tasks: number of chunks each batch is split into.
        beta: unused; kept for backward compatibility. The outer learning rate is the
            one ``meta_optimizer`` was constructed with.
        alpha: inner-loop step size forwarded to ``maml_train_step``.

    Returns:
        The meta-loss as a Python float, or ``None`` if ``train_loader`` was empty
        (in which case no update is made).
    """
    try:
        from torch.func import functional_call
    except ImportError:  # torch < 2.0
        from torch.nn.utils.stateless import functional_call

    # named_parameters() yields the same order as parameters(), which is the order of
    # the fast-weight list returned by maml_train_step.
    param_names = [name for name, _ in model.named_parameters()]
    task_loss = nn.MSELoss()
    meta_loss = None

    for inputs, labels in train_loader:
        # chunk() may return fewer than num_tasks pieces for small batches.
        tasks = list(zip(inputs.chunk(num_tasks), labels.chunk(num_tasks)))
        task_losses = 0.0
        for task_inputs, task_labels in tasks:
            fast_weights = maml_train_step(model, (task_inputs, task_labels), alpha=alpha)
            adapted = dict(zip(param_names, fast_weights))
            # Forward pass with the adapted weights, not model's own parameters.
            task_predictions = functional_call(model, adapted, (task_inputs,))
            task_losses = task_losses + task_loss(task_predictions, task_labels)

        batch_loss = task_losses / len(tasks)
        meta_loss = batch_loss if meta_loss is None else meta_loss + batch_loss

    if meta_loss is None:
        return None

    # A single meta-update per call, over the loss accumulated from all batches.
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()
    return meta_loss.item()

# Simulate tasks
def generate_task(slope=0.5, noise_std=0.5, num_points=100):
    """Generate one synthetic 1-D linear-regression task.

    Inputs are ``num_points`` evenly spaced values on [-5, 5], shaped
    ``(num_points, 1)``. Targets are ``slope * x`` plus Gaussian noise with standard
    deviation ``noise_std``.

    The defaults reproduce the original fixed task (slope 0.5, noise 0.5, 100 points).
    Calls that share a slope differ only in noise; vary ``slope`` per call to get
    genuinely distinct tasks.

    Returns:
        ``(x, y)`` tensors of identical shape ``(num_points, 1)``.
    """
    x = torch.linspace(-5, 5, num_points).view(-1, 1)
    y = slope * x + torch.randn_like(x) * noise_std
    return x, y

# Prepare data loader for multiple tasks.
# Each generated task is one contiguous block of rows, and the loader neither
# shuffles nor splits them across batches. meta_train's `chunk(num_tasks)` therefore
# recovers exactly the generated tasks. Shuffling single samples would mix points from
# different tasks into every "task" chunk.
num_tasks = 4
task_datasets = [generate_task() for _ in range(num_tasks)]
input_batches = [task[0] for task in task_datasets]
label_batches = [task[1] for task in task_datasets]
task_dataset = TensorDataset(torch.cat(input_batches), torch.cat(label_batches))
task_loader = DataLoader(task_dataset, batch_size=len(task_dataset), shuffle=False)

# Initialize model and optimizer
beta = 0.001  # Meta-learning (outer-loop) rate, used by the Adam meta-optimizer
maml_model = MAMLModel()
meta_optimizer = optim.Adam(maml_model.parameters(), lr=beta)

# Training loop: each call to meta_train performs one meta-update.
for epoch in range(100):
    meta_train(maml_model, meta_optimizer, task_loader, num_tasks=num_tasks, beta=beta)
    print(f"Epoch {epoch + 1} complete")

print("Meta-training complete.")