<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Full_MAML_Meta_Training_Script_(Synthetic_Tasks).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

# --- Simple MLP Model ---
class SimpleNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the hidden layer.
        output_size: Number of values produced per input row.
    """

    def __init__(self, input_size=10, hidden_size=32, output_size=2):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they stay fixed.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the output of the second linear layer for input ``x``."""
        hidden = nn.functional.relu(self.fc1(x))
        out = self.fc2(hidden)
        return out

# --- MAML Inner-Loop Update ---
def maml_inner_update(model, loss_fn, x_support, y_support, inner_lr):
    """Take one differentiable SGD step on the support set.

    The adapted weights are computed as ``p - inner_lr * grad`` with
    ``create_graph=True``, so they remain functions of ``model``'s own
    parameters. A loss evaluated through the returned callable can therefore
    be backpropagated into ``model`` (including second-order terms).

    The previous implementation deep-copied the model and stepped an SGD
    optimizer on the copy; the copy's parameters had no autograd link to
    ``model``, so the outer-loop backward pass never produced gradients for
    the meta-parameters.

    Args:
        model: Module whose parameters are the meta-parameters. Its
            parameters are not modified. Note that the support forward pass
            runs through ``model`` itself, so any buffers it updates in
            ``forward`` are updated on ``model``.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar.
        x_support: Support-set inputs passed to ``model``.
        y_support: Support-set targets passed to ``loss_fn``.
        inner_lr: Step size of the inner-loop update.

    Returns:
        A callable with the same call signature as ``model`` that runs
        ``model``'s forward pass using the adapted weights. It is a plain
        function, not an ``nn.Module``.
    """
    # Function-scope import so the file's top-level imports need no change.
    try:
        from torch.func import functional_call
    except ImportError:  # PyTorch releases that predate torch.func
        from torch.nn.utils.stateless import functional_call

    params = dict(model.named_parameters())
    # Frozen parameters cannot be differentiated; they are passed through.
    trainable = {name: p for name, p in params.items() if p.requires_grad}

    support_loss = loss_fn(model(x_support), y_support)
    grads = torch.autograd.grad(
        support_loss,
        list(trainable.values()),
        create_graph=True,   # keep the graph so the outer loss can reach model
        allow_unused=True,   # parameters not used in forward get grad None
    )

    fast_weights = dict(params)
    for (name, param), grad in zip(trainable.items(), grads):
        if grad is not None:
            fast_weights[name] = param - inner_lr * grad

    def adapted_model(*args, **kwargs):
        """Run ``model``'s forward pass with the adapted weights."""
        return functional_call(model, fast_weights, args, kwargs)

    return adapted_model

# --- Meta-Training Step ---
def meta_train_step(meta_model, tasks, loss_fn, meta_optimizer, inner_lr):
    """Run one outer-loop update over a meta-batch of tasks.

    For each task the model is adapted on the support set with
    ``maml_inner_update`` and the adapted model is scored on the query set.
    The mean query loss is backpropagated and ``meta_optimizer`` takes one
    step.

    NOTE: the optimizer step changes ``meta_model`` only if the adapted
    model returned by ``maml_inner_update`` keeps its weights connected to
    ``meta_model``'s parameters in the autograd graph.

    Args:
        meta_model: Model holding the meta-parameters.
        tasks: Sized, non-empty collection of
            ``(x_support, y_support, x_query, y_query)`` tuples.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar
            tensor.
        meta_optimizer: Optimizer stepped once per call.
        inner_lr: Learning rate forwarded to ``maml_inner_update``.

    Returns:
        The mean query loss over the meta-batch as a Python float.

    Raises:
        ValueError: If ``tasks`` is empty (previously this surfaced as a
            ``ZeroDivisionError`` from the averaging step).
    """
    if not tasks:
        raise ValueError("tasks must contain at least one task")

    query_losses = []
    for x_support, y_support, x_query, y_query in tasks:
        # Inner loop: adapt to this task using its support set.
        adapted_model = maml_inner_update(
            meta_model, loss_fn, x_support, y_support, inner_lr
        )
        # Outer objective: how well the adapted model does on the query set.
        query_losses.append(loss_fn(adapted_model(x_query), y_query))

    meta_loss = sum(query_losses) / len(tasks)

    # Meta update
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    return meta_loss.item()

# --- Dummy Task Sampler ---
def generate_dummy_task(n_support=5, n_query=15, input_dim=10, num_classes=2):
    """Sample one synthetic classification task made of random data.

    Inputs are drawn with ``torch.randn`` and labels with ``torch.randint``
    over ``[0, num_classes)``; labels are sampled independently of inputs.

    Args:
        n_support: Number of support examples.
        n_query: Number of query examples.
        input_dim: Feature dimension of each example.
        num_classes: Number of distinct label values.

    Returns:
        Tuple ``(x_support, y_support, x_query, y_query)``.
    """

    def _sample_split(n_examples):
        # Inputs first, then labels: keeps the RNG draw order stable.
        features = torch.randn(n_examples, input_dim)
        labels = torch.randint(0, num_classes, (n_examples,))
        return features, labels

    support = _sample_split(n_support)
    query = _sample_split(n_query)
    return (*support, *query)

# --- Main Training Loop ---
# Settings for the demo run.
input_dim = 10
inner_lr = 0.01
loss_fn = nn.CrossEntropyLoss()

# Meta-learner and the optimizer that performs its outer-loop updates.
meta_model = SimpleNN(input_size=input_dim)
meta_optimizer = optim.Adam(meta_model.parameters(), lr=1e-3)

# Meta-training: 100 epochs, each on a freshly sampled meta-batch of 4 tasks.
for epoch in range(1, 101):
    tasks = [generate_dummy_task() for _ in range(4)]
    loss = meta_train_step(meta_model, tasks, loss_fn, meta_optimizer, inner_lr)

    # Report on the first epoch and on every tenth epoch.
    if epoch == 1 or epoch % 10 == 0:
        print(f"Epoch {epoch:03d} | Meta Loss: {loss:.4f}")