In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
from scipy.spatial import distance
from statistics import median
from torch.utils.data import Dataset, DataLoader
import pickle 
from img2vec import Img2Vec
from PIL import Image
from gensim.models import Word2Vec

### Base Model 

In [None]:
class Encoder(nn.Module):
    """Two-branch encoder that maps an image vector and a text vector into one shared code.

    Each modality has its own stack of three fully connected layers
    (input_vec_dim -> 300 -> 200 -> latent_rep_dim) with ReLU activations.
    The two branch outputs are summed and passed through a final ReLU, so
    feeding zeros for one modality yields a code driven by the other modality
    (plus the bias terms of the zeroed branch).

    Args:
        input_vec_dim: width of each input vector; the same width is used for
            both the image and the text branch.
        latent_rep_dim: width of the shared code. Defaults to 50, the value
            that was previously hard-coded, so existing callers and saved
            state dicts keep working.
    """

    def __init__(self, input_vec_dim, latent_rep_dim=50):

        super(Encoder, self).__init__()

        # Layer names and creation order are kept as-is: the names are the
        # state_dict keys of saved checkpoints, and the order determines the
        # random initialisation under a fixed seed.
        self.FC1_img = nn.Linear(input_vec_dim, 300)
        self.FC2_img = nn.Linear(300, 200)
        self.FC3_img = nn.Linear(200, latent_rep_dim)

        self.FC1_txt = nn.Linear(input_vec_dim, 300)
        self.FC2_txt = nn.Linear(300, 200)
        self.FC3_txt = nn.Linear(200, latent_rep_dim)

    def forward(self, img, txt):
        """Encode a batch of (img, txt) pairs.

        Args:
            img: tensor of shape (batch, input_vec_dim).
            txt: tensor of shape (batch, input_vec_dim).

        Returns:
            Non-negative tensor of shape (batch, latent_rep_dim).
        """
        x = F.relu(self.FC1_img(img))
        x = F.relu(self.FC2_img(x))
        x = F.relu(self.FC3_img(x))

        y = F.relu(self.FC1_txt(txt))
        y = F.relu(self.FC2_txt(y))
        y = F.relu(self.FC3_txt(y))

        return F.relu(torch.add(x,y))
    
class Decoder(nn.Module):
    """Two-branch decoder that reconstructs both modalities from the shared code.

    The same code is fed to an image branch and a text branch
    (output_vec_dim -> 200 -> 300 -> recon_vec_dim, ReLU after every layer);
    the two reconstructions are concatenated along dim 1.

    Args:
        output_vec_dim: width of the shared code the decoder consumes.
            Despite its name this is the decoder's *input* width; the name is
            kept for backward compatibility with existing callers.
        recon_vec_dim: width of each per-modality reconstruction. Defaults to
            512, the value that was previously hard-coded.
    """

    def __init__(self, output_vec_dim, recon_vec_dim=512):

        super(Decoder, self).__init__()

        # Names and creation order are preserved so saved state dicts and
        # seeded initialisation stay compatible.
        self.FC1_img = nn.Linear(output_vec_dim, 200)
        self.FC2_img = nn.Linear(200, 300)
        self.FC3_img = nn.Linear(300, recon_vec_dim)

        self.FC1_txt = nn.Linear(output_vec_dim, 200)
        self.FC2_txt = nn.Linear(200, 300)
        self.FC3_txt = nn.Linear(300, recon_vec_dim)

    def forward(self, rep):
        """Decode a batch of shared codes.

        Args:
            rep: tensor of shape (batch, output_vec_dim).

        Returns:
            Non-negative tensor of shape (batch, 2 * recon_vec_dim): the image
            reconstruction followed by the text reconstruction.
        """
        x = F.relu(self.FC1_img(rep))
        x = F.relu(self.FC2_img(x))
        x = F.relu(self.FC3_img(x))

        y = F.relu(self.FC1_txt(rep))
        y = F.relu(self.FC2_txt(y))
        y = F.relu(self.FC3_txt(y))

        # Both halves are already ReLU outputs, so this final ReLU is a no-op
        # kept only to mirror the original computation exactly.
        combined=F.relu(torch.cat((x, y), 1))
        return combined

In [None]:
class Corrnet(nn.Module):
    """Autoencoder over paired image/text vectors: Encoder followed by Decoder.

    Args:
        input_vec_dim: width of each input vector, forwarded to the Encoder.
        latent_rep_dim: width of the shared code, forwarded to the Decoder.
            The Encoder is constructed from input_vec_dim only, so this value
            has to agree with the code width the Encoder emits.
    """

    def __init__(self, input_vec_dim, latent_rep_dim):
        super().__init__()
        self.encoder = Encoder(input_vec_dim)
        self.decoder = Decoder(latent_rep_dim)

    def forward(self, img, txt):
        """Encode the (img, txt) batch to the shared code and decode it again."""
        return self.decoder(self.encoder(img, txt))


class CorrnetDataset(Dataset):
    """Dataset of aligned (image vector, text vector) pairs.

    Args:
        img: indexable container with a ``shape`` attribute (tensor/array);
            row i is the image vector of sample i.
        txt: same, for the text vectors; must have as many rows as ``img``.

    Raises:
        ValueError: if ``img`` and ``txt`` have a different number of rows.
    """

    def __init__(self,img,txt):
        # Validate before storing anything so a failed construction does not
        # leave a half-initialised object around. ValueError is a subclass of
        # Exception, so callers catching the old generic Exception still work.
        if img.shape[0]!=txt.shape[0]:
            raise ValueError("Different no. of samples")
        self.img = img
        self.txt = txt

    def __len__(self):
        """Number of paired samples."""
        return self.img.shape[0]

    def __getitem__(self, index):
        """Return the ``(img, txt)`` pair stored at ``index``."""
        return self.img[index], self.txt[index]

In [None]:
def correlation(x, y, lamda=0.02, eps=1e-12):
    """Return the negated, scaled correlation between two batches of codes.

    Both inputs are centred with their per-feature batch mean (dim 0). A single
    pooled coefficient is then computed over all entries:
    sum(xc * yc) / (||xc|| * ||yc||), and multiplied by ``-lamda`` so that
    minimising the returned value increases the correlation.

    Args:
        x, y: tensors of identical shape (n, d); in this notebook the n x 50
            encoder outputs obtained from the n x 512 embeddings.
        lamda: weight of the correlation term in the total loss.
        eps: lower bound applied to each squared norm before the square root.
            Results are unchanged whenever both squared norms are >= eps.

    Returns:
        Scalar tensor in [-lamda, lamda] (up to rounding).
    """
    x_centered = torch.sub(x, torch.mean(x, dim = 0)) # xi - X_mean, n x d
    y_centered = torch.sub(y, torch.mean(y, dim = 0)) # yi - Y_mean, n x d
    corr_nr = torch.sum(torch.mul(x_centered, y_centered)) # The numerator

    # A constant batch (or a batch of one sample) centres to all zeros. Without
    # the clamp that gives 0/0 = NaN in the forward pass and an infinite sqrt
    # derivative in the backward pass, which poisons every weight on the next
    # optimizer step. Clamping *before* the sqrt keeps both finite.
    corr_dr1 = torch.sqrt(torch.clamp(torch.sum(torch.square(x_centered)), min=eps))
    corr_dr2 = torch.sqrt(torch.clamp(torch.sum(torch.square(y_centered)), min=eps))
    corr_dr = corr_dr1 * corr_dr2

    return -lamda * corr_nr / corr_dr

In [None]:
# Load the precomputed image/text embeddings and build the train/val/test loaders.
# (A legacy loading path that read './files/features' and './files/text_embeddings'
# with pickle used to sit here as commented-out code; it has been removed.)

# NOTE(review): pd.read_pickle unpickles the file, which can execute arbitrary code.
# Only load these files if they come from a trusted source.
dataset_train = pd.read_pickle('dataset_img_text_train')
train_image_vectors = np.array(list(dataset_train['img_vec']))
train_text_vectors = np.array(list(dataset_train['text_vec']))

dataset_test = pd.read_pickle('dataset_img_text_test')
test_image_vectors = np.array(list(dataset_test['img_vec']))
test_text_vectors = np.array(list(dataset_test['text_vec']))

# The validation split is the tail 0.1/0.85 fraction of the training file.
# NOTE(review): the rows are split in file order without shuffling; this assumes
# the file is not sorted in a way that biases the tail -- TODO confirm.
val_size = int((0.1 / 0.85) * len(train_image_vectors))
train_size = len(train_image_vectors) - val_size
test_size = len(test_text_vectors)

img_train = torch.from_numpy(train_image_vectors[:train_size].astype(np.float32))
txt_train = torch.from_numpy(train_text_vectors[:train_size].astype(np.float32))
train_dataset = DataLoader(CorrnetDataset(img_train, txt_train), batch_size=64, shuffle=True)

# Evaluation loaders are NOT shuffled: the loss is averaged per batch (and the
# correlation term is computed per batch), so a fixed batch composition keeps the
# validation/test numbers reproducible and comparable from epoch to epoch.
img_val = torch.from_numpy(train_image_vectors[train_size:(train_size + val_size)].astype(np.float32))
txt_val = torch.from_numpy(train_text_vectors[train_size:(train_size + val_size)].astype(np.float32))
val_dataset = DataLoader(CorrnetDataset(img_val, txt_val), batch_size=64, shuffle=False)

img_test = torch.from_numpy(test_image_vectors.astype(np.float32))
txt_test = torch.from_numpy(test_text_vectors.astype(np.float32))
test_dataset = DataLoader(CorrnetDataset(img_test, txt_test), batch_size=64, shuffle=False)
print(train_size)
print(val_size)
print(test_size)

65301
8706
13053


In [None]:
# Build the model (512-d inputs, 50-d shared code), its optimizer and the
# reconstruction loss used by the training and evaluation cells.
corrnet = Corrnet(512,50)
optimizer = optim.Adam(corrnet.parameters(), lr=0.001)
load_pretrained = True
model_save_path = './result/model_state_third.pt'

# Warm-start from an existing checkpoint when one is present at exactly this path.
# Note that the training cell writes its checkpoints to this path plus an
# '_epoch_...' suffix, so those files are not picked up here automatically.
if load_pretrained and os.path.exists(model_save_path):
    # The model is never moved off the CPU in this notebook, so map the stored
    # tensors to the CPU; otherwise a checkpoint saved from a GPU session fails
    # to deserialize on a machine without CUDA.
    corrnet.load_state_dict(torch.load(model_save_path, map_location='cpu'))

criterion = nn.MSELoss()

In [None]:
def evaluate_model(model, val_dataset, epoch=None):
    """Compute and print the mean loss of ``model`` over ``val_dataset``.

    For every batch four terms are computed, mirroring the training loop:
    reconstruction from both modalities, from the image only, from the text
    only (all measured with the notebook-level ``criterion``), and the
    ``correlation`` term between the image-only and text-only codes.

    Args:
        model: network exposing ``model(img, txt)`` and ``model.encoder(img, txt)``.
        val_dataset: iterable yielding ``(img, txt)`` batches.
        epoch: label printed in the log line. When omitted, the notebook-global
            loop variable ``e`` is used if it exists, which is what earlier
            versions of this function printed.

    Returns:
        Mean over batches of the summed loss (numpy float).
    """
    if epoch is None:
        # Earlier versions read the global `e` directly and raised NameError
        # when called outside the training loop; fall back gracefully instead.
        epoch = globals().get('e', 'n/a')

    was_training = model.training
    model.eval()
    L = []
    err = [[], [], [], []]
    try:
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for img, txt in val_dataset:
                # img, txt: batches of embedding vectors, concatenated along
                # dim 1 to form the reconstruction target.
                concat_inputs = torch.cat((img,txt),1)
                no_img = torch.zeros_like(img)
                no_txt = torch.zeros_like(txt)

                # Use the `model` argument (not the global `corrnet`) so the
                # function evaluates whatever network the caller passes in.
                res_combined_input = model(img, txt)
                res_img_input = model(img, no_txt)
                res_txt_input = model(no_img, txt)

                err1 = criterion(res_combined_input,concat_inputs)
                err2 = criterion(res_img_input,concat_inputs)
                err3 = criterion(res_txt_input,concat_inputs)
                err4 = correlation(
                    model.encoder(img, no_txt),
                    model.encoder(no_img, txt)
                )
                loss = (err1 + err2 + err3 + err4)

                L.append(loss.item())
                err[0].append(err1.item())
                err[1].append(err2.item())
                err[2].append(err3.item())
                err[3].append(err4.item())
    finally:
        # Put the model back in the mode it was in, even if evaluation failed.
        model.train(was_training)

    mean_loss = np.mean(L)
    print("Epoch: {}:, Val Loss: {}".format(epoch, mean_loss))
    for i in range(len(err)):
        print("err{}: {}".format(i+1,np.mean(err[i])),end="\t")
    # Was print('/n'), which printed the two characters "/n" instead of a blank line.
    print("\n")
    return mean_loss

In [None]:
# Train corrnet, validating after every epoch and checkpointing on improvement.
epochs = 150
# Start from +inf so the first checked epoch always counts as an improvement,
# whatever the scale of the loss.
best_val_loss = float('inf')

# torch.save does not create missing directories; make sure the target exists
# up front rather than failing after the first epoch of training.
os.makedirs(os.path.dirname(model_save_path) or '.', exist_ok=True)

# The loop variable must stay named `e`: evaluate_model reads it as a global to
# label its log line.
for e in range(epochs):
    L = []
    err = [[], [], [], []]
    for img, txt in train_dataset:

        # img, txt: batches of embedding vectors; their concatenation along
        # dim 1 is the reconstruction target for all three reconstruction terms.
        concat_inputs = torch.cat((img,txt),1)
        no_img = torch.zeros_like(img)
        no_txt = torch.zeros_like(txt)
        optimizer.zero_grad()

        res_combined_input = corrnet(img, txt)
        res_img_input = corrnet(img, no_txt)
        res_txt_input = corrnet(no_img, txt)

        err1 = criterion(res_combined_input,concat_inputs)
        err2 = criterion(res_img_input,concat_inputs)
        err3 = criterion(res_txt_input,concat_inputs)
        # Correlation between the image-only and the text-only shared codes.
        err4 = correlation(
            corrnet.encoder(img, no_txt),
            corrnet.encoder(no_img, txt)
        )

        loss = (err1 + err2 + err3 + err4)
        loss.backward()
        L.append(loss.item())
        err[0].append(err1.item())
        err[1].append(err2.item())
        err[2].append(err3.item())
        err[3].append(err4.item())
        optimizer.step()

    val_loss = evaluate_model(corrnet, val_dataset)
    print("Epoch: {}:, Train Loss: {}".format(e, np.mean(L)))
    for i in range(len(err)):
        print("err{}: {}".format(i+1,np.mean(err[i])),end="\t")
    print("\n")

    # Checkpoints are only considered every 5th epoch to limit disk writes.
    if(e%5 == 0 and val_loss < best_val_loss):
        best_val_loss = val_loss
        print('model_saved')
        # Tag the file with the epoch that produced it. The previous code used
        # `epochs` (the total, always 150), so every checkpoint got the same
        # name and the label never reflected when the weights were saved.
        torch.save(corrnet.state_dict(),model_save_path + '_epoch_' + str(e))

Epoch: 0:, Val Loss: 0.9985712134476864
err1: 0.3310464927644441	err2: 0.33081432635133917	err3: 0.3518437345822652	err4: -0.015133342762110811	/n
Epoch: 0:, Train Loss: 1.1142857383126799
err1: 0.36933847484381305	err2: 0.36839324676472207	err3: 0.3889300369698068	err4: -0.012376014598387131	

model_saved
Epoch: 1:, Val Loss: 0.7582643104322029
err1: 0.23685420327114337	err2: 0.23575549821058908	err3: 0.3000434519666614	err4: -0.014388840024669966	/n
Epoch: 1:, Train Loss: 0.8339816225611645
err1: 0.26723964007004447	err2: 0.2668178075681562	err3: 0.3142921682285226	err4: -0.014367995382813008	

Epoch: 2:, Val Loss: 0.7019507018002596
err1: 0.2120134144118338	err2: 0.21175435817602908	err3: 0.2928014388590148	err4: -0.014618521104707863	/n
Epoch: 2:, Train Loss: 0.7285153440807177
err1: 0.22449586968059126	err2: 0.22397436888321587	err3: 0.2946596806463988	err4: -0.014614574918928354	

Epoch: 3:, Val Loss: 0.6715728785052444
err1: 0.19922717412312826	err2: 0.19895951585336166	err3: 0.

Epoch: 28:, Val Loss: 0.5399116042888525
err1: 0.1486329689170375	err2: 0.14838983976479733	err3: 0.2596664966055841	err4: -0.016777697583716927	/n
Epoch: 28:, Train Loss: 0.5126942135717558
err1: 0.13885627220506253	err2: 0.1386207413090312	err3: 0.25233710330465564	err4: -0.0171199023399664	

Epoch: 29:, Val Loss: 0.5413086152437961
err1: 0.14977662129835648	err2: 0.1493824521700541	err3: 0.2590049591028329	err4: -0.01685541396904172	/n
Epoch: 29:, Train Loss: 0.5106323443029238
err1: 0.13831749678305957	err2: 0.13811634322223457	err3: 0.2513805943338767	err4: -0.01718209063148369	

Epoch: 30:, Val Loss: 0.539765067172773
err1: 0.14903862503441898	err2: 0.14867750397234253	err3: 0.25888681276278064	err4: -0.016837866553528744	/n
Epoch: 30:, Train Loss: 0.5096019464990367
err1: 0.13804954257996185	err2: 0.137862770324168	err3: 0.2508775279573772	err4: -0.01718789478359015	

model_saved
Epoch: 31:, Val Loss: 0.5390081920407035
err1: 0.148819211305994	err2: 0.14832744273272427	err3: 0.2

Epoch: 56:, Val Loss: 0.5350042175162922
err1: 0.1495871783205957	err2: 0.14880124321489624	err3: 0.2536682974208485	err4: -0.017052504149350254	/n
Epoch: 56:, Train Loss: 0.47958291380301765
err1: 0.1302347217888936	err2: 0.1300309045807175	err3: 0.23687762248775232	err4: -0.01756033360067269	

Epoch: 57:, Val Loss: 0.5323901618971969
err1: 0.14863487920074753	err2: 0.14797581393610348	err3: 0.2528452927416021	err4: -0.017065835213570885	/n
Epoch: 57:, Train Loss: 0.47916808141314465
err1: 0.13036247407612594	err2: 0.13018515699583552	err3: 0.2362073676093765	err4: -0.0175869165757752	

Epoch: 58:, Val Loss: 0.5353233570402319
err1: 0.14916168740301422	err2: 0.14860627348675873	err3: 0.254592752366355	err4: -0.017037363243148182	/n
Epoch: 58:, Train Loss: 0.4777578572864118
err1: 0.12994434299028437	err2: 0.12976015158321547	err3: 0.23563961950333223	err4: -0.01758625640574357	

Epoch: 59:, Val Loss: 0.5321552934068622
err1: 0.14898248649004733	err2: 0.1483590973146034	err3: 0.2519483

Epoch: 84:, Val Loss: 0.5350510649608843
err1: 0.15049887561436856	err2: 0.14986219053918665	err3: 0.25187345029729785	err4: -0.017183448893554283	/n
Epoch: 84:, Train Loss: 0.46422357053860375
err1: 0.12675996720790864	err2: 0.1265585326306198	err3: 0.2286167192718257	err4: -0.017711648425978162	

Epoch: 85:, Val Loss: 0.5348100544828357
err1: 0.15071158562645767	err2: 0.14980008385398172	err3: 0.2514559202121966	err4: -0.017157533826927345	/n
Epoch: 85:, Train Loss: 0.4632553973923559
err1: 0.12670909665201022	err2: 0.12653544778409212	err3: 0.22775183028501014	err4: -0.01774097833296527	

Epoch: 86:, Val Loss: 0.5348374021775795
err1: 0.15057174003485477	err2: 0.14992582662539047	err3: 0.2515312478397832	err4: -0.017191409189818485	/n
Epoch: 86:, Train Loss: 0.4631851600564044
err1: 0.12662621639345004	err2: 0.12643327900896903	err3: 0.22787436084902804	err4: -0.017748696863165367	

Epoch: 87:, Val Loss: 0.5340784444953456
err1: 0.15078953766461575	err2: 0.15005340540047848	err3: 0.

Epoch: 112:, Val Loss: 0.5373450002887032
err1: 0.15236158240022082	err2: 0.1516470606579925	err3: 0.25044133536743396	err4: -0.017104973480331177	/n
Epoch: 112:, Train Loss: 0.45399432195269546
err1: 0.12441667447919431	err2: 0.1242288378269776	err3: 0.22317851606918418	err4: -0.01782970705434032	

Epoch: 113:, Val Loss: 0.534771153421113
err1: 0.1510529454910394	err2: 0.15034951766331991	err3: 0.25056936795061285	err4: -0.01720067771208106	/n
Epoch: 113:, Train Loss: 0.45386389688305234
err1: 0.12444399371743202	err2: 0.12426948327085247	err3: 0.22297919781311698	err4: -0.017828776671186736	

Epoch: 114:, Val Loss: 0.5381377899285519
err1: 0.15145020548141364	err2: 0.1506789457617384	err3: 0.2531582482836463	err4: -0.017149613351758682	/n
Epoch: 114:, Train Loss: 0.4538912206888199
err1: 0.12429700691414916	err2: 0.12416444239409073	err3: 0.2232609148906625	err4: -0.017831145081183185	

Epoch: 115:, Val Loss: 0.5356501765323408
err1: 0.15080551035476453	err2: 0.1499731960621747	err3:

Epoch: 140:, Val Loss: 0.5367931804873727
err1: 0.15229988052989496	err2: 0.15165729956193405	err3: 0.2501782937483354	err4: -0.017342300916259937	/n
Epoch: 140:, Train Loss: 0.4474583811086157
err1: 0.12278534403961638	err2: 0.12262547991198042	err3: 0.2199489960204	err4: -0.01790143681447143	

Epoch: 141:, Val Loss: 0.5393565920266238
err1: 0.15299149896159317	err2: 0.15237329809954672	err3: 0.2512542659586126	err4: -0.017262478443709286	/n
Epoch: 141:, Train Loss: 0.4471494551586068
err1: 0.12267978641649951	err2: 0.12254201432932978	err3: 0.21979228951360869	err4: -0.0178646352061111	

Epoch: 142:, Val Loss: 0.5359348365754792
err1: 0.15215286522200613	err2: 0.15156695743401846	err3: 0.24940314843799127	err4: -0.017188130877912045	/n
Epoch: 142:, Train Loss: 0.4473518746054691
err1: 0.12266906610001689	err2: 0.12257006864832795	err3: 0.22003383940976598	err4: -0.017921098750894485	

Epoch: 143:, Val Loss: 0.5374732143951185
err1: 0.15271116206140228	err2: 0.152054358160857	err3: 0.

In [None]:
def predict(img, txt):
    """Return pairwise distances between image-only and text-only codes.

    Row k of ``img`` is encoded with the text input zeroed out, row k of ``txt``
    with the image input zeroed out (both via the notebook-level ``corrnet``),
    and the two resulting codes are compared with each other.

    Args:
        img: tensor of image vectors, one per row.
        txt: tensor of text vectors with the same shape as ``img``.

    Returns:
        Tuple ``(euc, cos)`` of numpy arrays holding, per row, the Euclidean
        and the cosine distance between the two codes.
    """
    img_codes = corrnet.encoder(img, torch.zeros_like(txt)).cpu().detach().numpy()
    txt_codes = corrnet.encoder(torch.zeros_like(img), txt).cpu().detach().numpy()

    code_pairs = list(zip(img_codes, txt_codes))
    euc = np.array([distance.euclidean(a, b) for a, b in code_pairs])
    cos = np.array([distance.cosine(a, b) for a, b in code_pairs])
    return euc, cos

def print_metrics(img_test, txt_test):
    """Print text->image retrieval metrics for paired test embeddings.

    Every text i is used as a query against all images. Images are ranked by
    cosine distance between the image-only code and the text-only code (both
    from the notebook-level ``corrnet.encoder``), and the 0-based rank of the
    matching image i is the number of images that are strictly closer. Median
    rank and R@1 / R@5 / R@10 are printed as percentages of the test set size,
    followed by the raw hit counts.

    Differences from the earlier implementation:
      * The modalities are no longer swapped: images go through the image
        branch and texts through the text branch (the old call passed the text
        array as ``img`` and the image array as ``txt``).
      * Every vector is encoded once instead of re-encoding the whole test set
        for every query.
      * All-zero codes get cosine distance 1 to everything (norms are clamped),
        so the ranking never has to order NaN values.

    Args:
        img_test: image vectors, one per row (tensor or array-like).
        txt_test: text vectors aligned row-by-row with ``img_test``.

    Raises:
        ValueError: if the two inputs do not have the same number of rows.
    """
    test_size = len(img_test)
    if len(txt_test) != test_size:
        raise ValueError("img_test and txt_test must contain the same number of samples")
    if test_size == 0:
        # median() and the percentage computations are undefined without data.
        print('No test samples; retrieval metrics are undefined.')
        return

    images = torch.as_tensor(img_test, dtype=torch.float32)
    texts = torch.as_tensor(txt_test, dtype=torch.float32)

    ranks = []
    chunk = 1024  # queries scored per step; bounds the distance matrix to test_size x chunk
    with torch.no_grad():
        img_vecs = corrnet.encoder(images, torch.zeros_like(texts))
        txt_vecs = corrnet.encoder(torch.zeros_like(images), texts)

        # Unit-normalise once so cosine similarity becomes a plain dot product.
        img_unit = img_vecs / img_vecs.norm(dim=1, keepdim=True).clamp_min(1e-12)
        txt_unit = txt_vecs / txt_vecs.norm(dim=1, keepdim=True).clamp_min(1e-12)

        for start in range(0, test_size, chunk):
            stop = min(start + chunk, test_size)
            # Column j holds the distance from every image to text (start + j).
            dists = 1.0 - img_unit @ txt_unit[start:stop].T
            # Distance of each query to its own matching image.
            own = dists[start:stop].diagonal()
            # Rank = how many images are strictly closer than the true match;
            # equals the first index of the true distance in the sorted list.
            ranks.extend((dists < own).sum(dim=0).tolist())

    mr = [rank + 1 for rank in ranks]
    top_1_count = sum(1 for rank in ranks if rank < 1)
    top_5_count = sum(1 for rank in ranks if rank < 5)
    top_10_count = sum(1 for rank in ranks if rank < 10)

    print('Median Rank(txt->img):', median(mr) * 100 / test_size, '%')
    print('R@1(txt->img):', top_1_count * 100 / test_size, '%')
    print('R@5(txt->img):', top_5_count * 100 / test_size, '%')
    print('R@10(txt->img):', top_10_count * 100 / test_size, '%')
    print(top_1_count)
    print(top_5_count)
    print(top_10_count)

In [None]:
# Report txt->img retrieval metrics (median rank, R@1/5/10) on the test-set embeddings.
print_metrics(img_test, txt_test)