In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import random
from tqdm import tqdm
from torch.utils.data import TensorDataset
import os
from torch.utils.tensorboard.writer import SummaryWriter
import numpy as np
import time

# Use your prepared  dataset :
answer_embeddings_ = np.array(data['answer_embedding'].tolist()).astype(np.float32)
questions = data['question'].tolist()
question_embeddings_ = np.array(data['question_embedding'].tolist()).astype(np.float32)
answers = data['answer'].tolist()


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps')
print(f"Using {device} device")
dtype = torch.float32
dtype_int = torch.int32

answer_embeddings  = torch.tensor(answer_embeddings_ ,  device = device, dtype=dtype )
question_embeddings  = torch.tensor(question_embeddings_ , device = device, dtype=dtype )

eps = 0.8
scores_q_a = torch.rand((answer_embeddings.size(0), answer_embeddings.size(0)), device = answer_embeddings.device)
scores_q_a  = torch.where( scores_q_a < eps, 0 ,scores_q_a)

In [10]:
'''
question_embeddings is a tensor that contains the embeddings for each question (queries).
answer_embeddings is a tensor that contains the embeddings for each answer (corpus).
scores_q_a is a score that measures the relevance between a question and its corresponding answer. This score can be manually labeled, or derived by using a more powerful embedding model, which will serve as the ground truth.

'''

class encoder(nn.Module):
    def __init__(self, input_size):
        super(encoder, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size//2)  
        self.batch_norm1 = nn.BatchNorm1d(input_size//2)  
        self.relu = nn.ReLU()  
        self.fc2 = nn.Linear(input_size//2, input_size)  

    def forward(self, x):
        z = self.fc1(x)
        z = self.batch_norm1(z)  
        z = self.relu(z)
        z = self.fc2(z)
        return x + z
         
class decoder(nn.Module):
    def __init__(self, input_size):
        super(decoder, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size//2)  
        self.batch_norm1 = nn.BatchNorm1d(input_size//2)  
        self.relu = nn.ReLU()  
        self.fc2 = nn.Linear(input_size//2, input_size)  

    def forward(self, x):
        z = self.fc1(x)
        z = self.batch_norm1(z)
        z = self.relu(z)
        z = self.fc2(z)
        return x + z
        
class Models(nn.Module):
    def __init__(self, input_size):
        super(Models, self).__init__()
        self.encoder = encoder(input_size)
        self.decoder = decoder(input_size)
    def forward(self, x):
        return self.encoder(x)
    def decode(self,x):
        return self.decoder(x)


In [11]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, question_embeddings, answer_embeddings, scores, answer_batch_size, device):
        super().__init__()  
        self.question_embeddings = question_embeddings
        self.answer_embeddings = answer_embeddings  
        self.scores = scores
        self.answer_batch_size = answer_batch_size
        non_zero_indices = []
        for k  in range(len(scores_q_a)) :
            non_zero =  torch.nonzero(scores_q_a[k], as_tuple=True)[0]
            non_zero_indices.append( non_zero if non_zero.numel() > 0 else None )

        self.non_zero_indices = non_zero_indices

    def __len__(self):
        return self.question_embeddings.size(0)  

    def __getitem__(self, index):
        relevant_indices = self.non_zero_indices[index]

        if relevant_indices is not None:  
            num_relevant = min(self.answer_batch_size//2, relevant_indices.numel())
            num_other = self.answer_batch_size - num_relevant

            all_indices = torch.randperm(relevant_indices.numel())  
            random_other_indices = all_indices[:num_other].to(device)
            random_relevant_indices = relevant_indices[torch.randperm(relevant_indices.size(0))[:num_relevant]].to(device)
            # while torch.any(random_relevant_indices[:, None] == random_other_indices): That will take longer !!
            #     random_relevant_indices = torch.randint(0, relevant_indices.size(0), size=(self.answer_batch_size - self.answer_batch_size // 2,))

            corpus_indices = torch.cat([random_relevant_indices, random_other_indices])
        else:
             corpus_indices = torch.randint(0, self.answer_embeddings.size(0), size=(self.answer_batch_size,))
        corpus_indices = torch.randint(0, self.answer_embeddings.size(0), size=(self.answer_batch_size,))

        return index, corpus_indices


In [12]:
input_dim = answer_embeddings_.shape[1]
models_dummy = Models(input_dim).to(device)

criterion_recovery = nn.L1Loss(reduction ='sum')
criterion_pred = nn.L1Loss(reduction=  'none')

batch_size = 128
dataset = CustomDataset(question_embeddings, answer_embeddings , scores_q_a , 2*batch_size, device)
dataloader = DataLoader(dataset, batch_size= batch_size , shuffle= True)

optimizer= torch.optim.AdamW(models_dummy.parameters(), lr = 3e-4)  # type: ignore


num_epochs = 5
alpha = 0.1
beta =  1

In [13]:

def ranking_loss_final(criterion_recovery, question_batch, answer_sampled_batch,  scores_batch, question_batch_encoded, answer_sampled_encoded, answer_sampled_decoded, alpha, beta, M , N) :
    diff = scores_batch.unsqueeze(1) - scores_batch.unsqueeze(2)  # (N, M, M), 1 -> j, 2 -> k
    question_batch_encoded = F.normalize(question_batch_encoded, dim=1)  # (N, dim)
    answer_sampled_encoded = F.normalize(answer_sampled_encoded, dim=2)  # (N, M, dim)
    N, M = answer_sampled_encoded.size(0), answer_sampled_encoded.size(1)
    similarity_encoded_question_answer = torch.sum(question_batch_encoded.unsqueeze(1) * answer_sampled_encoded, dim=2)  # (N, M)
    sim = torch.log(1 + torch.exp(similarity_encoded_question_answer.unsqueeze(2) - similarity_encoded_question_answer.unsqueeze(1)))  # (N, M, M)
    loss_rank = torch.sum(F.relu(diff) * sim) / (M * N) 
    loss_recovery_question = torch.sum(torch.norm(question_batch - question_batch_encoded, p=1, dim=1)) / N
    loss_recovery_answer = torch.sum(torch.norm(answer_sampled_batch - answer_sampled_encoded, p=1, dim=2)) / (M * N)
    loss_recovery = alpha * (loss_recovery_question + loss_recovery_answer) 
    decoding_contribution = torch.norm(question_batch_encoded.unsqueeze(1) - answer_sampled_decoded, p=1, dim=2)  # (N, M)
    loss_pred = beta * torch.sum(scores_batch * decoding_contribution) / (M * N)  
    loss = loss_rank + loss_recovery + loss_pred

    return loss, loss_rank, loss_recovery, loss_pred



writer = SummaryWriter()
train_loader = dataloader
val_loader = dataloader
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}/{num_epochs}')
    
    models_dummy.train()

    total_train_loss = 0.0
    total_train_recovery = 0.0
    total_train_pred = 0.0
    total_train_rank = 0.0
    
    for k, batch in enumerate(train_loader):
        batch_indices, corpus_indices = batch
        question_batch = question_embeddings[batch_indices]  # (N, dim)
        scores_batch = scores_q_a[batch_indices.unsqueeze(1), corpus_indices]  # (N, M)
        answers_sampled_batch = answer_embeddings[corpus_indices]  # (N, M, dim)
        
        N, M = answers_sampled_batch.size(0), answers_sampled_batch.size(1)
        answer_sampled_batch_flattened = answers_sampled_batch.view(N * M, -1)  # (N*M, dim)

        optimizer.zero_grad()

        question_batch_encoded = models_dummy(question_batch)  # (N, dim)
        answer_sampled_encoded_flattened = models_dummy(answer_sampled_batch_flattened)  # (N*M, dim)
        
        answer_sampled_decoded_flattened = models_dummy.decode(answer_sampled_batch_flattened)  # (N*M, dim)
        
        answer_sampled_encoded = answer_sampled_encoded_flattened.view(N, M, input_dim)  # (N, M, dim)
        answer_sampled_decoded = answer_sampled_decoded_flattened.view(N, M, input_dim)  # (N, M, dim)
        
        loss, loss_rank, loss_recovery, loss_pred = ranking_loss_final(
            criterion_recovery=criterion_recovery,
            question_batch=question_batch,
            answer_sampled_batch=answers_sampled_batch,
            scores_batch=scores_batch,
            question_batch_encoded=question_batch_encoded,
            answer_sampled_encoded=answer_sampled_encoded,
            answer_sampled_decoded=answer_sampled_decoded,
            alpha=alpha, beta=beta,
            M = M,
            N = N
        )


        total_train_loss += loss.item()
        total_train_pred += loss_pred.item()
        total_train_recovery += loss_recovery.item()
        total_train_rank += loss_rank.item()

        loss.backward()
        optimizer.step()

    avg_train_loss = total_train_loss / len(dataloader)
    avg_train_pred = total_train_pred / len(dataloader)
    avg_train_recovery = total_train_recovery / len(dataloader)
    avg_train_rank = total_train_rank / len(dataloader)

    print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, '
          f' Rank Loss: {avg_train_rank:.4f} ,'
          f'Pred Loss: {avg_train_pred:.4f}, Recovery Loss: {avg_train_recovery:.4f} ')

    
    writer.add_scalar('Loss/Train/Global', avg_train_loss, epoch*len(train_loader))
    writer.add_scalar('Loss/Train/Pred', avg_train_pred,epoch*len(train_loader))
    writer.add_scalar('Loss/Train/Recovery', avg_train_recovery,epoch*len(train_loader))
    writer.add_scalar('Loss/Train/Rank', avg_train_rank,epoch*len(train_loader))
    

    # Validation step
    models_dummy.eval()
    total_val_loss = 0.0
    total_val_recovery = 0.0
    total_val_pred = 0.0
    total_val_rank = 0.0
    
    with torch.no_grad():
        for val_batch in val_loader:
            batch_indices, corpus_indices = batch
            question_batch = question_embeddings[batch_indices]  # (N, dim)
            scores_batch = scores_q_a[batch_indices.unsqueeze(1), corpus_indices]  # (N, M)
            answers_sampled_batch = answer_embeddings[corpus_indices]  # (N, M, dim)
            N, M = answers_sampled_batch.size(0), answers_sampled_batch.size(1)
            answer_sampled_batch_flattened = answers_sampled_batch.view(N * M, -1)  # (N*M, dim)
            question_batch_encoded = models_dummy(question_batch)  # (N, dim)
            answer_sampled_encoded_flattened = models_dummy(answer_sampled_batch_flattened)  # (N*M, dim)
            
            answer_sampled_decoded_flattened = models_dummy.decode(answer_sampled_batch_flattened)  # (N*M, dim)
            
            answer_sampled_encoded = answer_sampled_encoded_flattened.view(N, M, input_dim)  # (N, M, dim)
            answer_sampled_decoded = answer_sampled_decoded_flattened.view(N, M, input_dim)  # (N, M, dim)
            
            loss, loss_rank, loss_recovery, loss_pred = ranking_loss_final(
                criterion_recovery=criterion_recovery,
                question_batch=question_batch,
                answer_sampled_batch=answers_sampled_batch,
                scores_batch=scores_batch,
                question_batch_encoded=question_batch_encoded,
                answer_sampled_encoded=answer_sampled_encoded,
                answer_sampled_decoded=answer_sampled_decoded,
                alpha=alpha, beta=beta,
                M = M,
                N = N
            )


            total_val_loss += loss.item()
            total_val_pred += loss_pred.item()
            total_val_recovery += loss_recovery.item()
            total_val_rank += loss_rank.item()

    avg_val_loss = total_val_loss / len(val_loader)
    avg_val_pred = total_val_pred / len(val_loader)
    avg_val_recovery = total_val_recovery / len(val_loader)
    avg_val_rank = total_val_rank / len(val_loader)

    print(f'Epoch {epoch + 1}/{num_epochs} - Val Loss: {avg_val_loss:.4f}, '
          f'Rank Loss: {avg_val_rank:.4f} , '
          f'Pred Loss: {avg_val_pred:.4f}, Recovery Loss: {avg_val_recovery:.4f} .')
    
   
    writer.add_scalar('Loss/Valid/Global', avg_train_loss, epoch*len(train_loader))
    writer.add_scalar('Loss/Valid/Pred', avg_train_pred,epoch*len(train_loader))
    writer.add_scalar('Loss/Valid/Recovery', avg_train_recovery,epoch*len(train_loader))
    writer.add_scalar('Loss/Valid/Rank', avg_train_rank,epoch*len(train_loader))

writer.close()


Starting epoch 1/5
Epoch 1/5 - Train Loss: 56.1745,  Rank Loss: 25.6878 ,Pred Loss: 25.6388, Recovery Loss: 4.8480 
Epoch 1/5 - Val Loss: 36.3511, Rank Loss: 25.7814 , Pred Loss: 6.6155, Recovery Loss: 3.9542 .
Starting epoch 2/5
Epoch 2/5 - Train Loss: 35.8685,  Rank Loss: 25.6665 ,Pred Loss: 6.4122, Recovery Loss: 3.7898 
Epoch 2/5 - Val Loss: 34.9239, Rank Loss: 25.7135 , Pred Loss: 5.5890, Recovery Loss: 3.6214 .
Starting epoch 3/5
Epoch 3/5 - Train Loss: 34.1623,  Rank Loss: 25.6893 ,Pred Loss: 4.9568, Recovery Loss: 3.5162 
Epoch 3/5 - Val Loss: 33.3551, Rank Loss: 25.3898 , Pred Loss: 4.6059, Recovery Loss: 3.3594 .
Starting epoch 4/5
Epoch 4/5 - Train Loss: 33.4007,  Rank Loss: 25.6929 ,Pred Loss: 4.3933, Recovery Loss: 3.3145 
Epoch 4/5 - Val Loss: 33.6683, Rank Loss: 26.0893 , Pred Loss: 4.3935, Recovery Loss: 3.1855 .
Starting epoch 5/5
Epoch 5/5 - Train Loss: 32.8941,  Rank Loss: 25.6816 ,Pred Loss: 4.0474, Recovery Loss: 3.1652 
Epoch 5/5 - Val Loss: 32.7191, Rank Loss: 25

There is no evaluation using @NDCG, as was done in the article, because the scores here were sampled randomly. Although a dataset with relevant scores wasn’t available, I did my best with this code.

This is the function I will be using if everything was provided .

In [14]:
def calculate_evaluation_metrics(true_data, predicted_current, predicted_old, log_writer, top_k_values, iteration_step):
    def compute_dcg(relevance_vals):
        dcg_score = 0.0
        for idx, rel in enumerate(relevance_vals):
            dcg_score += rel / np.log2(idx + 2)
        return dcg_score

    def normalized_dcg_at_top_k(actual_ids, pred_ids, top_k):
        ndcg_scores = []
        for actual, predicted in zip(actual_ids, pred_ids):
            relevance_scores = [1 if item in actual else 0 for item in set(predicted[:top_k])]
            dcg_val = compute_dcg(relevance_scores)
            ideal_relevance = sorted(relevance_scores, reverse=True)
            idcg_val = compute_dcg(ideal_relevance)
            ndcg = dcg_val / idcg_val if idcg_val else 0
            ndcg_scores.append(ndcg)
        return sum(ndcg_scores) / len(ndcg_scores) if ndcg_scores else 0

    # Model comparison and logging function
    def log_and_compare(true_ids, current_model_ids, baseline_model_ids, top_k_values, step):
        for  top_k in top_k_values:
           
            baseline_ndcg = normalized_dcg_at_top_k(true_ids, baseline_model_ids, top_k)
            current_ndcg = normalized_dcg_at_top_k(true_ids, current_model_ids, top_k)
            log_writer.add_scalar(f'NDCG@{top_k}/Baseline', baseline_ndcg, step)
            log_writer.add_scalar(f'NDCG@{top_k}/Current', current_ndcg, step)

    actual_relevant_ids, baseline_model_ids, current_model_ids = true_data, predicted_old, predicted_current
    primary_evaluation = log_and_compare(actual_relevant_ids, current_model_ids, baseline_model_ids, top_k_values, iteration_step)
    return primary_evaluation
