# ZSL modification of original training code

This code should be used for ZSL only, not for the generalized ZSL scenario.


# With attributes

In [1]:
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

from zsldataset import ZSLDataset
from models import ContinuousMap, ContinuousMapResidual

def dist_matrix(batch1, batch2):
    delta = batch1.unsqueeze(1) - batch2.unsqueeze(0)
    
    dist_matrix = (delta * delta).mean(dim=-1)
    
    return dist_matrix


def mag(u):
    return torch.dot(u, u)


def dist(u, v):
    return torch.dot(u - v, u - v)

In [2]:
from models import EncoderAttributes, DecoderAttributes

In [3]:
def normalize_feature(data):
    mean = data.mean(0)
    std = data.var(0)
    if std > 0:
        return (data - mean) / std
    else:
        return (data - mean) 

In [4]:
trainset = ZSLDataset('Data/CUB/train_set', use_predicates=True, use_irevnet=False)
testset = ZSLDataset('Data/CUB/test_set', use_predicates=True, use_irevnet=False)



#trainset.image_embeddings = normalize_feature(trainset.image_embeddings)
#testset.image_embeddings = normalize_feature(testset.image_embeddings)

In [5]:
bsize = 32
n_epochs = 200
num_classes = trainset.classes.shape[0]

dim_semantic = trainset[0]['class_embedding'].shape[0]
dim_visual = trainset[0]['image_embedding'].shape[0]
dim_attributes = trainset[0]['class_predicates'].shape[0]

all_class_embeddings = torch.tensor(np.array(trainset.class_embeddings)).cuda().float()
all_train_image_embeddings = torch.tensor(np.array(trainset.image_embeddings)).cuda().float()
all_train_labels = torch.tensor(trainset.labels['class_id'].values).cuda() - 1
all_class_predicates = torch.tensor(np.array(trainset.class_predicates)).cuda().float()
classes_enum = torch.tensor(np.array(range(num_classes), dtype=np.int64)).cuda()

In [6]:
query_classes = set([testset[i]['class_label'] for i in range(len(testset))])
query_ids = set([testset[i]['class_id'] for i in range(len(testset))])

In [7]:
ids = list(i-1 for i in query_ids)
query_mask = np.zeros((num_classes))
query_mask[ids] = 1
query_mask = torch.tensor(query_mask, dtype=torch.int64).cuda()

In [8]:
v_to_s = DecoderAttributes(dim_source=dim_visual, dim_target1=dim_attributes, dim_target2=dim_semantic, width=512).cuda()
s_to_v = EncoderAttributes(dim_source1=dim_semantic, dim_source2=dim_attributes, dim_target=dim_visual, width=512).cuda()

# optimizer = torch.optim.Adam(list(v_to_s.parameters()) + list(s_to_v.parameters()),
#                                 lr = 1e-3,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=3e-2)

optimizer = torch.optim.SGD(list(v_to_s.parameters()) + list(s_to_v.parameters()),
                                lr = 5e-4,
                                momentum=.5,
                                nesterov=True,
                                weight_decay=3e-2)




In [9]:
mse = torch.nn.MSELoss()
triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2)
positive_part = torch.nn.ReLU()

In [10]:
trainloader = torch.utils.data.DataLoader(trainset, 
                                             batch_size=bsize, 
                                             shuffle=True, 
                                             num_workers=4, 
                                             pin_memory=True, 
                                             drop_last=True)

testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=bsize, 
                                         shuffle=True, 
                                         num_workers=4, 
                                         pin_memory=True, 
                                         drop_last=True)

In [11]:
gamma = 1

alpha1 = 50 # triplet
alpha4 = 50 #triplet visual
alpha2 = 5e-3 # surjection
alpha3 = 5e-3 # l2 loss

margin_s = 0
margin_v = 0

In [None]:
for e in range(n_epochs):
    v_to_s = v_to_s.train()
    s_to_v = s_to_v.train()
    
    for i, sample in enumerate(trainloader):
        optimizer.zero_grad()
        
        batch_visual = sample['image_embedding'].cuda().float()
        batch_semantic = sample['class_embedding'].cuda().float()
        batch_predicates = sample['class_predicates'].cuda().float()
        batch_classes = sample['class_id'].cuda() - 1
         
        e_hat = v_to_s(s_to_v(all_class_embeddings, all_class_predicates))
        delta = (e_hat[1] - all_class_embeddings)
        surjection_loss = (delta * delta).sum(dim=-1).mean()
        delta = (e_hat[0] - all_class_predicates)
        surjection_loss = (1-gamma) * surjection_loss + gamma * (delta * delta).sum(dim=-1).mean()
        
        s_out = v_to_s(batch_visual)
        s_attr, s_word = s_out
        
        
        
        same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)
        same_class = same_class.detach()
        
        ##Triplet loss in semantic space
        d_matrix_s = (1 - gamma) * dist_matrix(s_word, all_class_embeddings) + \
                    gamma * dist_matrix(s_attr, all_class_predicates)
        
        
        
        closest_negative_s, _ = (d_matrix_s + same_class.float() * 1e6).min(dim=-1)
        furthest_positive_s, _ = (d_matrix_s * same_class.float()).max(dim=-1)
        
        l2_loss_s = (1-gamma) * (s_word * s_word).sum(dim=-1).mean() + \
                    gamma * (s_attr * s_attr).sum(dim=-1).mean()

        trip_loss_s = positive_part(furthest_positive_s - closest_negative_s + margin_s)
        
        
        ##Triplet loss in visual space
        same_class = all_train_labels.unsqueeze(0) == batch_classes.unsqueeze(1)
        same_class = same_class.detach()
        v_out = s_to_v(batch_semantic, batch_predicates)
        d_matrix_v =  dist_matrix(v_out,all_train_image_embeddings)
        
        closest_negative_v, _ = (d_matrix_v + same_class.float() * 1e6).min(dim=-1)
        
        furthest_positive_v, _ = (d_matrix_v * same_class.float()).max(dim=-1)
        
        
        
        l2_loss_v = (1-gamma) * (v_out * v_out).sum(dim=-1).mean() 
        

        trip_loss_v = positive_part(furthest_positive_v - closest_negative_v + margin_v)
        
        ##Total loss
        
        loss = alpha1 * trip_loss_s.mean() + alpha2 * surjection_loss + alpha3 * l2_loss_s  + alpha4 * trip_loss_v.mean() + alpha3 * l2_loss_v 
        
        
        loss.backward()
        
        #print( trip_loss_s.mean(),l2_loss_s,trip_loss_v.mean(),l2_loss_v, surjection_loss)
        
#         c_hat = d_matrix.argmin(dim = -1)
#         print((c_hat == batch_classes).float().mean().item())
        
        optimizer.step()
        
       # print(loss.item(), end=', ')
        
    if (e+1) % 20 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.7
            
    
    if (e+1) % 1 == 0:
        print('\n\n- Evaluation on epoch {}'.format(e))
        
        avg_accuracy = 0.
        avg_loss = 0.
        n = 0
        
        v_to_s = v_to_s.eval() 
        s_to_v = s_to_v.eval() 
        
        v_out = s_to_v(all_class_embeddings, all_class_predicates)
        with torch.no_grad():
            for i, sample in enumerate(testloader):
                n += 1
                
                batch_visual = sample['image_embedding'].cuda().float()
                #batch_semantic = sample['class_embedding'].cuda().float()

                batch_classes = sample['class_id'].cuda() - 1

                s_out = v_to_s(batch_visual)
                s_attr, s_word = s_out
                
                
                
                same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)
                same_class = same_class.detach()

                
                

                ##Triplet loss in semantic space
                d_matrix_s = (1 - gamma) * dist_matrix(s_word, all_class_embeddings) + \
                            gamma * dist_matrix(s_attr, all_class_predicates)

                 

                closest_negative_s, _ = (d_matrix_s + same_class.float() * 1e6).min(dim=-1)
                furthest_positive_s, _ = (d_matrix_s * same_class.float()).max(dim=-1)

                l2_loss_s = (1-gamma) * (s_word * s_word).sum(dim=-1).mean() + \
                            gamma * (s_attr * s_attr).sum(dim=-1).mean()

                trip_loss_s = positive_part(furthest_positive_s - closest_negative_s + margin_s)


                ##Triplet loss in visual space
                
                d_matrix_v =  dist_matrix(batch_visual,v_out)
                same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1) 
                same_class = same_class.detach()


                closest_negative_v, _ = (d_matrix_v + same_class.float() * 1e6).min(dim=-1)
                furthest_positive_v, _ = (d_matrix_v * same_class.float()).max(dim=-1)
                c_hat = (d_matrix_v + (1 - query_mask).float() * 1e9).argmin(dim = -1)
                
                l2_loss_v = (1-gamma) * (v_out * v_out).sum(dim=-1).mean() 

                trip_loss_v = positive_part(furthest_positive_v - closest_negative_v + margin_v)
                

                ##Total loss
                loss = alpha1 * trip_loss_s.mean() + alpha2 * surjection_loss + alpha3 * l2_loss_s + alpha4 * trip_loss_v.mean()  + alpha3 * l2_loss_v
                #loss = alpha1*trip_loss_v.mean() +  alpha3 * l2_loss_v

                avg_loss += loss.item()
                #print(c_hat)
                #print(batch_visual)
                #print(batch_classes)
                avg_accuracy += (c_hat == batch_classes).float().mean().item()

        avg_accuracy /= n
        avg_loss /= n

        print('Average acc.: {}, Average loss:{}\n\n'.format(avg_accuracy, avg_loss))






- Evaluation on epoch 0
Average acc.: 0.04042119565217391, Average loss:4043.613354226817




- Evaluation on epoch 1
Average acc.: 0.030570652173913044, Average loss:3572.493896484375




- Evaluation on epoch 2
Average acc.: 0.030400815217391304, Average loss:3234.5476472274117




- Evaluation on epoch 3
Average acc.: 0.026834239130434784, Average loss:3271.6826556661854


