In [3]:
import sys
import random
import numpy as np
import torch
import learn2learn as l2l
import pandas as pd

from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from PIL.Image import LANCZOS
from config import *

In [2]:
torch.cuda.is_available()

True

In [4]:
maml_omniglot

{'root': '/home/anuj/Desktop/Work/TU_Delft/research/implement/omniglot',
 'n_ways': 5,
 'k_shots': 1,
 'q_shots': 1,
 'inner_adapt_steps': 1,
 'inner_lr': 0.5,
 'meta_lr': 0.003,
 'meta_batch_size': 32,
 'iterations': 60000,
 'order': False,
 'device': 'cuda'}

In [3]:
sys.path.append('/home/anuj/Desktop/Work/TU_Delft/research/implement/learning_to_meta-learn')

In [4]:
sys.path

['/home/anuj/Desktop/Work/TU_Delft/research/implement/learning_to_meta-learn/src',
 '/home/anuj/.vscode/extensions/ms-toolsai.jupyter-2021.6.832593372/pythonFiles',
 '/home/anuj/.vscode/extensions/ms-toolsai.jupyter-2021.6.832593372/pythonFiles',
 '/home/anuj/.vscode/extensions/ms-toolsai.jupyter-2021.6.832593372/pythonFiles/lib/python',
 '/home/anuj/anaconda3/envs/torch/lib/python38.zip',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8/lib-dynload',
 '',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8/site-packages',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8/site-packages/datasets-1.2.1-py3.8.egg',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8/site-packages/locket-0.2.1-py3.8.egg',
 '/home/anuj/anaconda3/envs/torch/lib/python3.8/site-packages/IPython/extensions',
 '/home/anuj/.ipython',
 '/home/anuj/Desktop/Work/TU_Delft/research/implement/learning_to_meta-learn']

In [5]:
from data.loaders import Omniglotmix, MiniImageNet
from data.taskers import gen_tasks
from src.zoo.archs import EncoderNN
#from src.zoo.maml_utils import inner_adapt_maml, setup, accuracy
#from src.utils import Profiler

In [6]:
classes = list(range(1623))
random.shuffle(classes)
image_transforms = transforms.Compose([transforms.Resize(28, interpolation=LANCZOS),
                                                    transforms.ToTensor(),
                                                    lambda x: 1.0 - x,
                                                ])



In [15]:
train_tasks = gen_tasks('omniglot', '/home/anuj/Desktop/Work/TU_Delft/research/implement/omniglot', image_transforms=image_transforms, n_ways=5, k_shots=5, q_shots=5, classes=classes[:1100])

In [8]:
train_tasks.sample()[0].shape

torch.Size([50, 1, 28, 28])

In [18]:
task = train_tasks.sample()

In [25]:
train_tasks = gen_tasks(dataname='miniimagenet', root='../../mini_imagenet', mode='train', n_ways=5, k_shots=1, q_shots=3, image_transforms=None)

In [26]:
data1, labels1 = train_tasks.sample()

In [27]:
data1.shape

torch.Size([20, 3, 84, 84])

In [28]:
class MatchingNetwork(nn.Module):
    def __init__(self, n: int, k: int, q: int, fce: bool, num_input_channels: int,
                 lstm_layers: int, lstm_input_size: int, unrolling_steps: int, device: torch.device):
        """Creates a Matching Network as described in Vinyals et al.

        # Arguments:
            n: Number of examples per class in the support set
            k: Number of classes in the few shot classification task
            q: Number of examples per class in the query set
            fce: Whether or not to us fully conditional embeddings
            num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
                miniImageNet = 3
            lstm_layers: Number of LSTM layers in the bidrectional LSTM g that embeds the support set (fce = True)
            lstm_input_size: Input size for the bidirectional and Attention LSTM. This is determined by the embedding
                dimension of the few shot encoder which is in turn determined by the size of the input data. Hence we
                have Omniglot -> 64, miniImageNet -> 1600.
            unrolling_steps: Number of unrolling steps to run the Attention LSTM
            device: Device on which to run computation
        """
        super(MatchingNetwork, self).__init__()
        self.n = n
        self.k = k
        self.q = q
        self.fce = fce
        self.num_input_channels = num_input_channels
        self.encoder = get_few_shot_encoder(self.num_input_channels)
        if self.fce:
            self.g = BidrectionalLSTM(lstm_input_size, lstm_layers).to(device, dtype=torch.double)
            self.f = AttentionLSTM(lstm_input_size, unrolling_steps=unrolling_steps).to(device, dtype=torch.double)

    def forward(self, inputs):
        pass


class BidrectionalLSTM(nn.Module):
    def __init__(self, size: int, layers: int):
        """Bidirectional LSTM used to generate fully conditional embeddings (FCE) of the support set as described
        in the Matching Networks paper.

        # Arguments
            size: Size of input and hidden layers. These are constrained to be the same in order to implement the skip
                connection described in Appendix A.2
            layers: Number of LSTM layers
        """
        super(BidrectionalLSTM, self).__init__()
        self.num_layers = layers
        self.batch_size = 1
        # Force input size and hidden size to be the same in order to implement
        # the skip connection as described in Appendix A.1 and A.2 of Matching Networks
        self.lstm = nn.LSTM(input_size=size,
                            num_layers=layers,
                            hidden_size=size,
                            bidirectional=True)

    def forward(self, inputs):
        # Give None as initial state and Pytorch LSTM creates initial hidden states
        output, (hn, cn) = self.lstm(inputs, None)

        forward_output = output[:, :, :self.lstm.hidden_size]
        backward_output = output[:, :, self.lstm.hidden_size:]

        # g(x_i, S) = h_forward_i + h_backward_i + g'(x_i) as written in Appendix A.2
        # AKA A skip connection between inputs and outputs is used
        output = forward_output + backward_output + inputs
        return output, hn, cn


class AttentionLSTM(nn.Module):
    def __init__(self, size: int, unrolling_steps: int):
        """Attentional LSTM used to generate fully conditional embeddings (FCE) of the query set as described
        in the Matching Networks paper.

        # Arguments
            size: Size of input and hidden layers. These are constrained to be the same in order to implement the skip
                connection described in Appendix A.2
            unrolling_steps: Number of steps of attention over the support set to compute. Analogous to number of
                layers in a regular LSTM
        """
        super(AttentionLSTM, self).__init__()
        self.unrolling_steps = unrolling_steps
        self.lstm_cell = nn.LSTMCell(input_size=size,
                                     hidden_size=size).double()

    def forward(self, support, queries):
        # Get embedding dimension, d
        if support.shape[-1] != queries.shape[-1]:
            raise(ValueError("Support and query set have different embedding dimension!"))

        batch_size = queries.shape[0]
        embedding_dim = queries.shape[1]

        h_hat = torch.zeros_like(queries).cuda().double()
        c = torch.zeros(batch_size, embedding_dim).cuda().double()

        for k in range(self.unrolling_steps):
            # Calculate hidden state cf. equation (4) of appendix A.2
            h = h_hat + queries

            # Calculate softmax attentions between hidden states and support set embeddings
            # cf. equation (6) of appendix A.2
            attentions = torch.mm(h.double(), support.t().double())
            attentions = attentions.softmax(dim=1)

            # Calculate readouts from support set embeddings cf. equation (5)
            readout = torch.mm(attentions.double(), support.double())

            # Run LSTM cell cf. equation (3)
            # h_hat, c = self.lstm_cell(queries, (torch.cat([h, readout], dim=1), c))
            h_hat, c = self.lstm_cell(queries.double(), (h.double() + readout.double(), c.double()))

        h = h_hat + queries

        return h


In [29]:
model = EncoderNN(3, (2,2), True)

In [30]:
a = model(data1)

In [31]:
a.shape

torch.Size([20, 1600])

In [32]:
labels1.shape

torch.Size([20])

In [34]:
device = 'cuda'
n_ways = 5; k_shots = 1; q_shots = 3

In [35]:
data, labels = a, labels1
data, labels = data.to(device), labels.to(device)
total = n_ways * (k_shots + q_shots)
queries_index = np.zeros(total)

# Extracting the evaluation datums from the entire task set, for the meta gradient calculation
for offset in range(n_ways):
    queries_index[np.random.choice(
        k_shots+q_shots, q_shots, replace=False) + ((k_shots + q_shots)*offset)] = True
support = data[np.where(queries_index == 0)]
support_labels = labels[np.where(queries_index == 0)]
queries = data[np.where(queries_index == 1)]
queries_labels = labels[np.where(queries_index == 1)]


In [37]:
queries.shape

torch.Size([15, 1600])

In [17]:
support.unsqueeze(1)

tensor([[[0.0717, 0.1103, 0.1152,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.3320, 0.1983, 0.0410,  ..., 0.3116, 0.3548, 0.0000]],

        [[0.0693, 0.1728, 0.1736,  ..., 0.0000, 0.0000, 0.0394]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.5983, 0.5914, 0.3825]],

        [[0.0000, 0.0000, 0.0132,  ..., 0.1370, 0.0000, 0.0000]]],
       device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [38]:
g = BidrectionalLSTM(1600, 1).to('cuda')

In [39]:
support, _, _ = g(support.unsqueeze(1))
support = support.squeeze(1)

In [40]:
support

tensor([[ 0.1885,  0.5725,  0.2545,  ...,  0.6749,  0.4760,  0.4216],
        [ 0.1941,  0.3804,  0.4836,  ...,  0.1441,  0.7202,  1.3960],
        [ 0.1631,  0.4248,  0.3174,  ...,  0.8357,  0.0433,  2.3977],
        [ 0.2788,  1.0453,  0.7681,  ...,  0.6483,  0.3816,  0.6864],
        [ 0.7467,  0.7731,  0.5384,  ..., -0.1106,  0.2103,  0.1904]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [41]:
f = AttentionLSTM(1600, 2).to('cuda')

In [42]:
queries = f(support, queries)

In [43]:
def logits(support, queries, EPSILON):
    # Module with cosine similarity

    n_queries = queries.shape[0]
    n_support = support.shape[0]

    normalised_queries = queries / (queries.pow(2).sum(dim=1, keepdim=True).sqrt() + EPSILON)
    normalised_support = support / (support.pow(2).sum(dim=1, keepdim=True).sqrt() + EPSILON)

    expanded_x = normalised_queries.unsqueeze(1).expand(n_queries, n_support, -1)
    expanded_y = normalised_support.unsqueeze(0).expand(n_queries, n_support, -1)

    logits = (expanded_x * expanded_y).sum(dim=2)
    return 1 - logits

In [52]:
attention = (-logits(support, queries, 0.00001)).softmax(dim=1)

In [53]:
y_onehot = torch.zeros(n_ways * k_shots, n_ways).to(device)

# Unsqueeze to force y to be of shape (K*n, 1) as this
# is needed for .scatter()
y = support_labels.unsqueeze(-1)
y_onehot = y_onehot.scatter(1, y, 1)

y_pred = torch.mm(attention, y_onehot.cuda().double())

# Calculated loss with negative log likelihood
# Clip predictions for numerical stability
clipped_y_pred = y_pred.clamp(0.0001, 1 - 0.0001)
#eval_loss = loss(clipped_y_pred.log(), queries_labels)


In [56]:
clipped_y_pred.argmax(dim=1)

tensor([0, 3, 0, 0, 3, 3, 1, 1, 1, 3, 0, 3, 1, 1, 2], device='cuda:0')

In [None]:

    # Optionally apply full context embeddings
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size). `support` is of
        # shape (k_way * n_shot, embedding_dim) and we want the LSTM to treat the
        # support set as a sequence so add a single dimension to transform support set
        # to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch dimension
        # afterwards

        # Calculate the fully conditional embedding, g, for support set samples as described
        # in appendix A.2 of the paper. g takes the form of a bidirectional LSTM with a
        # skip connection from inputs to outputs
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)

        # Calculate the fully conditional embedding, f, for the query set samples as described
        # in appendix A.1 of the paper.
        queries = model.f(support, queries)

    # Efficiently calculate distance between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, support, distance)

    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)

    # Calculate predictions as in equation (1) from Matching Networks
    # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i
    y_pred = matching_net_predictions(attention, n_shot, k_way, q_queries)

    # Calculated loss with negative log likelihood
    # Clip predictions for numerical stability
    clipped_y_pred = y_pred.clamp(EPSILON, 1 - EPSILON)
    loss = loss_fn(clipped_y_pred.log(), y)

    if train:
        # Backpropagate gradients
        loss.backward()
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        # Take gradient step
        optimiser.step()

    return loss, y_pred


In [29]:
def logits(support, queries, n, k, q):
    prototypes = support.view(n, k, -1).mean(dim=1)
    a = queries.shape[0]
    b = prototypes.shape[0]
    logits = -((queries.unsqueeze(1).expand(a,b,-1) - prototypes.unsqueeze(0).expand(a,b,-1))**2).sum(dim=2)
    return logits

In [None]:
support.view()

In [30]:
logits(support, queries, 30, 1, 15).shape 

torch.Size([450, 30])

In [67]:
queries_labels

tensor([1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4,
        4], device='cuda:0')

In [25]:
support

tensor([[0.4918, 0.5899, 0.0555,  ..., 1.4189, 1.5469, 1.7144],
        [0.1789, 0.5006, 0.4659,  ..., 1.0633, 1.3545, 0.5393],
        [0.2796, 0.1395, 0.0412,  ..., 0.5320, 0.4916, 0.8295],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.7130, 0.5586, 0.8045],
        [0.2357, 0.5359, 0.5710,  ..., 0.6700, 0.0000, 0.3146],
        [0.0000, 0.0662, 0.5751,  ..., 0.6450, 0.8230, 0.8644]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [44]:
a = support.view(n_ways, k_shots, -1).mean(dim=1)

In [45]:
a.shape

torch.Size([5, 1600])

In [46]:
b = queries

In [50]:
n = a.shape[0]
m = b.shape[0]

In [64]:
input1 = torch.randn(100, 128)
input1.shape

torch.Size([100, 128])

In [None]:
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
output = cos(input1, input2)

In [60]:
queries.unsqueeze(0).expand(n,m,-1)

tensor([[[0.0000, 0.0000, 0.1363,  ..., 0.7791, 0.6836, 0.4011],
         [0.6607, 0.6905, 0.4536,  ..., 1.0306, 1.2191, 1.4360],
         [0.0000, 0.0103, 0.1365,  ..., 0.0908, 0.3466, 0.5294],
         ...,
         [0.6833, 0.2566, 0.2808,  ..., 0.9656, 0.0514, 0.3068],
         [0.2939, 0.4591, 0.6109,  ..., 0.8635, 0.5289, 0.8151],
         [0.4777, 0.2959, 0.3257,  ..., 0.6681, 0.5542, 0.3793]],

        [[0.0000, 0.0000, 0.1363,  ..., 0.7791, 0.6836, 0.4011],
         [0.6607, 0.6905, 0.4536,  ..., 1.0306, 1.2191, 1.4360],
         [0.0000, 0.0103, 0.1365,  ..., 0.0908, 0.3466, 0.5294],
         ...,
         [0.6833, 0.2566, 0.2808,  ..., 0.9656, 0.0514, 0.3068],
         [0.2939, 0.4591, 0.6109,  ..., 0.8635, 0.5289, 0.8151],
         [0.4777, 0.2959, 0.3257,  ..., 0.6681, 0.5542, 0.3793]],

        [[0.0000, 0.0000, 0.1363,  ..., 0.7791, 0.6836, 0.4011],
         [0.6607, 0.6905, 0.4536,  ..., 1.0306, 1.2191, 1.4360],
         [0.0000, 0.0103, 0.1365,  ..., 0.0908, 0.3466, 0.

In [61]:
(a.unsqueeze(1).expand(n, m, -1) - queries.unsqueeze(0).expand(n,m,-1))**2

tensor([[[3.6123e-02, 1.1260e-01, 2.0007e-02,  ..., 7.4561e-02,
          4.5304e-02, 2.9212e-01],
         [2.2153e-01, 1.2597e-01, 3.0933e-02,  ..., 4.6551e-04,
          1.0405e-01, 2.4451e-01],
         [3.6123e-02, 1.0580e-01, 1.9951e-02,  ..., 9.2419e-01,
          3.0241e-01, 1.6983e-01],
         ...,
         [2.4324e-01, 6.2327e-03, 9.3467e-06,  ..., 7.4968e-03,
          7.1422e-01, 4.0293e-01],
         [1.0791e-02, 1.5251e-02, 1.1098e-01,  ..., 3.5604e-02,
          1.3514e-01, 1.5996e-02],
         [8.2713e-02, 1.5748e-03, 2.3003e-03,  ..., 1.4747e-01,
          1.1716e-01, 3.1607e-01]],

        [[8.5715e-02, 9.8171e-02, 5.0794e-02,  ..., 1.8140e-04,
          1.3637e-03, 2.7103e-02],
         [1.3539e-01, 1.4225e-01, 8.4547e-03,  ..., 5.6651e-02,
          2.4849e-01, 7.5748e-01],
         [8.5715e-02, 9.1829e-02, 5.0705e-02,  ..., 4.9246e-01,
          1.3987e-01, 1.3141e-03],
         ...,
         [1.5248e-01, 3.2159e-03, 6.5401e-03,  ..., 2.9931e-02,
          4.478

In [None]:

logits = -((a.unsqueeze(1).expand(n, m, -1) -
            b.unsqueeze(0).expand(n, m, -1))**2).sum(dim=2)

In [7]:
train_tasks, valid_tasks, test_tasks, learner = setup('omniglot', '../../omniglot', 5, 5, 5, False, 0.03, 'cuda')
opt = optim.Adam(learner.parameters(), 0.01)
loss = nn.CrossEntropyLoss(reduction='mean')



In [8]:
ttask = train_tasks.sample()

In [9]:
ttask[0].shape

torch.Size([50, 1, 28, 28])

In [10]:
model = learner.clone()

In [11]:
data, labels = ttask
data, labels = data.to(device), labels.to(device)
total = n_ways * (k_shots + q_shots)
queries_index = np.zeros(total)

# Extracting the evaluation datums from the entire task set, for the meta gradient calculation 
for offset in range(n_ways):
    queries_index[np.random.choice(k_shots+q_shots, q_shots, replace=False) + ((k_shots + q_shots)*offset)] = True
support = data[np.where(queries_index == 0)]
support_labels = labels[np.where(queries_index == 0)]
queries = data[np.where(queries_index == 1)]
queries_labels = labels[np.where(queries_index == 1)]

# Inner adapt step
for _ in range(1):
    adapt_loss = loss(learner(support), support_labels)
    learner.adapt(adapt_loss)


In [12]:
preds = learner(queries)
preds

tensor([[-2.3058,  0.8712, -0.3009, -1.7466,  0.2174],
        [-1.1646,  0.6933,  0.0663, -0.8611,  0.2450],
        [-2.1604,  1.5395, -1.5346, -2.1959,  0.7367],
        [-1.6999,  0.4818, -1.1877, -2.2231, -0.6955],
        [-2.3995,  1.2572, -0.9198, -1.8828, -0.8643],
        [-3.2638,  0.4875, -0.9614, -0.7871,  0.4653],
        [-3.3043, -0.3927, -0.9587, -3.2453,  0.3632],
        [-1.9540, -2.1205, -1.3125, -1.1369, -0.3256],
        [-2.2095, -1.0025, -2.8084, -1.4914, -2.5339],
        [-3.6524, -1.3129,  0.4602, -2.3179,  1.8694],
        [-0.2180,  0.5564,  0.4151, -0.3363, -0.8863],
        [ 3.1093, -2.3283, -1.8008,  0.9621, -2.6799],
        [ 0.1015,  0.7979, -0.8207, -1.3333, -1.3204],
        [ 0.5794, -1.4207, -0.9276,  0.3640, -1.8067],
        [ 3.1537, -6.6684, -2.5053,  0.5397, -2.4689],
        [ 1.1471, -0.7925, -0.6154, -1.3091, -1.4252],
        [ 1.9084, -4.7103, -3.9228, -2.1826, -1.3713],
        [ 2.2197, -0.5845, -0.1133,  0.5664, -1.8442],
        [ 

In [13]:
eval_loss = loss(preds, queries_labels)

In [16]:
 preds.argmax(dim=1) == queries_labels

tensor([ True,  True,  True,  True,  True, False,  True,  True, False,  True,
        False, False, False, False, False,  True,  True,  True,  True,  True,
         True, False, False, False, False], device='cuda:0')

In [14]:
predictions = preds.argmax(dim=1).view(queries_labels.shape)
(predictions == queries_labels).sum().float() / queries_labels.size(0)

tensor(0.5600, device='cuda:0')

In [28]:
evaluation_loss, evaluation_accuracy = inner_adapt_maml(ttask, loss, model, 5,5,5, 1, device)

In [29]:
evaluation_loss.backward()

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

In [27]:
for p in model.parameters():
    print(p.grad.data)

  print(p.grad.data)


AttributeError: 'NoneType' object has no attribute 'data'

In [32]:
evaluation_accuracy.item()

0.6800000071525574

In [36]:
np.arange(3).std()

0.816496580927726

In [38]:
name = 'anuj'

In [41]:
path = '../logs/{name}.csv' 
path

'../logs/{name}.csv'

In [6]:
profiler = Profiler('ProNets_{}_{}-shot_{}-way_{}-queries'.format('omni', 5,5,5))

In [87]:
name = 'omni'
prof = Profiler('abc_{}'.format(name))

In [7]:
profiler.log([2,2,3,4])

In [46]:
prof = Profiler('MAML_omni')

In [47]:
prof.log([1,2,3,4,5,6,7,8])

In [96]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

In [147]:
def set_Dataset(seed, mach, nways, kshots, tasks, data_loc):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cuda') if mach == 'cuda' else torch.device('cpu')

    data = l2l.vision.datasets.FullOmniglot(root=data_loc, transform=transforms.Compose([
                                                transforms.Resize(28),
                                                transforms.ToTensor()]))
    data = l2l.data.MetaDataset(data)
    transform = [
    l2l.data.transforms.NWays(data, n=2*nways),
    l2l.data.transforms.KShots(data, k=2*kshots),
    l2l.data.transforms.LoadData(data),
]
    tasksets = l2l.data.TaskDataset(data, transform, num_tasks=tasks)
    
    return tasksets, data


In [148]:
tasksets, data = set_Dataset(42, 'cuda', 5, 5, 20000, '../../omniglot')

In [157]:
len(data)

32460

In [158]:
# Load train/validation/test tasksets using the benchmark interface (5-way, 5-shot)
tasksets = l2l.vision.benchmarks.get_tasksets('omniglot',
                                                train_ways=5,
                                                train_samples=2*5,
                                                test_ways=5,
                                                test_samples=2*5,
                                                num_tasks=20000,
                                                root='../../omniglot',
)

Files already downloaded and verified
Files already downloaded and verified


In [161]:
len(tasksets.validation)

20000

In [46]:
tasksets.train.sample()[1].shape

torch.Size([50])

In [47]:
tasksets.test.sample()[0].shape

torch.Size([50, 1, 28, 28])

In [48]:
# Create model
model = l2l.vision.models.OmniglotFC(28 ** 2, 5)
model.to(device)
maml = l2l.algorithms.MAML(model, lr=0.5, first_order=False) #wrapper on model for in-place weight updation (adaptation)
opt = optim.Adam(maml.parameters(), 0.003) #meta-optimizer
loss = nn.CrossEntropyLoss(reduction='mean')

In [49]:
maml

MAML(
  (module): OmniglotFC(
    (features): Sequential(
      (0): Flatten()
      (1): Sequential(
        (0): LinearBlock(
          (relu): ReLU()
          (normalize): BatchNorm1d(256, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
          (linear): Linear(in_features=784, out_features=256, bias=True)
        )
        (1): LinearBlock(
          (relu): ReLU()
          (normalize): BatchNorm1d(128, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
          (linear): Linear(in_features=256, out_features=128, bias=True)
        )
        (2): LinearBlock(
          (relu): ReLU()
          (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
          (linear): Linear(in_features=128, out_features=64, bias=True)
        )
        (3): LinearBlock(
          (relu): ReLU()
          (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
          (linear): 

In [68]:

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True # set even indices to true
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy



In [90]:
learner = maml.clone()
batch = tasksets.train.sample()
evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                    learner,
                                                    loss,
                                                    1,
                                                    5,
                                                    5,
                                                    'cuda') # updates model inplace for inner update

evaluation_error.backward() # gradients comp on eval set for meta-update 

In [91]:
a=[]
for p in maml.parameters():
    a.append(p.grad.data)

torch.Size([256])
torch.Size([256])
torch.Size([256, 784])
torch.Size([256])
torch.Size([128])
torch.Size([128])
torch.Size([128, 256])
torch.Size([128])
torch.Size([64])
torch.Size([64])
torch.Size([64, 128])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64])
torch.Size([64])
torch.Size([5, 64])
torch.Size([5])


In [94]:
a

tensor([ 0.0839, -0.0497, -0.0046, -0.0065, -0.0231], device='cuda:0')

In [None]:

for iteration in range(1):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0
    for task in range(32):
        # Compute meta-training loss
        learner = maml.clone()
        batch = tasksets.train.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                            learner,
                                                            loss,
                                                            adaptation_steps,
                                                            shots,
                                                            ways,
                                                            device) # updates model inplace for inner update

        evaluation_error.backward() # gradients comp on eval set for meta-update 
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        # Compute meta-validation loss
        learner = maml.clone()
        batch = tasksets.validation.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                            learner,
                                                            loss,
                                                            adaptation_steps,
                                                            shots,
                                                            ways,
                                                            device)
        meta_valid_error += evaluation_error.item()
        meta_valid_accuracy += evaluation_accuracy.item()

    # Print some metrics
    print('\n')
    print('Iteration', iteration)
    print('Meta Train Error', meta_train_error / meta_batch_size)
    print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
    print('Meta Valid Error', meta_valid_error / meta_batch_size)
    print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

    # Average the accumulated gradients and optimize
    for p in maml.parameters():
        p.grad.data.mul_(1.0 / meta_batch_size)
    opt.step()

meta_test_error = 0.0
meta_test_accuracy = 0.0
for task in range(meta_batch_size):
    # Compute meta-testing loss
    learner = maml.clone()
    batch = tasksets.test.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                        learner,
                                                        loss,
                                                        adaptation_steps,
                                                        shots,
                                                        ways,
                                                        device)
    meta_test_error += evaluation_error.item()
    meta_test_accuracy += evaluation_accuracy.item()
print('Meta Test Error', meta_test_error / meta_batch_size)
print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)



In [None]:

def main(
        ways=5,
        shots=1,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    
    



In [None]:
if __name__ == '__main__':
    main()