In [1]:
import os
import random

import numpy as np
import torch as th
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import ImageFolder

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

In [2]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = th.zeros(data.size(0)).byte()
    adaptation_indices[th.arange(shots*ways) * 2] = 1
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[1 - adaptation_indices], labels[1 - adaptation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        train_error /= len(adaptation_data)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy


In [3]:
#variables to change
train_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='train')
valid_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='validation')
test_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='test')
train_dataset = l2l.data.MetaDataset(train_dataset)
valid_dataset = l2l.data.MetaDataset(valid_dataset)
test_dataset = l2l.data.MetaDataset(test_dataset)

Downloading: ../data/mini-imagenet-cache-train.pkl
Download finished
Downloading: ../data/mini-imagenet-cache-validation.pkl
Download finished
Downloading: ../data/mini-imagenet-cache-test.pkl
Download finished


In [4]:
ways=5
shots=5
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32
adaptation_steps=1
num_iterations=5
cuda=True
seed=42

In [5]:
random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
device = th.device('cpu')
if cuda and th.cuda.device_count():
    th.cuda.manual_seed(seed)
    device = th.device('cuda')

In [6]:
train_transforms = [NWays(train_dataset, ways),KShots(train_dataset, 2*shots),LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms,num_tasks=20000)

In [7]:
valid_transforms = [NWays(valid_dataset, ways), KShots(valid_dataset, 2*shots), LoadData(valid_dataset),
        ConsecutiveLabels(train_dataset),
        RemapLabels(valid_dataset),
    ]
valid_tasks = l2l.data.TaskDataset(valid_dataset, task_transforms=valid_transforms, num_tasks=600)

In [8]:
test_transforms = [NWays(test_dataset, ways), KShots(test_dataset, 2*shots), LoadData(test_dataset), 
        RemapLabels(test_dataset),
        ConsecutiveLabels(train_dataset),
    ]
test_tasks = l2l.data.TaskDataset(test_dataset, task_transforms=test_transforms, num_tasks=600)

In [9]:
 # Create model
model = l2l.vision.models.MiniImagenetCNN(ways)
model.to(device)
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

In [10]:
 for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = maml.clone()
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()



















Iteration 0
Meta Train Error 0.06454412487801164
Meta Train Accuracy 0.25374999339692295
Meta Valid Error 0.06716673786286265
Meta Valid Accuracy 0.22499999357387424
Meta Test Error 0.06522285891696811
Meta Test Accuracy 0.25499999430030584






















Iteration 1
Meta Train Error 0.06523260218091309
Meta Train Accuracy 0.24374999315477908
Meta Valid Error 0.06682178436312824
Meta Valid Accuracy 0.20999999367631972
Meta Test Error 0.06616182962898165
Meta Test Accuracy 0.23124999389983714






















Iteration 2
Meta Train Error 0.06523774075321853
Meta Train Accuracy 0.23749999422580004
Meta Valid Error 0.06605448352638632
Meta Valid Accuracy 0.22124999482184649
Meta Test Error 0.06479842809494585
Meta Test Accuracy 0.26124999253079295




















Iteration 3
Meta Train Error 0.06471077061723918
Meta Train Accuracy 0.2437499943189323
Meta Valid Error 0.06630631268490106
Meta Valid Accuracy 0.2412499925121665
Meta Test Error 0.0653734871884808
Meta Test Accuracy 0.24624999263323843




















Iteration 4
Meta Train Error 0.06469488423317671
Meta Train Accuracy 0.2512499939184636
Meta Valid Error 0.06475561310071498
Meta Valid Accuracy 0.24374999478459358
Meta Test Error 0.06391599902417511
Meta Test Accuracy 0.2774999928660691


