In [1]:
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

from zsldataset import ZSLDataset
from models import ContinuousMap, ContinuousMapResidual

## Load Yahoo dataset

Should be converted into ZSL format beforehand.

In [6]:
trainset = ZSLDataset('Data/APY/train', use_predicates=False, use_irevnet=True)
testset = ZSLDataset('Data/APY/test', use_predicates=False, use_irevnet=True)

In [7]:
bsize = 128
n_epochs = 150
num_classes = trainset.classes.shape[0]

dim_semantic = trainset[0]['class_embedding'].shape[0]
dim_visual = trainset[0]['image_embedding'].shape[0]

all_class_embeddings = torch.tensor(np.array(trainset.class_embeddings)).cuda().float()
classes_enum = torch.tensor(np.array(range(num_classes), dtype=np.int64)).cuda()

In [14]:
v_to_s = ContinuousMap(dim_source=dim_visual, dim_dest=dim_semantic, width=512).cuda()
s_to_v = ContinuousMap(dim_source=dim_semantic, dim_dest=dim_visual, width=512).cuda()

# optimizer = torch.optim.Adam(list(v_to_s.parameters()) + list(s_to_v.parameters()),
#                                 lr = 1e-3,
#                                 betas=(0.9, 0.999),
#                                 weight_decay=3e-2)

optimizer = torch.optim.SGD(list(v_to_s.parameters()) + list(s_to_v.parameters()),
                                lr = 1e-3,
                                momentum=.9,
                                weight_decay=3e-2)

In [15]:
mse = torch.nn.MSELoss()
triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2)
positive_part = torch.nn.ReLU()

In [16]:
trainloader = torch.utils.data.DataLoader(trainset, 
                                             batch_size=bsize, 
                                             shuffle=True, 
                                             num_workers=4, 
                                             pin_memory=True, 
                                             drop_last=True)

testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=bsize, 
                                         shuffle=True, 
                                         num_workers=4, 
                                         pin_memory=True, 
                                         drop_last=True)

In [17]:
def dist_matrix(batch1, batch2):
    delta = batch1.unsqueeze(1) - batch2.unsqueeze(0)
    
    dist_matrix = (delta * delta).mean(dim=-1)
    
    return dist_matrix


def mag(u):
    return torch.dot(u, u)


def dist(u, v):
    return torch.dot(u - v, u - v)

In [None]:
for e in range(n_epochs):
    v_to_s = v_to_s.train()
    s_to_v = s_to_v.train()
    
    for i, sample in enumerate(trainloader):
        optimizer.zero_grad()
        
        batch_visual = sample['image_embedding'].cuda().float()
        batch_semantic = sample['class_embedding'].cuda().float()

        batch_classes = sample['class_id'].cuda()
        
        e_hat = v_to_s(s_to_v(all_class_embeddings))
        delta = (e_hat - all_class_embeddings)
        surjection_loss = (delta * delta).sum(dim=-1).mean()
        
        s_out = v_to_s(batch_visual)
        
        same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)
        same_class = same_class.detach()

        d_matrix = dist_matrix(s_out, all_class_embeddings)

        closest_negative, _ = (d_matrix + same_class.float() * 1e6).min(dim=-1)
        furthest_positive, _ = (d_matrix * (1 - same_class.float())).max(dim=-1)
        
        l2_loss = (s_out * s_out).sum(dim=-1).mean()

        loss = positive_part(furthest_positive - closest_negative + 1)
        loss = 10 * loss.mean() + .1 * surjection_loss + 1e-3 * l2_loss
        
        loss.backward()
        
        c_hat = d_matrix.argmin(dim = -1)
        print((c_hat == batch_classes).float().mean().item())
        
        optimizer.step()
        
        print(loss.item(), end=', ')
        
    
    if (e+1) % 5 == 0:
        print('\n\n- Evaluation on epoch {}'.format(e))
        
        avg_accuracy = 0.
        avg_loss = 0.
        n = 0
        
        v_to_s = v_to_s.eval() 
        s_to_v = s_to_v.eval() 
        
        with torch.no_grad():
            for i, sample in enumerate(testloader):
                n += 1
                
                batch_visual = sample['image_embedding'].cuda().float()
                batch_semantic = sample['class_embedding'].cuda().float()

                batch_classes = sample['class_id'].cuda()

                s_out = v_to_s(batch_visual)

                same_class = classes_enum.unsqueeze(0) == batch_classes.unsqueeze(1)
                same_class = same_class.detach()

                d_matrix = dist_matrix(s_out, all_class_embeddings)
                
                c_hat = d_matrix.argmin(dim = -1)

                closest_negative, _ = (d_matrix + same_class.float() * 1e6).min(dim=-1)
                furthest_positive, _ = (d_matrix * (1 - same_class.float())).max(dim=-1)

                loss = positive_part(furthest_positive - closest_negative + 5)
                loss = 10 * furthest_positive.mean()

                avg_loss += loss.item()
                avg_accuracy += (c_hat == batch_classes).float().mean().item()

        avg_accuracy /= n
        avg_loss /= n

        print('Average acc.: {}, Average loss:{}\n\n'.format(avg_accuracy, avg_loss))


0.171875
18.196125030517578, 0.1484375
17.79766082763672, 0.1796875
17.19512939453125, 0.171875
16.655811309814453, 0.2265625
16.23833465576172, 0.1875
15.936046600341797, 0.2265625
15.676041603088379, 0.1484375
15.457474708557129, 0.234375
15.239941596984863, 0.234375
15.058481216430664, 0.125
14.894320487976074, 0.1484375
14.732766151428223, 0.203125
14.598896026611328, 0.2109375
14.468913078308105, 0.1796875
14.342936515808105, 0.2109375
14.221004486083984, 0.1796875
14.1026029586792, 0.21875
14.000967025756836, 0.2578125
13.916930198669434, 0.203125
13.81291675567627, 0.203125
13.724823951721191, 0.203125
13.640247344970703, 0.28125
13.565312385559082, 0.2578125
13.485698699951172, 0.2265625
13.425681114196777, 0.2421875
13.343587875366211, 0.3203125
13.272185325622559, 0.25
13.215731620788574, 0.296875
13.141411781311035, 0.3046875
13.087477684020996, 0.2578125
13.033390045166016, 0.265625
12.981819152832031, 0.2890625
12.920666694641113, 0.2734375
12.877577781677246, 0.28125
12.8

In [None]:
closest_negative