<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_for_Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleModel(nn.Module):
    """Minimal regression network: one affine layer mapping 1 feature to 1 output."""

    def __init__(self):
        super().__init__()
        # The only learnable layer; exposed as ``fc`` so state-dict keys are
        # ``fc.weight`` and ``fc.bias``.
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.fc(x)
        return prediction

def inner_loop(model, data, target, lr=0.01):
    """Adapt ``model`` to one task with a single SGD step.

    Args:
        model: Module to update in place.
        data: Input tensor passed to ``model``.
        target: Tensor the model output is compared against with MSE.
        lr: Learning rate of the SGD step.

    Returns:
        The same ``model`` object, after its parameters have been updated.
        Its ``.grad`` fields still hold the gradients computed before the
        step.
    """
    sgd = optim.SGD(model.parameters(), lr=lr)
    sgd.zero_grad()

    prediction = model(data)
    # Mean-reduced squared error, the same default as nn.MSELoss().
    task_loss = nn.functional.mse_loss(prediction, target)
    task_loss.backward()
    sgd.step()

    # The reported value is the loss measured before the update was applied.
    print(f"Inner Loop Loss: {task_loss.item():.4f}")
    return model

def outer_loop(meta_model, tasks, meta_lr=0.001, meta_optimizer=None):
    """Run one meta-update of ``meta_model`` over a collection of tasks.

    For every task a fresh ``SimpleModel`` is initialised from the meta
    weights and adapted with ``inner_loop``. The loss is then re-evaluated
    on the *adapted* weights, and its gradient with respect to those
    weights is accumulated onto the meta parameters (a first-order
    approximation: the dependence of the adapted weights on the meta
    weights is ignored). The summed gradient is clipped and applied with
    the meta optimizer.

    Args:
        meta_model: Model whose parameters are meta-learned; updated in place.
        tasks: Iterable of ``(data, target)`` tensor pairs.
        meta_lr: Learning rate for the Adam optimizer that is created when
            ``meta_optimizer`` is not supplied.
        meta_optimizer: Optional optimizer over ``meta_model.parameters()``.
            Pass the same instance on every call to keep optimizer state
            (e.g. Adam moment estimates) across meta-rounds. When omitted, a
            new Adam is built on each call and ``meta_lr`` is used.
    """
    if meta_optimizer is None:
        meta_optimizer = optim.Adam(meta_model.parameters(), lr=meta_lr)
    meta_optimizer.zero_grad()

    loss_fn = nn.MSELoss()
    meta_params = list(meta_model.parameters())

    for data, target in tasks:
        model = SimpleModel()
        model.load_state_dict(meta_model.state_dict())
        model = inner_loop(model, data, target)

        # inner_loop leaves ``.grad`` holding the gradient at the
        # pre-update weights, which equal the meta weights. Using that
        # would make the adaptation step irrelevant, so differentiate the
        # loss at the adapted weights instead.
        adapted_loss = loss_fn(model(data), target)
        task_grads = torch.autograd.grad(adapted_loss, list(model.parameters()))

        for meta_param, grad in zip(meta_params, task_grads):
            if meta_param.grad is None:
                # Clone so the meta gradient never shares storage with a
                # tensor owned by the throwaway task model.
                meta_param.grad = grad.clone()
            else:
                meta_param.grad.add_(grad)

    torch.nn.utils.clip_grad_norm_(meta_params, max_norm=1.0)  # Gradient clipping
    meta_optimizer.step()

def validate(meta_model, tasks):
    """Report the mean post-adaptation MSE of ``meta_model`` over ``tasks``.

    Each task gets a fresh ``SimpleModel`` initialised from the meta
    weights, adapted with one ``inner_loop`` step (lr=0.01), and then
    scored on the same task data.

    Args:
        meta_model: Model providing the initial weights. Its parameters are
            not modified; its train/eval mode is restored before returning.
        tasks: Iterable of ``(data, target)`` tensor pairs.

    Returns:
        The average loss as a float, or ``nan`` when ``tasks`` is empty.
        The same value is printed.
    """
    was_training = meta_model.training
    meta_model.eval()
    loss_fn = nn.MSELoss()
    losses = []
    try:
        for data, target in tasks:
            model = SimpleModel()
            model.load_state_dict(meta_model.state_dict())
            model = inner_loop(model, data, target, lr=0.01)
            # Scoring only: no gradients are needed for the final forward pass.
            with torch.no_grad():
                losses.append(loss_fn(model(data), target).item())
    finally:
        # Do not leave the caller's model stuck in eval mode.
        meta_model.train(was_training)

    # Guard against an empty task collection instead of dividing by zero.
    avg_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"Validation Loss: {avg_loss:.4f}")
    return avg_loss
# Create diverse synthetic tasks: each task is an (inputs, targets) pair of
# 10 samples drawn from a standard normal and multiplied by a per-task scale.
# NOTE(review): targets are sampled independently of the inputs, so there is
# no input-to-target relationship to learn beyond the noise scale -- confirm
# this is intended for the demo.
tasks = [
    (torch.randn(10, 1) * scale, torch.randn(10, 1) * scale)
    for scale in [1, 0.5, 2, 1.5, 0.8]
]

meta_model = SimpleModel()
# ``meta_round`` rather than ``round`` so the builtin round() is not shadowed
# for the rest of the notebook session.
for meta_round in range(100):
    outer_loop(meta_model, tasks)
    # Report progress and validation loss every 10th round (validation reuses
    # the training tasks).
    if meta_round % 10 == 0:
        print(f'Meta-learning round {meta_round} completed')
        validate(meta_model, tasks)