In [2]:
from datasets import load_dataset
import random
import re
from tqdm import tqdm
import pickle
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
# Seed every RNG this notebook relies on (Python `random`, torch, NumPy's legacy global) with 42,
# in the same order as before, so shuffling and negative sampling are reproducible.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(42)
del _seed_fn


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Pre-embedded MS MARCO splits, one CSV per split, each loaded into a DataFrame.
# NOTE(review): the files are named "embedde_marco_*" (missing "d"?) — confirm this matches what is on disk.
# NOTE(review): twoTowerDataSet treats columns 2 and 3 of each row as sequences; writing list-valued
# cells with DataFrame.to_csv stores them as strings, so they may need parsing (e.g. ast.literal_eval)
# or a typed format such as parquet/pickle — TODO confirm.
embedded_marco_train = pd.read_csv("embedde_marco_train.csv")
embedded_marco_val = pd.read_csv("embedde_marco_val.csv")
embedded_marco_test = pd.read_csv("embedde_marco_test.csv")

In [4]:
class twoTowerDataSet(Dataset):
    """Yield (query, positive passage, negative passage) triples from one MS MARCO split.

    The row layout is inferred from how rows are indexed below (TODO confirm against the data):
    column 1 holds the query, column 2 a per-passage selection mask containing a 1 for the
    relevant passage, and column 3 the sequence of candidate passages.

    Args:
        marco_splt: a pandas DataFrame (rows are read positionally via ``.iloc``) or any
            row-indexable array exposing ``.shape[0]``.
    """

    def __init__(self, marco_splt):
        super().__init__()
        self.data = marco_splt
        self.len = self.data.shape[0]

    def __len__(self):
        # `self.len` is already an int; calling len() on it raised TypeError.
        return self.len

    def _row(self, i):
        """Return row `i` as a positionally indexable sequence."""
        # `DataFrame[i]` selects the *column* labelled i, so DataFrame rows go through `.iloc`.
        if hasattr(self.data, "iloc"):
            return self.data.iloc[i].tolist()
        return self.data[i]

    def __getitem__(self, idx):
        """Return (query, positive, negative) for row `idx`; the negative comes from a random row."""
        row = self._row(idx)
        query, selected, passages = row[1], row[2], row[3]
        # The positive is the passage flagged 1 in the selection mask. Previously the row itself
        # was indexed with that passage position instead of the passage list.
        positive = passages[list(selected).index(1)]
        # randrange excludes its upper bound; randint(0, self.len) could return self.len.
        # NOTE(review): the sampled row may be `idx` itself, and its passage 1 may be a relevant one.
        negative = self._row(random.randrange(self.len))[3][1]
        return query, positive, negative

In [None]:
class twoTowerModel(nn.Module):
    """One tower of the two-tower retriever: a 3-layer MLP (embedding_dim -> 128 -> 64 -> output_dim).

    Args:
        embedding_dim: size of the last dimension of the input tensor.
        output_dim: size of the produced vector. Defaults to 1 to keep the original architecture
            and checkpoint shapes. NOTE(review): with output_dim=1 the cosine similarity used by
            contrastiveLoss can only be -1, 0 or +1, so a larger value is likely intended.
    """

    def __init__(self, embedding_dim=256, output_dim=1):
        # `super.__init__()` invoked the builtin `super` type's __init__ and raised TypeError.
        super().__init__()
        # Attribute names are kept so state_dict keys (first/second/third) stay checkpoint-compatible.
        self.first = nn.Linear(embedding_dim, 128)
        self.second = nn.Linear(128, 64)
        self.third = nn.Linear(64, output_dim)

    def forward(self, X):
        """Map X of shape (..., embedding_dim) to (..., output_dim)."""
        # `nn.ReLU(t)` constructs a module (treating t as its `inplace` flag) instead of applying
        # the activation, so ReLU is applied functionally here.
        hidden = F.relu(self.first(X))
        hidden = F.relu(self.second(hidden))
        return self.third(hidden)
    
def contrastiveLoss(query, pos, neg, m=0.6):
    """Triplet margin loss on cosine similarity, averaged over the batch.

    loss = mean(max(0, m - cos(query, pos) + cos(query, neg)))

    Args:
        query, pos, neg: tensors (or array-likes) of matching shape. Similarity is taken over the
            last dimension, so both single vectors and (batch, dim) batches work.
        m: margin by which the positive similarity should exceed the negative one.

    Returns:
        A 0-dim tensor that stays connected to the inputs' autograd graph, so `.backward()`
        and `.item()` work on it.
    """
    query, pos, neg = (torch.as_tensor(t) for t in (query, pos, neg))
    # The previous numpy version detached from autograd, divided by an elementwise abs product
    # instead of L2 norms, and np.dot failed on (batch, dim) inputs.
    cosine_sim_pos = F.cosine_similarity(query, pos, dim=-1)
    cosine_sim_neg = F.cosine_similarity(query, neg, dim=-1)
    # clamp(min=0) is the elementwise, differentiable form of max(0, .).
    return torch.clamp(m - cosine_sim_pos + cosine_sim_neg, min=0).mean()

In [5]:
# Module-level per-epoch loss histories; both start empty and are two distinct list objects.
epoch_val_loss_history, epoch_train_loss_history = [], []

In [None]:
def _run_epoch(loader, query_model, passage_model, device, optimizer=None):
    """Run one pass over `loader` and return the mean per-batch contrastive loss.

    With an optimizer, both towers are put in train mode and updated. Without one, both are
    put in eval mode and no gradients are tracked. Returns NaN for an empty loader, so an
    empty split never counts as an improvement.
    """
    training = optimizer is not None
    query_model.train(training)
    passage_model.train(training)

    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for query, pos, neg in tqdm(loader, desc="train" if training else "val"):
            query, pos, neg = query.to(device), pos.to(device), neg.to(device)
            if training:
                optimizer.zero_grad()

            embedded_query = query_model(query)
            embedded_pos = passage_model(pos)
            embedded_neg = passage_model(neg)
            loss = contrastiveLoss(embedded_query, embedded_pos, embedded_neg)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

    return total_loss / len(loader) if len(loader) else float("nan")


def train(batchSize=64, numEpochs= 20, lr=1e-3, numWorkers=2):
    """Train the query and passage towers jointly and checkpoint on best validation loss.

    Reads the notebook globals `embedded_marco_train` and `embedded_marco_val`, appends each
    epoch's mean losses to `epoch_train_loss_history` / `epoch_val_loss_history`, and saves the
    best models to checkpoints/bestQuery.pt and checkpoints/bestPassage.pt.

    Args:
        batchSize: DataLoader batch size.
        numEpochs: number of passes over the training split.
        lr: Adam learning rate.
        numWorkers: DataLoader worker processes. NOTE(review): with the "spawn" start method
            (macOS/Windows), a Dataset class defined inside a notebook may fail to pickle; use 0 there.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Previously trained on the *test* split.
    train_dataset = twoTowerDataSet(embedded_marco_train)
    val_dataset = twoTowerDataSet(embedded_marco_val)

    train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True, num_workers=numWorkers)
    val_loader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False, num_workers=numWorkers)

    passage_model = twoTowerModel().to(device)
    query_model = twoTowerModel().to(device)

    # Adam needs one flat iterable of tensors (a tuple of two generators is rejected),
    # and the `lr` argument was previously never passed.
    optimizer = optim.Adam([*passage_model.parameters(), *query_model.parameters()], lr=lr)

    os.makedirs("checkpoints", exist_ok=True)
    best_loss = float("inf")

    for epoch in range(numEpochs):
        avg_train_loss = _run_epoch(train_loader, query_model, passage_model, device, optimizer)
        # Validation now iterates val_loader (not train_loader) with *both* towers in eval mode.
        avg_val_loss = _run_epoch(val_loader, query_model, passage_model, device)

        epoch_train_loss_history.append(avg_train_loss)
        epoch_val_loss_history.append(avg_val_loss)

        print(f"\nEpoch {epoch+1}/{numEpochs} — " f"Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(query_model.state_dict(), 'checkpoints/bestQuery.pt')
            torch.save(passage_model.state_dict(), 'checkpoints/bestPassage.pt')
            print("Model improved. Saved.")




    


