<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Meta_Learning_(Model_Agnostic_Meta_Learning%2C_MAML).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the SimpleNN class
class SimpleNN(nn.Module):
    """Two-layer MLP: ``Linear -> ReLU -> Linear``.

    Args:
        input_dim: Number of input features (size of the last input dimension).
        output_dim: Number of outputs (e.g. class logits).
        hidden_dim: Width of the hidden layer. Defaults to 128, the size that
            was previously hard-coded, so existing callers and saved
            state_dicts (keys ``fc1.*`` / ``fc2.*``) remain compatible.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # input -> hidden
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # hidden -> output

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Define the MAML step function
def maml_step(model, loss_fn, optimizer, data, labels, inner_lr=0.01,
              query_data=None, query_labels=None, first_order=False):
    """Perform one MAML meta-update of ``model`` on a single task.

    The inner (task-specific) step adapts a *functional copy* of the
    parameters, ``theta' = theta - inner_lr * grad L_support(theta)``, without
    modifying ``model``. The meta-loss ``L_query(theta')`` is then
    differentiated with respect to the original ``theta``, and ``optimizer``
    applies the meta-update to ``theta`` only.

    Args:
        model: The ``nn.Module`` whose parameters are meta-learned.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        optimizer: Outer-loop optimizer over ``model``'s parameters.
        data: Support inputs used for the inner adaptation step.
        labels: Support targets used for the inner adaptation step.
        inner_lr: Step size of the inner (plain SGD) update.
        query_data: Query inputs for the meta-loss; defaults to ``data``.
        query_labels: Query targets for the meta-loss; defaults to ``labels``.
            Pass both query arguments or neither.
        first_order: If True, do not backpropagate through the inner gradient
            (first-order MAML). Otherwise second-order terms are included.

    Returns:
        The meta-loss (query loss at the adapted parameters) as a Python float.

    Raises:
        ValueError: If only one of ``query_data`` / ``query_labels`` is given.
    """
    try:
        from torch.func import functional_call
    except ImportError:  # torch < 2.0
        from torch.nn.utils.stateless import functional_call

    if (query_data is None) != (query_labels is None):
        raise ValueError("query_data and query_labels must be given together")
    if query_data is None:
        query_data, query_labels = data, labels

    # Only trainable parameters are adapted; frozen ones (and buffers) are
    # taken from the module itself by functional_call.
    params = {name: p for name, p in model.named_parameters() if p.requires_grad}

    # Inner loop: compute the task gradient WITHOUT stepping an optimizer, so
    # the meta-parameters are never overwritten by the task adaptation.
    support_loss = loss_fn(model(data), labels)
    grads = torch.autograd.grad(
        support_loss,
        list(params.values()),
        create_graph=not first_order,  # keep graph for second-order MAML
        allow_unused=True,             # parameters unused by the loss get no update
    )
    adapted = {
        name: p if g is None else p - inner_lr * g
        for (name, p), g in zip(params.items(), grads)
    }

    # Outer loop: evaluate the adapted parameters on the query set and
    # backpropagate through the inner update to the original parameters.
    meta_output = functional_call(model, adapted, (query_data,))
    meta_loss = loss_fn(meta_output, query_labels)
    optimizer.zero_grad()
    meta_loss.backward()
    optimizer.step()

    return meta_loss.item()

# Example usage: one meta-update of a 784 -> 10 classifier on random data.
INPUT_DIM, NUM_CLASSES, BATCH_SIZE = 784, 10, 64

model = SimpleNN(input_dim=INPUT_DIM, output_dim=NUM_CLASSES)
loss_fn = nn.CrossEntropyLoss()  # logits vs. integer class-index targets
optimizer = optim.Adam(model.parameters(), lr=0.001)  # outer-loop (meta) optimizer

# Dummy batch: standard-normal features and random class indices in [0, NUM_CLASSES).
data = torch.randn(BATCH_SIZE, INPUT_DIM)
labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,))

# Run a single MAML step; this call passes one batch for both inner and meta loss.
meta_loss = maml_step(model, loss_fn, optimizer, data, labels)
print(f"Meta loss after MAML step: {meta_loss}")