In [1]:
import os
import pickle
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np

In [2]:
device = torch.device("cuda:1")

In [17]:
class FeatureDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset[0])
    
    def __getitem__(self, idx):
        img,text,label = self.dataset[0][idx],self.dataset[1][idx],self.dataset[3][idx]
        
        return img, text, label

In [18]:

class Encoder(nn.Module):
    def __init__(self,output_size,input_size=1024):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Linear(input_size,512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )
        self.layer2 = nn.Linear(512,output_size)
    
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

In [19]:
def load_pickled_data(path):
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    return data

In [22]:
img_text_features_train_LABELLED = load_pickled_data('/common/home/gg676/535/data/labelled_data/embeddings_train1_LABELLED.pkl')
features_val = load_pickled_data('/common/home/gg676/Downloads/embeddings_val1.pkl')
features_test = load_pickled_data('/common/home/gg676/Downloads/embeddings_test1.pkl')

img_val = features_val[0]
text_val = features_val[1]

img_test = features_test[0]
text_test = features_test[1]

dataset = FeatureDataset(img_text_features_train_LABELLED)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=60)

In [7]:
def rank(txt_data, img_data):
    results_dict = {}
    projection_txt, projection_img = txt_data, img_data
    med_dict = {}
    idxs = range(1000)
    
    glob_rank = []
    glob_recall = {1:0.0,5:0.0,10:0.0}
    
    for i in range(10):
        ids = random.sample(range(0,txt_data.shape[0]-1), 1000)
        
        txt_sample = projection_txt[ids,:]
        img_sample = projection_img[ids,:]
        
        similarity = np.dot(txt_sample.cpu().numpy(), img_sample.T.cpu().numpy())

        med_rank = []
        
        recall = {1:0.0,5:0.0,10:0.0}
        
        for ii in idxs:
            # get a column of similarities
            sim = similarity[ii,:]
            # sort indices in descending order
            sorting = np.argsort(sim)[::-1].tolist()
            # find where the index of the pair sample ended up in the sorting
            pos = sorting.index(ii)  
            if (pos+1) == 1:
                recall[1]+=1
            if (pos+1) <=5:
                recall[5]+=1
            if (pos+1)<=10:
                recall[10]+=1
            # store the position
            med_rank.append(pos+1)
        for i in recall.keys():
            recall[i]=recall[i]/1000
        med = np.median(med_rank)
        for i in recall.keys():
            glob_recall[i]+=recall[i]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i]/10
    
    med_dict["mean_median"] = np.average(glob_rank)
    med_dict["recall"] = glob_recall
    med_dict["median_all"] = glob_rank
    print("Result:",med_dict)
    return med_dict

In [8]:
model_img = Encoder(512)
model_text = Encoder(512)
optimizer = optim.Adam(list(model_text.parameters()) + list(model_img.parameters()) , lr=1e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, verbose=True)
loss = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))

In [9]:
def get_hard_negative(anchor,positive,label):
    labels_equal = torch.eq(label.unsqueeze(0),label.unsqueeze(1))
    mask_neg = torch.logical_not(labels_equal)
    distance_matrix = torch.matmul(anchor, positive.T)
    anchor_negative_dist = mask_neg * distance_matrix
    hard_negative_idx = torch.argmax(anchor_negative_dist, dim=1)
    hard_negative = torch.index_select(positive, 0, hard_negative_idx)
    
    return hard_negative

In [10]:
from tqdm import tqdm
def train(train_loader, img_model, txt_model, loss, optimizer): 
    total_loss = 0.0       
    for i, (img_input_batch, text_input_batch,label) in tqdm(enumerate(train_loader)):
        img_input_batch = img_input_batch.to(device)
        text_input_batch = text_input_batch.to(device)
        label = label.to(device)
        hard_negative_emb = get_hard_negative(img_input_batch,text_input_batch,label)
        optimizer.zero_grad()
        anchor = img_model(img_input_batch.to(device)) 
        positive = txt_model(text_input_batch.to(device))
        negative = txt_model(hard_negative_emb.to(device))
        curr_loss = loss(anchor,positive,negative) 
        curr_loss.backward()
        optimizer.step()
        total_loss += curr_loss.item()
    return total_loss/i

In [11]:
def save_model(model, file_name):
    torch.save(model, '/common/home/gg676/535/saved_models/'+file_name)

In [12]:
def load_model(file_name):
    model = torch.load('/common/home/gg676/535/saved_models/'+file_name)
    return model

In [13]:
no_epochs = 4
model_img = model_img.to(device)
model_text = model_text.to(device)
model_img.train()
model_text.train()
lowest_median_rank = 99.0
for epoch in range(no_epochs):
    train_loss = train(train_loader, model_img, model_text,loss,optimizer)
    with torch.no_grad():
        model_img.eval()
        model_text.eval()
        out_text, out_img = model_text(torch.tensor(text_val).to('cuda:1')), model_img(torch.tensor(img_val).to('cuda:1'))
        med_rank = rank(out_text,out_img)
        if med_rank['mean_median'] < lowest_median_rank:
            lowest_median_rank = med_rank['mean_median']
            save_model(model_text, 'text_model_triplet')
            save_model(model_img, 'img_model_triplet')
    scheduler.step(med_rank['recall'][1])
    print('Epoch {} loss: {}'.format(epoch, train_loss))

4400it [00:44, 99.33it/s] 


Result: {'mean_median': 7.3, 'recall': {1: 0.2072, 5: 0.45, 10: 0.5612999999999999}, 'median_all': [6.0, 7.0, 7.5, 6.0, 9.0, 7.0, 8.5, 7.0, 8.0, 7.0]}
Epoch 0 loss: 0.9752261804894822


4400it [00:43, 101.56it/s]


Result: {'mean_median': 2.0, 'recall': {1: 0.45890000000000003, 5: 0.7373000000000001, 10: 0.8215999999999999}, 'median_all': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]}
Epoch 1 loss: 0.8567092926915761


4400it [00:42, 102.75it/s]


Result: {'mean_median': 2.0, 'recall': {1: 0.4511, 5: 0.7352, 10: 0.8240999999999999}, 'median_all': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]}
Epoch 2 loss: 0.7607392693248818


4400it [00:45, 96.83it/s] 


Result: {'mean_median': 2.0, 'recall': {1: 0.4064, 5: 0.6943, 10: 0.7861}, 'median_all': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]}
Epoch 3 loss: 0.7115090129781404


In [14]:
txt_model = load_model('text_model_triplet')
img_model = load_model('img_model_triplet')
with torch.no_grad():
    img_model.eval()
    txt_model.eval()
    r_text = txt_model(torch.tensor(text_test).to('cuda:1'))
    r_img = img_model(torch.tensor(img_test).to('cuda:1'))
    test_result = rank(r_text, r_img)
test_result

Result: {'mean_median': 2.0, 'recall': {1: 0.45339999999999997, 5: 0.7399, 10: 0.8193999999999999}, 'median_all': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]}


{'mean_median': 2.0,
 'recall': {1: 0.45339999999999997, 5: 0.7399, 10: 0.8193999999999999},
 'median_all': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]}

In [16]:
with open('/common/home/gg676/535/task_2/tsne_data/all_text_triplet.pkl', 'wb') as fp:
    pickle.dump(r_text, fp)
with open('/common/home/gg676/535/task_2/tsne_data/all_img_triplet.pkl', 'wb') as fp:
    pickle.dump(r_img, fp)