In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (including the local `zsl` package) are re-imported before each cell runs
# and edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import torch
from torch.utils.data import DataLoader

sys.path.append('..')
import zsl

In [3]:
# Seed torch's global RNG before any model or DataLoader is created, so that
# torch-driven randomness (parameter initialisation, shuffling that draws
# from the global generator) is repeatable across Restart & Run All.
# Some GPU kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Feature type passed to every dataset in this notebook
# (presumably pre-extracted ResNet-101 features; confirm against zsl.data).
fe = 'ResNet101'

# Dataset built without any split flags; used below to inspect one sample.
ds = zsl.data.AwAFeaturesDataset(
    '../data/AwA2/Animals_with_Attributes2', features_type=fe
)

In [5]:
# Unpack the first sample; `features` and `semantic` are reused by later
# cells to size the model and run a smoke-test forward pass.
features, target, semantic = ds[0]

# Print the sample's class name, then its feature shape, one per line.
print(ds.classes[target], features.shape, sep='\n')

antelope
torch.Size([2048])


In [6]:
# Both splits are read from the same AwA2 directory.
awa2_root = '../data/AwA2/Animals_with_Attributes2'

# Training split: built with load_unseen=False
# (going by the flag name, unseen classes are excluded).
train_ds = zsl.data.AwAFeaturesDataset(
    awa2_root, features_type=fe, load_unseen=False)

# Validation split: built with load_unseen=True and load_only_unseen=True
# (going by the flag names, only the unseen classes are kept).
valid_ds = zsl.data.AwAFeaturesDataset(
    awa2_root, features_type=fe, load_unseen=True, load_only_unseen=True)

# Training batches are shuffled; both loaders use the project's collate helper.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
    collate_fn=zsl.utils.collate_image_folder,
)

# Validation batches keep the dataset order (no shuffle argument is passed).
valid_dl = DataLoader(
    dataset=valid_ds,
    batch_size=32,
    collate_fn=zsl.utils.collate_image_folder,
)

In [7]:
# Display the number of samples in the validation and training splits,
# in that order.
len(valid_ds), len(train_ds)

(7913, 29409)

In [8]:
# Semantic branch: takes one input per dataset attribute and outputs 1024
# features.
semantic_unit = zsl.models.LinearSemanticUnit(
    in_features=len(train_ds.attrs), out_features=1024)

# Visual branch: an Identity module sized from the first dimension of
# `features`. NOTE: this expects `features` to still be the single sample
# unpacked from `ds[0]`, not a batch.
visual_fe = zsl.models.Identity(features.shape[0])

In [9]:
# Turn the sample's attribute vector into a float tensor.
semantic = torch.FloatTensor(semantic)

# Feed each branch a batch of one (leading dimension added with unsqueeze)
# and report the shape of what comes out.
print(
    'Semantic repr shape:',
    semantic_unit(semantic.unsqueeze(0)).shape,
)
print(
    'Visual embedding shape:',
    visual_fe(features.unsqueeze(0)).shape,
)

Semantic repr shape: torch.Size([1, 1024])
Visual embedding shape: torch.Size([1, 2048])


In [10]:
# Combine the visual and semantic branches into one zero-shot model.
zs_model = zsl.models.ZeroShot(visual_fe, semantic_unit)

# Smoke-test a forward pass on a batch of one sample.
image_embed, semantic_embed = zs_model(
    features.unsqueeze(0),
    torch.FloatTensor(semantic).unsqueeze(0),
)

# Display the shapes of the two returned embeddings.
(image_embed.shape, semantic_embed.shape)

(torch.Size([1, 2048]), torch.Size([1, 2048]))

In [11]:
# Move the model to the selected device; the trailing `;` suppresses the
# cell's output.
zs_model.to(device);

In [12]:
# Pull one batch from the training loader.
# NOTE: this rebinds `features` from the single sample to a whole batch, so
# re-running the cell that sizes `visual_fe` from `features` after this one
# would read the batch dimension instead.
features, labels, semantics = next(iter(train_dl))

# Forward the batch through the model on the target device.
images_embeds, semantics_embeds = zs_model(
    features.to(device), semantics.to(device))

# Display the shapes of the two batched embeddings.
(images_embeds.shape, semantics_embeds.shape)

(torch.Size([64, 2048]), torch.Size([64, 2048]))

In [13]:
# Hand only the parameters that require gradients to the optimizer.
parameters = list(filter(lambda p: p.requires_grad, zs_model.parameters()))

# AdamW with learning rate 1e-4 and weight decay 9e-5.
optimizer = torch.optim.AdamW(parameters, lr=1e-4, weight_decay=9e-5)

In [14]:
# Number of train + evaluate rounds.
EPOCHS = 8

# The validation dataset's attribute matrix as a float tensor; passed to
# evaluation as the class representations.
semantic_repr = torch.FloatTensor(valid_ds.attr_matrix)

for epoch in range(EPOCHS):
    # One pass over the training loader (print_freq=300 controls how often
    # the engine logs progress).
    zsl.engine.train_epoch(
        model=zs_model,
        dl=train_dl,
        optimizer=optimizer,
        epoch=epoch,
        print_freq=300,
        device=device,
    )

    # Evaluate on the validation loader against the class representations.
    zsl.engine.evaluate(
        model=zs_model,
        dl=valid_dl,
        class_representations=semantic_repr,
        device=device,
    )

Epoch [0] [299/460] loss: 0.4748
Epoch [0] [460/460] loss: 0.4499
Validation accuracy: 0.2174 top_5_accuracy: 0.8322 loss: 0.5206
Epoch [1] [299/460] loss: 0.3927
Epoch [1] [460/460] loss: 0.3883
Validation accuracy: 0.2709 top_5_accuracy: 0.8625 loss: 0.5146
Epoch [2] [299/460] loss: 0.3771
Epoch [2] [460/460] loss: 0.3737
Validation accuracy: 0.2932 top_5_accuracy: 0.8606 loss: 0.5131
Epoch [3] [299/460] loss: 0.3677
Epoch [3] [460/460] loss: 0.3655
Validation accuracy: 0.3029 top_5_accuracy: 0.8659 loss: 0.5123
Epoch [4] [299/460] loss: 0.3623
Epoch [4] [460/460] loss: 0.3599
Validation accuracy: 0.3146 top_5_accuracy: 0.8705 loss: 0.5112
Epoch [5] [299/460] loss: 0.3584
Epoch [5] [460/460] loss: 0.3557
Validation accuracy: 0.3249 top_5_accuracy: 0.8783 loss: 0.5113
Epoch [6] [299/460] loss: 0.3534
Epoch [6] [460/460] loss: 0.3527
Validation accuracy: 0.3244 top_5_accuracy: 0.8771 loss: 0.5109
Epoch [7] [299/460] loss: 0.3517
Epoch [7] [460/460] loss: 0.3500
Validation accuracy: 0.3