In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from Maml import Maml

# Example image-classification data: CIFAR-10.
# Each image is converted to a tensor and normalised per RGB channel with
# mean 0.5 / std 0.5. Both splits are downloaded to ./data on first use.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Build the train split first, then the test split (same order as the downloads).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Batches of 32 samples; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=reshuffle)
    for split_dataset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [02:01<00:00, 1401860.37it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


TypeError: adaptation() missing 1 required positional argument: 'model'

In [16]:
from _utils import train_val_split_regression
# Configuration for MAML, consumed by Maml(config) in the training cell.
# Every key is set exactly once so the dict reads as the configuration actually used.
SEED = 0
# Seed torch's default generator so model initialisation and the shuffled
# train_loader order are reproducible across Restart & Run All.
torch.manual_seed(SEED)

config = {
    # The training loop takes the first num_ways * k_shot samples of each batch as
    # the support set and the remainder as the query set.
    # NOTE(review): CIFAR-10 has 10 classes, but num_ways is 1 here. With batch_size=32,
    # num_ways=10 and k_shot=5 would leave no query samples. Revisit together with
    # the loss function below.
    'num_ways': 1,
    'k_shot': 5,
    # Network options forwarded to Maml.
    'network_architecture': 'CNN',
    'batchnorm': True,
    'strided': True,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    # NOTE(review): MSELoss and a *regression* train/val split look carried over from a
    # regression setup. CIFAR-10 targets are integer class indices, for which
    # torch.nn.CrossEntropyLoss (with num_ways = number of classes) is the usual choice.
    # Confirm what Maml expects before switching.
    'loss_function': torch.nn.MSELoss(),
    'train_val_split_function': train_val_split_regression,
    'v_shot': 10,  # presumably query samples per class -- confirm in Maml
    'num_models': 16,
    'KL_weight': 1e-5,
    # Inner-loop (task adaptation) and outer-loop (meta) optimisation settings.
    'inner_lr': 0.001,
    'num_inner_updates': 15,
    'meta_lr': 1e-3,
    'train_flag': False,
    'num_episodes': 1,
}

In [17]:
# Train and validate the MAML model.
# Each batch is split into a support set (first num_ways * k_shot samples), used for
# inner-loop adaptation, and a query set (remaining samples), used for evaluation.
# NOTE(review): a shuffled DataLoader batch is not an N-way K-shot episode (labels are
# not grouped by class); an episodic sampler is probably needed -- confirm against Maml.
# NOTE(review): no outer (meta) update is visible in this loop; confirm whether
# Maml.adaptation / Maml.evaluation perform it internally.
maml_classifier = Maml(config)
epochs = 10
log_every = 100  # print every `log_every` batches (and the last batch of each epoch)

# Load or initialise the model ONCE. Calling load_model inside the batch loop rebound
# `model` on every batch, discarding whatever the previous batch had done to it.
model = maml_classifier.load_model(resume_epoch=0, eps_dataloader=train_loader)

n_support = config['num_ways'] * config['k_shot']
if train_loader.batch_size is not None and train_loader.batch_size <= n_support:
    raise ValueError(
        f"batch_size={train_loader.batch_size} leaves no query samples for a support set "
        f"of num_ways * k_shot = {n_support}; increase batch_size or reduce num_ways/k_shot"
    )

num_batches = len(train_loader)
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # A short trailing batch may be too small to leave any query samples.
        if data.shape[0] <= n_support:
            continue

        # Inner loop: adapt on this batch's support set.
        adaptation_data = data[:n_support]
        adaptation_labels = target[:n_support]
        adaptation_loss, _ = maml_classifier.adaptation(adaptation_data, adaptation_labels, model)

        # Evaluate on the query set.
        query_data = data[n_support:]
        query_labels = target[n_support:]
        validation_loss, accuracy = maml_classifier.evaluation(query_data, query_labels, model)

        if (batch_idx + 1) % log_every == 0 or batch_idx + 1 == num_batches:
            # float() accepts Python numbers as well as single-element tensors.
            print(f"Epoch [{epoch + 1}/{epochs}] - Batch [{batch_idx + 1}/{num_batches}] "
                  f"Adaptation Loss: {float(adaptation_loss):.4f}, "
                  f"Validation Loss: {float(validation_loss):.4f}, "
                  f"Validation Accuracy: {float(accuracy):.2f}%")


IndexError: index 73065 is out of bounds for dimension 0 with size 32