# Zero-Shot Learning With PyTorch

Hello guys, today I have a pretty interesting framework to talk about. It allows us to classify images of classes that we have never seen during training.

Last year I worked on One-Shot Learning (ONE) [1] to predict the similarity between time series. To do so, we implemented a Siamese Neural Network [1] that computes features from two different time series at a time and then compares them using a simple L1 distance.

But today I don't want to talk about ONE. I want to talk about a more powerful framework called Zero-Shot Learning (ZSL), which, as I mentioned above, can classify images of unseen classes.

## How ZSL works

To explain at a high level how ZSL works, I want to quote a Quora answer [2] which I find useful and comprehensive.
 
> What if I show you an image of an animal, given you have never seen that animal before, can you guess the name of the animal? Maybe, if you have somewhere read about that particular animal. Let's say I show a picture of zebra to a child who has never seen a zebra but has seen a horse and also she is taught that a zebra looks like a horse but with stripes. Can she identify it now? Most probably, yes!

This idea may ring a bell if you are familiar with Natural Language Processing (NLP) and *word embeddings*. With *word embeddings* we can represent a word as a vector. If two words have similar meanings, their vectors will be close to each other in the vector space [3].

> Let's consider a word embedding of three dimensions (although in practice the embedding varies from 100 to 300-dimensional vectors). Let these three dimensions represent features like stripes, animalness, and whiteness. So for a tiger, it would be [1, 1, 0] i.e a tiger has stripes, it is an animal but is not white in colour and for a rabbit it would be [0, 1, 1] as a rabbit does not have stripes, but is an animal and white in colour.
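
To make the quoted example concrete, here is a minimal sketch that compares those toy vectors with cosine similarity. The tiger and rabbit vectors come from the quote; the zebra vector is my own assumed value (striped, an animal, partly white), and real word embeddings are learned rather than hand-written like this.

```python
import math

# Toy 3-d "embeddings": [stripes, animalness, whiteness].
# tiger and rabbit are taken from the quote; zebra is an assumed value.
embeddings = {
    "tiger": [1.0, 1.0, 0.0],
    "rabbit": [0.0, 1.0, 1.0],
    "zebra": [1.0, 1.0, 1.0],
}


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


sim_tiger_rabbit = cosine_similarity(embeddings["tiger"], embeddings["rabbit"])
sim_tiger_zebra = cosine_similarity(embeddings["tiger"], embeddings["zebra"])
print(f"tiger vs rabbit: {sim_tiger_rabbit:.3f}")
print(f"tiger vs zebra:  {sim_tiger_zebra:.3f}")
```

With these values the tiger ends up closer to the zebra than to the rabbit, because they share the stripes dimension.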

Now, coming back to CNNs, we can represent an image using a set of features, also known as a **feature vector**. A straightforward way to create a feature vector from an image is to use a pretrained network.

For instance, we can retrieve a feature vector from an image using PyTorch with a few lines of code.

In [3]:
import torch
import torch.nn.functional as F
import torchvision.models as zoo

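image = torch.randn((1, 3, 224, 224))  # a random "image" batch: (N, C, H, W)
model = zoo.densenet121(pretrained=True)
model.eval()  # inference mode: use the stored batch-norm statistics

with torch.no_grad():  # no gradients are needed to extract features
    # Global average pooling over the last convolutional feature map
    feature_vector = F.adaptive_avg_pool2d(model.features(image), (1, 1))
print('Random image feature vector:', feature_vector.view(-1).size())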

Random image feature vector: torch.Size([1024])


For ZSL tasks, however, we won't use the feature vectors to classify the images directly into labels. Instead, we are going to match each feature vector to a semantic representation of its label, for example a word embedding.

So, if we have an image of a tiger and its corresponding word embedding [1, 1, 0], we train a regression task that pulls the label representation and the feature vector closer together. Later, when we feed a zebra image (a class unseen during training) to the pretrained CNN, we can search the word embedding space and retrieve the closest embedding. *This way we are able to identify an image of a zebra, which we didn't have in our training data but did have a word embedding for.*
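
The whole idea fits in a small toy sketch. Everything here is an assumption made for illustration: the attribute values, the `fake_cnn_features` helper that stands in for a pretrained CNN, and the use of a closed-form least-squares fit (`torch.linalg.lstsq`) in place of gradient-based training.

```python
import torch

torch.manual_seed(0)

# Toy attributes [stripes, animalness, whiteness]; all values are assumptions.
seen = {"tiger": [1., 1., 0.], "rabbit": [0., 1., 1.], "horse": [0., 1., 0.]}
unseen = {"zebra": [1., 1., 1.], "snowman": [0., 0., 1.]}

# Stand-in for a pretrained CNN: a fixed random linear map from attributes
# to 8-d "feature vectors", plus a little noise.
true_map = torch.randn(3, 8)


def fake_cnn_features(attrs, n=20):
    """Return n noisy 8-d feature vectors for a class with these attributes."""
    attrs = torch.tensor(attrs).repeat(n, 1)
    return attrs @ true_map + 0.01 * torch.randn(n, 8)


# Training data: seen classes only.
train_attrs = torch.cat([torch.tensor(a).repeat(20, 1) for a in seen.values()])
train_feats = torch.cat([fake_cnn_features(a) for a in seen.values()])

# Regression: find W such that attrs @ W is close to the image features.
W = torch.linalg.lstsq(train_attrs, train_feats).solution  # shape (3, 8)

# Inference: project the unseen class embeddings and pick the nearest one.
zebra_feature = fake_cnn_features(unseen["zebra"], n=1)
names = list(unseen)
class_embeds = torch.tensor([unseen[name] for name in names]) @ W
distances = torch.cdist(zebra_feature, class_embeds)  # shape (1, num_unseen)
prediction = names[distances.argmin(dim=1).item()]
print(prediction)
```

Note that no zebra image is used to fit `W`; the zebra only shows up at inference time, through its attribute vector.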

## ZSL Framework definition

At training time we have a dataset $D$ with $N$ training samples, given by $D = \{(I_i, y^u_i, t^u_i), i = 1, \dots, N\}$ with an associated label set $\tau$, where $I_i$ is the $i$th training image, $y^u_i$ is the corresponding $L$-dimensional semantic representation (word embedding, text description, etc.), and $t^u_i \in \tau$ is the $u$th training class, assigned to the $i$th image.

Given a new test image $I_j$, the goal of ZSL is to predict a class label $t^u_j \in \tau'$, where $\tau \cap \tau' = \emptyset$. This means that neither the image nor its label is seen during training.
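
One common way to write this down is the following sketch. The symbols $\phi$ (visual embedding), $\psi$ (semantic embedding), $d$ (a distance) and $y_t$ (the semantic representation of class $t$) are notation I am assuming here, and the exact loss used by the code below is not necessarily this one. Training pulls matched pairs together:

$$
\min_{\phi, \psi} \; \frac{1}{N} \sum_{i=1}^{N} d\big(\phi(I_i), \psi(y^u_i)\big)
$$

and prediction searches only among the unseen classes:

$$
\hat{t}_j = \arg\min_{t \in \tau'} \; d\big(\phi(I_j), \psi(y_t)\big)
$$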


In [2]:
import sys

from torch.utils.data import DataLoader

sys.path.append('..')
import zsl

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
fe = 'ResNet101'
ds = zsl.data.AwAFeaturesDataset(
    '../data/AwA2/Animals_with_Attributes2',
    features_type=fe)

In [5]:
features, target, semantic = ds[0]
print(ds.classes[target])
print(features.shape)

antelope
torch.Size([2048])


In [6]:
train_ds = zsl.data.AwAFeaturesDataset(
    '../data/AwA2/Animals_with_Attributes2',
    features_type=fe,
    load_unseen=False)

valid_ds = zsl.data.AwAFeaturesDataset(
    '../data/AwA2/Animals_with_Attributes2',
    features_type=fe,
    load_unseen=True,
    load_only_unseen=True)

train_dl = DataLoader(
    train_ds, 
    shuffle=True, 
    batch_size=64,
    collate_fn=zsl.utils.collate_image_folder)

valid_dl = DataLoader(
    valid_ds, 
    batch_size=32,
    collate_fn=zsl.utils.collate_image_folder)

In [7]:
len(valid_ds), len(train_ds)

(7913, 29409)

In [8]:
semantic_unit = zsl.models.LinearSemanticUnit(
    in_features=len(train_ds.attrs),
    out_features=1024)
visual_fe = zsl.models.Identity(features.size(0))                                

In [9]:
semantic = torch.FloatTensor(semantic)
print('Semantic repr shape:', semantic_unit(semantic.unsqueeze(0)).size())
print('Visual embedding shape:', visual_fe(features.unsqueeze(0)).size())

Semantic repr shape: torch.Size([1, 1024])
Visual embedding shape: torch.Size([1, 2048])


In [10]:
zs_model = zsl.models.ZeroShot(visual_fe, semantic_unit)
image_embed, semantic_embed = zs_model(
    features.unsqueeze(0), torch.FloatTensor(semantic).unsqueeze(0))
image_embed.size(), semantic_embed.size()

(torch.Size([1, 2048]), torch.Size([1, 2048]))
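
As a rough mental model of what a two-branch model like this does (this is not the actual `zsl.models.ZeroShot` code), here is a hypothetical sketch. The `TwoBranchZeroShot` class and its layer sizes are my assumptions; the 85 attributes and the 2048-d ResNet101 features match AwA2 and the shapes printed above.

```python
import torch
import torch.nn as nn


class TwoBranchZeroShot(nn.Module):
    """Hypothetical stand-in: embeds image features and class semantics."""

    def __init__(self, visual_branch, semantic_branch):
        super().__init__()
        self.visual_branch = visual_branch
        self.semantic_branch = semantic_branch

    def forward(self, features, semantic):
        # Both outputs must live in the same space to be comparable.
        return self.visual_branch(features), self.semantic_branch(semantic)


num_attrs, visual_dim = 85, 2048
model = TwoBranchZeroShot(
    visual_branch=nn.Identity(),  # precomputed features are used as they are
    semantic_branch=nn.Sequential(  # assumed sizes, for illustration only
        nn.Linear(num_attrs, 1024), nn.ReLU(), nn.Linear(1024, visual_dim)),
)

image_embed, semantic_embed = model(
    torch.randn(4, visual_dim), torch.rand(4, num_attrs))
print(image_embed.size(), semantic_embed.size())
```

The important part is that both branches end in the same dimensionality, so that a distance between an image and a class description makes sense.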

In [11]:
zs_model.to(device);

In [12]:
features, labels, semantics = next(iter(train_dl))
images_embeds, semantics_embeds = zs_model(features.to(device), 
                                           semantics.to(device))

images_embeds.size(), semantics_embeds.size()

(torch.Size([64, 2048]), torch.Size([64, 2048]))

In [13]:
parameters = [p for p in zs_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(parameters, 1e-4, weight_decay=9e-5)

In [14]:
EPOCHS = 8
semantic_repr = torch.FloatTensor(valid_ds.attr_matrix)

for epoch in range(EPOCHS):
    zsl.engine.train_epoch(
         model=zs_model, dl=train_dl, optimizer=optimizer, 
         epoch=epoch, print_freq=300, device=device)
    
    zsl.engine.evaluate(
        model=zs_model, dl=valid_dl,
        class_representations=semantic_repr, device=device)

Epoch [0] [299/460] loss: 0.4748
Epoch [0] [460/460] loss: 0.4499
Validation accuracy: 0.2174 top_5_accuracy: 0.8322 loss: 0.5206
Epoch [1] [299/460] loss: 0.3927
Epoch [1] [460/460] loss: 0.3883
Validation accuracy: 0.2709 top_5_accuracy: 0.8625 loss: 0.5146
Epoch [2] [299/460] loss: 0.3771
Epoch [2] [460/460] loss: 0.3737
Validation accuracy: 0.2932 top_5_accuracy: 0.8606 loss: 0.5131
Epoch [3] [299/460] loss: 0.3677
Epoch [3] [460/460] loss: 0.3655
Validation accuracy: 0.3029 top_5_accuracy: 0.8659 loss: 0.5123
Epoch [4] [299/460] loss: 0.3623
Epoch [4] [460/460] loss: 0.3599
Validation accuracy: 0.3146 top_5_accuracy: 0.8705 loss: 0.5112
Epoch [5] [299/460] loss: 0.3584
Epoch [5] [460/460] loss: 0.3557
Validation accuracy: 0.3249 top_5_accuracy: 0.8783 loss: 0.5113
Epoch [6] [299/460] loss: 0.3534
Epoch [6] [460/460] loss: 0.3527
Validation accuracy: 0.3244 top_5_accuracy: 0.8771 loss: 0.5109
Epoch [7] [299/460] loss: 0.3517
Epoch [7] [460/460] loss: 0.3500
Validation accuracy: 0.3
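
The log reports both `accuracy` and `top_5_accuracy`. How `zsl.engine.evaluate` computes them is not shown here, so the sketch below is only one plausible way to do it: rank the class embeddings by distance to each image embedding and check whether the true class is among the `k` nearest. The `topk_accuracy` helper and the toy tensors are assumptions.

```python
import torch


def topk_accuracy(image_embeds, class_embeds, targets, k=1):
    """Fraction of samples whose true class is among the k nearest classes."""
    distances = torch.cdist(image_embeds, class_embeds)         # (N, C)
    nearest = distances.topk(k, dim=1, largest=False).indices   # (N, k)
    hits = (nearest == targets.unsqueeze(1)).any(dim=1)         # (N,)
    return hits.float().mean().item()


# Toy example: 4 classes and 4 images in a 2-d embedding space.
class_embeds = torch.tensor([[0., 0.], [1., 0.], [0., 1.], [5., 5.]])
image_embeds = torch.tensor([[0.1, 0.0], [0.9, 0.1], [0.2, 0.6], [4.0, 4.0]])
targets = torch.tensor([0, 1, 0, 3])

top1 = topk_accuracy(image_embeds, class_embeds, targets, k=1)
top2 = topk_accuracy(image_embeds, class_embeds, targets, k=2)
print(f"top-1: {top1:.2f}  top-2: {top2:.2f}")
```

In this toy data the third image sits closest to class 2 while its label is class 0, so it counts as a miss for top-1 but as a hit for top-2. That is the same kind of gap we see between the two metrics in the log above.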

## References

[1] Siamese Neural Networks for One-Shot Image Recognition - https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

[2] Quora: What is zero-shot learning? - https://www.quora.com/What-is-zero-shot-learning

[3] Playing with word embeddings - https://guillem96.github.io/guillem96-blog

[4] Zero-Shot Learning - A Comprehensive Evaluation of the Good, the Bad and the Ugly - https://arxiv.org/pdf/1707.00600.pdf