In [1]:
import os
import pickle
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [2]:
def load_pickled_data(path):
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    return data

In [3]:
features_train = load_pickled_data('/common/home/gg676/Downloads/embeddings_train1.pkl')
features_val = load_pickled_data('/common/home/gg676/Downloads/embeddings_val1.pkl')
features_test = load_pickled_data('/common/home/gg676/Downloads/embeddings_test1.pkl')

In [4]:
img_val = features_val[0]
text_val = features_val[1]
img_test = features_test[0]
text_test = features_test[1]

In [5]:
class FeatureDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return self.dataset[0].shape[0]
    
    def __getitem__(self, idx):
        text, img = self.dataset[1][idx], self.dataset[0][idx]
        return text, img

dataset = FeatureDataset(features_train)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=64)

In [6]:
class TextEncoder(nn.Module):
    def __init__(self,output_size,input_size=1024):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )
        self.output = nn.Linear(512, output_size)
    
    def forward(self, x):
        x = self.layers(x)
        return self.output(x)

In [7]:
class ImgEncoder(nn.Module):
    def __init__(self,output_size,input_size=1024):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(),
            nn.LeakyReLU()
        )
        self.output = nn.Linear(512, output_size)
    
    def forward(self, x):
        x = self.layers(x)
        return self.output(x)

In [8]:
def rank(txt_data, img_data):
    results_dict = {}
    projection_txt, projection_img = txt_data, img_data
    
    idxs = range(1000)
    
    glob_rank = []
    glob_recall = {1:0.0,5:0.0,10:0.0}
    
    for i in range(10):
        ids = random.sample(range(0,txt_data.shape[0]-1), 1000)
        
        txt_sample = projection_txt[ids,:]
        img_sample = projection_img[ids,:]
        
        similarity = np.dot(txt_sample.cpu().numpy(), img_sample.T.cpu().numpy())

        med_rank = []
        
        recall = {1:0.0,5:0.0,10:0.0}
        
        for ii in idxs:
            # get a column of similarities
            sim = similarity[ii,:]
            # sort indices in descending order
            sorting = np.argsort(sim)[::-1].tolist()
            # find where the index of the pair sample ended up in the sorting
            pos = sorting.index(ii)  
            if (pos+1) == 1:
                recall[1]+=1
            if (pos+1) <=5:
                recall[5]+=1
            if (pos+1)<=10:
                recall[10]+=1
            # store the position
            med_rank.append(pos+1)
        for i in recall.keys():
            recall[i]=recall[i]/1000
        med = np.median(med_rank)
        for i in recall.keys():
            glob_recall[i]+=recall[i]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i]/10
    med_dict = {}
    med_dict["mean_median"] = np.average(glob_rank)
    med_dict["recall"] = glob_recall
    med_dict["median_all"] = glob_rank
    print("Result:",med_dict)
    return med_dict

In [9]:
def train(train_loader, img_model, txt_model, criterion, optimizer_txt, optimizer_img, epoch):
    img_model.train()
    txt_model.train()   
    running_loss = []
    optimizer_txt.zero_grad()
    optimizer_img.zero_grad()
    
    for i, (img_input_batch, text_input_batch) in tqdm(enumerate(train_loader)):
        img_input_batch = img_input_batch.to('cuda:1')
        text_input_batch = text_input_batch.to('cuda:1')
        
        
        optimizer_txt.zero_grad()
        optimizer_img.zero_grad()
        out_img_emb = img_model(img_input_batch)
        out_txt_emb = txt_model(text_input_batch)
       
        loss = criterion(out_txt_emb, out_img_emb)    
        loss.backward()
        optimizer_txt.step()
        optimizer_img.step()
        
        running_loss.append(loss.item())
    return sum(running_loss)/len(running_loss)

In [10]:
def save_model(model, file_name):
    torch.save(model, '/common/home/gg676/535/saved_models/'+file_name)

In [11]:
def load_model(file_name):
    model = torch.load('/common/home/gg676/535/saved_models/'+file_name)
    return model

In [12]:
no_epochs = 3
img_model = ImgEncoder(512).to('cuda:1')
txt_model = TextEncoder(512).to('cuda:1')
lowest_median_rank = 99.0
optimizer_txt = torch.optim.Adam(txt_model.parameters(), lr=1e-6, weight_decay=1e-7)
optimizer_img = torch.optim.Adam(img_model.parameters(), lr=1e-6, weight_decay=1e-7)
criterion = nn.MSELoss() 
for epoch in range(no_epochs):
    img_model.train()
    txt_model.train()
    train_loss = train(train_loader, img_model, txt_model, criterion, optimizer_txt, optimizer_img, epoch)
    print('  Epoch {} loss: {} {}'.format(epoch, train_loss, "\n"))
    with torch.no_grad():
        txt_model.eval()
        img_model.eval()
        out_text, out_img = txt_model(torch.tensor(text_val).to('cuda:1')), img_model(torch.tensor(img_val).to('cuda:1'))
        med_rank = rank(out_text,out_img)
        if med_rank['mean_median'] < lowest_median_rank:
            lowest_median_rank = med_rank['mean_median']
            save_model(txt_model, 'text_model_mse')
            save_model(img_model, 'img_model_mse')

4400it [00:34, 129.16it/s]


  Epoch 0 loss: 0.4023148629137061 

Result: {'mean_median': 18.5, 'recall': {1: 0.11789999999999998, 5: 0.29569999999999996, 10: 0.4025}, 'median_all': [18.0, 19.0, 18.0, 17.0, 20.0, 20.0, 18.0, 18.0, 18.0, 19.0]}


4400it [00:31, 138.12it/s]


  Epoch 1 loss: 0.3048621171645143 

Result: {'mean_median': 6.3, 'recall': {1: 0.22559999999999997, 5: 0.47840000000000005, 10: 0.5936}, 'median_all': [7.0, 6.0, 6.0, 6.0, 6.0, 8.0, 7.0, 6.0, 6.0, 5.0]}


4400it [00:31, 140.90it/s]


  Epoch 2 loss: 0.2589708590913903 

Result: {'mean_median': 5.65, 'recall': {1: 0.24210000000000004, 5: 0.4985000000000001, 10: 0.605}, 'median_all': [6.0, 5.0, 6.0, 6.0, 6.5, 5.0, 6.0, 5.0, 5.0, 6.0]}


In [13]:
txt_model = load_model('text_model_mse')
img_model = load_model('img_model_mse')

In [17]:
with torch.no_grad():
    txt_model.eval()
    img_model.eval()
    out_text_test, out_img_test = txt_model(torch.tensor(text_test).to('cuda:1')), img_model(torch.tensor(img_test).to('cuda:1'))
    med_rank = rank(out_text_test,out_img_test)

Result: {'mean_median': 6.05, 'recall': {1: 0.23230000000000003, 5: 0.4867, 10: 0.6017}, 'median_all': [6.0, 6.0, 5.0, 7.0, 7.0, 6.5, 7.0, 6.0, 5.0, 5.0]}


In [18]:
with open('/common/home/gg676/535/task_2/tsne_data/all_text.pkl', 'wb') as fp:
    pickle.dump(out_text_test, fp)
with open('/common/home/gg676/535/task_2/tsne_data/all_img.pkl', 'wb') as fp:
    pickle.dump(out_img_test, fp)