In [1]:
import os
import random
import sys
#sys.path.append('../../../') 

import numpy as np
import torch as th
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import ImageFolder

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

In [2]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

def transformbatch(batch, transformer):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)
    data_transformed = transformer(data)
    return [data_transformed, labels]

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = th.zeros(data.size(0)).byte()
    adaptation_indices[th.arange(shots*ways) * 2] = 1
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[1 - adaptation_indices], labels[1 - adaptation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        train_error /= len(adaptation_data)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy

In [3]:
#variables to change
train_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='train')
valid_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='validation')
test_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='test')
train_dataset = l2l.data.MetaDataset(train_dataset)
valid_dataset = l2l.data.MetaDataset(valid_dataset)
test_dataset = l2l.data.MetaDataset(test_dataset)

In [4]:
ways=5
shots=5
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32
adaptation_steps=1
num_iterations=5
cuda=True
seed=42

In [5]:
random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
device = th.device('cpu')
if cuda and th.cuda.device_count():
    th.cuda.manual_seed(seed)
    device = th.device('cpu')

In [6]:
train_transforms = [NWays(train_dataset, ways),KShots(train_dataset, 2*shots),LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms,num_tasks=20000)

In [7]:
valid_transforms = [NWays(valid_dataset, ways), KShots(valid_dataset, 2*shots), LoadData(valid_dataset),
        ConsecutiveLabels(train_dataset),
        RemapLabels(valid_dataset),
    ]
valid_tasks = l2l.data.TaskDataset(valid_dataset, task_transforms=valid_transforms, num_tasks=600)

In [8]:
test_transforms = [NWays(test_dataset, ways), KShots(test_dataset, 2*shots), LoadData(test_dataset), 
        RemapLabels(test_dataset),
        ConsecutiveLabels(train_dataset),
    ]
test_tasks = l2l.data.TaskDataset(test_dataset, task_transforms=test_transforms, num_tasks=600)

In [9]:
# Create models, 2 networks, one is the transformer (conditioned on gradients of minibatches)
# the other is the meta-learner doing MAML
embedding = 64
#same as miniImagenetCNN to learn the latent space/conditioned on the gradients,
#used size 64 like LEO encoder
transformer = l2l.vision.models.MiniImagenetCNN(embedding)
transformer.to(device)
metal = l2l.vision.models.LinearBlock(embedding, ways)
metal.to(device)
maml = l2l.algorithms.MAML(metal, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')
print(transformer)
print(metal)

MiniImagenetCNN(
  (base): ConvBase(
    (0): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1),

In [11]:
 for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy = fast_adapt(batch_transformed, learner, loss, adaptation_steps, shots, ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy = fast_adapt(batch_transformed, learner,loss, adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = maml.clone()
            batch = test_tasks.sample()
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy = fast_adapt(batch_transformed, learner, loss, adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()























Iteration 0
Meta Train Error 0.06276281294412911
Meta Train Accuracy 0.2899999995715916
Meta Valid Error 0.06453563086688519
Meta Valid Accuracy 0.2787499986588955
Meta Test Error 0.06388386012986302
Meta Test Accuracy 0.2750000008381903
























Iteration 1
Meta Train Error 0.0618893614737317
Meta Train Accuracy 0.31749999918974936
Meta Valid Error 0.06396064907312393
Meta Valid Accuracy 0.26749999984167516
Meta Test Error 0.060491363517940044
Meta Test Accuracy 0.32625000015832484
























Iteration 2
Meta Train Error 0.06243692373391241
Meta Train Accuracy 0.3287500001024455
Meta Valid Error 0.06328856549225748
Meta Valid Accuracy 0.30750000057742
Meta Test Error 0.062115458538755774
Meta Test Accuracy 0.30874999961815774
























Iteration 3
Meta Train Error 0.06213757861405611
Meta Train Accuracy 0.3287500024307519
Meta Valid Error 0.06418373435735703
Meta Valid Accuracy 0.31000000261701643
Meta Test Error 0.06271003617439419
Meta Test Accuracy 0.3074999994132668
























Iteration 4
Meta Train Error 0.06166734057478607
Meta Train Accuracy 0.3412499991245568
Meta Valid Error 0.06433674821164459
Meta Valid Accuracy 0.2875000019557774
Meta Test Error 0.05917097348719835
Meta Test Accuracy 0.3475000001490116


