# ZSL training - with attributes only

This code should be used for ZSL only, not for the generalized ZSL scenario.


In [7]:
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

from zsldataset import ZSLDataset
from models import ContinuousMap, ContinuousMapResidual

def dist_matrix(batch1, batch2):
    """Return the pairwise mean squared difference between two batches.

    Entry ``[i, j]`` of the result is the squared difference between
    ``batch1[i]`` and ``batch2[j]``, averaged over the last dimension.
    """
    # Singleton axes make the subtraction broadcast every row of batch1
    # against every row of batch2.
    pairwise_diff = batch1[:, None] - batch2[None]
    return (pairwise_diff * pairwise_diff).mean(dim=-1)


def mag(u):
    """Return the dot product of the 1-D tensor ``u`` with itself (squared norm)."""
    return u.dot(u)


def dist(u, v):
    """Return the squared Euclidean distance between 1-D tensors ``u`` and ``v``."""
    offset = u - v
    return torch.dot(offset, offset)

## Load APY dataset

The dataset should be converted into ZSL format beforehand.

In [8]:
# Build the train and test splits with identical options; only the root
# directory differs between the two.
trainset, testset = (
    ZSLDataset(split_root, use_predicates=True, use_irevnet=True)
    for split_root in ('Data/APY_Zero/train', 'Data/APY_Zero/test')
)

In [9]:
# Batch size and number of training epochs.
bsize, n_epochs = 128, 200

# Number of classes listed by the training split.
num_classes = trainset.classes.shape[0]

# Feature sizes are read off the first training sample. Note that the
# "visual" input of this notebook is the sample's predicate vector.
dim_visual = trainset[0]['class_predicates'].shape[0]
dim_semantic = trainset[0]['class_embedding'].shape[0]

# Index of every class (0 .. num_classes - 1), kept on the GPU so it can be
# compared against batch labels.
classes_enum = torch.tensor(np.arange(num_classes, dtype=np.int64)).cuda()

# Embeddings of all classes as a float tensor on the GPU.
all_class_embeddings = torch.tensor(
    np.array(trainset.class_embeddings)
).cuda().float()

In [10]:
# Collect the labels and ids of the classes that occur in the test split.
# Both sets are filled in a single pass so every test sample is read from
# the dataset once instead of twice.
query_classes = set()
query_ids = set()
for sample_idx in range(len(testset)):
    test_sample = testset[sample_idx]
    query_classes.add(test_sample['class_label'])
    query_ids.add(test_sample['class_id'])

In [11]:
# Shift the test-class ids down by one to obtain row indices into the class
# tables (presumably class ids are 1-based; the training loop applies the
# same shift to its batch labels).
ids = [class_id - 1 for class_id in query_ids]

# 1 for classes that occur in the test split, 0 for all others.
query_mask = np.zeros(num_classes)
query_mask[ids] = 1
query_mask = torch.tensor(query_mask, dtype=torch.int64).cuda()

In [12]:
# Forward map (visual -> semantic) and backward map (semantic -> visual),
# both with a width of 512.
v_to_s = ContinuousMap(width=512, dim_source=dim_visual, dim_dest=dim_semantic).cuda()
s_to_v = ContinuousMap(width=512, dim_source=dim_semantic, dim_dest=dim_visual).cuda()

# A single SGD optimizer updates the parameters of both maps.
#
# Alternative kept for reference (disabled):
#   torch.optim.Adam(<same parameters>, lr=1e-3, betas=(0.9, 0.999),
#                    weight_decay=3e-2)
optimizer = torch.optim.SGD(
    [*v_to_s.parameters(), *s_to_v.parameters()],
    lr=1e-4,
    momentum=0.9,
    weight_decay=5e-2,
)

In [13]:
# Loss building blocks: a mean-squared-error criterion and a ReLU that is
# used as the positive part max(x, 0) of the hinge term.
mse, positive_part = torch.nn.MSELoss(), torch.nn.ReLU()

In [14]:
# Training batches: reshuffled every epoch, with the incomplete last batch
# dropped so that every training step sees exactly `bsize` samples.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=bsize,
                                          shuffle=True,
                                          num_workers=4,
                                          pin_memory=True,
                                          drop_last=True)

# Evaluation batches: every test sample must be scored exactly once, so the
# loader neither shuffles nor drops the last (possibly smaller) batch.
# With shuffle=True and drop_last=True each evaluation skipped a different
# random tail of the test set, and a test set smaller than `bsize` produced
# no batches at all.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=bsize,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)

In [15]:
# Margin of the hinge (triplet) term: the distance to the correct class must
# undercut the distance to the closest wrong class by at least this much.
margin = 1

# Weights of the individual terms of the training objective.
alpha1 = 20            # hinge (triplet) term
alpha2 = 0.01          # surjection term (class embedding -> visual -> semantic)
alpha3 = 0.001         # L2 penalty on the outputs of the forward map
alpha_backward = 0.01  # L2 penalty on the outputs of the backward map

In [None]:
# Training loop with an evaluation pass on the test split every 5 epochs.
#
# Objective per batch:
#   alpha1 * hinge(d(positive) - d(closest negative) + margin)
# + alpha2 * surjection loss (class embedding -> visual -> semantic)
# + alpha3 * L2 penalty on the forward-map outputs
# + alpha_backward * L2 penalty on the backward-map outputs
for e in range(n_epochs):
    v_to_s.train()
    s_to_v.train()

    for sample in trainloader:
        optimizer.zero_grad()

        batch_visual = sample['class_predicates'].cuda().float()

        # Shift class ids down by one so they index rows of
        # `all_class_embeddings` (presumably ids are 1-based; the same shift
        # is used when `query_mask` is built).
        batch_classes = sample['class_id'].cuda() - 1

        # Surjection term: map every class embedding to the visual space and
        # back, and penalise the squared reconstruction error.
        backward_v = s_to_v(all_class_embeddings)
        e_hat = v_to_s(backward_v)
        delta = e_hat - all_class_embeddings
        surjection_loss = (delta * delta).sum(dim=-1).mean()

        s_out = v_to_s(batch_visual)

        # same_class[b, c] is True where class c is the label of sample b.
        same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)

        d_matrix = dist_matrix(s_out, all_class_embeddings)

        # The 1e6 offset removes the correct class from the minimum; the
        # multiplication zeroes every wrong class before the maximum.
        closest_negative, _ = (d_matrix + same_class.float() * 1e6).min(dim=-1)
        furthest_positive, _ = (d_matrix * same_class.float()).max(dim=-1)

        l2_loss = (s_out * s_out).sum(dim=-1).mean()
        backwards_l2_loss = (backward_v * backward_v).sum(dim=-1).mean()

        loss = positive_part(furthest_positive - closest_negative + margin)

        loss = alpha1 * loss.mean() + alpha2 * surjection_loss + \
                alpha3 * l2_loss + alpha_backward * backwards_l2_loss

        loss.backward()
        optimizer.step()

        print(loss.item(), end=', ')

    # Multiply the learning rate by 0.7 every 50 epochs.
    if (e+1) % 50 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.7

    if (e+1) % 5 == 0:
        print('\n\n- Evaluation on epoch {}'.format(e))

        # Per-sample accumulators: the averages stay correct when the last
        # batch is smaller than the others, and an empty loader no longer
        # causes a division by zero.
        n_correct = 0
        n_samples = 0
        loss_sum = 0.

        v_to_s.eval()
        s_to_v.eval()

        with torch.no_grad():
            for sample in testloader:
                batch_visual = sample['class_predicates'].cuda().float()
                batch_classes = sample['class_id'].cuda() - 1

                s_out = v_to_s(batch_visual)

                same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)

                d_matrix = dist_matrix(s_out, all_class_embeddings)

                # Predictions are restricted to the classes of the test
                # split: every other class is pushed away by 1e6.
                c_hat = (d_matrix + (1 - query_mask).float() * 1e6).argmin(dim = -1)

                closest_negative, _ = (d_matrix + same_class.float() * 1e6).min(dim=-1)
                furthest_positive, _ = (d_matrix * same_class.float()).max(dim=-1)

                # Report the same weighted hinge term that is optimised during
                # training. Previously this value was computed and then
                # overwritten by the mean positive distance.
                hinge = positive_part(furthest_positive - closest_negative + margin)

                loss_sum += alpha1 * hinge.sum().item()
                n_correct += (c_hat == batch_classes).sum().item()
                n_samples += batch_classes.shape[0]

        if n_samples > 0:
            avg_accuracy = n_correct / n_samples
            avg_loss = loss_sum / n_samples
        else:
            # No test batch was produced; report NaN instead of crashing.
            avg_accuracy = avg_loss = float('nan')

        print('Average acc.: {}, Average loss:{}\n\n'.format(avg_accuracy, avg_loss))


21.626522064208984, 21.831846237182617, 21.797863006591797, 21.527210235595703, 21.581579208374023, 21.62843894958496, 21.540184020996094, 21.582786560058594, 21.639509201049805, 21.373931884765625, 21.524520874023438, 21.480995178222656, 21.42144012451172, 21.4009952545166, 21.48554229736328, 21.482646942138672, 21.38574981689453, 21.445228576660156, 21.33075714111328, 21.42051887512207, 21.377683639526367, 21.312255859375, 21.148357391357422, 21.280555725097656, 21.354936599731445, 21.190929412841797, 21.224994659423828, 21.245349884033203, 21.09380531311035, 21.237028121948242, 21.24837303161621, 20.954631805419922, 20.879640579223633, 21.063318252563477, 21.10121726989746, 20.97960662841797, 20.9656982421875, 21.016647338867188, 21.025257110595703, 21.055082321166992, 21.110498428344727, 20.90865707397461, 20.860883712768555, 20.86359405517578, 20.828580856323242, 20.910438537597656, 20.872699737548828, 20.96702003479004, 20.769025802612305, 20.750896453857422, 20.805509567260742, 

18.18260383605957, 18.681957244873047, 18.46857261657715, 18.174043655395508, 18.067270278930664, 18.588520050048828, 18.403247833251953, 18.178720474243164, 17.91973304748535, 18.342037200927734, 18.669498443603516, 18.313020706176758, 18.23192024230957, 18.274120330810547, 18.191709518432617, 18.384014129638672, 18.5758056640625, 18.111631393432617, 18.398263931274414, 18.530536651611328, 17.95981216430664, 18.400028228759766, 18.040786743164062, 18.15311050415039, 18.478361129760742, 18.445205688476562, 18.01478385925293, 18.26280975341797, 18.1685733795166, 18.215864181518555, 18.276493072509766, 18.081762313842773, 18.111377716064453, 17.97718048095703, 18.197891235351562, 18.144014358520508, 18.596357345581055, 18.046302795410156, 18.2386531829834, 18.56035041809082, 17.773818969726562, 18.244802474975586, 18.045469284057617, 17.99386978149414, 18.1984806060791, 18.453441619873047, 18.02641487121582, 18.048614501953125, 18.195819854736328, 18.1612491607666, 17.988525390625, 17.84

16.661827087402344, 17.105384826660156, 16.940156936645508, 16.655908584594727, 16.587322235107422, 16.65022850036621, 16.71297264099121, 16.99538803100586, 16.3099365234375, 17.06844711303711, 17.10977554321289, 16.530759811401367, 16.41852569580078, 16.98876953125, 16.869022369384766, 16.486913681030273, 16.69896125793457, 17.50823974609375, 16.964502334594727, 17.190174102783203, 16.728321075439453, 16.86725425720215, 16.642850875854492, 16.708070755004883, 17.14051628112793, 16.26302146911621, 16.34958267211914, 16.844493865966797, 16.428050994873047, 16.758167266845703, 16.62713050842285, 16.69371223449707, 16.547412872314453, 16.882381439208984, 16.881710052490234, 16.716333389282227, 16.92180633544922, 16.64116859436035, 16.581899642944336, 16.775997161865234, 16.69621467590332, 16.69732666015625, 16.755382537841797, 16.462064743041992, 16.74283790588379, 16.635009765625, 16.87989616394043, 16.475400924682617, 16.664846420288086, 16.337797164916992, 16.561548233032227, 16.452020

16.3931827545166, 16.037599563598633, 15.799846649169922, 15.719684600830078, 15.952982902526855, 15.40995979309082, 15.875943183898926, 15.0660400390625, 15.650259971618652, 15.30361557006836, 15.860596656799316, 15.19624137878418, 15.331766128540039, 15.57806396484375, 14.950273513793945, 15.723356246948242, 15.878609657287598, 15.326936721801758, 15.633862495422363, 15.51990032196045, 15.587722778320312, 15.070497512817383, 15.271142959594727, 15.748065948486328, 15.144529342651367, 14.573816299438477, 14.86674976348877, 15.340375900268555, 15.016793251037598, 15.206432342529297, 15.742267608642578, 14.799433708190918, 15.01598072052002, 14.784069061279297, 15.481156349182129, 14.683658599853516, 15.835243225097656, 15.192510604858398, 15.64499282836914, 15.495772361755371, 15.519622802734375, 15.39673137664795, 15.476088523864746, 15.260761260986328, 15.774542808532715, 14.987380027770996, 15.016458511352539, 15.287775993347168, 15.350390434265137, 15.054814338684082, 15.3716144561

14.684781074523926, 14.440291404724121, 14.4516019821167, 14.13287353515625, 13.633955001831055, 14.175944328308105, 14.341293334960938, 14.037240028381348, 14.319143295288086, 13.720710754394531, 14.447550773620605, 14.26552963256836, 14.522953033447266, 13.802881240844727, 13.956962585449219, 14.413142204284668, 15.004664421081543, 14.757427215576172, 14.213775634765625, 14.174628257751465, 14.717645645141602, 13.238310813903809, 14.050925254821777, 13.924151420593262, 14.103264808654785, 13.942874908447266, 14.185487747192383, 14.87396240234375, 14.622550010681152, 13.943228721618652, 14.350279808044434, 14.241588592529297, 14.057409286499023, 13.627036094665527, 13.627589225769043, 13.909385681152344, 13.955985069274902, 13.772621154785156, 14.079373359680176, 14.398104667663574, 13.83895206451416, 14.138755798339844, 14.024748802185059, 13.43718147277832, 14.19406795501709, 14.709439277648926, 13.95596981048584, 13.871350288391113, 14.283126831054688, 13.870586395263672, 14.288336

12.90578556060791, 13.044063568115234, 12.628018379211426, 13.348260879516602, 12.925190925598145, 13.06873893737793, 12.485247611999512, 12.24502182006836, 13.055296897888184, 13.457350730895996, 13.07153034210205, 13.484113693237305, 13.080387115478516, 12.356426239013672, 12.640822410583496, 12.793391227722168, 13.506183624267578, 13.197697639465332, 12.588340759277344, 12.931527137756348, 13.495763778686523, 12.639348030090332, 13.447846412658691, 12.688875198364258, 12.751209259033203, 12.446146965026855, 12.449901580810547, 13.302311897277832, 13.198976516723633, 13.131428718566895, 12.374874114990234, 12.973244667053223, 13.302789688110352, 12.872011184692383, 12.607413291931152, 12.863967895507812, 12.201926231384277, 12.28445816040039, 13.365080833435059, 12.983914375305176, 13.01285457611084, 12.99332046508789, 12.598163604736328, 12.911722183227539, 13.434992790222168, 13.048186302185059, 12.856419563293457, 12.851470947265625, 12.316258430480957, 13.36953067779541, 12.79676

11.549243927001953, 11.736556053161621, 11.452280044555664, 12.09809684753418, 12.101640701293945, 12.052680969238281, 12.077905654907227, 11.799114227294922, 11.892147064208984, 11.402592658996582, 11.764887809753418, 11.23827075958252, 12.172703742980957, 13.059720039367676, 12.081201553344727, 11.865330696105957, 11.748234748840332, 13.159747123718262, 11.980545997619629, 12.236334800720215, 11.548572540283203, 12.464029312133789, 12.166632652282715, 12.029038429260254, 12.355228424072266, 12.130411148071289, 11.145687103271484, 11.86668586730957, 11.528216361999512, 11.676589012145996, 11.872177124023438, 12.66719913482666, 11.346927642822266, 10.968866348266602, 11.932534217834473, 12.312335968017578, 11.912606239318848, 11.919590950012207, 12.024818420410156, 11.543084144592285, 11.413585662841797, 11.950196266174316, 11.200279235839844, 11.079272270202637, 12.106224060058594, 12.26183795928955, 12.27004623413086, 11.99100399017334, 10.952073097229004, 12.213754653930664, 11.3811

10.99769115447998, 11.19347858428955, 10.862113952636719, 11.239029884338379, 11.266582489013672, 11.395651817321777, 10.627076148986816, 11.545331954956055, 11.513633728027344, 10.551194190979004, 11.522296905517578, 11.886292457580566, 10.9959077835083, 11.860644340515137, 11.016124725341797, 11.493706703186035, 10.665905952453613, 11.712285995483398, 11.93914794921875, 11.7870454788208, 10.434375762939453, 10.908318519592285, 11.060883522033691, 10.426579475402832, 11.560025215148926, 10.985363960266113, 10.135302543640137, 11.421323776245117, 11.473682403564453, 10.607990264892578, 10.439972877502441, 11.55384349822998, 10.622040748596191, 10.610424041748047, 10.470956802368164, 10.659149169921875, 11.303423881530762, 11.490901947021484, 10.723213195800781, 10.856914520263672, 11.045550346374512, 11.021334648132324, 10.728358268737793, 11.583738327026367, 11.189071655273438, 11.026880264282227, 10.882893562316895, 11.162230491638184, 10.814001083374023, 11.14491081237793, 10.598673