<div class="alert alert-block alert-info">
<b>Number of points for this notebook:</b> 3
<br>
<b>Deadline:</b> April 27, 2020 (Monday). 23:00
</div>

# Exercise 8. Few-shot learning with Prototypical Networks

The goal of this exercise is to get familiar with methods that can solve few-shot classification tasks. In this notebook, we will implement Prototypical Networks. We recommend reading the original paper by [Snell et al. (2017)](https://arxiv.org/pdf/1703.05175.pdf) before doing this assignment.

In [1]:
# We will use interactive figures in this notebook
%matplotlib notebook

In [2]:
skip_training = True  # Set this flag to True before validation and submission

In [3]:
# During evaluation, this cell sets skip_training to True
# skip_training = True

In [4]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset

import tools
import tests

In [5]:
# When running on your own computer, you can specify the data directory by:
# data_dir = tools.select_data_dir('/your/local/data/directory')
data_dir = tools.select_data_dir()

The data directory is /coursedata


In [6]:
# Select the device for training (use GPU if you have one)
#device = torch.device('cuda:0')
device = torch.device('cpu')

In [7]:
if skip_training:
    # The models are always evaluated on CPU
    device = torch.device("cpu")

# Omniglot data

We will use Omniglot data for training. Omniglot is a collection of 19280 images of 964 characters from 30 alphabets. There are 20 images for each of the 964 characters in the dataset.

In [8]:
transform = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.Omniglot(root=data_dir, download=True, transform=transform)

Files already downloaded and verified


In [9]:
# Let us plot some samples from the dataset.
x, y = dataset[0]  # x is the image, y is the label (character)
print(x.shape, y)

torch.Size([1, 105, 105]) 0


In [10]:
fig, ax = plt.subplots(1, figsize=(3, 3))
ax.matshow(1-x[0], cmap=plt.cm.Greys)


<matplotlib.image.AxesImage at 0x7f15fec9eb70>

## Custom data loader for the few-shot learning task

The task of few-shot learning is to learn a classification task from a few training examples. In this notebook, we will consider $n$-way $k$-shot classification problems, in which each classification problem has $n$ classes with $k$ examples per class in the training dataset.

We take the meta-learning approach, in which we learn how to learn new ($n$-way $k$-shot classification) tasks using multiple training examples of tasks. Thus, in the meta-learning approach, a single "training example" is one learning task (e.g. a classification task) drawn from a distribution of tasks that we create using the Omniglot dataset.

We perform meta-learning using **episodic training**. In each episode, we process one training task or a mini-batch of tasks. Each task contains two datasets:
* *support set*, which is used to build a classifier,
* *query set*, which is used to test the accuracy of the built classifier.
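
To make the notion of an episode concrete, here is a minimal sketch of how one $n$-way task could be sampled from a labeled collection of images. The function `sample_episode` and the random toy data are hypothetical and are not part of the course code; the notebook itself uses a custom dataset class for this purpose.

```python
import torch

def sample_episode(images, labels, n_way, n_support, n_query):
    """Sample one n-way task (hypothetical helper, not part of the course code).

    Returns:
      support of shape (n_way, n_support, ...) and query of shape (n_way, n_query, ...).
    """
    classes = torch.unique(labels)
    # Pick n_way distinct classes at random
    chosen = classes[torch.randperm(len(classes))[:n_way]]
    episode = []
    for c in chosen:
        # Indices of all images of class c, shuffled, then truncated
        idx = torch.nonzero(labels == c, as_tuple=True)[0]
        idx = idx[torch.randperm(len(idx))[:n_support + n_query]]
        episode.append(images[idx])
    episode = torch.stack(episode)  # (n_way, n_support+n_query, ...)
    return episode[:, :n_support], episode[:, n_support:]

# Toy data standing in for Omniglot: 10 classes with 20 random "images" each
images = torch.randn(200, 1, 28, 28)
labels = torch.arange(10).repeat_interleave(20)

support, query = sample_episode(images, labels, n_way=5, n_support=1, n_query=3)
print(support.shape, query.shape)
```

The support part is what the classifier is built from, and the query part is what it is scored on.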

In order to load such training examples in the training loop, we build a custom dataset class on top of the `Omniglot` dataset available in torchvision.

In [11]:
class OmniglotFewShot(Dataset):
    """Omniglot data set for few-shot learning.

    Args:
      root (string): Root directory to put the data.
      n_support (int): Number of support samples in each training task.
      n_query (int): Number of query samples in each training task.
      transform (callable): Transforms applied to Omniglot images. We rescale them to 28x28,
          convert to tensors and invert so that image background is encoded as 0 (original Omniglot images have
          background encoded as 1).
      mix: If True, all examples can be used either as support or query examples. If False, the first
          n_support images are always used as support examples and the following n_query images are used
          as query examples.
      train: If True, use training set. If False, use test set.
    """
    def __init__(self, root, n_support, n_query,
                 transform=transforms.Compose([
                     transforms.Resize(28),
                     transforms.ToTensor(),
                     transforms.Lambda(lambda x: 1-x),
                 ]),
                 mix=False,  # Mix support and query examples
                 train=True
                ):

        assert n_support + n_query <= 20, "Omniglot contains only 20 images per character."
        self.n_support = n_support
        self.n_query = n_query
        self.mix = mix
        self.train = train  # training set or test set
        
        self._omniglot = torchvision.datasets.Omniglot(root=root, download=True, transform=transform)
        
        self.character_classes = character_classes = np.array([
            character_class for _, character_class in self._omniglot._flat_character_images
        ])
        
        n_classes = max(character_classes)
        self.indices_for_class = {
            i: np.where(character_classes == i)[0].tolist()
            for i in range(n_classes)
        }
        
        np.random.seed(1)
        rp = np.random.permutation(n_classes)
        if train:
            self.used_classes = rp[:770]
        else:
            self.used_classes = rp[770:]
        
    def __getitem__(self, index):
        """
        Returns:
          support_query of shape (n_support+n_query, 1, height, width):
                      support_query[:n_support] is the support set
                      support_query[n_support:] is the query set
        """
        class_ix = self.used_classes[index]
        indices = self.indices_for_class[class_ix]
        if self.mix:
            indices = np.random.permutation(indices)

        indices = indices[:self.n_support+self.n_query]  # First support, then query
        support_query = torch.stack([self._omniglot[ix][0] for ix in indices])

        return support_query

    def __len__(self):
        return len(self.used_classes)

One sample from the dataset represents one class which consists of `n_support` support samples and `n_query` query samples.

In [12]:
dataset = OmniglotFewShot(root=data_dir, n_support=1, n_query=3, train=True)
support_query = dataset[0]
print(support_query.shape)

Files already downloaded and verified
torch.Size([4, 1, 28, 28])


We can now build data for $n$-way $k$-shot classification tasks using the following data loader. Each mini-batch that this data loader produces is one $n$-way $k$-shot classification task. In principle, we could include more tasks in each mini-batch, but we do not do so in this notebook.

In [13]:
n_way = 5
trainloader = DataLoader(dataset=dataset, batch_size=n_way, shuffle=True, pin_memory=True)

for support_query in trainloader:
    print(support_query.shape)
    # support_query is (n-way, n_support+n_query, 1, 28, 28)
    break

torch.Size([5, 4, 1, 28, 28])
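
Note that such a batch carries no explicit labels: the class of each image is implied by its position along the first dimension. The following sketch, which uses a random tensor in place of a real batch (an assumption made so that it runs without the data), shows how the batch splits into support and query parts and how the query targets can be constructed.

```python
import torch

n_way, n_support, n_query = 5, 1, 3
# Random stand-in for one batch produced by the data loader
support_query = torch.randn(n_way, n_support + n_query, 1, 28, 28)

support = support_query[:, :n_support]  # (n_way, n_support, 1, 28, 28)
query = support_query[:, n_support:]    # (n_way, n_query, 1, 28, 28)

# The i-th row of the batch holds images of the i-th class of this task,
# so the target of every query image in row i is simply i
targets = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query)

print(support.shape, query.shape)
print(targets)
```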


# Prototypical networks

## The embedding CNN

We first build a convolutional neural network that embeds images into a lower-dimensional space.

The exact architecture is not important in this exercise but the following architecture worked for us:
* Four blocks with the following layers:
    * `Conv2d` layer with kernel size 3 and 64 output channels, followed by `BatchNorm2d`, ReLU and 2d max pooling (with kernel 2 and stride 2).
* A fully-connected layer with 64 output features.
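
To see how many input features the fully-connected layer receives, it helps to trace the spatial size of a 28x28 image through the four blocks. The short calculation below assumes that each convolution uses padding 1 (the text above does not specify the padding) and applies the standard output-size formula for convolution and pooling layers.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
sizes = [size]
for _ in range(4):
    size = out_size(size, kernel=3, padding=1)  # 3x3 convolution with padding 1 keeps the size
    size = out_size(size, kernel=2, stride=2)   # 2x2 max pooling with stride 2 halves it (rounding down)
    sizes.append(size)

print(sizes)          # [28, 14, 7, 3, 1]
print(64 * size ** 2)  # number of features after flattening: 64
```

Under this assumption the feature map after the fourth block is 64 channels of size 1x1, so the fully-connected layer has 64 inputs.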

In [14]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
      
        # YOUR CODE HERE
        # First block: 1 input channel -> 64 channels
        self.layer = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block with 64 -> 64 channels. Note that forward() applies this same module three times,
        # so blocks two to four share their weights and batch-norm statistics.
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # After four poolings a 28x28 image is reduced to 1x1, which leaves 64 features
        self.fc = nn.Linear(64, 64)

    def forward(self, x):
        x = self.layer(x)   # (batch, 64, 14, 14)
        x = self.layer2(x)  # (batch, 64, 7, 7)
        x = self.layer2(x)  # (batch, 64, 3, 3)
        x = self.layer2(x)  # (batch, 64, 1, 1)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64)
        x = self.fc(x)

        return x

In [15]:
def test_CNN_shapes():
    net = CNN()

    x = torch.randn(2, 1, 28, 28)
    y = net(x)
    assert y.shape == torch.Size([2, 64]), f"Wrong y.shape: {y.shape}"
    print('Success')

test_CNN_shapes()

Success


## One episode of training

In the cell below, you need to build the computational graph for one training episode of Prototypical Networks.

The required steps:
* Use the provided network to embed both support and query examples.
* Compute one prototype per class using the support set. The prototypes are the mean values of the embeddings of the samples from the same class.
* Compute the log-probabilities that the query samples belong to one of the n classes.
  The probabilities are softmax of the negative squared Euclidean distance from an embedded sample to a class prototype.
* Compute the negative log-likelihood loss using the query samples.
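
The steps above can be checked by hand on a tiny example. The sketch below skips the embedding network and works directly with two-dimensional embeddings (an assumption made purely for illustration), using two classes with two support samples each and one query sample per class.

```python
import torch
import torch.nn.functional as F

# Embeddings are given directly (assumption: no embedding network is involved here)
support_emb = torch.tensor([[[0., 0.], [2., 0.]],   # support samples of class 0
                            [[0., 4.], [2., 4.]]])  # support samples of class 1
query_emb = torch.tensor([[1., 1.],   # query sample whose true class is 0
                          [1., 3.]])  # query sample whose true class is 1
targets = torch.tensor([0, 1])

# Prototypes are the class means of the support embeddings: [[1, 0], [1, 4]]
prototypes = support_emb.mean(dim=1)

# Squared Euclidean distance from every query to every prototype: [[1, 9], [9, 1]]
sq_dist = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=2)

# Log-probabilities are the log-softmax of the negative squared distances
log_p = F.log_softmax(-sq_dist, dim=1)

# Negative log-likelihood of the true classes and the predicted classes
loss = F.nll_loss(log_p, targets)
pred = log_p.argmax(dim=1)

print(prototypes)
print(sq_dist)
print(pred, loss.item())
```

Each query is 8 units (in squared distance) closer to its own prototype than to the other one, so both queries are classified correctly and the loss is close to zero.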

def episode_pn(net, support_query, n_support):
    """Build a computational graph for one episode of training of prototypical networks.
    
    Args:
      net: An embedding network.
      support_query of shape (n_way, n_support+n_query, 1, height, width):
                      support_query[:, :n_support] is the support set
                      support_query[:, n_support:] is the query set
    
    Returns:
      loss (scalar tensor): The negative log-likelihood loss.
      accuracy (float): The classification accuracy on the given example (needed for tracking the progress).
      outputs of shape (n_way, n_query, n_way): Logits (log-softmax) of the probabilities of query classes
          belonging to one of the n classes. The first dimension corresponds to the true class, the last
          dimension corresponds to predicted classes.
    """
    # YOUR CODE HERE
    # Loop-based draft; the vectorized implementation in the next cell redefines this function.
    n_way, n_sq = support_query.shape[:2]
    n_query = n_sq - n_support

    # Embed all images in one forward pass so that batch normalization sees the whole episode
    embeddings = net(support_query.reshape(n_way * n_sq, *support_query.shape[2:]))
    embeddings = embeddings.reshape(n_way, n_sq, -1)
    support_emb = embeddings[:, :n_support]  # (n_way, n_support, dim)
    query_emb = embeddings[:, n_support:]    # (n_way, n_query, dim)

    # 1. One prototype per class: the mean embedding of the support samples of that class
    prototypes = torch.stack([support_emb[i].mean(dim=0) for i in range(n_way)])  # (n_way, dim)

    # 2. Log-probabilities: log-softmax of the negative squared Euclidean distances
    sq_dist = torch.stack([
        torch.stack([
            ((query_emb[j, m] - prototypes) ** 2).sum(dim=1)  # distances to all prototypes, (n_way,)
            for m in range(n_query)
        ])
        for j in range(n_way)
    ])  # (n_way, n_query, n_way)
    outputs = F.log_softmax(-sq_dist, dim=2)

    # 3. Negative log-likelihood of the true classes of the query samples
    targets = torch.arange(n_way, device=outputs.device).view(n_way, 1).expand(n_way, n_query)
    loss = F.nll_loss(outputs.reshape(-1, n_way), targets.reshape(-1))
    accuracy = (outputs.argmax(dim=2) == targets).float().mean()

    return loss, accuracy, outputs
In [17]:
def episode_pn(net, support_query, n_support):
    """Build a computational graph for one episode of training of prototypical networks.
    
    Args:
      net: An embedding network.
      support_query of shape (n_way, n_support+n_query, 1, height, width):
                      support_query[:, :n_support] is the support set
                      support_query[:, n_support:] is the query set
    
    Returns:
      loss (scalar tensor): The negative log-likelihood loss.
      accuracy (float): The classification accuracy on the given example (needed for tracking the progress).
      outputs of shape (n_way, n_query, n_way): Logits (log-softmax) of the probabilities of query classes
          belonging to one of the n classes. The first dimension corresponds to the true class, the last
          dimension corresponds to predicted classes.
    """
    # YOUR CODE HERE
    n_way, n_sq = support_query.shape[:2]
    n_query = n_sq - n_support
    support_set = support_query[:, :n_support]  # (n_way, n_support, 1, height, width)
    query_set = support_query[:, n_support:]    # (n_way, n_query, 1, height, width)

    # True class of every query sample, shape (n_way, n_query, 1), on the same device as the data
    targets = torch.arange(n_way, device=support_query.device).view(n_way, 1, 1).expand(n_way, n_query, 1)

    # Embed all support and query images in a single forward pass
    sq_set = torch.cat([support_set.reshape(n_way*n_support, *support_set.size()[2:]),
                        query_set.reshape(n_way*n_query, *query_set.size()[2:])], 0)
    sq_forward = net(sq_set)
    sq_dim = sq_forward.size(-1)

    # 1. Prototypes are the class means of the support embeddings, shape (n_way, sq_dim)
    prototypes = sq_forward[:n_way*n_support].view(n_way, n_support, sq_dim).mean(1)
    query_sample = sq_forward[n_way*n_support:]  # (n_way*n_query, sq_dim)

    # 2. Squared Euclidean distance from every query to every prototype, shape (n_way*n_query, n_way)
    sq_dist = torch.pow(query_sample.unsqueeze(1) - prototypes.unsqueeze(0), 2).sum(2)
    log_softmax = F.log_softmax(-sq_dist, dim=1).view(n_way, n_query, -1)

    # 3. Negative log-likelihood of the true classes of the query samples
    loss = -log_softmax.gather(2, targets).mean()

    # squeeze(2) keeps the (n_way, n_query) shape even when n_query is 1
    y_predict = log_softmax.argmax(2)
    accuracy = torch.eq(y_predict, targets.squeeze(2)).float().mean()

    return loss, accuracy, log_softmax

In [18]:
def test_episode_pn_shapes():
    n_support = 2
    n_query = 4
    n_way = 5
    support_query = torch.randn(n_way, n_support+n_query, 1, 28, 28)

    net = CNN()
    loss, accuracy, outputs = episode_pn(net, support_query, n_support)
    assert loss.shape == torch.Size([]), "Bad loss.shape"
    assert 0. <= float(accuracy) <= 1., "accuracy should be a scalar between 0 and 1."
    assert outputs.shape == torch.Size([n_way, n_query, n_way]), f"Bad outputs.shape: {outputs.shape}"
    print('Success')

test_episode_pn_shapes()

Success


In [19]:
tests.test_episode_pn(episode_pn)

outputs:
 tensor([[[-4.7113e+00, -7.1130e-01, -7.1130e-01, -4.7113e+00, -1.2711e+01],
         [-1.2711e+01, -4.7113e+00, -7.1130e-01, -7.1130e-01, -4.7113e+00]],

        [[-1.2711e+01, -4.7113e+00, -7.1130e-01, -7.1130e-01, -4.7113e+00],
         [-2.4702e+01, -1.2702e+01, -4.7023e+00, -7.0227e-01, -7.0227e-01]],

        [[-2.4702e+01, -1.2702e+01, -4.7023e+00, -7.0227e-01, -7.0227e-01],
         [-4.0018e+01, -2.4018e+01, -1.2018e+01, -4.0182e+00, -1.8156e-02]],

        [[-4.0018e+01, -2.4018e+01, -1.2018e+01, -4.0182e+00, -1.8156e-02],
         [-5.6000e+01, -3.6000e+01, -2.0000e+01, -8.0003e+00, -3.3540e-04]],

        [[-5.6000e+01, -3.6000e+01, -2.0000e+01, -8.0003e+00, -3.3540e-04],
         [-7.2000e+01, -4.8000e+01, -2.8000e+01, -1.2000e+01, -6.1989e-06]]])
expected outputs:
 tensor([[[-4.7113e+00, -7.1130e-01, -7.1130e-01, -4.7113e+00, -1.2711e+01],
         [-1.2711e+01, -4.7113e+00, -7.1130e-01, -7.1130e-01, -4.7113e+00]],

        [[-1.2711e+01, -4.7113e+00, -7.1130e-01

## Train Prototypical Networks

In [24]:
# Prepare dataloader
n_support = 1
n_query = 3
n_way = 5
trainset = OmniglotFewShot(root=data_dir, n_support=n_support, n_query=n_query, train=True, mix=True)
trainloader = DataLoader(dataset=trainset, batch_size=n_way, shuffle=True, pin_memory=True, num_workers=3)

testset = OmniglotFewShot(root=data_dir, n_support=n_support, n_query=n_query, train=False, mix=True)
testloader = DataLoader(dataset=testset, batch_size=n_way, shuffle=True, pin_memory=True, num_workers=3)

Files already downloaded and verified
Files already downloaded and verified


In [25]:
# Create the model
net = CNN()
net.to(device)

CNN(
  (layer): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=64, out_features=64, bias=True)
)

### Training loop

Implement the training loop in the cell below.

Recommended hyperparameters:
* Adam optimizer with learning rate 0.001. It helps to anneal the learning rate to 0.00001 during training (but it is not needed to pass the tests).
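
One possible way to anneal the learning rate is an exponential schedule whose decay factor is chosen such that the rate falls from 0.001 to 0.00001 over the planned number of epochs. The sketch below uses `ExponentialLR` from `torch.optim.lr_scheduler`; the linear stand-in model, the dummy loss and the choice of 30 epochs are assumptions for illustration, and other schedules would work as well.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # stand-in for the embedding network
epochs = 30              # assumed number of epochs
lr_start, lr_end = 1e-3, 1e-5

optimizer = optim.Adam(model.parameters(), lr=lr_start)
# After `epochs` multiplications by gamma the learning rate equals lr_end
gamma = (lr_end / lr_start) ** (1 / epochs)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

for epoch in range(epochs):
    # Stand-in for one epoch of training episodes
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # multiply the learning rate by gamma once per epoch

final_lr = optimizer.param_groups[0]['lr']
print(final_lr)
```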

Hints:
* We recommend tracking the training and test accuracies returned by the function `episode_pn()`.
* During training, both training and test accuracies should reach at least the level of 0.96. Note that we sample a limited number of tasks to compute the accuracies and therefore the accuracy values may fluctuate.
* **Do not forget to set the network into the training mode during training and to evaluation mode during evaluation.**
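
Because the accuracy of a single task fluctuates, it can help to average it over several test tasks while the network is in evaluation mode. The sketch below illustrates this pattern; `episode_accuracy` and `evaluate` are hypothetical helpers, and a small linear network with random tensors stands in for the CNN and the test loader.

```python
import torch
import torch.nn as nn

def episode_accuracy(net, support_query, n_support):
    """Accuracy of nearest-prototype classification on one task (hypothetical helper)."""
    n_way, n_sq = support_query.shape[:2]
    emb = net(support_query.flatten(0, 1)).reshape(n_way, n_sq, -1)
    prototypes = emb[:, :n_support].mean(dim=1)  # (n_way, dim)
    queries = emb[:, n_support:]                 # (n_way, n_query, dim)
    # Squared distances of shape (n_way, n_query, n_way)
    sq_dist = ((queries.unsqueeze(2) - prototypes) ** 2).sum(dim=-1)
    targets = torch.arange(n_way, device=emb.device).unsqueeze(1)
    return (sq_dist.argmin(dim=2) == targets).float().mean().item()

def evaluate(net, loader, n_support, max_episodes=100):
    """Mean accuracy over up to max_episodes tasks, computed in evaluation mode."""
    was_training = net.training
    net.eval()  # e.g. batch normalization then uses its running statistics
    accuracies = []
    with torch.no_grad():
        for i, support_query in enumerate(loader):
            if i >= max_episodes:
                break
            accuracies.append(episode_accuracy(net, support_query, n_support))
    net.train(was_training)  # restore the previous mode
    return sum(accuracies) / len(accuracies)

# Stand-ins for the trained network and the test loader
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16))
fake_loader = [torch.randn(5, 4, 1, 28, 28) for _ in range(10)]

net.train()
acc = evaluate(net, fake_loader, n_support=1)
print(f"mean accuracy over {len(fake_loader)} tasks: {acc:.3f}")
```

Restoring the previous mode at the end of `evaluate` is a design choice that makes it safe to call the function in the middle of a training loop.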

In [39]:
# Implement the training loop in this cell
lr = 0.001
epochs = 30
optimizer = optim.Adam(net.parameters(), lr=lr)
if not skip_training:
    # YOUR CODE HERE
    for epoch in range(epochs):
        net.train()  # batch normalization must use batch statistics during training
        accuracy_sum = 0.0
        for batch, support_query in enumerate(trainloader):
            support_query = support_query.to(device)
            loss, accuracy, outputs = episode_pn(net, support_query, n_support=n_support)
            optimizer.zero_grad()
            # The graph is rebuilt in every episode, so it does not need to be retained
            loss.backward()
            optimizer.step()
            accuracy_sum += accuracy.item()
        print('[%d]%d,accuracy: %.03f' % (epoch + 1, batch, accuracy_sum / (batch + 1)))

[1]153,accuracy: 0.954
[2]153,accuracy: 0.948
[3]153,accuracy: 0.952
[4]153,accuracy: 0.955
[5]153,accuracy: 0.957
[6]153,accuracy: 0.959
[7]153,accuracy: 0.968
[8]153,accuracy: 0.958
[9]153,accuracy: 0.958
[10]153,accuracy: 0.961
[11]153,accuracy: 0.951
[12]153,accuracy: 0.952


Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [31]:
# Save the model to disk (the pth-files will be submitted automatically together with your notebook)
if not skip_training:
    tools.save_model(net, '8_pn.pth')
else:
    net = CNN()
    tools.load_model(net, '8_pn.pth', device)

Do you want to save the model (type yes to confirm)? yes
Model saved to 8_pn.pth.


In [32]:
# This cell tests the accuracy of your model

# Test the trained model

In [35]:
# Use one classification task from the test set
net.eval()
with torch.no_grad():
    support_query = next(iter(testloader))
    _, acc, outputs = episode_pn(net, support_query.to(device), n_support=n_support)
    print(outputs.argmax(dim=2))

tensor([[0, 3, 0],
        [1, 1, 1],
        [2, 2, 2],
        [0, 3, 3],
        [4, 4, 4]])
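
In the printed matrix, row $i$ holds the predicted classes of the three query samples whose true class is $i$. As a small sketch of how to read it, the accuracy on this particular task can be recomputed from the predictions shown above:

```python
import torch

# Predictions copied from the output above: row i contains the query samples of true class i
predictions = torch.tensor([[0, 3, 0],
                            [1, 1, 1],
                            [2, 2, 2],
                            [0, 3, 3],
                            [4, 4, 4]])
n_way, n_query = predictions.shape

# The true class of every query sample in row i is i
targets = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query)

correct = (predictions == targets).sum().item()
accuracy = correct / predictions.numel()
print(correct, accuracy)  # 13 of the 15 query samples are classified correctly
```

The two errors are one sample of class 0 assigned to class 3 and one sample of class 3 assigned to class 0.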


## Interactive demo

You can use the figure in the cell below to draw three images of the support set (top row) and three images of the query set (bottom row).

In [41]:
canvas = tests.Canvas()


In the next cell, we classify the images of the query set to one of three classes specified by the support set.
The colors of the frames in the bottom row represent the labels produced by the classifier for the query set.

In [42]:
# Convert images into torch tensors
support_query = canvas.get_images()
print(support_query.shape)

net.eval()
with torch.no_grad():
    _, _, outputs = episode_pn(net, support_query.float().to(device), n_support=1)
    # outputs is (n_way, n_query, n_way)
classes = outputs.argmax(dim=2).view(-1)

tests.plot_classification(support_query, classes)

torch.Size([3, 2, 1, 28, 28])

