# Exercise 8: Dive into Few-Shot Learning 🚀

In this exercise, you'll:

1. Learn about the few-shot-learning concept and ways to tackle it.
2. Get a walkthrough of the `Few-Shot-Bench` code base. This code base offers a structured way to compare various methods meta-learning methods an various datasets. Depending on the project you choose, this might be helpful.

Let's get started and explore this together.

------------------
# 0.0 Environment Setup

In [1]:
# !pip install -r requirements.txt

In [2]:
%load_ext autoreload
%autoreload 2

In [14]:
from abc import abstractmethod, ABC
import os

import torch.nn as nn
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from IPython.display import clear_output
from backbones.fcnet import FCNet
from backbones.pem import PEM
from methods.meta_template import MetaTemplate
from methods.protonet import ProtoNet, euclidean_dist
from datasets.prot.swissprot import SPSetDataset
from datasets.cell.tabula_muris import TMSetDataset
import gdown

import matplotlib.pyplot as plt
from itertools import combinations

In [4]:
# url = 'https://drive.google.com/u/0/uc?id=1a3IFmUMUXBH8trx_VWKZEGteRiotOkZS&export=download'

# if os.path.exists('swissprot.zip'):
#     print('File already downloaded.')
# else:
#     output = 'swissprot.zip'
#     gdown.download(url, output, quiet=False)
#     print('Download completed.')

In [5]:
# !unzip -q swissprot.zip

# !rm -rf swissprot.zip
# clear_output()

# !mv data/swissprot/go-basic.obo ./

## PrototypeFormer Implementation

In [6]:
class MetaTemplate(nn.Module):
    def __init__(self, backbone, n_way, n_support):
        super(MetaTemplate, self).__init__()
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = -1  # (change depends on input)|
        self.feature = backbone
        self.feat_dim = self.feature.final_feat_dim

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    @abstractmethod
    def set_forward(self, x, is_feature=False):
        pass

    @abstractmethod
    def set_forward_loss(self, x):
        pass

    def forward(self, x):
        out = self.feature.forward(x)
        return out

    def parse_feature(self, x):
        '''
        :param x: [n_way, n_support + n_query, **embedding_dim]
        '''
        x = x.to(self.device)
        x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), * x.size()[2:])
        # Compute support and query feature.
        z_all = self.forward(x)

        # Reverse the transformation to distribute the samples based on the dimensions of their individual categories and flatten the embeddings.
        z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1)

        # Extract the support and query features.
        z_support = z_all[:, :self.n_support]
        z_query = z_all[:, self.n_support:]

        return z_support, z_query

    def correct(self, x):
        # Compute the predictions scores.
        scores = self.set_forward(x)

        # Compute the top1 elements.
        _, topk_indices = scores.topk(k=1, dim=1, largest=True, sorted=True)

        # Detach the variables.
        topk_ind = topk_indices.detach().cpu().numpy()

        # Create the category labels for the queries.
        y_query = np.repeat(range(self.n_way), self.n_query)

        # Compute number of elements that are correctly classified.
        top1_correct = np.sum(topk_ind[:, 0] == y_query) # first column denotes the most likely label

        return float(top1_correct), len(y_query)

    def train_loop(self, epoch, train_loader, optimizer):
        print_freq = 10

        avg_loss = 0
        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            optimizer.zero_grad()
            loss = self.set_forward_loss(x)
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss + loss.item()

            if i % print_freq == 0:
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
                                                                        avg_loss / float(i + 1)))
        return avg_loss/len(train_loader)
    
    def test_loop(self, epoch, test_loader, record=None, return_std=False):
        acc_all = []

        iter_num = len(test_loader)
        for i, (x, _) in enumerate(test_loader):
            self.n_query = x.size(1) - self.n_support
            correct_this, count_this = self.correct(x)
            acc_all.append(correct_this / count_this * 100)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print(f'Epoch {epoch} | Test Acc = {acc_mean:4.2f}% +- {1.96 * acc_std / np.sqrt(iter_num):4.2f}%')

        if return_std:
            return acc_mean, acc_std
        else:
            return acc_mean
        
    def count_parameters(self):
        """Count the number of learnable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [7]:
class ProtoNet(MetaTemplate):
    def __init__(self, backbone, n_way, n_support):
        super(ProtoNet, self).__init__(backbone, n_way, n_support)
        self.loss_fn = nn.CrossEntropyLoss()

    def set_forward(self, x):
        # Compute the prototypes (support) and queries (embeddings) for each datapoint.
        # Remember that you implemented a function to compute this before.
        z_support, z_query = self.parse_feature(x)
            
        # Compute the prototype.
        z_support = z_support.contiguous().view(self.n_way, self.n_support, -1)
        z_proto = z_support.mean(axis=1)
        
        # Format the queries for the similarity computation.
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)

        # Compute similarity score based on the euclidean distance between prototypes and queries.
        scores = -euclidean_dist(z_query, z_proto)
        return scores

    def set_forward_loss(self, x):
        # Compute the similarity scores between the prototypes and the queries.
        scores = self.set_forward(x)
        
        # Create the category labels for the queries.
        y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))

        # Compute the loss
        loss = self.loss_fn(scores, y_query)
        return loss


In [8]:
import torch

temp = torch.arange(5, 5, 5)

sub_temp = temp.repeat(1, 4, 1).view(5 * 5, 4, -1)


sub_temp[:, 0].shape

torch.Size([25, 0])

In [9]:
def euclidean_dist( x, y):
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    assert d == y.size(1)

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    return torch.pow(x - y, 2).sum(2)

def euclidean_dist_3d(x):
    # dist[i][j] is matrix of pairwise distances of classes i and j
    # x: N x K x D
    n = x.size(0)
    k = x.size(1)
    d = x.size(2)

    flat_x = x.view(-1, d)
    dist = euclidean_dist(flat_x, flat_x)

    return dist.view(n, k, -1).transpose(1, 2).reshape(n , n, k, k)

def contrastive_loss(pairwise_dist):
    n = pairwise_dist.shape[0]

    # Get the device from the input tensor
    device = pairwise_dist.device

    # Create the mask tensor on the same device as pairwise_dist
    mask = torch.eye(n).to(device)
    
    dist_sums = pairwise_dist.sum((2, 3))

    positive_sums = (dist_sums * mask).sum() + 1
    negative_sums = (dist_sums * (1 - mask)).sum() + 1

    return torch.exp(positive_sums / negative_sums / n)

In [10]:
class PrototypeFormer(MetaTemplate):
    def __init__(self, backbone, n_way, n_support, n_sub_supports):
        super().__init__(backbone, n_way, n_support)
        self.classifier_loss_fn = nn.CrossEntropyLoss()
        self.prototype_loss_fn = contrastive_loss

        if n_support < n_sub_supports:
            raise ValueError('Number of sub-supports must be smaller than the number of supports.')
        self.n_sub_support = n_sub_supports
        self.pair_dist = None

    def parse_feature(self, x):
        """
        Prior to processing data with PEM, several preparatory steps are required:
        1. Creating sub-support sets from the main support set.
        2. Integrating a prototype token into both the support and sub-support sets.

        In alignment with the standard input format for Transformers (sequence_length, batch_size, embedding_dim), 
        the dimensions 'n_support + n_query' (serving as sequence_length) and 'n_way' (serving as batch_size) should be reordered to the first and second positions, respectively.

        Moreover, in the context of PrototypeFormer, it is important to note that query data does not undergo processing via PEM. 
        Instead, the query embeddings are directly obtained from the feature extractor. However, as x is already ESM embeddings, we leave query set as is.

        Args:
            x: one batch of data, shape (n_way, n_support + n_query, **embedding_dim)

        Returns:
            z_proto: prototypes, shape (n_way, **embedding_dim)
            z_sub_proto: sub-prototypes, shape (n_way, n_combinations, **embedding_dim)
        """
        x = x.to(self.device)

        # 1. Generate sub-support sets from support set
        # x_sub_support = self.generate_sub_supports(x[:, :self.n_support], n_sub_supports=self.n_support - 1) # (n_way, n_combinations, n_sub_supports, **embedding_dim)

        # 2. Add prototype token to support and sub-support sets
        # x_support     = self.add_prototype_token(x[:, :self.n_support]) # (n_way, n_support + 1, **embedding_dim)
        x_sub_support = self.generate_sub_supports(x[:, :self.n_support], n_sub_supports=self.n_sub_support) # (n_way, n_combinations, n_sub_supports, **embedding_dim)
        # x_sub_support = self.add_prototype_token(x_sub_support) # (n_way, n_combinations, n_sub_supports + 1, **embedding_dim)

        n_combos, n_subs_with_token = x_sub_support.shape[1:3]

        # 3. run through PEM
        # - First, Transposing the dimensions of the input tensor x_support to fit the expected input format for a Transformer encoder
        # - Second, transposing it back to its original format after running through PEM
        z_support = self.forward(x[:, :self.n_support]) # (n_way, n_support + 1, **embedding_dim)

        # z_sub_support = self.forward(x_sub_support.view(self.n_way * n_combos, n_subs_with_token, -1))\
                            # .view(self.n_way, n_combos, n_subs_with_token, -1) # (n_way, n_combinations, n_sub_supports + 1, **embedding_dim)
        
        z_sub_support = x[:, :self.n_support].repeat(1, self.n_support - 1, 1).view(self.n_way * self.n_support, self.n_support - 1, -1)
        z_sub_support = self.forward(z_sub_support)

        print(z_sub_support.shape)
        
        return z_support, z_sub_support

    def set_forward(self, x):

        # Compute the prototype.
        z_support, z_sub_support = self.parse_feature(x)
        # z_proto, z_sub_proto = z_support[:, 0].contiguous(), z_sub_support[:, :, 0].contiguous()
        z_proto, z_sub_proto = z_support[:, 0].contiguous(), z_sub_support[:, 0].view(self.n_way, self.n_support, -1).contiguous()
        # z_proto, z_sub_proto = z_support.mean(axis=1), z_sub_support.mean(axis=2)

        # Format the queries for the similarity computation.
        z_query = x[:, self.n_support:].contiguous().view(self.n_way * self.n_query, -1)

        # Compute pairwise distance between subprototypes
        self.pair_dist = euclidean_dist_3d(z_sub_proto)
        
        # Compute similarity score based on the euclidean distance between prototypes and queries.
        scores = -euclidean_dist(z_query, z_proto)
        return scores

    def set_forward_loss(self, x):
        # Compute the similarity scores between the prototypes and the queries.
        scores = self.set_forward(x)
        
        # Create the category labels for the queries.
        y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))

        # Compute the loss
        loss = self.classifier_loss_fn(scores, y_query) + self.prototype_loss_fn(self.pair_dist)
        return loss
    

    def generate_sub_supports(self, support_set, n_sub_supports):
        """
        Generate sub-support sets for each category in the support set.
        
        :param support_set: Tensor of shape [n_way, n_support, **embedding_dim]
        :param n_sub_supports: Number of elements in each sub-support set
        :return: Tensor of shape [n_way, n_combinations, n_sub_supports, **embedding_dim]
                where n_combinations = n_support choose n_sub_supports
        """
        embedding_dim = support_set.size()[2:]
        combo_indices = list(combinations(range(self.n_support), n_sub_supports))
        n_combinations = len(combo_indices)
        sub_supports = torch.zeros(self.n_way, n_combinations, n_sub_supports, *embedding_dim, device=self.device)

        for way in range(self.n_way):
            for idx, combo in enumerate(combo_indices):
                sub_supports[way, idx] = support_set[way, list(combo)]

        return sub_supports
    
    def count_parameters(self):
        """Count the number of learnable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [11]:
def train_model(n_way, n_support, n_query, n_train_episode):
    # Load train dataset. Remember to use make use of functions defined in the `SPSetDataset`.
    train_dataset = SPSetDataset(n_way, n_support, n_query, n_episode=n_train_episode, root='./data', mode='train')
    # train_dataset = TMSetDataset(n_way, n_support, n_query, n_episode=n_train_episode, root='./data', mode='train')
    train_loader = train_dataset.get_data_loader()

    # Load test dataset. Remember to use make use of functions defined in the `SPSetDataset`.
    test_dataset = SPSetDataset(n_way, n_support, n_query, n_episode=100, root='./data', mode='test')
    # test_dataset = TMSetDataset(n_way, n_support, n_query, n_episode=100, root='./data', mode='test')
    test_loader = test_dataset.get_data_loader()

    # Initialize a fully connected network `FCNet` in `fcnet.py` with two hidden layers of 512 units each as feature extractor.
    # backbone = FCNet(train_dataset.x_dim, [512, 512])
    backbone = PEM(train_dataset.x_dim, n_layer=1)

    # Initialize model using the backbone and the optimizer.
    # model = ProtoNet(backbone, n_way, n_support)
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model = PrototypeFormer(backbone, n_way, n_support, n_sub_supports=n_support-1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8)

    print(f"Number of parameters in the model: {model.count_parameters()}")


    test_accs = []; train_losses = []
    for epoch in range(100):
        model.train()

        # Implement training of the model. Remember to make use of functions defined in the `ProtoNet` and `MetaTemplate` class.
        epoch_loss = model.train_loop(epoch, train_loader, optimizer)
        train_losses.append(epoch_loss)

        # Evaluate test performance for epoch. Remember to make use of functions defined in the `ProtoNet` and `MetaTemplate` class.
        test_acc = model.test_loop(epoch, test_loader)
        test_accs.append(test_acc)
        print(f'Epoch {epoch} | Train Loss {epoch_loss} | Test Acc {test_acc}')
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 2.5))
    ax1.plot(range(len(train_losses)), train_losses)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Train Loss')
    ax1.grid()

    ax2.plot(range(len(test_accs)), test_accs)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Test Accuracy')
    ax2.grid()
    fig.suptitle(f"n_way={n_way}, n_support={n_support}, n_query={n_query}, n_train_episode={n_train_episode}")

    plt.tight_layout()

In [12]:
parameters = {'n_way': 5, 'n_support': 5, 'n_query': 15, 'n_train_episode': 5}

In [13]:
train_model(**parameters)


!gaf-version: 2.2


ERROR:root:Failed to validate header as GAF v2.2:
[]


HMS:0:00:03.895452 310,057 annotations READ: ./data/swissprot/filtered_goa_uniprot_all_noiea.gaf 
25933 IDs in loaded association branch, BP

!gaf-version: 2.2


ERROR:root:Failed to validate header as GAF v2.2:
[]


HMS:0:00:03.912192 310,057 annotations READ: ./data/swissprot/filtered_goa_uniprot_all_noiea.gaf 
25933 IDs in loaded association branch, BP
Number of parameters in the model: 9197828
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
torch.Size([25, 5, 1280])
Epoch 0 | Batch 0/5 | Loss 68.183174
torch.Size([25, 5, 1280])
torch.Size([25, 5, 1280])
torch.Size([25, 5, 1280])
torch.Size([25, 5, 1280])
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
  EXISTS: go-basic.

KeyboardInterrupt: 

  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-06-11) 46,420 Terms; optional_attrs(relationship)
