<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Meta_Learning_with_Model_Agnostic_Meta_Learning_(MAML).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model class for demonstration
class SimpleModel(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    Args:
        in_features: Size of each input sample. Defaults to 10, the width
            that was previously hard-coded.
        out_features: Number of logits produced per sample. Defaults to 2.
    """

    def __init__(self, in_features: int = 10, out_features: int = 2):
        super().__init__()
        # The attribute name ``fc`` is kept: it determines the state_dict keys
        # ("fc.weight", "fc.bias").
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        return self.fc(x)

def initialize_model():
    """Build and return a new ``SimpleModel`` instance."""
    model = SimpleModel()
    return model

def compute_task_loss(model, task_data):
    """Return the cross-entropy loss of ``model`` on one task's batch.

    ``task_data`` must unpack into exactly two items, ``(inputs, labels)``.
    ``inputs`` is passed through ``model`` and the result is scored against
    ``labels`` with cross-entropy using its default settings.
    """
    features, targets = task_data
    logits = model(features)
    return nn.functional.cross_entropy(logits, targets)

def compute_gradients(loss, parameters):
    """Return the gradient of ``loss`` with respect to each parameter.

    Gradients are obtained with ``torch.autograd.grad`` instead of
    ``loss.backward()``. ``backward()`` adds into ``param.grad``, so calling
    it repeatedly on the same parameters (e.g. several inner-loop steps)
    would return the running sum of all previous gradients rather than the
    gradient of the current loss. As a consequence this function no longer
    writes to ``param.grad``.

    Args:
        loss: Scalar tensor attached to an autograd graph.
        parameters: Iterable of tensors, e.g. ``model.parameters()``. It is
            materialised once, so a generator is accepted.

    Returns:
        A list with one tensor per parameter, in the given order. Parameters
        that do not require grad, or that did not contribute to ``loss``,
        get a zero tensor of matching shape.
    """
    params = list(parameters)
    # autograd.grad rejects inputs that do not require grad, so differentiate
    # only the trainable ones and fill the rest with zeros below.
    trainable = [p for p in params if p.requires_grad]
    if trainable:
        # retain_graph=True keeps the graph alive so the same loss can be
        # differentiated again by the caller, as the previous version allowed.
        computed = iter(
            torch.autograd.grad(
                loss, trainable, retain_graph=True, allow_unused=True
            )
        )
    else:
        computed = iter(())

    gradients = []
    for param in params:
        grad = next(computed) if param.requires_grad else None
        gradients.append(torch.zeros_like(param) if grad is None else grad)
    return gradients

def update_model(model, gradients, learning_rate):
    """Apply one in-place gradient-descent step to ``model``.

    Each parameter yielded by ``model.parameters()`` is paired with the
    gradient at the same position in ``gradients`` and decreased by
    ``learning_rate * gradient``. Returns ``None``.
    """
    pairs = zip(model.parameters(), gradients)
    # Gradient tracking is disabled so the in-place parameter update is not
    # recorded by autograd.
    with torch.no_grad():
        for weight, direction in pairs:
            weight.sub_(learning_rate * direction)

def compute_meta_gradient(task_loss, task_model, meta_model):
    """Return the meta-gradient contributed by one task.

    The previous version called ``task_loss.backward()`` and then read
    ``.grad`` from ``meta_model``. When ``task_model`` is an independent copy
    of ``meta_model`` the backward pass never reaches the meta-parameters, so
    every entry came back as zeros and the meta-update was a no-op.

    Resolution order:
      1. If the autograd graph of ``task_loss`` reaches ``meta_model``'s
         parameters, the gradient with respect to those parameters is used.
      2. Otherwise the gradient with respect to ``task_model``'s parameters
         is used as a first-order approximation, matched to the
         meta-parameters by position.

    ``torch.autograd.grad`` is used, so no ``.grad`` attribute is modified.

    Args:
        task_loss: Scalar loss evaluated with the adapted ``task_model``.
        task_model: Task-adapted model the loss was computed with.
        meta_model: Model whose parameters will receive the meta-update.

    Returns:
        A list with one tensor per parameter of ``meta_model``, in
        ``meta_model.parameters()`` order. Entries with no available
        gradient are zero tensors.
    """
    meta_params = list(meta_model.parameters())

    def grads_wrt(params):
        # autograd.grad rejects inputs that do not require grad, so only
        # trainable tensors are differentiated; the rest are reported as None.
        trainable = [p for p in params if p.requires_grad]
        if not trainable:
            return [None] * len(params)
        # retain_graph=True keeps the graph usable for further backward
        # passes, matching the previous behaviour.
        found = iter(
            torch.autograd.grad(
                task_loss, trainable, retain_graph=True, allow_unused=True
            )
        )
        return [next(found) if p.requires_grad else None for p in params]

    grads = grads_wrt(meta_params)
    if all(g is None for g in grads):
        # The loss graph does not touch the meta-parameters: fall back to the
        # gradient at the adapted task parameters (first-order approximation).
        grads = grads_wrt(list(task_model.parameters()))

    # zip pairs by position; task_model is expected to have the same
    # parameter layout as meta_model.
    return [
        torch.zeros_like(param) if grad is None else grad
        for param, grad in zip(meta_params, grads)
    ]

def update_meta_model(meta_model, meta_gradient, learning_rate):
    """Take one in-place descent step on ``meta_model`` and return it.

    Every parameter of ``meta_model`` is decreased by ``learning_rate`` times
    the entry of ``meta_gradient`` at the same position. The model passed in
    is modified and the same object is returned.
    """
    with torch.no_grad():  # the parameter update itself must not be tracked
        for weight, grad in zip(meta_model.parameters(), meta_gradient):
            scaled_step = learning_rate * grad
            weight.sub_(scaled_step)
    return meta_model

# Hyperparameters
num_iterations = 10         # outer (meta) iterations
num_tasks = 5               # number of dummy tasks
num_inner_updates = 1       # adaptation steps per task
meta_learning_rate = 0.001  # step size of the meta-update
inner_learning_rate = 0.01  # step size of the per-task update

# Model whose parameters are updated once per outer iteration
meta_model = initialize_model()

# Dummy task data: each task is 10 random 10-feature samples with labels in {0, 1}
tasks = [
    (torch.randn(10, 10), torch.randint(0, 2, (10,)))
    for _ in range(num_tasks)
]

for iteration in range(num_iterations):
    # Running per-parameter sum of the gradients returned for each task
    meta_gradient = [torch.zeros_like(param) for param in meta_model.parameters()]

    for task_data in tasks:
        # Each task starts from a fresh model loaded with the current meta-parameters
        task_model = initialize_model()
        task_model.load_state_dict(meta_model.state_dict())

        # Inner loop: adapt the copy to this task
        for step in range(num_inner_updates):
            task_loss = compute_task_loss(task_model, task_data)
            task_gradients = compute_gradients(task_loss, task_model.parameters())
            update_model(task_model, task_gradients, inner_learning_rate)

        # Loss of the adapted copy on the same task data, then this task's
        # contribution to the meta-gradient
        task_loss = compute_task_loss(task_model, task_data)
        task_meta_gradient = compute_meta_gradient(task_loss, task_model, meta_model)
        for i, grad in enumerate(task_meta_gradient):
            meta_gradient[i].add_(grad)

    # Average over tasks, then apply the meta-update
    meta_gradient = [grad_sum / num_tasks for grad_sum in meta_gradient]
    meta_model = update_meta_model(meta_model, meta_gradient, meta_learning_rate)

print("MAML training completed.")