<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np  # Import numpy for task generation
import torch
from torch import nn, optim
from torch.func import functional_call

# Define a simple model
class SimpleModel(nn.Module):
    """Minimal regressor: a single affine map from one feature to one output."""

    def __init__(self):
        super().__init__()
        # The only trainable layer; exposed as ``fc`` (one weight, one bias).
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.fc(x)
        return prediction

# Define a task (e.g., learning a sine wave)
def task(num_samples: int = 10, *, rng=None):
    """Sample one random sine-regression task and return a batch drawn from it.

    A task is the function ``x -> a * sin(x + phase)`` with amplitude ``a``
    drawn uniformly from [0.1, 5.0) and ``phase`` drawn uniformly from
    [0, pi). Inputs are drawn uniformly from [-5, 5).

    Args:
        num_samples: Number of (input, output) pairs to draw. Defaults to 10.
        rng: Optional random source exposing ``uniform(low, high, size=None)``,
            e.g. ``np.random.default_rng(seed)``. When ``None`` the global
            ``np.random`` state is used, so ``np.random.seed`` still controls
            the draw.

    Returns:
        A pair ``(inputs, outputs)`` of ``float32`` tensors, each of shape
        ``(num_samples, 1)``.
    """
    source = np.random if rng is None else rng

    # Draw order (amplitude, phase, inputs) is kept fixed so that seeding the
    # global NumPy state reproduces the same task as before.
    a = source.uniform(0.1, 5.0)  # Amplitude
    phase = source.uniform(0, np.pi)  # Phase shift
    inputs = source.uniform(-5, 5, (num_samples, 1))  # Input samples
    outputs = a * np.sin(inputs + phase)  # Sine wave outputs
    return (
        torch.tensor(inputs, dtype=torch.float32),
        torch.tensor(outputs, dtype=torch.float32),
    )

# MAML training loop
#
# The inner (task-specific) adaptation is carried out functionally on the
# meta-model's own parameters. Adapting a separate copy filled through
# ``load_state_dict`` would cut the autograd graph: the meta-loss would then
# only produce gradients for the copy, the meta-model's ``.grad`` would stay
# ``None`` and the meta-optimizer step would change nothing.
NUM_EPOCHS = 1000  # Number of meta-updates
TASKS_PER_UPDATE = 10  # Tasks averaged into each meta-update
INNER_STEPS = 5  # Gradient steps of task-specific adaptation
INNER_LR = 0.01  # Step size of the inner (plain gradient-descent) loop
META_LR = 0.001  # Step size of the meta-optimizer

loss_fn = nn.MSELoss()  # Built once instead of on every iteration


def _adapted_query_loss(model, support_x, support_y, query_x, query_y):
    """Adapt ``model`` to a task and return the loss of the adapted weights.

    Runs ``INNER_STEPS`` plain gradient-descent steps of size ``INNER_LR`` on
    the support set, starting from ``model``'s current parameters, then
    evaluates the adapted weights on the query set. The model's own
    parameters are not modified; the adapted weights are temporary tensors
    that stay connected to them through autograd (``create_graph=True``), so
    calling ``backward()`` on the result yields gradients with respect to
    ``model.parameters()``.

    Args:
        model: The meta-model whose parameters are the starting point.
        support_x: Inputs used for the adaptation steps.
        support_y: Targets matching ``support_x``.
        query_x: Inputs used to evaluate the adapted weights.
        query_y: Targets matching ``query_x``.

    Returns:
        Scalar tensor holding the MSE of the adapted weights on the query set.
    """
    fast_weights = dict(model.named_parameters())

    for _ in range(INNER_STEPS):
        support_predictions = functional_call(model, fast_weights, (support_x,))
        support_loss = loss_fn(support_predictions, support_y)
        grads = torch.autograd.grad(
            support_loss,
            tuple(fast_weights.values()),
            create_graph=True,  # Keep the graph so the meta-gradient can flow through
        )
        fast_weights = {
            name: weight - INNER_LR * grad
            for (name, weight), grad in zip(fast_weights.items(), grads)
        }

    query_predictions = functional_call(model, fast_weights, (query_x,))
    return loss_fn(query_predictions, query_y)


meta_model = SimpleModel()  # Meta-model
meta_optimizer = optim.Adam(meta_model.parameters(), lr=META_LR)  # Meta-optimizer

for epoch in range(NUM_EPOCHS):
    meta_optimizer.zero_grad()
    meta_loss = 0.0

    for _ in range(TASKS_PER_UPDATE):
        # Generate task data and split it: the first half drives the inner
        # adaptation (support set), the second half scores the adapted
        # weights (query set), so the meta-objective measures performance on
        # samples the adaptation did not see.
        inputs, outputs = task()
        half = inputs.shape[0] // 2
        meta_loss = meta_loss + _adapted_query_loss(
            meta_model,
            inputs[:half],
            outputs[:half],
            inputs[half:],
            outputs[half:],
        )

    # Average the meta-loss and update the meta-model
    meta_loss = meta_loss / TASKS_PER_UPDATE
    meta_loss.backward()
    meta_optimizer.step()

    # Print progress
    print(f"Epoch {epoch + 1}, Meta-Loss: {meta_loss.item():.4f}")