# Introduction

Machine learning problems usually require large datasets. But what if:
* Some cases are rare, and only a few examples of them exist at all. Formally, we then have a classification problem with $C$ classes, $K < C$ of which have very few training examples, and the model must still be able to tell all the classes apart, both the numerous and the rare ones.
* For some problems it is hard (or expensive) to collect a large enough dataset at all (for example, oil well data, computed tomography, MRI, etc.). Here we have $K$ classes, and each of them has only a few examples in the training set.

A number of <b>Few Shot Learning</b> algorithms have been developed for such problems. If we have $K$ rare classes, each represented by $N$ objects, the problem is called $K$-way $N$-shot classification/regression/detection/...

We'll give an overview of methods that make predictions of reasonable quality despite an extremely small (or even zero: <b>Zero Shot Learning</b>) number of examples.

In [1]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn, optim
from torch.nn import functional as F
from torchmetrics.functional import accuracy
from torchvision import transforms
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")
%matplotlib inline

Let's collect a small dataset: a subset of CIFAR10, which contains 32x32 color images from 10 classes.

Let's write a class CifarSubset: a dataset made of the first $N$ objects of each of the 10 classes.

In [2]:
class CifarSubset(torchvision.datasets.vision.VisionDataset):
    def __init__(self, root, k_n, train: bool, download: bool, transform):
        super().__init__(root, transform=transform, target_transform=None)
        
        self.k_n = k_n
        self.data = self.__sample_uniform_subset(root, train, download, transform)
        
    def __sample_uniform_subset(self, root, train, download, transform):
        """
        Returns a subset of CIFAR10 with the first self.k_n objects of each class
        
        root, train, download, transform - arguments passed to CIFAR10
        self.k_n - number of objects in each class
        """
        
        self.cifar_full = torchvision.datasets.CIFAR10(
            root=root, train=train, download=download, transform=transform
        )
        
        number_of_classes = len(self.cifar_full.classes)
        class_count = [0 for _ in range(number_of_classes)]
        subset = []
        
        for image, label in self.cifar_full:
            if class_count[label] < self.k_n:
                subset.append((image, label))
                class_count[label] += 1
            
            if sum(class_count) == number_of_classes * self.k_n:
                break
        return subset
    
    def __getitem__(self, index):
        return self.data[index][0], self.data[index][1]
    
    def __len__(self):
        return len(self.data)
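
The loop in `__sample_uniform_subset` decodes and transforms every image it visits until the subset is full. A possibly cheaper alternative selects the indices from the labels alone. This assumes the dataset exposes its labels as a list, which torchvision's `CIFAR10` does through its `targets` attribute. The helper name `first_k_per_class` is hypothetical:

```python
from typing import List, Sequence


def first_k_per_class(labels: Sequence[int], k_n: int, n_classes: int) -> List[int]:
    """Hypothetical helper: indices of the first k_n objects of every class.

    Only the labels are inspected, so no image has to be loaded or transformed.
    """
    class_count = [0] * n_classes
    indices = []
    for index, label in enumerate(labels):
        if class_count[label] < k_n:
            indices.append(index)
            class_count[label] += 1
        if len(indices) == n_classes * k_n:
            break  # every class already has k_n objects
    return indices


# Toy labels standing in for e.g. torchvision.datasets.CIFAR10(...).targets
labels = [0, 1, 0, 2, 1, 0, 2, 2, 1]
print(first_k_per_class(labels, k_n=2, n_classes=3))  # [0, 1, 2, 3, 4, 6]
```

The resulting indices could be passed to `torch.utils.data.Subset(cifar_full, indices)`. That selects the same objects as CifarSubset, but applies the transform lazily, only when an item is accessed.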

In [3]:
N_SHOT = 5 # number of objects of each class in train dataset

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ]
)

cifar_train_subset = CifarSubset(
    root="./cifar", k_n=N_SHOT, train=True, download=True, transform=transform
)

cifar_val_subset = CifarSubset(
    root="./cifar", k_n=N_SHOT, train=False, download=True, transform=transform
)

Files already downloaded and verified
Files already downloaded and verified


# Few Shot Learning

For most Few Shot (and especially Meta) Learning algorithms it's better to train the model in <b>episodes</b>, where an episode reproduces the conditions of inference as closely as possible. For instance, if at inference time we ask the model to classify $N$ objects of each of $K$ classes, then we should train it on batches of $N \cdot K$ objects.

In our case these $N \cdot K = 5 \cdot 10 = 50$ objects make up the whole training set, so one episode is equivalent to one epoch.
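
With a larger dataset, episodes would have to be sampled explicitly. Below is a minimal sketch that assumes only a list of integer labels. The hypothetical `EpisodeBatchSampler` yields, for each episode, the indices of `n_shot` random objects from each of `k_way` random classes, grouped class by class. `torch.utils.data.DataLoader` accepts an iterable of index lists as its `batch_sampler` argument (in place of `batch_size` and `shuffle`), so an instance of it could be passed there:

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Sequence


class EpisodeBatchSampler:
    """Hypothetical sampler yielding index lists for K-way N-shot episodes."""

    def __init__(self, labels: Sequence[int], k_way: int, n_shot: int,
                 n_episodes: int, seed: Optional[int] = None):
        self.k_way = k_way
        self.n_shot = n_shot
        self.n_episodes = n_episodes
        self.rng = random.Random(seed)
        # Group dataset indices by class label
        self.by_class: Dict[int, List[int]] = defaultdict(list)
        for index, label in enumerate(labels):
            self.by_class[label].append(index)
        # Only classes with at least n_shot objects can take part in an episode
        self.classes = [c for c, idx in self.by_class.items() if len(idx) >= n_shot]
        if len(self.classes) < k_way:
            raise ValueError("not enough classes with n_shot objects for a k_way episode")

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(self.n_episodes):
            episode = []
            for c in self.rng.sample(self.classes, self.k_way):
                # n_shot random objects of class c, in random order
                episode.extend(self.rng.sample(self.by_class[c], self.n_shot))
            yield episode

    def __len__(self) -> int:
        return self.n_episodes


labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3]  # class 3 has a single object and is skipped
sampler = EpisodeBatchSampler(labels, k_way=2, n_shot=3, n_episodes=4, seed=0)
for episode in sampler:
    print([labels[i] for i in episode])  # 3 objects from each of 2 classes
```

Objects are drawn in random order within each class. This matters for PrototypeLoss below, which treats the first objects of every class in a batch as the support set.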

In [4]:
train_dataloader = torch.utils.data.DataLoader(
    cifar_train_subset,
    batch_size=N_SHOT * 10,  # one episode: N_SHOT objects of each of the 10 classes
    shuffle=True,
    num_workers=0,  # 0 because multiprocessing workers can cause problems in Jupyter
)

val_dataloader = torch.utils.data.DataLoader(
    cifar_val_subset,
    batch_size=N_SHOT * 10,  # one episode: N_SHOT objects of each of the 10 classes
    shuffle=True,
    num_workers=0,  # 0 because multiprocessing workers can cause problems in Jupyter
)

With an ordinary approach (a standard classifier trained on these 50 images), the model would most likely fail to find any general patterns and would overfit heavily.

The main idea of approaches of this kind is to learn what the features (embedding) of a typical representative of each class look like.
Having such class "prototypes", we can compare the features of each object with the features of every prototype and choose the prototype most similar to the object.
For this to work, training must place similar objects close to each other in the feature space and dissimilar objects far from each other (as in a clustering problem).

At a high level, these methods all look roughly the same:
1. We train a model whose output is not <b>class probabilities</b> but a <b>multidimensional feature vector</b>. For this we can take any classification/regression model and drop its last (classification) layer.
2. The loss function is chosen so that the model learns to group similar objects together.
3. Both during training and at inference, in $K$-way $N$-shot classification ($K$ classes with $N$ objects each), some fraction of the $N$ objects of each class is used to build the class prototype by averaging their features. This subset is called the <b>support set</b>. The remaining objects, which are compared with the prototypes, form the <b>query set</b>.
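
Step 3 can be sketched in a few lines of PyTorch, assuming the embeddings of one episode and their labels are already available. The function name `split_and_prototypes` and the random toy embeddings are illustrative only. For each class, the first `n_support` embeddings form the support set and their mean is the prototype. The remaining embeddings form the query set:

```python
import torch


def split_and_prototypes(embeddings: torch.Tensor, labels: torch.Tensor, n_support: int):
    """Hypothetical helper: build class prototypes from the support sets.

    embeddings: (N, emb_size), labels: (N,)
    Returns prototypes (n_classes, emb_size), query embeddings and their class indices.
    """
    classes = torch.unique(labels)  # sorted labels present in the episode
    prototypes, queries, query_targets = [], [], []
    for ci, c in enumerate(classes):
        class_emb = embeddings[labels == c]
        prototypes.append(class_emb[:n_support].mean(dim=0))  # average of the support set
        queries.append(class_emb[n_support:])                  # the rest is the query set
        query_targets.append(torch.full((len(class_emb) - n_support,), ci, dtype=torch.long))
    return torch.stack(prototypes), torch.cat(queries), torch.cat(query_targets)


# Toy 3-way 5-shot episode with 8-dimensional embeddings
torch.manual_seed(0)
labels = torch.arange(3).repeat_interleave(5)  # [0, 0, 0, 0, 0, 1, ..., 2]
embeddings = torch.randn(15, 8)
prototypes, queries, query_targets = split_and_prototypes(embeddings, labels, n_support=3)
print(prototypes.shape, queries.shape, query_targets.tolist())
# torch.Size([3, 8]) torch.Size([6, 8]) [0, 0, 1, 1, 2, 2]
```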

Let's build a convolutional network whose output is a multidimensional vector of learned features.

In [5]:
class PrototypeNet(nn.Module):
    def __init__(self, input_dim=3, hid_dim=64, output_dim=64):
        super(PrototypeNet, self).__init__()
        
        self.encoder = nn.Sequential(
            self.__conv_block(input_dim, hid_dim),
            self.__conv_block(hid_dim, hid_dim),
            self.__conv_block(hid_dim, hid_dim),
            self.__conv_block(hid_dim, output_dim)
        )
        
    def __conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
    
    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)
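
Each `__conv_block` keeps the spatial size (3x3 kernel with `padding=1`) and then halves it with `MaxPool2d(2)`. A 32x32 CIFAR10 image therefore shrinks to 16, 8, 4 and finally 2 pixels per side, so the flattened embedding has `output_dim * 2 * 2` components. A small sketch of this arithmetic (the function name is hypothetical):

```python
def embedding_size(image_size: int, n_blocks: int, output_dim: int) -> int:
    """Length of the flattened PrototypeNet output for square inputs.

    Conv2d(kernel_size=3, padding=1) preserves the size; MaxPool2d(2) halves it (floor).
    """
    size = image_size
    for _ in range(n_blocks):
        size = size // 2
    return output_dim * size * size


print(embedding_size(32, n_blocks=4, output_dim=64))  # 256
```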

Now we need a loss function that penalizes the model for:
* a large distance between an object and the prototype of its own class;
* a small distance between an object and the prototypes of <b>other</b> classes.

One common formulation is the following:
$$Loss = \frac{1}{N_c \cdot N_q} \sum_{c} \sum_{q \in Q_c} \left[ dist(emb(q), p_c) + \log \sum_{c'} e^{-dist(emb(q), p_{c'})}\right]$$

Here $Q_c$ is the query set of class $c$, and $emb(q)$ is the embedding obtained by passing the object $q$ through our convolutional network. $dist(emb(q), p_c)$ is a distance function between this embedding and the prototype $p_c$ of class $c$, and the inner sum runs over all classes $c'$ of the episode. $N_c$ is the number of classes in the episode and $N_q$ is the number of query objects per class.
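
The bracketed term is the negative log-softmax of the negated distances, taken at the true class: $dist(emb(q), p_c) + \log \sum_{c'} e^{-dist(emb(q), p_{c'})} = -\log \frac{e^{-dist(emb(q), p_c)}}{\sum_{c'} e^{-dist(emb(q), p_{c'})}}$. When the inner sum runs over all prototypes, as in the original Prototypical Networks formulation, the whole loss is therefore a cross-entropy with logits $-dist$, averaged over the $N_c \cdot N_q$ query objects.

Below is a vectorized sketch with a check against the term-by-term formula. It assumes Euclidean distance computed with `torch.cdist` and query targets given as prototype indices; `prototypical_loss` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def prototypical_loss(queries: torch.Tensor, query_targets: torch.Tensor,
                      prototypes: torch.Tensor) -> torch.Tensor:
    """Hypothetical vectorized prototype loss.

    queries: (n_queries, emb_size); query_targets: (n_queries,) prototype indices;
    prototypes: (n_classes, emb_size).
    """
    distances = torch.cdist(queries, prototypes)  # (n_queries, n_classes), Euclidean
    # cross_entropy(-d, c) = d_c + logsumexp(-d), averaged over all queries
    return F.cross_entropy(-distances, query_targets)


# Compare with the term-by-term formula on random data
torch.manual_seed(0)
prototypes = torch.randn(3, 4)
queries = torch.randn(6, 4)
query_targets = torch.tensor([0, 0, 1, 1, 2, 2])

manual = torch.stack([
    torch.linalg.norm(q - prototypes[c])
    + torch.logsumexp(-torch.linalg.norm(q - prototypes, dim=1), dim=0)
    for q, c in zip(queries, query_targets)
]).mean()
print(torch.allclose(prototypical_loss(queries, query_targets, prototypes), manual))  # True
```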

In [6]:
class PrototypeLoss(nn.Module):
    def __init__(self, n_shot: int):
        super(PrototypeLoss, self).__init__()
        self.n_shot = n_shot
        self.prototypes: torch.Tensor = None  # updated on every forward pass
    
    def __distance(self, inp, other):
        # Euclidean distance along the feature dimension
        return torch.linalg.norm(inp - other, dim=-1)
    
    def __prototype_loss(
        self,
        predicted_embeddings: torch.Tensor,
        target_labels: torch.Tensor,
        support_percent=0.6,
    ):
        """
        predicted_embeddings - feature vectors produced by the model
                               shape = (N, emb_size)
        target_labels - target classes
                        shape = (N)
        """
        
        assert 0 < support_percent < 1
        
        classes = torch.unique(target_labels)  # sorted labels present in the episode
        n_classes = len(classes)  # number of classes in the episode
        
        n_support = int(self.n_shot * support_percent)  # size of the support set
        n_query = self.n_shot - n_support  # size of the query set
        
        prototypes = []
        class_queries = []
        
        for c in classes:
            class_embeddings = predicted_embeddings[target_labels == c]
            supports = class_embeddings[:n_support]
            queries = class_embeddings[n_support:]
            
            prototypes.append(supports.mean(0))  # prototype = mean of the support set
            class_queries.append(queries)
        
        prototypes = torch.stack(prototypes)  # shape = (n_classes, emb_size)
        loss = predicted_embeddings.new_zeros(())
        
        for ci, queries in enumerate(class_queries):  # ci - index of the class
            for query in queries:  # query - an object of class ci
                # distances from the query to the prototypes of all classes
                distances = self.__distance(query, prototypes)
                
                # distance to its own prototype + log-sum-exp over all classes
                loss = loss + distances[ci] + torch.logsumexp(-distances, dim=0)
        
        # save the prototypes for further predictions
        self.prototypes = prototypes.detach()
        
        return loss / (n_classes * n_query)
    
    def forward(self, predicted_embeddings, targets) -> torch.Tensor:
        return self.__prototype_loss(predicted_embeddings, targets)
    
    def get_class_prototypes(self) -> torch.Tensor:
        return self.prototypes

In [7]:
class FewShotMetricLearner(pl.LightningModule):
    def __init__(self, n_shot: int) -> None:
        super().__init__()
        
        self.model = PrototypeNet()
        self.loss = PrototypeLoss(n_shot)
        self.prototypes = None
        
    def forward(self, x) -> torch.Tensor:
        return self.model(x)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-2)
    
    def predict_labels(self, batch_embeddings):
        if self.prototypes is None:
            raise ValueError("self.prototypes wasn't set")
        
        # index of the nearest prototype for every embedding
        ans = []
        for e in batch_embeddings:
            dists = torch.linalg.norm(self.prototypes - e, dim=1)
            ans.append(torch.argmin(dists))
        
        return torch.stack(ans)
    
    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
        images, target = train_batch
        embeddings = self.forward(images)
        
        loss = self.loss(embeddings, target)
        self.prototypes = self.loss.prototypes  # prototypes of the current episode
        
        label_predictions = self.predict_labels(embeddings)
        acc = accuracy(label_predictions, target, task="multiclass", num_classes=10)
        
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss
    
    def validation_step(self, val_batch, batch_idx) -> None:
        images, target = val_batch
        embeddings = self.forward(images)
        
        # prototypes are built from the support objects of the validation episode
        loss = self.loss(embeddings, target)
        self.prototypes = self.loss.prototypes
        
        label_predictions = self.predict_labels(embeddings)
        acc = accuracy(label_predictions, target, task="multiclass", num_classes=10)
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
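
`predict_labels` loops over the batch in Python, but the same nearest-prototype rule can be computed for the whole batch at once. This minimal sketch assumes prototypes are stored row by row in the order of the sorted class labels, as `torch.unique` returns them. A row index then equals a CIFAR10 label when all 10 classes are present. The function name is hypothetical:

```python
import torch


def nearest_prototype(embeddings: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Index of the closest prototype (Euclidean distance) for every embedding.

    embeddings: (N, emb_size), prototypes: (n_classes, emb_size) -> (N,) long tensor
    """
    return torch.cdist(embeddings, prototypes).argmin(dim=1)


prototypes = torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
embeddings = torch.tensor([[1.0, 1.0], [9.0, -1.0], [-1.0, 8.0], [6.0, 0.0]])
print(nearest_prototype(embeddings, prototypes).tolist())  # [0, 1, 2, 1]
```

Inside `FewShotMetricLearner.predict_labels` this would amount to `return nearest_prototype(batch_embeddings, self.prototypes)`.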

In [8]:
few_shot_metric_learner = FewShotMetricLearner(n_shot=N_SHOT)

trainer = pl.Trainer(accelerator='cpu', max_epochs=50, num_sanity_val_steps=2)
trainer.fit(
    few_shot_metric_learner,
    train_dataloader, 
    val_dataloader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type          | Params
----------------------------------------
0 | model | PrototypeNet  | 113 K 
1 | loss  | PrototypeLoss | 0     
----------------------------------------
113 K     Trainable params
0         Non-trainable params
113 K     Total params
0.452     Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=50` reached.


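As you can see, a very simple convolutional network, trained for about a minute on a CPU on only 5 (!!!) objects of each class, classifies objects with an accuracy of about 0.7. Note that this validation accuracy is computed over the whole validation episode, including the support objects used to build the prototypes.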

# Few Shot Meta Learning

# Zero Shot Learning