# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

# Define the neural network
class SimpleNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input tensor.
        hidden_dim: width of the single hidden layer.
        output_dim: size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The layers live in one ``nn.Sequential`` attribute named ``net`` so
        # the state_dict keys are ``net.0.*`` (first Linear) and ``net.2.*``
        # (second Linear); ReLU at index 1 has no parameters.
        layer_stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layer_stack)

    def forward(self, x):
        """Run ``x`` through the Linear -> ReLU -> Linear stack."""
        output = self.net(x)
        return output

# Function to generate sine wave tasks
def generate_sine_wave(amplitude, phase, num_samples):
    """Sample ``y = amplitude * sin(x + phase)`` on an evenly spaced grid.

    Args:
        amplitude: scale of the sine wave.
        phase: horizontal shift added to ``x`` before taking the sine.
        num_samples: number of grid points, spanning [-5, 5] inclusive.

    Returns:
        ``(x, y)``: two float32 tensors of shape ``(num_samples, 1)``.
    """
    def to_column(values):
        # float64 numpy -> float32 torch, then add a trailing feature axis.
        return torch.from_numpy(values.astype(np.float32)).unsqueeze(1)

    grid = np.linspace(-5, 5, num_samples)
    targets = amplitude * np.sin(grid + phase)
    return to_column(grid), to_column(targets)

# MAML training loop
def maml_training(model, tasks, meta_lr, inner_lr, num_inner_steps, meta_batch_size, num_iterations,
                  num_support=10, num_query=10, first_order=False):
    """Meta-train ``model`` in place with MAML on randomly sampled sine tasks.

    Each meta-iteration samples ``meta_batch_size`` tasks with
    amplitude ~ U(0.1, 5.0) and phase ~ U(0, pi). For each task:

    1. A dictionary of "fast weights" is initialized from ``model``'s own
       parameters.
    2. It is adapted with ``num_inner_steps`` plain gradient steps on a
       support set.
    3. The adapted weights are scored on a disjoint query set.

    The mean query loss is back-propagated *through* the inner updates into
    ``model``'s parameters, and Adam then updates them.

    (The previous version adapted a freshly constructed copy of the network.
    Its parameters had no autograd connection to ``model``, so ``model``
    never received gradients and never changed.)

    Args:
        model: ``nn.Module`` to meta-train; updated in place.
        tasks: callable ``tasks(amplitude, phase, num_samples) -> (x, y)``
            returning tensors with ``num_samples`` rows, e.g.
            ``generate_sine_wave``. If it is not callable,
            ``generate_sine_wave`` is used.
        meta_lr: Adam learning rate for the outer (meta) update.
        inner_lr: step size of each inner-loop gradient step.
        num_inner_steps: number of inner-loop steps per task.
        meta_batch_size: number of tasks per meta-update; must be >= 1.
        num_iterations: number of meta-updates.
        num_support: support-set size per task (used for adaptation).
        num_query: query-set size per task (used for the meta-loss).
        first_order: if True, do not differentiate through the inner-loop
            gradients (first-order MAML). This is cheaper but approximate.

    Raises:
        ValueError: if ``meta_batch_size`` is less than 1.
    """
    try:
        from torch.func import functional_call
    except ImportError:  # torch < 2.0
        from torch.nn.utils.stateless import functional_call

    if meta_batch_size < 1:
        raise ValueError(f"meta_batch_size must be >= 1, got {meta_batch_size}")

    task_fn = tasks if callable(tasks) else generate_sine_wave
    loss_fn = nn.MSELoss()
    meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        meta_optimizer.zero_grad()

        task_losses = []
        for _ in range(meta_batch_size):
            # Sample a task and split its points into disjoint support/query sets
            amplitude, phase = np.random.uniform(0.1, 5.0), np.random.uniform(0, np.pi)
            x_all, y_all = task_fn(amplitude, phase, num_support + num_query)
            perm = torch.randperm(x_all.shape[0])
            support_idx, query_idx = perm[:num_support], perm[num_support:]
            x_train, y_train = x_all[support_idx], y_all[support_idx]
            x_val, y_val = x_all[query_idx], y_all[query_idx]

            # Inner loop: the fast weights start as model's own parameters,
            # so every update below stays in the autograd graph of model.
            fast_weights = dict(model.named_parameters())
            for _ in range(num_inner_steps):
                train_loss = loss_fn(functional_call(model, fast_weights, (x_train,)), y_train)
                grads = torch.autograd.grad(
                    train_loss, list(fast_weights.values()), create_graph=not first_order
                )
                fast_weights = {
                    name: weight - inner_lr * grad
                    for (name, weight), grad in zip(fast_weights.items(), grads)
                }

            # Outer objective: query loss of the adapted weights
            val_loss = loss_fn(functional_call(model, fast_weights, (x_val,)), y_val)
            task_losses.append(val_loss)

        # Meta optimization step: the gradient flows back into model's parameters
        meta_loss = torch.stack(task_losses).mean()
        meta_loss.backward()
        meta_optimizer.step()

        if iteration % 100 == 0:
            print(f"Iteration {iteration}, Meta Loss: {meta_loss.item()}")

# Hyperparameters
input_dim = 1
hidden_dim = 40
output_dim = 1
meta_lr = 0.001
inner_lr = 0.01
num_inner_steps = 1
meta_batch_size = 5
num_iterations = 1000
num_support = 10      # support points available from the new task at test time
num_adapt_steps = 10  # gradient steps taken on those points before predicting

# Initialize model and meta-train it
model = SimpleNN(input_dim, hidden_dim, output_dim)
maml_training(model, generate_sine_wave, meta_lr, inner_lr, num_inner_steps, meta_batch_size, num_iterations)

# Testing on a new task. MAML learns an *initialization*: the model is only
# expected to fit a new task after a few gradient steps on that task's
# support data. So adapt a copy first, keeping the meta-trained model intact,
# and then predict.
amplitude, phase = 2.5, 0.5
x_test, y_test = generate_sine_wave(amplitude, phase, 50)
x_support, y_support = generate_sine_wave(amplitude, phase, num_support)

adapted_model = SimpleNN(input_dim, hidden_dim, output_dim)
adapted_model.load_state_dict(model.state_dict())
adapt_optimizer = optim.SGD(adapted_model.parameters(), lr=inner_lr)
adapt_loss_fn = nn.MSELoss()
for _ in range(num_adapt_steps):
    adapt_optimizer.zero_grad()
    adapt_loss_fn(adapted_model(x_support), y_support).backward()
    adapt_optimizer.step()

with torch.no_grad():
    y_pred_meta = model(x_test).numpy()
    y_pred = adapted_model(x_test).numpy()

plt.plot(x_test.numpy(), y_test.numpy(), label='True')
plt.plot(x_test.numpy(), y_pred_meta, linestyle='--', label='Meta-init (no adaptation)')
plt.plot(x_test.numpy(), y_pred, label='Predicted')
plt.scatter(x_support.numpy(), y_support.numpy(), marker='^', label='Support points')
plt.legend()
plt.show()
