# ZSL training - with attributes only

This code is intended for the standard zero-shot learning (ZSL) setting only, not for the generalized ZSL scenario.


In [1]:
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

from zsldataset import ZSLDataset
from models import ContinuousMap, ContinuousMapResidual

def dist_matrix(batch1, batch2):
    """Return the pairwise mean squared difference between two batches.

    For 2-D inputs of shape ``(n, d)`` and ``(m, d)`` the result has shape
    ``(n, m)``; entry ``[i, j]`` is the mean over the last dimension of the
    element-wise squared difference between ``batch1[i]`` and ``batch2[j]``.
    """
    # Singleton axes make the subtraction broadcast every row of batch1
    # against every row of batch2.
    pairwise_diff = batch1[:, None] - batch2[None]
    return torch.mean(pairwise_diff * pairwise_diff, dim=-1)


def mag(u):
    """Return the dot product of the 1-D tensor ``u`` with itself.

    This is the squared Euclidean norm of ``u`` (no square root is taken).
    """
    return u.dot(u)


def dist(u, v):
    """Return the squared Euclidean distance between 1-D tensors ``u`` and ``v``.

    No square root is taken: the result is ``(u - v) . (u - v)``.
    """
    offset = u - v
    return offset.dot(offset)

## Load APY dataset

The dataset should be converted into the ZSL format beforehand.

In [2]:
# Load the train and test splits from their ZSL-format directories.
# NOTE(review): `use_predicates` / `use_irevnet` presumably select the class
# attribute vectors and i-RevNet image features -- confirm against ZSLDataset.
trainset = ZSLDataset('Data/APY_Zero/train', use_predicates=True, use_irevnet=True)    ############# changed this line
testset = ZSLDataset('Data/APY_Zero/test', use_predicates=True, use_irevnet=True)    ############# changed this line

In [4]:
# Batch size and number of training epochs.
bsize = 128
n_epochs = 200

# Number of classes known to the training dataset object.
num_classes = trainset.classes.shape[0]

# Read both embedding sizes from a single sample; indexing the dataset once
# avoids loading the same item twice.
first_sample = trainset[0]
dim_semantic = first_sample['class_predicates'].shape[0]
dim_visual = first_sample['image_embedding'].shape[0]

# Class attribute vectors as a single float tensor on the GPU
# (presumably one row per class -- it is compared against `classes_enum` later).
all_class_embeddings = torch.tensor(np.array(trainset.class_predicates)).cuda().float()

# Zero-based class indices 0 .. num_classes - 1 as an int64 tensor on the GPU.
classes_enum = torch.arange(num_classes, dtype=torch.int64).cuda()

In [5]:
# Collect the distinct class labels and class ids present in the test split.
# A single pass fetches every test sample once instead of twice.
# Indexing by position (rather than iterating the dataset directly) relies only
# on `__len__` and `__getitem__`, exactly as the dataset is used elsewhere.
query_classes = set()
query_ids = set()
for sample_index in range(len(testset)):
    test_sample = testset[sample_index]
    query_classes.add(test_sample['class_label'])
    query_ids.add(test_sample['class_id'])

In [6]:
# Build a 0/1 mask over all classes marking those that occur in the test split.
# Class ids are shifted down by one to obtain zero-based positions in the mask.
ids = [class_id - 1 for class_id in query_ids]
query_mask = np.zeros(num_classes)
query_mask[ids] = 1
# Move the mask to the GPU as an int64 tensor.
query_mask = torch.from_numpy(query_mask).to(torch.int64).cuda()

In [7]:
# Two mapping networks: visual -> semantic and semantic -> visual.
v_to_s = ContinuousMap(
    dim_source=dim_visual,
    dim_dest=dim_semantic,
    width=512,
).cuda()
s_to_v = ContinuousMap(
    dim_source=dim_semantic,
    dim_dest=dim_visual,
    width=512,
).cuda()

# Alternative optimizer configuration, kept for reference:
# optimizer = torch.optim.Adam(list(v_to_s.parameters()) + list(s_to_v.parameters()),
#                                 lr = 1e-3,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=3e-2)

# One SGD optimizer updates the parameters of both networks.
optimizer = torch.optim.SGD(
    [*v_to_s.parameters(), *s_to_v.parameters()],
    lr=1e-4,
    momentum=.9,
    weight_decay=5e-2,
)

In [8]:
# Mean-squared-error criterion.
mse = torch.nn.MSELoss()
# ReLU used as the "positive part" max(0, x) of the hinge in the triplet loss.
positive_part = torch.nn.ReLU()

In [9]:
# Training batches: reshuffled every epoch; the last incomplete batch is
# dropped so every optimisation step sees exactly `bsize` samples.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=bsize,
                                          shuffle=True,
                                          num_workers=4,
                                          pin_memory=True,
                                          drop_last=True)

# Evaluation batches: keep every test sample (no dropped tail) in a fixed
# order, so the metrics cover the whole test split and are reproducible.
# With drop_last=True a test split smaller than `bsize` would yield no
# batches at all.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=bsize,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)

In [11]:
# Inspect the shape of one sample's class attribute vector.
trainset[1]['class_predicates'].shape

torch.Size([64])

In [12]:
# Weights of the individual terms in the training loss.
alpha1 = 20 # triplet (hinge) loss between true-class and closest wrong-class distances
alpha2 = 1e-2 # surjection loss: class embeddings mapped s_to_v then v_to_s should be recovered
alpha3 = 1e-3 # l2 loss on the v_to_s outputs of the batch

# Weight of the l2 penalty on the s_to_v outputs for the class embeddings.
alpha_backward = 1e-2

# Margin added inside the hinge of the triplet loss.
margin = 1

In [None]:
# Jointly train the visual->semantic map (v_to_s) and the semantic->visual map
# (s_to_v); every 5 epochs evaluate nearest-class-embedding accuracy on the
# test split, restricted to the classes that occur in it.
for e in range(n_epochs):
    v_to_s = v_to_s.train()
    s_to_v = s_to_v.train()

    for sample in trainloader:
        optimizer.zero_grad()

        batch_visual = sample['image_embedding'].cuda().float()
        # Shift class ids down by one to index rows of `all_class_embeddings`.
        batch_classes = sample['class_id'].cuda() - 1

        # Surjection term: class embeddings mapped to visual space and back
        # through v_to_s should reproduce themselves.
        backward_v = s_to_v(all_class_embeddings)
        e_hat = v_to_s(backward_v)
        delta = e_hat - all_class_embeddings
        surjection_loss = (delta * delta).sum(dim=-1).mean()

        s_out = v_to_s(batch_visual)

        # same_class[b, c] is 1.0 where class c is the label of sample b.
        same_class = (classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)).float()

        d_matrix = dist_matrix(s_out, all_class_embeddings)

        # Negative: add a large constant to the true class so the minimum is
        # taken over wrong classes only. Positive: zero out wrong classes so
        # the maximum is the distance to the true class.
        closest_negative, _ = (d_matrix + same_class * 1e6).min(dim=-1)
        furthest_positive, _ = (d_matrix * same_class).max(dim=-1)

        triplet_loss = positive_part(furthest_positive - closest_negative + margin)

        # l2 penalties on the outputs of both maps.
        l2_loss = (s_out * s_out).sum(dim=-1).mean()
        backwards_l2_loss = (backward_v * backward_v).sum(dim=-1).mean()

        loss = (alpha1 * triplet_loss.mean() + alpha2 * surjection_loss
                + alpha3 * l2_loss + alpha_backward * backwards_l2_loss)

        loss.backward()
        optimizer.step()

        print(loss.item(), end=', ')

    # Step decay: multiply the learning rate by 0.7 every 50 epochs.
    if (e+1) % 50 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.7

    if (e+1) % 5 == 0:
        print('\n\n- Evaluation on epoch {}'.format(e))

        v_to_s = v_to_s.eval()
        s_to_v = s_to_v.eval()

        # Accumulate per-sample totals so that batches of different sizes are
        # weighted correctly and an empty loader cannot cause a division by zero.
        n_samples = 0
        n_correct = 0.
        sum_positive_dist = 0.

        # Large penalty on classes absent from the test split, so predictions
        # are restricted to query classes. It does not depend on the batch.
        non_query_penalty = (1 - query_mask).float() * 1e6

        with torch.no_grad():
            for sample in testloader:
                batch_visual = sample['image_embedding'].cuda().float()
                batch_classes = sample['class_id'].cuda() - 1

                s_out = v_to_s(batch_visual)

                same_class = (classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)).float()

                d_matrix = dist_matrix(s_out, all_class_embeddings)

                # Predicted class: nearest class embedding among query classes.
                c_hat = (d_matrix + non_query_penalty).argmin(dim=-1)

                # Distance from each output to its true class embedding.
                furthest_positive, _ = (d_matrix * same_class).max(dim=-1)

                n_samples += batch_classes.shape[0]
                n_correct += (c_hat == batch_classes).float().sum().item()
                sum_positive_dist += furthest_positive.sum().item()

        if n_samples == 0:
            print('Test loader produced no batches; skipping evaluation.\n\n')
        else:
            avg_accuracy = n_correct / n_samples
            # NOTE(review): the reported evaluation loss is alpha1 times the
            # mean distance to the true class embedding, not the triplet hinge
            # used for training -- confirm this is the intended metric.
            avg_loss = alpha1 * sum_positive_dist / n_samples

            print('Average acc.: {}, Average loss:{}\n\n'.format(avg_accuracy, avg_loss))


23.862548828125, 23.815404891967773, 23.87175178527832, 23.79213523864746, 23.81304168701172, 23.843965530395508, 23.90216064453125, 23.825401306152344, 23.826231002807617, 23.902111053466797, 23.654884338378906, 23.871702194213867, 23.78680419921875, 23.792404174804688, 23.690378189086914, 23.803163528442383, 23.679607391357422, 23.83673095703125, 23.785581588745117, 23.769989013671875, 23.691364288330078, 23.722431182861328, 23.625415802001953, 23.61703872680664, 23.649789810180664, 23.626506805419922, 23.6696834564209, 23.739299774169922, 23.75798225402832, 23.617061614990234, 23.67873764038086, 23.60660743713379, 23.675582885742188, 23.667987823486328, 23.67871856689453, 23.630874633789062, 23.66771697998047, 23.437854766845703, 23.480878829956055, 23.533554077148438, 23.536422729492188, 23.59556770324707, 23.407636642456055, 23.632667541503906, 23.532798767089844, 23.514942169189453, 23.474397659301758, 23.505352020263672, 23.539169311523438, 23.5230770111084, 23.378650665283203, 

22.375022888183594, 21.830097198486328, 21.651351928710938, 22.272605895996094, 22.083925247192383, 21.705820083618164, 22.242271423339844, 22.092472076416016, 21.88823699951172, 21.404460906982422, 21.44552993774414, 21.373497009277344, 21.842273712158203, 21.732500076293945, 22.10805320739746, 21.76866340637207, 21.755037307739258, 21.998559951782227, 21.41439437866211, 21.290443420410156, 21.95783233642578, 21.80103302001953, 21.70794677734375, 21.68015480041504, 21.63582420349121, 21.88988494873047, 21.436307907104492, 21.57554054260254, 21.870912551879883, 22.03746795654297, 22.35259246826172, 21.60399055480957, 21.872312545776367, 22.035673141479492, 22.169734954833984, 21.514421463012695, 21.793155670166016, 21.691490173339844, 21.363887786865234, 21.727468490600586, 21.824199676513672, 21.96822738647461, 21.920059204101562, 21.910356521606445, 21.704769134521484, 21.77998924255371, 21.90180015563965, 21.367172241210938, 21.20092010498047, 20.867389678955078, 21.027761459350586,