<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Compute device shared by the code below: a CUDA GPU when PyTorch can see
# one, otherwise the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Simple neural network model
class SimpleNet(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.fc(x)

# Generate synthetic data for MAML
def generate_data(num_samples, input_size):
    x = torch.randn(num_samples, input_size)
    y = (x.sum(dim=1) > 0).float().unsqueeze(1)  # Ensure y has the correct shape
    return x, y

# Define MAML algorithm
class MAML:
    """Model-Agnostic Meta-Learning with a single inner gradient step.

    The inner step is differentiated through (``create_graph=True``). As a
    result, the meta-gradient applied by ``self.optimizer`` includes the
    second-order terms of that step.

    Args:
        model: network to meta-train; it is moved to the module-level
            ``device``. ``forward`` can only evaluate it functionally when
            its computation is a plain chain of ``nn.Linear`` layers (e.g.
            ``SimpleNet``); see ``forward``.
        lr_inner: step size of the inner (task-adaptation) SGD step.
        lr_meta: learning rate of the Adam meta-optimizer.
    """

    def __init__(self, model, lr_inner, lr_meta):
        self.model = model.to(device)
        self.lr_inner = lr_inner
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr_meta)

    def inner_update(self, x, y):
        """Return the model parameters after one SGD step on the batch (x, y).

        The returned tensors follow ``self.model.parameters()`` order. They
        stay attached to the autograd graph, so a loss computed from them
        back-propagates into the model's original parameters.
        """
        loss_fn = nn.BCEWithLogitsLoss()
        params = list(self.model.parameters())
        loss = loss_fn(self.model(x), y)
        # create_graph=True keeps the graph of the gradient itself; the outer
        # (meta) loss needs it to differentiate through this update.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        return [p - self.lr_inner * g for p, g in zip(params, grads)]

    def forward(self, x, params):
        """Evaluate the model's Linear layers on ``x`` using ``params`` as weights.

        ``params`` must follow ``self.model.parameters()`` order, as returned
        by ``inner_update``. Each ``nn.Linear`` consumes its weight and, only
        if the layer has one, its bias. Parameter-free modules (activations,
        containers, ...) are NOT applied. The result therefore equals
        ``self.model(x)`` only for models that are a plain chain of Linear
        layers.

        Raises:
            ValueError: if ``params`` does not hold exactly one tensor per
                model parameter.
            NotImplementedError: if a non-Linear module owns parameters,
                because those could not be mapped onto ``params``.
        """
        expected = sum(1 for _ in self.model.parameters())
        if len(params) != expected:
            raise ValueError(
                f"expected {expected} parameter tensors, got {len(params)}"
            )

        out = x
        i = 0
        for layer in self.model.modules():
            if isinstance(layer, nn.Linear):
                weight = params[i]
                i += 1
                bias = None
                # Linear(bias=False) registers no bias parameter, so only
                # consume a tensor from ``params`` when the layer has one.
                if layer.bias is not None:
                    bias = params[i]
                    i += 1
                out = torch.nn.functional.linear(out, weight, bias)
            elif next(layer.parameters(recurse=False), None) is not None:
                raise NotImplementedError(
                    f"functional forward does not support parameterized "
                    f"module {type(layer).__name__}"
                )
        return out

    def train_step(self, x_train, y_train, x_val, y_val):
        """Run one meta-update: adapt on (x_train, y_train), evaluate on (x_val, y_val).

        Returns:
            float: loss of the adapted parameters on the validation batch,
            as computed before the meta-optimizer step.
        """
        loss_fn = nn.BCEWithLogitsLoss()
        adapted = self.inner_update(x_train, y_train)
        loss = loss_fn(self.forward(x_val, adapted), y_val)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

# Create data and model
x_train, y_train = generate_data(100, 5)
x_val, y_val = generate_data(100, 5)
x_train, y_train = x_train.to(device), y_train.to(device)
x_val, y_val = x_val.to(device), y_val.to(device)
model = SimpleNet(input_size=5, output_size=1)

# DataLoader for batch processing
train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Train MAML. Every training batch is paired with every validation batch;
# each pair performs its own inner update followed by one meta-update.
maml = MAML(model, lr_inner=0.01, lr_meta=0.001)
for epoch in range(100):
    epoch_losses = []
    for x_batch, y_batch in train_loader:
        for x_val_batch, y_val_batch in val_loader:
            epoch_losses.append(
                maml.train_step(x_batch, y_batch, x_val_batch, y_val_batch)
            )
    # Report the mean validation loss over the whole epoch, not just the
    # last meta-step. NaN (rather than a NameError) if no step ran.
    mean_loss = (
        sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
    )
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")