# Advances in few-shot learning: reproducing results in PyTorch

![fsl](https://miro.medium.com/max/3840/1*cbsl9NSG5DXq4qSWJXLNsw.png)

Few-shot learning is an exciting field of machine learning which aims to close the gap between machine and human in the challenging task of learning from few examples. In my [previous post](https://medium.com/@oknagg/advances-in-few-shot-learning-a-guided-tour-36bc10a68b77) I provided a high-level summary of three cutting-edge papers in few-shot learning. I assume you’ve either read that, are already familiar with these papers, or are in the process of reproducing them yourself.

In this post I will guide you through my experience in reproducing the results of these papers on the Omniglot and miniImageNet datasets, including some of the pitfalls and stumbling blocks on the way. Each paper has its own section in which I provide a GitHub gist with PyTorch code to perform a single parameter update on the model described by the paper. To train the model you just have to put that function inside a loop over the training data. Less interesting details such as dataset handling are omitted for brevity.

Reproducibility is very important: it is the foundation of any field that claims to be scientific. This makes me believe that the prevalence of code-sharing and open-sourcing in machine learning is truly admirable. While publishing code alone is not reproducibility (as there may be implementation errors), it opens up researchers’ methods to public scrutiny and, more importantly, accelerates the research of others in the field. In light of this I’d like to thank the authors of these papers for sharing their code, as well as any others who’ve open-sourced their implementations.

For the full implementation please see my Github repo at https://github.com/oscarknagg/few-shot.

![Things don’t always go to plan. Just see this training curve of a failed MAML implementation for example!](https://miro.medium.com/max/784/1*cVebkA3CiwpaIIUA1UNe8Q.png)
*Things don’t always go to plan. Just see this training curve of a failed MAML implementation for example!*


## Datasets

There are two image datasets on which few-shot learning algorithms are commonly evaluated. The first is the Omniglot dataset, which contains 20 images each of roughly 1600 characters from 50 alphabets. The images are typically resized to 28x28 grayscale, and because there are many classes with few examples each (the opposite of MNIST) this dataset is often called the transpose of MNIST.

![Samples from the Omniglot dataset.](https://miro.medium.com/max/1400/1*T_4SiA5WB1tJ4makjdsW3Q.png)
*Samples from the Omniglot dataset*

The second is the miniImageNet dataset, a subset of ImageNet intended to be a more challenging benchmark without being as cumbersome as the full ImageNet dataset. miniImageNet consists of 60,000 84x84 RGB images across 100 classes, with 600 images per class.

![Samples from the miniImageNet dataset before taking center crop and resizing to 84x84](https://miro.medium.com/max/1400/1*HL2_qYvwx_6wPrBB4RxPKA.png)
*Samples from the miniImageNet dataset before taking center crop and resizing to 84x84*

In both cases the classes in the training and validation sets are disjoint. I did not use the same training and validation splits as the original papers as my goal is not to reproduce them down to their last minute detail.


## Matching Networks

In [Matching Networks](https://arxiv.org/pdf/1606.04080.pdf) Vinyals et al. introduce the idea of a **fully differentiable nearest neighbours classifier** that is both trained and tested on few-shot tasks.

The Matching Networks algorithm can be summarised as follows:

- First embed all samples (query and support set) using an **encoder network** (4 layer CNN in this case). This is performed by `model.encoder()` (line 41).
- Optionally calculate **full context embeddings (FCE)**. An LSTM takes the original embeddings as inputs and outputs modified embeddings, taking into account the support set. This is performed by `model.f()` and `model.g()` (lines 62 and 67).
- Calculate **pairwise distances** between query samples and support samples and normalise using softmax (lines 69 to 77).
- Calculate **predictions** by taking the weighted average of the support set labels with the normalised distance (lines 83–89).
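
To make the last two steps concrete, here is a minimal sketch on hand-picked numbers. It assumes squared Euclidean distance and a 2-way, 1-shot task; the tensors are illustrative values of my choosing, not taken from the gist.

```python
import torch

# Toy 2-way, 1-shot task with 2-dimensional embeddings (illustrative values).
support = torch.tensor([[0.0, 0.0], [3.0, 0.0]])   # one embedding per class
y_support = torch.tensor([0, 1])
queries = torch.tensor([[0.5, 0.0], [2.5, 0.0]])

# Pairwise squared Euclidean distances, shape (num_queries, num_support).
distances = torch.cdist(queries, support).pow(2)

# Attention over the support set: closer samples receive more weight.
attention = (-distances).softmax(dim=1)

# Weighted average of the one-hot support labels gives class probabilities.
y_onehot = torch.nn.functional.one_hot(y_support, num_classes=2).float()
y_pred = attention @ y_onehot
```

Each row of `y_pred` sums to 1, and each query is assigned to the class of its nearest support sample.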

Some things to note:

- In this example the $x$ Tensor contains first the support set samples and then the query. For Omniglot it will have shape $(n_{support} + n_{query}, 1, 28, 28)$.
- The math in the previous post is for one query sample but Matching Networks are in fact trained with a batch of query samples of size $q_{queries} * k_{way}$.
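
The layout described in these notes can be sketched with random data. `make_dummy_episode` is a hypothetical helper written for this illustration (it is not part of the repo), and it assumes labels are relative to the episode, running from 0 to `k_way - 1`.

```python
import torch

def make_dummy_episode(n_shot: int, k_way: int, q_queries: int):
    """Random Omniglot-shaped episode: support samples first, then queries."""
    n_support = n_shot * k_way
    n_query = q_queries * k_way
    x = torch.rand(n_support + n_query, 1, 28, 28)
    # Labels are relative to the episode (0 .. k_way - 1), grouped by class.
    y_support = torch.arange(k_way).repeat_interleave(n_shot)
    y_query = torch.arange(k_way).repeat_interleave(q_queries)
    return x, torch.cat([y_support, y_query])

# 5-way, 1-shot task with 3 queries per class: 5 support + 15 query samples.
x, y = make_dummy_episode(n_shot=1, k_way=5, q_queries=3)
```

Slicing at `n_shot * k_way` then separates the support set from the query set, for both samples and labels.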


In [1]:
from typing import Callable

import torch
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer

In [None]:
def matching_net_episode(model: Module,
                         optimizer: Optimizer,
                         loss_fn: Callable,
                         x: torch.Tensor,
                         y: torch.Tensor,
                         n_shot: int,
                         k_way: int,
                         q_queries: int,
                         distance: str,
                         fce: bool,
                         train: bool):
    """
    Performs a single training episode for a Matching Network.
    
    Args:
        model: Matching Network to be trained.
        optimizer: Optimizer to calculate gradient step from loss.
        loss_fn: Loss function to calculate between predictions and outputs.
        x: Input samples of few-shot classification task.
        y: Input labels of few-shot classification task.
        n_shot: Number of examples per class in the support set.
        k_way: Number of classes in the few-shot classification task.
        q_queries: Number of example per class in the query set.
        distance: Distance metric to use when calculating distance between support and query set samples.
        fce: Whether or not to use fully conditional embeddings.
        train: Whether (True) or not (False) to perform a parameter update.
        
    Returns:
        loss: Loss of the Matching Network on this task.
        y_pred: Predicted class probabilities for the query set on this task.
    """
    
    if train:
        model.train()
        optimizer.zero_grad()
    else:
        model.eval()
        
    # Embed all samples
    embeddings = model.encoder(x)
    
    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[: n_shot * k_way]
    queries = embeddings[n_shot * k_way :]
    y_support = y[: n_shot * k_way]
    y_queries = y[n_shot * k_way :]
    
    # Optionally apply full context embeddings (FCE):
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size).
        # 'support' is of shape (k_way * n_shot, embedding_dim) and we want the LSTM to 
        # treat the support set as a sequence so add a single dimension to transform support
        # set to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch 
        # dimension afterwards
        
        # Calculate the fully conditional embedding, g, for support set samples as 
        # described in appendix A.2 of the paper. 
        # g takes the form of a bidirectional LSTM with a skip connection from inputs to outputs.
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)
        
        # Calculate the fully conditional embedding, f, for the query set samples
        # as described in appendix A.1 of the paper.
        queries = model.f(support, queries)
    
    # Calculate squared Euclidean distances between all queries and all support samples
    # Output should have shape (q_queries * k_way, n_shot * k_way) = (num_queries, num_support)
    distances = (
        queries.unsqueeze(1).expand(queries.shape[0], support.shape[0], -1) - 
        support.unsqueeze(0).expand(queries.shape[0], support.shape[0], -1)
    ).pow(2).sum(dim=2)
    
    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)
    
    # Calculate predictions as in equation (1) from Matching Networks
    # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i
    # Create one-hot encoded label vector for the support set, the 
    # default PyTorch format is for labels to be integers
    # (assumes the labels are relative to the episode, i.e. 0 .. k_way - 1)
    y_onehot = torch.zeros(k_way * n_shot, k_way, device=attention.device, dtype=attention.dtype)
    
    # Unsqueeze to force y_support to be 2D as this
    # is needed for .scatter()
    y_onehot = y_onehot.scatter(1, y_support.unsqueeze(-1), 1)
    
    y_pred = torch.mm(attention, y_onehot)
    
    # Calculated loss with negative log likelihood
    # Clip predictions for numerical stability
    clipped_y_pred = y_pred.clamp(1e-8, 1 - 1e-8)
    loss = loss_fn(clipped_y_pred.log(), y_queries)
    
    if train:
        # Backpropagate gradients
        loss.backward()
        
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        
        # Take gradient step
        optimizer.step()
        
    return loss, y_pred

I was unable to reproduce the results of this paper using cosine distance but was successful when using l2 distance. 

I believe this is because cosine similarity is bounded between -1 and 1, which limits how strongly the attention function ($a(\hat{x}, x_i)$ below) can point to a particular sample in the support set. Because of this bound, $a(\hat{x}, x_i)$ can never get close to 1. In 5-way, 1-shot classification the maximum possible value of $a(\hat{x}, x_i)$ is $e^{1} / (e^{1} + 4e^{-1}) \approx 0.65$. This led to very slow convergence when using cosine distance.

![eq1](https://miro.medium.com/max/756/1*Quo_tUQ2kE4v0c-y7n3RCA.png)

![eq2](https://miro.medium.com/max/1400/1*KpI9WoSeoz0G3u9JesUdUQ.png)

I think it’s possible to reproduce the results using cosine distance with either longer training times, better hyperparameters or a heuristic like multiplying the cosine distance by a constant factor. Seeing as the choice of distance is not key to the paper and results are very good using l2 distance I decided to spare myself that debugging effort.
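
The bound above, and the effect of the constant-factor heuristic, can be checked with a few lines of arithmetic. This is a sketch assuming one support sample per class; `max_attention` and its `scale` argument are hypothetical names used only for this illustration.

```python
import math

def max_attention(k_way: int, scale: float = 1.0) -> float:
    """Upper bound on softmax attention when similarities lie in [-scale, scale].

    Best case for 1-shot: the matching support sample has similarity +1 and
    the other k_way - 1 samples have similarity -1, both multiplied by scale.
    """
    best = math.exp(scale)
    rest = (k_way - 1) * math.exp(-scale)
    return best / (best + rest)

print(round(max_attention(5), 2))            # 0.65
print(round(max_attention(5, scale=10), 4))  # 1.0
```

With a scale of 10 the attention can saturate almost completely, which is why a multiplicative factor is a plausible fix.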

![result1](https://miro.medium.com/max/888/1*BQJY8y9mU-LRkNfXR-JaJw.png)
![result2](https://miro.medium.com/max/812/1*X9wkdq19dl4hpExUmhpwdA.png)

## Prototypical Networks

In [Prototypical Networks](https://arxiv.org/pdf/1703.05175.pdf) Snell et al. use a compelling inductive bias motivated by the theory of Bregman divergences to achieve impressive few-shot performance.

The Prototypical Network algorithm can be summarised as follows:

- Embed all query and support samples (line 36).
- Calculate class prototypes taking the mean of the embeddings of each class (line 48).
- Predictions are a softmax over the distances between the query samples and the class prototypes (line 63).
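
As a quick sanity check of the prototype and distance steps, here is a minimal sketch on toy embeddings. The values are illustrative, and I use `torch.cdist` for the pairwise distances rather than the manual broadcasting in the gist.

```python
import torch

n_shot, k_way = 2, 2
# Support embeddings ordered class by class: two samples of class 0, then class 1.
support = torch.tensor([[0.0, 0.0], [2.0, 0.0],
                        [10.0, 0.0], [12.0, 0.0]])
queries = torch.tensor([[1.5, 0.0], [10.5, 0.0]])

# Mean embedding per class, shape (k_way, embedding_dim).
prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)

# Squared Euclidean distance from every query to every prototype.
distances = torch.cdist(queries, prototypes).pow(2)

# Softmax over negative distances gives class probabilities.
y_pred = (-distances).softmax(dim=1)
```

The prototypes land at (1, 0) and (11, 0), so the first query is assigned to class 0 and the second to class 1.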



In [None]:
def proto_net_episode(model: Module,
                      optimizer: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """
    Performs a single training episode for a Prototypical Network.
    
    Args:
        model: Prototypical Network to be trained.
        optimizer: Optimizer to calculate gradient steps.
        loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy.
        x: Input samples of few-shot classification task.
        y: Input labels of few-shot classification task.
        n_shot: Number of examples per class in the support set.
        k_way: Number of classes in the few shot classification task.
        q_queries: Number of examples per class in the query set.
        distance: Distance metric to use when calculating distance between class prototypes and queries.
        train: Whether (True) or not (False) to perform a parameter update.
        
    Returns:
        loss: Loss of the Prototypical Network on this task.
        y_pred: Predicted class probabilities for the query set on this task.
    """
    
    if train:
        model.train()
        optimizer.zero_grad()
    else:
        model.eval()
    
    # Embed all samples
    embeddings = model(x)
    
    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[: n_shot * k_way]
    queries = embeddings[n_shot * k_way :]
    y_support = y[: n_shot * k_way]
    y_queries = y[n_shot * k_way :]
    
    # Reshape so the first dimension indexes by class then take the mean
    # along the second dimension to generate the "prototypes" for each class
    prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)
    
    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = (
        queries.unsqueeze(1).expand(queries.shape[0], prototypes.shape[0], -1) -
        prototypes.unsqueeze(0).expand(queries.shape[0], prototypes.shape[0], -1)
    ).pow(2).sum(dim=2)
    
    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y_queries)
    
    # Prediction probabilities are softmax over distances
    y_pred = (-distances).softmax(dim=1)
    
    if train:
        # Take gradient step
        loss.backward()
        optimizer.step()
        
    return loss, y_pred

I found this paper delightfully easy to reproduce as the authors provided the full set of hyperparameters. Hence I was easily able to achieve the stated performance to within ~0.2% on the Omniglot benchmark and within a few % on the miniImageNet benchmark without having to perform any tuning of my own.

![result1](https://miro.medium.com/max/774/1*soUZhuA4G226kxxAKSHYew.png)
![result2](https://miro.medium.com/max/606/1*xHT4OOi1BoDxdoeLxS9S6g.png)




## Model-Agnostic Meta-Learning (MAML)

In [MAML](https://arxiv.org/pdf/1703.03400.pdf) Finn et al. introduce a powerful and broadly applicable meta-learning algorithm to learn a network initialisation that can quickly adapt to new tasks. This paper was the most difficult yet most rewarding to reproduce of the three in this article.

The MAML algorithm can be summarised as follows:

- For each n-shot task in a meta-batch of tasks, create a new model using the weights of the base model AKA meta-learner (line 79).
- Update the weights of the new model using the loss from the samples in the task by stochastic gradient descent (lines 81–92).
- Calculate loss of the updated model on some more data from the same task (lines 94–97)
- If performing 1st order MAML, update the meta-learner weights with the gradient of the loss from the previous step. If performing 2nd order MAML, calculate the derivative of this loss with respect to the *original weights* (lines 110+).

The biggest appeal of PyTorch is its autograd system. This is a piece of code-magic that records operations acting on `torch.Tensor` objects and dynamically builds the directed acyclic graph of these operations under the hood. Backpropagation is as simple as calling `.backward()` on the final result. I had to learn a bit more about this system in order to calculate and apply parameter updates to the meta-learner, which I will now share with you.
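
As a minimal illustration of that recording behaviour (toy values, nothing specific to MAML):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Each operation on x is recorded, so y carries a reference to its history.
y = (x ** 2).sum()
print(y.grad_fn)  # the root node of the recorded graph

# Walk the graph backwards and populate x.grad with dy/dx = 2x.
y.backward()
print(x.grad)  # tensor([2., 4., 6.])
```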

### 1st Order MAML — gradient swapping

Typically when training a model in PyTorch you create an Optimizer object tied to the parameters of a particular model.


In [None]:
from torch.optim import Adam

opt = Adam(model.parameters(), lr=0.001)

When `opt.step()` is called the optimizer reads the gradients on the model parameters and calculates an update to those parameters. However, in 1st order MAML we’re going to calculate the gradients using one model (*the fast weights*) and apply the update to a *different model*, i.e. the meta-learner.

A solution to this is to use an under-utilised bit of PyTorch functionality in the form of `torch.Tensor.register_hook(hook)`. Register a hook function to a Tensor and this hook function will be called whenever a gradient with respect to this tensor is computed. For each parameter Tensor in the meta-learner I register a hook that simply replaces the gradient with the corresponding gradient on the fast weights (lines 111–129 in gist). This means that when `opt.step()` is called the gradients of the fast model will be used to update the meta-learner weights as desired.
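
The gradient-swapping trick can be seen in isolation on a single tensor. This is a minimal sketch: `param` stands in for one meta-learner parameter and `precomputed_grad` for a gradient computed elsewhere (both are illustrative names, not from the gist).

```python
import torch

# A stand-in for one meta-learner parameter and a gradient computed elsewhere,
# for example on the fast weights.
param = torch.zeros(3, requires_grad=True)
precomputed_grad = torch.tensor([1.0, 2.0, 3.0])

# The hook ignores the gradient autograd computed and returns the replacement.
handle = param.register_hook(lambda grad: precomputed_grad)

# Any loss that depends on param will do; its own gradient would be all ones.
loss = param.sum()
loss.backward()

handle.remove()  # hooks persist until removed
print(param.grad)  # tensor([1., 2., 3.])
```

An optimizer stepping on `param` afterwards would use the swapped-in gradient, which is exactly what 1st order MAML needs.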

### 2nd Order MAML — autograd issues

When making my first attempt at implementing MAML I instantiated a new model object (subclass of `torch.nn.Module`) and set the values of its weights equal to the meta-learner’s weights. However this makes it impossible to perform 2nd order MAML as the weights of the fast model are disconnected from the weights of the meta-learner in the eyes of `torch.autograd`. What this means is when I call `optimizer.step()` (line 140 in the gist) the autograd graph for the meta-learner weights is empty and no update is performed.

In [None]:
# This didn't work, meta_learner weights remain unchanged
meta_learner = ModelClass()
opt = Adam(meta_learner.parameters(), lr=0.001)

task_losses = []
for x, y in meta_batch:
    fast_model = ModelClass()
    # torch.autograd loses reference here!
    # (copy_weights and update_with_batch are placeholders for illustration)
    copy_weights(src=meta_learner, dst=fast_model)
    task_losses.append(update_with_batch(x, y))
    
meta_batch_loss = torch.stack(task_losses).mean()
meta_batch_loss.backward()
opt.step()

The solution to this is `functional_forward()` (line 17) which is a slightly awkward hack that manually performs the same operations (convolution, max pooling, etc…) as the model class using `torch.nn.functional`. This also means that I have to manually perform a parameter update of the fast model. The consequence of this is that `torch.autograd` knows to backpropagate gradients to the original weights of the meta-learner. This leads to a spectacularly large autograd graph.

![PyTorch autograd graph for 1st order MAML (left) and 2nd order MAML (right) with one inner training step. You can probably see why 2nd order MAML has much higher memory requirements!](https://miro.medium.com/max/1400/1*4nyXG0ozncYxzuCE1EOBBg.png)
*PyTorch autograd graph for 1st order MAML (left) and 2nd order MAML (right) with one inner training step. You can probably see why 2nd order MAML has much higher memory requirements!*
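
The `functional_forward()` code later in this post calls a `functional_conv_block()` helper that is not shown. A plausible version might look like the sketch below. It assumes each block is convolution, batch norm, ReLU and 2x2 max pooling; the padding and pooling sizes are my assumptions rather than values confirmed by the text.

```python
import torch
import torch.nn.functional as F

def functional_conv_block(x, weights, biases, bn_weights, bn_biases):
    """Conv -> batch norm -> ReLU -> max pool, with weights passed explicitly."""
    x = F.conv2d(x, weights, biases, padding=1)
    # No running statistics: always normalise with the current batch.
    x = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=bn_weights, bias=bn_biases, training=True)
    x = F.relu(x)
    return F.max_pool2d(x, kernel_size=2, stride=2)

# Because the weights are ordinary tensors, gradients flow back to them.
w = torch.randn(8, 1, 3, 3, requires_grad=True)
b = torch.zeros(8, requires_grad=True)
out = functional_conv_block(torch.rand(4, 1, 28, 28), w, b, None, None)
out.sum().backward()
```

Passing the weights as arguments, rather than reading them from a `Module`, is what lets updated fast weights stay connected to the meta-learner's parameters in the autograd graph.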

However 2nd order MAML is a trickier beast than just that. When I first wrote my 2nd order MAML implementation I thought I had got everything to work miraculously on the first try. At least there were no exceptions right? Only after running a full set of Omniglot and miniImageNet experiments did I begin to doubt my work — the results were just too similar to 1st order MAML. This is typical of an unfortunate breed of silent ML bugs which don’t cause exceptions but only become visible in the final performance of a model.

Hence I decided to buckle down and write a unit test that would confirm that I was truly performing a 2nd order update. Disclaimer: in the spirit of true test-driven development I should’ve written this test before running any experiments 😛.

The test I decided on was to run the `meta_gradient_step` function on a dummy model and manually parse the autograd graph, counting the number of double backwards operations. This way I can be absolutely sure that I am performing a 2nd order update when desired. Conversely, I was able to test that my 1st order MAML implementation only performs a 1st order update with no double backwards operations.
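
A rough sketch of this kind of graph inspection is shown below. It is not the test from the repo: the names of double backward nodes vary between PyTorch versions, so this simplified version just counts all nodes reachable from the loss and checks that the 2nd order graph is larger. The toy model and helper names are assumptions made for the example.

```python
import torch

def count_graph_nodes(tensor: torch.Tensor) -> int:
    """Count the distinct autograd nodes reachable from tensor.grad_fn."""
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

def query_loss_after_one_step(create_graph: bool) -> torch.Tensor:
    """One inner SGD step on a toy weight vector, then a loss on the result."""
    w = torch.ones(3, requires_grad=True)
    x = torch.tensor([1.0, 2.0, 3.0])
    inner_loss = torch.tanh(w @ x) ** 2
    (grad,) = torch.autograd.grad(inner_loss, w, create_graph=create_graph)
    fast_w = w - 0.1 * grad
    return torch.tanh(fast_w @ x) ** 2

first_order_nodes = count_graph_nodes(query_loss_after_one_step(False))
second_order_nodes = count_graph_nodes(query_loss_after_one_step(True))
print(first_order_nodes < second_order_nodes)  # True
```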

I finally traced the bug to not passing the `create_graph` argument in the inner training loop (line 86). I was retaining the autograd graph of the loss on the query samples (line 97), but this was insufficient to perform a 2nd order update because the graph of the unrolled inner training steps was never created.
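
The effect of `create_graph` is easy to see on a one-parameter problem where the answer can be worked out by hand. This is a toy sketch with made-up losses, not the MAML code itself.

```python
import torch

def meta_gradient(create_graph: bool) -> float:
    """Gradient of the query loss w.r.t. the original weight after one inner step."""
    w = torch.tensor(1.0, requires_grad=True)
    inner_lr = 0.1

    # Inner loss on the "support" sample: (2w - 1)^2, so d/dw = 4(2w - 1).
    inner_loss = (2.0 * w - 1.0) ** 2
    (grad,) = torch.autograd.grad(inner_loss, w, create_graph=create_graph)
    fast_w = w - inner_lr * grad

    # Query loss on the updated weight.
    query_loss = fast_w ** 2
    query_loss.backward()
    return w.grad.item()

print(meta_gradient(create_graph=False))  # first-order: treats grad as a constant
print(meta_gradient(create_graph=True))   # also differentiates through the inner step
```

By hand, the fast weight is 0.6, so the first-order gradient is 2 × 0.6 = 1.2. The second-order gradient multiplies this by d(fast_w)/dw = 1 - 0.1 × 8 = 0.2, giving 0.24. Without `create_graph=True` that second factor is silently dropped, which matches the symptom described above.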


In [3]:
from collections import OrderedDict
from typing import Callable, Union

import torch
import torch.nn.functional as F

In [None]:
def replace_grad(parameter_gradients, parameter_name):
    """
    Creates a backward hook function that replaces the calculated gradient
    with a precomputed value when .backward() is called.
    
    See: https://pytorch.org/docs/stable/autograd.html?highlight=hook#torch.Tensor.register_hook
    for more info.
    """
    
    def replace_grad_(grad):
        return parameter_gradients[parameter_name]
    
    return replace_grad_

def functional_forward(x: torch.Tensor, weights: dict):
    """
    Performs a forward pass of the network using the PyTorch functional API.
    """
    for block in [1, 2, 3, 4]:
        x = functional_conv_block(x, 
                                  weights[f'conv{block}.0.weight'], 
                                  weights[f'conv{block}.0.bias'],
                                  weights.get(f'conv{block}.1.weight'),
                                  weights.get(f'conv{block}.1.bias'))
    x = x.view(x.size(0), -1)
        
    x = F.linear(x, weights['logits.weight'], weights['logits.bias'])
        
    return x
    
def meta_gradient_step(model: Module,
                       optimizer: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.
    
    Args:
        model: Base model of the meta-learner being trained.
        optimizer: Optimizer to calculate gradient step from loss.
        loss_fn: Loss function to calculate between predictions and outputs.
        x: Input samples for all few-shot tasks.
        y: Input labels of all few-shot tasks.
        n_shot: Number of examples per class in the support set of each task.
        k_way: Number of classes in the few shot classification task of each task.
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
                   meta-gradients after the fast weights have been updated on the support set.
        order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the
               query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated
               weights on the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update.
        inner_lr: Learning rate used to update the fast weights on the inner update.
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation.
    """
    
    data_shape = x.shape[2:]
    create_graph = (order == 2) and train
    
    task_gradients = []
    task_losses = []
    task_predictions = []
    
    for meta_batch_samples, meta_batch_labels in zip(x, y):
        # By construction x is a 5-D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first dimension we are iterating through the tasks in the meta batch
        # Equivalently y is a 2-D tensor of shape: (meta_batch_size, n*k + q*k)
        x_task_train = meta_batch_samples[: n_shot * k_way]
        x_task_val = meta_batch_samples[n_shot * k_way :]
        y_task_train = meta_batch_labels[: n_shot * k_way]
        y_task_val = meta_batch_labels[n_shot * k_way :]
        
        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())
        
        # Train the model for 'inner_train_steps' iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            logits = functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y_task_train)
            gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
            
            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )
            
        # Do a pass of the model on the validation data from the current task
        logits = functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y_task_val)
        loss.backward(retain_graph=True)
        
        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)
        
        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimizer.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            # Any valid labels work here because the hooks overwrite the gradients
            loss = loss_fn(logits, torch.arange(k_way, device=device))
            loss.backward()
            optimizer.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimizer.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimizer.step()

        return meta_batch_loss, torch.cat(task_predictions)

Training time was quite long (over 24 hours for the 5-way, 5-shot miniImageNet experiment) but in the end I had fairly good success reproducing results.

![result1](https://miro.medium.com/max/824/1*U5eIDnl4xRyaOLEVZtk_Pg.png)
![result2](https://miro.medium.com/max/662/1*8VZj62bp1dMMZTsLz1NBug.png)

I hope that you’ve learnt something useful from this technical deep dive. If you have any questions feel free to let me know in the comments.