<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Learning_to_Learn).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict

# Define a simple feedforward neural network
class SimpleModel(nn.Module):
    """Two-layer perceptron computing ``fc2(relu(fc1(x)))``.

    Args:
        input_dim: Number of features in each input row.
        hidden_dim: Width of the single hidden layer.
        output_dim: Number of features in each output row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, then fc2) fixes the order of named_parameters().
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply Linear -> ReLU -> Linear to ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

# MAML framework
class MAML:
    """Model-Agnostic Meta-Learning (MAML) wrapper around an ``nn.Module``.

    The wrapped model's own parameters are the meta-parameters, i.e. the
    shared initialization being meta-learned. Task adaptation never writes to
    them:

    - ``adapt`` returns new tensors that stay connected to the autograd graph.
    - ``forward_with_params`` evaluates the model functionally with those
      tensors swapped in.

    The meta-loss therefore backpropagates through the inner-loop update(s) to
    the meta-parameters. This is second-order MAML unless ``first_order=True``
    is requested.

    Args:
        model: Network whose parameters are meta-learned.
        lr_inner: Step size of the inner-loop (task adaptation) SGD update.
        lr_outer: Learning rate of the Adam optimizer used in the outer loop.
        num_adapt_steps: Number of inner-loop steps taken by ``adapt_to_task``.
    """

    def __init__(self, model, lr_inner=0.01, lr_outer=0.001, num_adapt_steps=1):
        self.model = model
        self.lr_inner = lr_inner  # Inner loop learning rate
        self.lr_outer = lr_outer  # Outer loop learning rate
        self.num_adapt_steps = num_adapt_steps  # Steps in the inner loop
        self.outer_optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer)

    def adapt(self, loss, params=None, first_order=False):
        """Take one inner-loop SGD step and return the adapted parameters.

        Args:
            loss: Scalar task loss computed with ``params``. If ``params`` is
                None, the loss is computed with the model's own parameters.
            params: Mapping ``name -> tensor`` to differentiate with respect
                to. Defaults to the model's named parameters. Pass the result
                of a previous ``adapt`` call to chain several steps.
            first_order: If True, do not build the graph of the gradient
                computation (first-order MAML approximation). The returned
                tensors stay connected to ``params``, but second derivatives
                are dropped.

        Returns:
            OrderedDict mapping each name to ``param - lr_inner * grad``.
            Parameters that do not influence ``loss`` are returned unchanged.
        """
        if params is None:
            params = OrderedDict(self.model.named_parameters())
        names = list(params)
        tensors = [params[name] for name in names]
        grads = torch.autograd.grad(
            loss,
            tensors,
            # Keeping the graph lets the meta-loss differentiate through this step.
            create_graph=not first_order,
            # Parameters unused by the loss get grad None instead of raising.
            allow_unused=True,
        )
        updated_params = OrderedDict()
        for name, tensor, grad in zip(names, tensors, grads):
            updated_params[name] = tensor if grad is None else tensor - self.lr_inner * grad
        return updated_params

    def forward_with_params(self, x, params):
        """Run the model on ``x`` with ``params`` substituted for its parameters.

        The substitution is functional. The model's stored parameters are not
        modified, and the output stays differentiable with respect to the
        tensors in ``params``. Names missing from ``params`` fall back to the
        model's own parameters.
        """
        try:
            from torch.func import functional_call  # PyTorch >= 2.0
        except ImportError:
            from torch.nn.utils.stateless import functional_call  # PyTorch 1.12 - 1.13
        return functional_call(self.model, dict(params), (x,))

    def adapt_to_task(self, x_support, y_support, loss_fn, first_order=False):
        """Adapt to one task with ``num_adapt_steps`` inner-loop SGD steps.

        Args:
            x_support: Inputs of the task's support (training) set.
            y_support: Targets of the task's support set.
            loss_fn: Callable ``loss_fn(prediction, target) -> scalar tensor``.
            first_order: Forwarded to ``adapt``.

        Returns:
            OrderedDict of adapted parameters, usable with ``forward_with_params``.
        """
        params = OrderedDict(self.model.named_parameters())
        for _ in range(self.num_adapt_steps):
            loss = loss_fn(self.forward_with_params(x_support, params), y_support)
            params = self.adapt(loss, params, first_order=first_order)
        return params

    def meta_update(self, meta_loss):
        """Backpropagate ``meta_loss`` and take one Adam step on the meta-parameters."""
        self.outer_optimizer.zero_grad()
        meta_loss.backward()
        self.outer_optimizer.step()

# Sinusoidal regression task data generator
def generate_sinusoidal_data(batch_size, amplitude_range=(0.1, 5.0), phase_range=(0, 3.14),
                             num_points=100, x_range=(-5, 5)):
    """Sample ``batch_size`` sinusoid regression tasks on a shared input grid.

    Task ``i`` is ``y = A_i * sin(x + phi_i)``, where:

    - ``A_i`` is drawn uniformly from ``amplitude_range``.
    - ``phi_i`` is drawn uniformly from ``phase_range``. The default upper
      bound 3.14 approximates pi.

    Every task is evaluated on the same ``num_points`` evenly spaced inputs
    spanning ``x_range`` (both ends inclusive), in ascending order.

    Args:
        batch_size: Number of tasks to sample.
        amplitude_range: ``(low, high)`` bounds for the amplitude.
        phase_range: ``(low, high)`` bounds for the phase shift.
        num_points: Number of grid points per task.
        x_range: ``(start, end)`` of the input grid.

    Returns:
        Tuple ``(x, y)`` of tensors, each of shape ``(batch_size * num_points, 1)``.
        Rows ``i * num_points`` to ``(i + 1) * num_points - 1`` belong to task ``i``.
    """
    amp_low, amp_high = amplitude_range
    phase_low, phase_high = phase_range
    # Draw amplitudes before phases so seeded runs match the original sampling order.
    amplitudes = torch.rand(batch_size) * (amp_high - amp_low) + amp_low
    phases = torch.rand(batch_size) * (phase_high - phase_low) + phase_low
    x = torch.linspace(x_range[0], x_range[1], num_points).unsqueeze(0).repeat(batch_size, 1)  # (batch_size, num_points)
    y = amplitudes.unsqueeze(1) * torch.sin(x + phases.unsqueeze(1))  # (batch_size, num_points)
    return x.view(-1, 1), y.view(-1, 1)  # (batch_size * num_points, 1)

# Training loop
input_dim = 1  # Input dimension
hidden_dim = 40  # Hidden layer dimension
output_dim = 1  # Output dimension
meta_batch_size = 32  # Number of tasks in each meta-update
task_data_points = 10  # Support-set size per task; the task's remaining points form the query set
num_meta_epochs = 100  # Number of meta-epochs

# Initialize model and MAML framework
model = SimpleModel(input_dim, hidden_dim, output_dim)
maml = MAML(model)
loss_fn = nn.MSELoss()  # One instance reused for every task instead of rebuilding it per loss

for epoch in range(num_meta_epochs):
    meta_loss = 0
    for _ in range(meta_batch_size):
        # Generate one task (batch size = 1). Its points come back in ascending
        # x order on a fixed grid. Shuffle before splitting: slicing the first
        # `task_data_points` rows would give a support set covering only the
        # left edge of the input range.
        x, y = generate_sinusoidal_data(1)
        perm = torch.randperm(x.size(0))
        support_idx, query_idx = perm[:task_data_points], perm[task_data_points:]
        x_train, y_train = x[support_idx], y[support_idx]  # Support (inner-loop) set
        x_val, y_val = x[query_idx], y[query_idx]  # Query (meta-loss) set

        # Inner loop: one adaptation step on the support set
        inner_loss = loss_fn(maml.model(x_train), y_train)
        task_params = maml.adapt(inner_loss)

        # Outer loop: evaluate the adapted parameters on the query set
        outputs_val = maml.forward_with_params(x_val, task_params)
        meta_loss = meta_loss + loss_fn(outputs_val, y_val)

    # Meta-update (outer loop)
    meta_loss = meta_loss / meta_batch_size  # Average meta-loss across tasks
    maml.meta_update(meta_loss)

    # Log meta-loss for the epoch
    print(f"Epoch [{epoch + 1}/{num_meta_epochs}], Meta Loss: {meta_loss.item():.4f}")