In [1]:
import os
import pickle
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np

In [2]:
device = torch.device("cuda:3")

In [3]:
class FeatureDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset[0])
    
    def __getitem__(self, idx):
        img, text, label = self.dataset[0][idx],self.dataset[4][idx],self.dataset[3][idx]
        
        return img, text, label

In [4]:

class Encoder(nn.Module):
    def __init__(self,output_size,input_size=1024):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Linear(input_size,512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )
        self.layer2 = nn.Linear(512,output_size)
    
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

In [5]:
def load_pickled_data(path):
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    return data

In [6]:
img_text_features_train_LABELLED = load_pickled_data('/common/home/gg676/535/data/labelled_data/TITLE_embeddings_train_LABELLED.pkl')
features_val = load_pickled_data('/common/home/gg676/535/data/labelled_data/TITLE_embeddings_val_LABELLED.pkl')
features_test = load_pickled_data('/common/home/gg676/535/data/labelled_data/TITLE_embeddings_test_LABELLED.pkl')

img_val = features_val[0]
text_val = features_val[4]

img_test = features_test[0]
text_test = features_test[4]


In [7]:
len(img_text_features_train_LABELLED)

5

In [8]:

dataset = FeatureDataset(img_text_features_train_LABELLED)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=60)

In [9]:
def rank(txt_data, img_data):
    results_dict = {}
    projection_txt, projection_img = txt_data, img_data
    med_dict = {}
    idxs = range(1000)
    
    glob_rank = []
    glob_recall = {1:0.0,5:0.0,10:0.0}
    
    for i in range(10):
        ids = random.sample(range(0,txt_data.shape[0]-1), 1000)
        
        txt_sample = projection_txt[ids,:]
        img_sample = projection_img[ids,:]
        
        similarity = np.dot(txt_sample.cpu().numpy(), img_sample.T.cpu().numpy())

        med_rank = []
        
        recall = {1:0.0,5:0.0,10:0.0}
        
        for ii in idxs:
            # get a column of similarities
            sim = similarity[ii,:]
            # sort indices in descending order
            sorting = np.argsort(sim)[::-1].tolist()
            # find where the index of the pair sample ended up in the sorting
            pos = sorting.index(ii)  
            if (pos+1) == 1:
                recall[1]+=1
            if (pos+1) <=5:
                recall[5]+=1
            if (pos+1)<=10:
                recall[10]+=1
            # store the position
            med_rank.append(pos+1)
        for i in recall.keys():
            recall[i]=recall[i]/1000
        med = np.median(med_rank)
        for i in recall.keys():
            glob_recall[i]+=recall[i]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i]/10
    
    med_dict["mean_median"] = np.average(glob_rank)
    med_dict["recall"] = glob_recall
    med_dict["median_all"] = glob_rank
    print("Result:",med_dict)
    return med_dict

In [10]:
model_img = Encoder(512)
model_text = Encoder(512)
optimizer = optim.Adam(list(model_text.parameters()) + list(model_img.parameters()) , lr=1e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, verbose=True)
loss = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))

In [11]:
def get_hard_negative(anchor,positive,label):
    labels_equal = torch.eq(label.unsqueeze(0),label.unsqueeze(1))
    mask_neg = torch.logical_not(labels_equal)
    distance_matrix = torch.matmul(anchor, positive.T)
    anchor_negative_dist = mask_neg * distance_matrix
    hard_negative_idx = torch.argmax(anchor_negative_dist, dim=1)
    hard_negative = torch.index_select(positive, 0, hard_negative_idx)
    
    return hard_negative

In [12]:
from tqdm import tqdm
def train(train_loader, img_model, txt_model, loss, optimizer): 
    total_loss = 0.0       
    for i, (img_input_batch, text_input_batch,label) in tqdm(enumerate(train_loader)):
        img_input_batch = img_input_batch.to(device)
        text_input_batch = text_input_batch.to(device)
        label = label.to(device)
        hard_negative_emb = get_hard_negative(img_input_batch,text_input_batch,label)
        optimizer.zero_grad()
        anchor = img_model(img_input_batch.to(device)) 
        positive = txt_model(text_input_batch.to(device))
        negative = txt_model(hard_negative_emb.to(device))
        curr_loss = loss(anchor,positive,negative) 
        curr_loss.backward()
        optimizer.step()
        total_loss += curr_loss.item()
    return total_loss/i

In [13]:
def save_model(model, file_name):
    torch.save(model, '/common/home/gg676/535/saved_models/'+file_name)

In [14]:
def load_model(file_name):
    model = torch.load('/common/home/gg676/535/saved_models/'+file_name)
    return model

In [15]:
no_epochs = 15
model_img = model_img.to(device)
model_text = model_text.to(device)
model_img.train()
model_text.train()
lowest_median_rank = 99.0
for epoch in range(no_epochs):
    train_loss = train(train_loader, model_img, model_text,loss,optimizer)
    with torch.no_grad():
        model_img.eval()
        model_text.eval()
        out_text, out_img = model_text(torch.tensor(text_val).to('cuda:3')), model_img(torch.tensor(img_val).to('cuda:3'))
        med_rank = rank(out_text,out_img)
        if med_rank['mean_median'] < lowest_median_rank:
            lowest_median_rank = med_rank['mean_median']
            save_model(model_text, 'title_model_triplet')
            save_model(model_img, 'img_title_model_triplet')
    scheduler.step(med_rank['recall'][1])
    print('Epoch {} loss: {}'.format(epoch, train_loss))

4400it [00:33, 132.85it/s]


Result: {'mean_median': 225.55, 'recall': {1: 0.011, 5: 0.04189999999999999, 10: 0.0671}, 'median_all': [233.5, 219.5, 232.0, 228.5, 209.5, 218.0, 240.0, 205.0, 243.5, 226.0]}
Epoch 0 loss: 0.993343964412175


4400it [00:24, 183.22it/s]


Result: {'mean_median': 149.45, 'recall': {1: 0.03779999999999999, 5: 0.11159999999999999, 10: 0.16209999999999997}, 'median_all': [146.5, 162.0, 146.5, 136.5, 166.5, 126.0, 163.0, 152.0, 142.5, 153.0]}
Epoch 1 loss: 0.9470708586189199


4400it [00:23, 189.92it/s]


Result: {'mean_median': 276.25, 'recall': {1: 0.0319, 5: 0.09109999999999999, 10: 0.1291}, 'median_all': [268.5, 284.0, 267.5, 281.0, 277.0, 270.5, 285.5, 289.5, 263.5, 275.5]}
Epoch 2 loss: 0.8729838645981236


4400it [00:22, 192.86it/s]


Result: {'mean_median': 331.0, 'recall': {1: 0.0249, 5: 0.0721, 10: 0.1051}, 'median_all': [343.0, 343.0, 341.5, 325.5, 311.0, 346.5, 336.0, 298.5, 341.0, 324.0]}
Epoch 3 loss: 0.8165228106022207


4400it [00:24, 182.29it/s]


Result: {'mean_median': 324.65, 'recall': {1: 0.0211, 5: 0.06399999999999999, 10: 0.09339999999999998}, 'median_all': [298.0, 306.5, 351.0, 289.5, 331.5, 346.5, 318.5, 319.0, 340.0, 346.0]}
Epoch     5: reducing learning rate of group 0 to 1.0000e-07.
Epoch 4 loss: 0.7908403161412668


4400it [00:23, 190.09it/s]


Result: {'mean_median': 331.5, 'recall': {1: 0.021899999999999996, 5: 0.06220000000000001, 10: 0.09559999999999999}, 'median_all': [332.5, 328.5, 376.0, 329.0, 331.5, 319.0, 337.5, 348.5, 306.5, 306.0]}
Epoch 5 loss: 0.7826489777450102


4400it [00:22, 193.14it/s]


Result: {'mean_median': 313.95, 'recall': {1: 0.022099999999999995, 5: 0.0623, 10: 0.09469999999999998}, 'median_all': [299.0, 326.5, 307.5, 339.0, 289.0, 305.5, 322.0, 352.5, 308.5, 290.0]}
Epoch 6 loss: 0.7812105380667695


4400it [00:23, 188.80it/s]


Result: {'mean_median': 315.3, 'recall': {1: 0.017799999999999996, 5: 0.06339999999999998, 10: 0.09329999999999998}, 'median_all': [308.0, 297.5, 334.0, 321.0, 290.0, 314.0, 346.0, 333.5, 310.0, 299.0]}
Epoch     8: reducing learning rate of group 0 to 1.0000e-08.
Epoch 7 loss: 0.7803765657842905


4400it [00:22, 193.48it/s]


Result: {'mean_median': 310.6, 'recall': {1: 0.02, 5: 0.06510000000000002, 10: 0.09489999999999998}, 'median_all': [281.5, 301.5, 305.5, 298.5, 297.5, 309.0, 333.0, 320.5, 317.5, 341.5]}
Epoch 8 loss: 0.7783861455930366


4400it [00:23, 189.76it/s]


Result: {'mean_median': 321.4, 'recall': {1: 0.0209, 5: 0.06130000000000001, 10: 0.09269999999999998}, 'median_all': [307.5, 321.0, 325.5, 301.5, 325.0, 338.0, 313.0, 308.0, 353.0, 321.5]}
Epoch 9 loss: 0.7790049427502479


4400it [00:22, 193.25it/s]


Result: {'mean_median': 310.1, 'recall': {1: 0.0204, 5: 0.06340000000000001, 10: 0.09239999999999998}, 'median_all': [298.5, 332.0, 311.0, 296.5, 310.0, 320.0, 316.5, 302.5, 311.0, 303.0]}
Epoch 10 loss: 0.7785375781237036


4400it [00:23, 185.15it/s]


Result: {'mean_median': 322.7, 'recall': {1: 0.019299999999999998, 5: 0.06180000000000001, 10: 0.09359999999999999}, 'median_all': [296.5, 326.5, 367.0, 328.0, 311.0, 328.0, 260.0, 328.5, 341.5, 340.0]}
Epoch 11 loss: 0.7794914712524328


4400it [00:22, 193.80it/s]


Result: {'mean_median': 315.45, 'recall': {1: 0.0197, 5: 0.06080000000000001, 10: 0.08759999999999998}, 'median_all': [303.5, 337.0, 302.5, 302.0, 306.5, 341.0, 340.5, 327.5, 306.0, 288.0]}
Epoch 12 loss: 0.7789189894650844


4400it [00:23, 188.33it/s]


Result: {'mean_median': 332.4, 'recall': {1: 0.0215, 5: 0.0639, 10: 0.09429999999999998}, 'median_all': [351.0, 314.5, 293.0, 370.0, 332.5, 354.0, 325.5, 331.0, 346.5, 306.0]}
Epoch 13 loss: 0.7795453060077087


4400it [00:22, 192.95it/s]


Result: {'mean_median': 323.55, 'recall': {1: 0.0195, 5: 0.0633, 10: 0.09369999999999998}, 'median_all': [331.0, 336.5, 313.5, 320.5, 359.0, 340.0, 341.0, 301.5, 304.0, 288.5]}
Epoch 14 loss: 0.7790868377707226


In [None]:
txt_model = load_model('text_model_triplet')
img_model = load_model('img_model_triplet')
with torch.no_grad():
    img_model.eval()
    txt_model.eval()
    r_text = txt_model(torch.tensor(text_test).to('cuda:1'))
    r_img = img_model(torch.tensor(img_test).to('cuda:1'))
    test_result = rank(r_text, r_img)
test_result

In [None]:
with open('/common/home/gg676/535/task_2/tsne_data/title_text_triplet.pkl', 'wb') as fp:
    pickle.dump(r_text, fp)
with open('/common/home/gg676/535/task_2/tsne_data/title_img_triplet.pkl', 'wb') as fp:
    pickle.dump(r_img, fp)