In [29]:
from datasets import load_dataset
import random
import re
from tqdm import tqdm
import pickle
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from CBOW import CBOW
# Seed the Python, PyTorch and NumPy random number generators with the same
# fixed value so that runs starting from a fresh kernel are repeatable.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)


In [3]:
# Load the validation and training splits from .npy files in the working
# directory (the test split is currently disabled).
# NOTE(review): allow_pickle=True unpickles Python objects, which can execute
# arbitrary code - only load files from a trusted source.
#marco_test = np.load("marco_test.npy",allow_pickle=True)
marco_val = np.load("marco_val.npy",allow_pickle=True)
marco_train = np.load("marco_train.npy",allow_pickle=True)

In [30]:
# Build the CBOW model, restore its weights from checkpoints/best.pt and put
# it in eval mode. The dataset class below reads this module-level `embedder`.
embedder = CBOW()
embedder.load_state_dict(torch.load("checkpoints/best.pt", weights_only=True))
embedder.eval()

CBOW(
  (embed): Embedding(76289, 128)
  (lin): Linear(in_features=128, out_features=76289, bias=True)
)

In [70]:
class twoTowerDataSet(Dataset):
    """Dataset of (query, positive, negative) mean-pooled embedding triplets.

    Each item is built from one row of ``marco_splt``:

    * ``row[1]``  - token ids of the query,
    * ``row[2]``  - a list in which the value ``1`` marks the selected passage,
    * ``row[selected + 2]`` - token ids used as the positive passage,
    * ``other_row[3]`` - token ids of a passage from a randomly drawn row,
      used as the negative.

    Token ids are looked up in a frozen embedding table and averaged over the
    token axis, so every returned tensor is one fixed-size vector.

    NOTE(review): the positive is read from ``row[selected + 2]`` while the
    negative is read from column 3. If passages start at column 3, the
    positive offset should probably be ``selected + 3`` (with ``selected == 0``
    the current code reads ``row[2]``, the selection list itself). The row
    schema is not visible here - confirm against the preprocessing code.

    Args:
        marco_splt: indexable collection of rows exposing ``.shape[0]``.
        embedding: optional callable mapping a tensor of token ids to a
            ``(num_tokens, dim)`` tensor. When omitted, the module-level
            ``embedder.embed`` is used (resolved lazily on first access).
    """

    def __init__(self, marco_splt, embedding=None):
        self.data = marco_splt
        self.len = self.data.shape[0]
        self.embedding = embedding

    def __len__(self):
        return self.len

    def _encode(self, token_ids):
        """Return the mean embedding of ``token_ids`` as a detached vector."""
        embed = self.embedding if self.embedding is not None else embedder.embed
        # The embedding table is frozen: disabling autograd keeps the training
        # loop from building a graph through it and from accumulating
        # gradients that no optimizer ever consumes.
        with torch.no_grad():
            return torch.mean(embed(torch.tensor(token_ids)), dim=0)

    def __getitem__(self, idx):
        # random.randint includes BOTH endpoints, so the upper bound must be
        # len - 1; using len would intermittently raise IndexError.
        randIdx = random.randint(0, self.len - 1)
        row = self.data[idx]
        selected = row[2].index(1)

        query = self._encode(row[1])
        pos = self._encode(row[selected + 2])
        # The random row may coincide with idx; no filtering is applied.
        neg = self._encode(self.data[randIdx][3])

        return query, pos, neg

In [68]:
class twoTowerModel(nn.Module):
    """One tower: an MLP projecting a pooled embedding to a 32-d vector.

    Layer widths are ``embedding_dim -> 128 -> 64 -> 32`` with a ReLU between
    consecutive linear layers and no activation after the last one.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        widths = (embedding_dim, 128, 64, 32)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The loop leaves one ReLU after the final projection; drop it so the
        # output layer stays linear.
        stack.pop()
        self.seq = nn.Sequential(*stack)

    def forward(self, X):
        """Project ``X`` through the stacked layers and return the result."""
        return self.seq(X)
    
def contrastiveLoss(query, pos, neg, m=0.6):
    """Margin-based triplet loss on cosine similarity, averaged over the batch.

    For every sample the loss is
    ``max(0, m - cos(query, pos) + cos(query, neg))``.

    Args:
        query: tensor of shape ``(batch, dim)`` (a single ``(dim,)`` vector
            also works).
        pos: tensor with the same shape as ``query`` for the positive passage.
        neg: tensor with the same shape as ``query`` for the negative passage.
        m: margin by which the positive similarity should exceed the negative.

    Returns:
        A 0-dim tensor holding the mean hinge loss. It is always a tensor
        (never a Python number), so ``.backward()`` and ``.item()`` remain
        valid even when the hinge is inactive.
    """
    # Similarity along the feature axis keeps one value per sample, so the
    # whole batch contributes instead of only its first element.
    cosine_sim_pos = F.cosine_similarity(query, pos, dim=-1)
    cosine_sim_neg = F.cosine_similarity(query, neg, dim=-1)
    # torch.clamp keeps the result inside the autograd graph; the builtin
    # max(0, tensor) returns the int 0 whenever the margin is satisfied.
    per_sample = torch.clamp(m - cosine_sim_pos + cosine_sim_neg, min=0)
    return per_sample.mean()

In [6]:
# Module-level lists for per-epoch validation / training losses.
# NOTE(review): nothing in the visible cells appends to these lists - confirm
# whether they are still needed or should be filled by the training loop.
epoch_val_loss_history = []
epoch_train_loss_history = []

In [None]:
def train(batchSize=2, numEpochs= 20, lr=1e-3 ):
    """Train the query and passage towers with the contrastive loss.

    Builds datasets from the module-level ``marco_train`` / ``marco_val``
    arrays, optimises both towers jointly with Adam, evaluates on the
    validation split after every epoch and saves both towers to
    ``checkpoints/`` whenever the average validation loss improves.

    Args:
        batchSize: number of triplets per batch for both loaders.
        numEpochs: number of passes over the training split.
        lr: Adam learning rate shared by both towers.

    Returns:
        None. Progress is printed and the best weights are written to
        ``checkpoints/bestQuery.pt`` and ``checkpoints/bestPassage.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = twoTowerDataSet(marco_train)
    val_dataset = twoTowerDataSet(marco_val)

    train_loader = DataLoader(train_dataset,batch_size=batchSize, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False)

    passage_model = twoTowerModel().to(device)
    query_model = twoTowerModel().to(device)

    optimizer = optim.Adam(list(passage_model.parameters()) + list(query_model.parameters()),lr=lr)

    # torch.save does not create missing directories.
    os.makedirs("checkpoints", exist_ok=True)

    best_loss = float("inf")

    for epoch in range(numEpochs):
        epoch_train_loss = 0.0

        passage_model.train()
        query_model.train()

        for query, pos, neg in tqdm(train_loader):
            query = query.to(device)
            pos = pos.to(device)
            neg = neg.to(device)

            optimizer.zero_grad()

            embedded_query = query_model(query)
            # Positive and negative passages share the same passage tower.
            embedded_pos = passage_model(pos)
            embedded_neg = passage_model(neg)

            loss = contrastiveLoss(embedded_query, embedded_pos, embedded_neg)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

        epoch_val_loss = 0.0

        # Both towers must be in eval mode for validation.
        passage_model.eval()
        query_model.eval()

        with torch.no_grad():
            # Validation has to read the held-out split; the average below
            # divides by len(val_loader), so the loader must match.
            for query, pos, neg in tqdm(val_loader):
                query = query.to(device)
                pos = pos.to(device)
                neg = neg.to(device)

                embedded_query = query_model(query)
                embedded_pos = passage_model(pos)
                embedded_neg = passage_model(neg)

                loss = contrastiveLoss(embedded_query, embedded_pos, embedded_neg)

                epoch_val_loss += loss.item()

        avg_train_loss = epoch_train_loss/len(train_loader)
        avg_val_loss = epoch_val_loss/len(val_loader)

        print(f"\nEpoch {epoch+1}/{numEpochs} — " f"Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(query_model.state_dict(), 'checkpoints/bestQuery.pt')
            torch.save(passage_model.state_dict(), 'checkpoints/bestPassage.pt')
            print("Model improved. Saved.")




    




In [71]:
# Start training with the default arguments of train()
# (batchSize=2, numEpochs=20, lr=1e-3).
train()

  0%|          | 2/82326 [00:00<2:09:54, 10.56it/s]

  0%|          | 58/82326 [00:05<2:21:21,  9.70it/s]


KeyboardInterrupt: 