# Advances in few-shot learning: a guided tour

![intro-image](https://miro.medium.com/max/4100/1*zgOPVND44UMM4Uvs-r-hPg.jpeg)

Few-shot learning is an exciting field of machine learning right now. The ability of deep neural networks to extract complex statistics and learn high level features from vast datasets is proven. Yet current deep learning approaches suffer from poor sample efficiency in stark contrast to human perception — even a child could recognise a giraffe after seeing a single picture. Fine-tuning a pre-trained model is a popular strategy to achieve high sample efficiency but it is a post-hoc hack. Can machine learning do better?

Few-shot learning aims to solve these issues. In this article I will explore some recent advances in few-shot learning through a deep dive into three cutting-edge papers:

1. [Matching Networks](https://arxiv.org/pdf/1606.04080.pdf): A differentiable nearest-neighbours classifier.

2. [Prototypical Networks](https://arxiv.org/pdf/1703.05175.pdf): Learning prototypical representations.

3. [Model-agnostic Meta-Learning](https://arxiv.org/pdf/1703.03400.pdf): Learning to fine-tune.

I will start with a brief explanation of n-shot, k-way classification tasks, which are the de facto benchmark for few-shot learning.

I’ve reproduced the main results of these papers in a single [GitHub repository](https://github.com/oscarknagg/few-shot). You can check out [this post](https://medium.com/@oknagg/advances-in-few-shot-learning-reproducing-results-in-pytorch-aba70dee541d) to read about my experience implementing this research.


## The n-shot, k-way task

The ability of an algorithm to perform few-shot learning is typically measured by its performance on n-shot, k-way tasks. These are run as follows:

- A model is given a query sample belonging to a new, previously unseen class.
- It is also given a support set, S, consisting of n examples each from k different unseen classes.
- The algorithm then has to determine which of the support set classes the query sample belongs to.

![omniglot](https://miro.medium.com/max/784/1*yXw5D5oNs3JZ1-SUzsoxfg.png)
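To make the task concrete, here is a minimal sketch of how a single n-shot, k-way episode could be sampled. The `sample_episode` function, the dict-of-lists dataset layout and the extra `q` parameter (query samples per class) are my own illustrative assumptions, not something prescribed by the papers.

```python
import random

def sample_episode(dataset, n, k, q=1, rng=random):
    """Sample one n-shot, k-way episode.

    dataset: dict mapping class name -> list of samples (hypothetical layout).
    Returns (support, queries), each a list of (sample, episode_label) pairs
    where episode_label is an index in range(k).
    """
    classes = rng.sample(sorted(dataset), k)       # pick k classes for this task
    support, queries = [], []
    for episode_label, cls in enumerate(classes):
        picked = rng.sample(dataset[cls], n + q)   # support and query never overlap
        support += [(x, episode_label) for x in picked[:n]]
        queries += [(x, episode_label) for x in picked[n:]]
    return support, queries

# Toy data: 5 classes with 4 string "samples" each.
toy = {c: [f"{c}{i}" for i in range(4)] for c in "abcde"}
support, queries = sample_episode(toy, n=2, k=3, q=1, rng=random.Random(0))
```

Here `support` holds n × k labelled examples and each entry of `queries` is a sample the model has to assign to one of the k support classes.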

## Matching Networks

![Vinyals et al.](https://miro.medium.com/max/1400/1*OkiAPbdYq1utWUGlDGuBKw.png)

While there is much previous research on few-shot approaches for deep learning, Matching Networks was the first to *both train and test on n-shot, k-way tasks*. The appeal of this is straightforward — training and evaluating on the same tasks lets us optimise for the target task in an end-to-end fashion. Earlier approaches such as siamese networks use a pairwise verification loss to perform metric learning and then in a separate phase use the learnt metric space to perform nearest-neighbours classification. This is not optimal as the initial embedding function is trained to maximise performance on a different task! However, Matching Networks combine both embedding and classification to form an **end-to-end differentiable nearest neighbours classifier**.

Matching Networks first embed a high dimensional sample into a low dimensional space and then perform a generalised form of nearest-neighbours classification described by the equation below.

![Equation (1) from Matching Networks](https://miro.medium.com/max/756/1*Quo_tUQ2kE4v0c-y7n3RCA.png)

**The meaning of this is that the prediction of the model, $\hat{y}$, is the weighted sum of the labels, $y_i$, of the support set, where the weights are given by a pairwise similarity function, $a(\hat{x}, x_i)$, between the query example, $\hat{x}$, and the support set samples, $x_i$. The labels $y_i$ in this equation are one-hot encoded label vectors.**

Notice that if we choose $a(\hat{x}, x_i)$ to be $1/k$ for the closest $k$ samples to the query sample and 0 otherwise we recover the k-nearest-neighbours algorithm. The key thing to note is that Matching Networks are end-to-end differentiable provided the attention function $a(\hat{x}, x_i)$ is differentiable.
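The k-nearest-neighbours special case can be sketched in a few lines. The function names and the toy distances below are made up for illustration, and note that `k` here is the number of neighbours, not the number of classes in the k-way task.

```python
def predict(attention, support_labels, num_classes):
    """y_hat = sum_i a_i * y_i, where each y_i is a one-hot label vector."""
    y_hat = [0.0] * num_classes
    for a_i, label in zip(attention, support_labels):
        y_hat[label] += a_i            # same as adding a_i * one_hot(label)
    return y_hat

def knn_attention(distances, k):
    """Hard attention: 1/k on the k closest support samples, 0 elsewhere."""
    closest = sorted(range(len(distances)), key=lambda i: distances[i])[:k]
    return [1.0 / k if i in closest else 0.0 for i in range(len(distances))]

distances = [0.9, 0.2, 0.4, 0.3]       # query-to-support distances (made up)
labels = [0, 1, 1, 0]                  # class index of each support sample
y_hat = predict(knn_attention(distances, k=3), labels, num_classes=2)
```

With this hard attention each entry of `y_hat` is just the fraction of the k nearest neighbours voting for that class. It is also piecewise constant in the distances, which is why a smooth attention function is needed for end-to-end training.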

The authors choose a straightforward softmax over cosine similarities in the embedding space as their attention function $a(\hat{x}, x_i)$. The embedding function they use for their few-shot image classification problems is a CNN which is, of course, differentiable, making the attention function and hence Matching Networks as a whole fully differentiable! This means it's straightforward to fit the whole model end-to-end with typical methods such as stochastic gradient descent.

![Attention function used in the Matching Networks paper](https://miro.medium.com/max/1400/1*KpI9WoSeoz0G3u9JesUdUQ.png)

In the above equation $c$ represents the cosine similarity and the functions $f$ and $g$ are the embedding functions for the query and support set samples respectively. Another interpretation of this equation is that the support set is a form of memory: upon seeing a new sample the network generates a prediction by retrieving the labels of samples with similar content from this memory.
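As a rough sketch of this attention mechanism, the snippet below applies a softmax over cosine similarities to hand-written 2-D vectors. Those vectors stand in for the outputs of $f$ and $g$ (in the paper these come from a CNN), and `matching_predict` is a name I have made up for illustration.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def softmax(scores):
    m = max(scores)                              # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matching_predict(query_emb, support_embs, support_labels, num_classes):
    """Attention-weighted sum of one-hot support labels."""
    attention = softmax([cosine(query_emb, s) for s in support_embs])
    y_hat = [0.0] * num_classes
    for a_i, label in zip(attention, support_labels):
        y_hat[label] += a_i
    return y_hat

# Pretend embeddings: two support samples per class.
support_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
support_labels = [0, 0, 1, 1]
y_hat = matching_predict([0.8, 0.2], support_embs, support_labels, num_classes=2)
```

Because the attention weights sum to one, `y_hat` is a valid probability distribution over the support classes, so it can be plugged straight into a cross-entropy loss.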

Interestingly, the possibility for the query and support set embedding functions, $f$ and $g$, to be different is left open in order to grant more flexibility to the model. In fact Vinyals et al. do exactly this and introduce the concept of **full context embeddings** or FCE for short.

They consider the myopic nature of the embedding functions a weakness in the sense that each element of the support set $x_i$ gets embedded by $g(x_i)$ in a fashion that is independent of the rest of the support set and the query sample. They propose that the embedding functions $f(\hat{x})$ and $g(x_i)$ should take on the more general form $f(\hat{x}, S)$ and $g(x_i, S)$ where $S$ is the support set. The reasoning behind this is that if two of the support set items are very close, e.g. we are performing fine-grained classification between dog breeds, we should change the way the samples are embedded to increase the distinguishability of these samples.

In practice the authors use an LSTM to calculate the FCE of the support set and then use another LSTM with attention to modify the embedding of the query sample. This results in an appreciable performance boost at the cost of introducing a bunch more computation and a slightly unappealing arbitrary ordering of the support set.

All in all this is a very novel paper that develops the idea of a fully differentiable nearest neighbours algorithm.

## Prototypical Networks

![Class prototypes c_i and query sample x.](https://miro.medium.com/max/1082/1*JX0QOZ4zoytOuss-Yn7o8g.png)

In Prototypical Networks Snell et al. apply a compelling inductive bias in the form of class prototypes to achieve impressive few-shot performance, exceeding Matching Networks without the complication of FCE. The key assumption made is that there exists an **embedding** in which samples from each class cluster around a single **prototypical representation** which is simply the **mean** of the individual samples. This idea streamlines n-shot classification in the case of n > 1 as classification is simply performed by taking the label of the closest class prototype.

![Equation (1) from Prototypical Networks — calculating class prototypes. S_k is the support set belonging to class k and f_phi is the embedding function.](https://miro.medium.com/max/816/1*752b0H0z407rps7wp63jsQ.png)
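Here is a minimal sketch of prototype calculation and nearest-prototype classification, assuming the samples have already been embedded. The function names and the toy 2-D embeddings are mine, and I use squared euclidean distance as discussed in this section.

```python
def prototypes(support_embs, support_labels, num_classes):
    """Class prototype = mean of the embedded support samples of that class."""
    dim = len(support_embs[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    counts = [0] * num_classes
    for emb, label in zip(support_embs, support_labels):
        counts[label] += 1
        for d in range(dim):
            sums[label][d] += emb[d]
    return [[s / counts[c] for s in sums[c]] for c in range(num_classes)]

def sq_euclidean(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def classify(query_emb, protos):
    """Return the index of the closest class prototype."""
    dists = [sq_euclidean(query_emb, p) for p in protos]
    return min(range(len(protos)), key=lambda c: dists[c])

# Pretend embeddings for a 2-shot, 2-way task.
support_embs = [[0.0, 0.0], [2.0, 0.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = [0, 0, 1, 1]
protos = prototypes(support_embs, support_labels, num_classes=2)
pred = classify([1.5, 0.5], protos)
```

For training, the negative distances would be passed through a softmax to give class probabilities, which keeps the whole pipeline differentiable.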

Another contribution of this paper is a persuasive theoretical argument to use euclidean distance over cosine distance in metric learning that also justifies the use of class means as prototypical representations. The key is to recognise that squared euclidean distance (but not cosine distance) is a member of a particular class of distance functions known as **Bregman divergences**.

Consider the clustering problem of finding the centroid of a cluster of points such that the total distance between the centroid and all other points is minimised. It has been [proven](http://www.jmlr.org/papers/volume6/banerjee05b/banerjee05b.pdf) that if your distance function is a **Bregman divergence** (such as squared euclidean distance) then the **centroid** that satisfies this condition is simply the **mean of the cluster** — this is not the case for cosine distance however! This centroid is the point that minimises the loss of information when representing a set of points as just a single point.
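A quick numerical sanity check of the squared euclidean case (a sketch on made-up random points, not a proof): no randomly perturbed candidate centroid achieves a lower total squared distance than the mean.

```python
import random

def sq_euclidean(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def total_distance(center, points):
    return sum(sq_euclidean(center, p) for p in points)

rng = random.Random(0)
points = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(20)]
mean = [sum(p[d] for p in points) / len(points) for d in range(2)]

best = total_distance(mean, points)

# Try 1000 candidate centroids scattered around the mean.
candidates = [
    [mean[0] + rng.uniform(-1, 1), mean[1] + rng.uniform(-1, 1)]
    for _ in range(1000)
]
mean_is_best = all(best <= total_distance(c, points) for c in candidates)
```

This follows from the identity that the total squared distance to any point $c$ equals the total squared distance to the mean plus $N \lVert c - \text{mean} \rVert^2$, where $N$ is the number of points.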

This intuition is backed up by experiments as the authors find that both ProtoNets and their own implementation of Matching Networks are improved across the board by swapping from cosine to euclidean distance.

Prototypical Networks are also amenable to **zero-shot learning**: one can simply learn class prototypes directly from a high level description of a class such as labelled attributes or a natural language description. Once you’ve done this it’s possible to classify new images as a particular class without having seen an image of that class. In their experiments they perform zero-shot species classification of images of birds based only on **attributes** such as colour, shape and feather patterns.

I am quite fond of this paper as it achieves the highest performance on typical benchmarks of all of the approaches in this article while also being elegant and the easiest for me to reproduce. Well done Snell et al!

## Model-Agnostic Meta-Learning (MAML)

Finn et al. take a very different approach to few-shot learning by learning a **network initialisation** that can quickly adapt to new tasks — this is a form of **meta-learning** or learning-to-learn. The end result of this meta-learning is a model that can reach high performance on a new task with as little as a **single step of regular gradient descent**. The brilliance of this approach is that it can not only work for supervised regression and classification problems but also for reinforcement learning using any differentiable model!

![Figure 1 from Model-Agnostic Meta-Learning. Theta represents the weights of the meta-learner. The gradients of L_i are the gradients of the losses for the tasks, i, in a meta-batch and the starred theta_i are the optimal weights for each task.](https://miro.medium.com/max/872/1*_fvx0_vmihh5_kH3HIO43g.png)

MAML does not learn on batches of samples like most deep learning algorithms but batches of **tasks** AKA meta-batches. For each task in a meta-batch we first initialise a new “fast model” using the weights of the base meta-learner. We then compute the gradient and hence a parameter update from samples drawn from that task and update the weights of the fast model i.e. perform typical mini-batch stochastic gradient descent on the weights of the fast model.

![The weight update due to a single task T_i. Alpha is a learning rate hyperparameter.](https://miro.medium.com/max/734/1*z18N_TTM_pXVh1hQvpbfBA.png)

After the parameter update we draw some more, previously unseen, samples from the same task and calculate the loss of the updated weights (AKA the fast model) on them. The final step is to update the weights of the meta-learner by taking the gradient of the sum of these post-update losses with respect to the original weights. This is in fact taking the gradient of a gradient and hence is a second-order update: the MAML algorithm differentiates through the unrolled training process.

![Weight update for meta-learner. Beta is a learning rate hyperparameter and p(T) is the distribution of tasks.](https://miro.medium.com/max/1026/1*Jq4LaXVep9PY7yxZ2RQqMQ.png)
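The two updates can be sketched on a deliberately tiny problem where every derivative has a closed form. Everything in this example is an assumption of mine for illustration: the model is a single weight `w` predicting `w * x`, each task is a line `y = slope * x`, and the loss is mean squared error, so the second derivative is a constant and no autograd library is needed.

```python
def grad(w, data):
    """dL/dw for L(w) = mean((w * x - y) ** 2)."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)

def hessian(data):
    """d^2L/dw^2, which is constant in w for this quadratic loss."""
    return sum(2 * x * x for x, _ in data) / len(data)

def maml_step(w, tasks, alpha, beta):
    """One meta-update over a meta-batch of (support, query) tasks."""
    meta_grad = 0.0
    for support, query in tasks:
        fast_w = w - alpha * grad(w, support)         # inner update: the "fast model"
        d_fast_dw = 1 - alpha * hessian(support)      # derivative of the inner update w.r.t. w
        meta_grad += grad(fast_w, query) * d_fast_dw  # chain rule: the second-order part
    return w - beta * meta_grad                       # outer update of the meta-learner

def make_task(slope):
    support = [(x, slope * x) for x in (-1.0, 1.0)]
    query = [(x, slope * x) for x in (-0.5, 0.5, 2.0)]
    return support, query

tasks = [make_task(1.0), make_task(3.0)]
w = 0.0
for _ in range(200):
    w = maml_step(w, tasks, alpha=0.1, beta=0.05)
```

In this symmetric toy setup the meta-learned weight settles midway between the two task slopes, the initialisation from which a single gradient step gets closest to either task. In a real network `d_fast_dw` becomes a matrix involving the Hessian, which is where the cost of the second-order update comes from.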

This is the key step as it means we are optimising for the performance of the base model **after a gradient step** i.e. we are optimising for quick and easy gradient descent. The result of this is that the meta-learned initialisation can be fine-tuned by gradient descent on datasets as small as a single example per class without overfitting.

A [follow-up paper](https://arxiv.org/pdf/1803.02999.pdf) from OpenAI provides some valuable intuition on why this works using a Taylor expansion of the gradient update. The conclusion they come to is that MAML is not only minimising the expected loss over a distribution of tasks but also maximising the expected inner product between gradients of different batches from the same task. Hence it is **optimising for generalisation** between batches.

![Results of a Taylor expansion analysis from “On First-Order Meta-Learning Algorithms” by Nichol et al.](https://miro.medium.com/max/1400/1*eHD6qsi59tU3Y6qRiXrnYg.png)

The above set of equations shows the expectation of the gradient of MAML, a first-order simplification of MAML (FOMAML) and Reptile, a first-order meta-learning algorithm introduced in the same paper. The **AvgGrad** term represents the loss over tasks and the **AvgGradInner** term represents the generalisation term. Note that to leading order in the learning rate, **alpha**, all of the algorithms are performing a very similar update, with second-order MAML putting the highest weight on the generalisation term.

Perhaps the only downside of MAML is the second-order update, as calculating the second derivative of the loss is very memory and compute intensive. However, first-order simplifications such as FOMAML and Reptile produce very similar performance, which hints that the second-order update can be approximated with the gradients on the updated weights.
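To see what the first-order simplification drops, here is a small sketch comparing the full meta-gradient with its first-order approximation for a single task. As before this is a toy of my own making: a one-weight linear model with mean squared error and hand-picked data, so that the second derivative can be written down directly.

```python
def grad(w, data):
    """dL/dw for L(w) = mean((w * x - y) ** 2)."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)

def hessian(data):
    """d^2L/dw^2, constant in w for this quadratic loss."""
    return sum(2 * x * x for x, _ in data) / len(data)

def meta_gradients(w, support, query, alpha):
    """Return (full second-order meta-gradient, first-order approximation)."""
    fast_w = w - alpha * grad(w, support)
    first_order = grad(fast_w, query)                   # gradient at the updated weights only
    full = first_order * (1 - alpha * hessian(support)) # chain rule through the inner step
    return full, first_order

support = [(-1.0, -3.0), (1.0, 3.0)]   # samples from the line y = 3x
query = [(0.5, 1.5), (2.0, 6.0)]
full, approx = meta_gradients(w=0.0, support=support, query=query, alpha=0.1)
```

In this scalar case the two gradients differ only by the factor `1 - alpha * hessian(support)`, so for a small learning rate they point the same way. In higher dimensions that factor is a matrix, so the first-order version is cheaper but only approximately aligned with the full update.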

However, high computational requirements have no bearing on the fact that Model-Agnostic Meta-Learning is a brilliant paper that has opened up exciting new paths for machine learning.

The field of few-shot learning is making fast progress and although there is much still to be learnt I’m confident that researchers in this field will keep closing the gap between machine and human performance on the challenging task of few-shot learning. I hope you enjoyed reading this post.

*I’ve reproduced the main results of these papers in a single [GitHub repository](https://github.com/oscarknagg/few-shot). If you’ve had enough concepts and want some juicy technical details and code you can check out [this post](https://medium.com/@oknagg/advances-in-few-shot-learning-reproducing-results-in-pytorch-aba70dee541d) to read about my experience implementing this research.*