<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/MAML_for_Few_Shot_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming a simple model for demonstration purposes
class SimpleModel(nn.Module):
    """Minimal network made of one fully connected layer.

    The layer is exposed as the ``fc`` attribute and maps inputs whose last
    dimension is ``input_dim`` to outputs whose last dimension is
    ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        """Create the single ``nn.Linear(input_dim, output_dim)`` layer."""
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        projected = self.fc(x)
        return projected

# Loss function: cross-entropy between model predictions and target labels
def compute_loss(predictions, labels):
    """Return the cross-entropy loss of ``predictions`` against ``labels``.

    Uses PyTorch's default settings for cross-entropy (mean reduction, no
    class weights), so the result is a scalar tensor.
    """
    # Functional form of nn.CrossEntropyLoss() with default arguments.
    return nn.functional.cross_entropy(predictions, labels)

# Define MAML training step
def maml_train_step(model, task_batch, inner_steps=5, lr_inner=0.01,
                    meta_optimizer=None):
    """Run one MAML meta-update of ``model`` over a batch of tasks.

    For every task, the model's parameters are adapted with ``inner_steps``
    steps of plain gradient descent on the task's support set. The adapted
    ("fast") weights are then evaluated on the task's query set, and the sum
    of the query losses is back-propagated *through the adaptation steps* to
    update the original parameters of ``model``.

    The adaptation is done functionally (no copy of the module is trained),
    so the query loss stays connected to ``model``'s parameters in the
    autograd graph. Training a detached copy instead would leave
    ``model.parameters()`` without gradients and make the outer optimizer
    step a silent no-op.

    Args:
        model: Any ``nn.Module`` whose forward takes a single input tensor.
        task_batch: Iterable of dicts with the keys ``'support_data'``,
            ``'support_labels'``, ``'query_data'`` and ``'query_labels'``.
        inner_steps: Number of gradient-descent adaptation steps per task.
        lr_inner: Step size of the inner-loop updates. Also used as the
            learning rate of the default outer optimizer.
        meta_optimizer: Optional optimizer over ``model.parameters()`` used
            for the outer update. Pass the same instance on every call to
            keep its state (e.g. Adam moments) across meta-steps. When
            omitted, a fresh ``Adam(model.parameters(), lr=lr_inner)`` is
            created for this call only.

    Returns:
        The summed query loss over all tasks as a Python float (measured
        before the outer update is applied).

    Raises:
        ValueError: If ``task_batch`` contains no tasks.
    """
    # Local import keeps this function self-contained; older PyTorch releases
    # expose the same helper under torch.nn.utils.stateless.
    try:
        from torch.func import functional_call
    except ImportError:
        from torch.nn.utils.stateless import functional_call

    tasks = list(task_batch)
    if not tasks:
        raise ValueError("task_batch must contain at least one task")

    if meta_optimizer is None:
        meta_optimizer = optim.Adam(model.parameters(), lr=lr_inner)

    # Only trainable parameters are adapted; anything else (frozen weights,
    # buffers) is taken from the module itself by functional_call.
    meta_params = {
        name: param
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    task_losses = []
    for task in tasks:
        # Start every task from the shared meta-parameters.
        fast_weights = meta_params

        # Inner loop: differentiable gradient descent on the support set.
        for _ in range(inner_steps):
            support_pred = functional_call(
                model, fast_weights, (task['support_data'],)
            )
            support_loss = compute_loss(support_pred, task['support_labels'])
            # create_graph=True keeps the update itself differentiable so the
            # outer backward pass can flow through it (second-order MAML).
            grads = torch.autograd.grad(
                support_loss,
                list(fast_weights.values()),
                create_graph=True,
                allow_unused=True,
            )
            fast_weights = {
                name: weight if grad is None else weight - lr_inner * grad
                for (name, weight), grad in zip(fast_weights.items(), grads)
            }

        # Evaluate the adapted weights on the query set.
        query_pred = functional_call(model, fast_weights, (task['query_data'],))
        task_losses.append(compute_loss(query_pred, task['query_labels']))

    meta_loss = torch.stack(task_losses).sum()

    # Outer loop: update the shared parameters from the query performance.
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    return meta_loss.item()

# Example usage: run a single meta-training step on one random task.
model = SimpleModel(input_dim=10, output_dim=2)

# Each task provides a support set (used for the inner-loop adaptation) and a
# query set (used to compute the meta loss). Here both hold 5 random
# 10-dimensional samples with labels drawn from {0, 1}.
task_batch = [
    dict(
        support_data=torch.randn(5, 10),
        support_labels=torch.randint(0, 2, (5,)),
        query_data=torch.randn(5, 10),
        query_labels=torch.randint(0, 2, (5,)),
    ),
    # Add more tasks as needed
]

loss = maml_train_step(model, task_batch)
print(f"Meta Loss: {loss}")