In [2]:
'''
question_embeddings is a tensor that contains the embeddings for each question (queries).
answer_embeddings is a tensor that contains the embeddings for each answer (corpus).
scores_q_a is a score that measures the relevance between a question and its corresponding answer. This score can be manually labeled, or derived by using a more powerful embedding model, which will serve as the ground truth.

'''

class encoder(nn.Module):
    def __init__(self, input_size):
        super(encoder, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size)  
        # self.batch_norm1 = nn.BatchNorm1d(input_size)  
        # self.relu = nn.ReLU()  
        # self.fc2 = nn.Linear(input_size, input_size)  

    def forward(self, x):
        z = self.fc1(x)
        # x = self.batch_norm1(x)  
        # x = self.relu(x)
        # x = self.fc2(x)
        return x +  z
         
class decoder(nn.Module):
    def __init__(self, input_size):
        super(decoder, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size)  
        # self.batch_norm1 = nn.BatchNorm1d(input_size)  
        # self.relu = nn.ReLU()  
        # self.fc2 = nn.Linear(input_size, input_size)  

    def forward(self, x):
        z = self.fc1(x)
        # x = self.batch_norm1(x)  
        # x = self.relu(x)
        # x = self.fc2(x)
        return x + z
        
class Models(nn.Module):
    def __init__(self, input_size):
        super(Models, self).__init__()
        self.encoder = encoder(input_size)
        self.decoder = decoder(input_size)
    def forward(self, x):
        return self.encoder(x)
    def decode(self,x):
        return self.decoder(x)


In [4]:
non_zero_indices = []
for k  in range(len(scores_q_a)) :
    non_zero =  torch.nonzero(scores_q_a[k], as_tuple=True)[0]
    non_zero_indices.append( non_zero if non_zero.numel() > 0 else None )


In [23]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, question_embeddings, answer_embeddings, scores, answer_batch_size, non_zero_indices, device):
        super().__init__()  
        self.question_embeddings = question_embeddings
        self.answer_embeddings = answer_embeddings  
        self.scores = scores
        self.answer_batch_size = answer_batch_size
        self.non_zero_indices = non_zero_indices

    def __len__(self):
        return self.question_embeddings.size(0)  

    def __getitem__(self, index):
        relevant_indices = self.non_zero_indices[index]

        if relevant_indices is not None:  
            num_relevant = min(self.answer_batch_size//2, relevant_indices.numel())
            num_other = self.answer_batch_size - num_relevant

            all_indices = torch.randperm(relevant_indices.numel())  
            random_other_indices = all_indices[:num_other].to(device)
            random_relevant_indices = relevant_indices[torch.randperm(relevant_indices.size(0))[:num_relevant]].to(device)
            # while torch.any(random_relevant_indices[:, None] == random_other_indices): That will take longer !!
            #     random_relevant_indices = torch.randint(0, relevant_indices.size(0), size=(self.answer_batch_size - self.answer_batch_size // 2,))

            corpus_indices = torch.cat([random_relevant_indices, random_other_indices])
        else:
             corpus_indices = torch.randint(0, self.answer_embeddings.size(0), size=(self.answer_batch_size,))
        corpus_indices = torch.randint(0, self.answer_embeddings.size(0), size=(self.answer_batch_size,))

        return index, corpus_indices


In [37]:
input_dim = answer_embeddings_.shape[1]
models_dummy = Models(input_dim).to(device)

In [42]:
def ranking_loss(criterion_recovery,criterion_pred, question_batch,their_answers_batch, answer_batch , scores_batch, question_encoded,their_answers_encoded , answer_encoded , answer_decoded , alpha, beta) :
    diff = (scores_batch.unsqueeze(1) - scores_batch.unsqueeze(2))
    similarity_encoded_question_answer = torch.sum(F.normalize(question_encoded, dim = 1).unsqueeze(1)*F.normalize(answer_encoded, dim = 2) , dim = 2)

    sim = torch.log(1+ (similarity_encoded_question_answer.unsqueeze(1)-similarity_encoded_question_answer.unsqueeze(2)))
    loss_rank = torch.sum((diff > 0)*diff*sim )
    loss = loss_rank
    loss_recovery = alpha*(criterion_recovery(question_batch, question_encoded) + criterion_recovery(their_answers_batch, their_answers_encoded) )
  
    decoding_contribution = torch.norm(question_encoded.unsqueeze(1)- answer_decoded, p=1 , dim =  2 )  # (N , M)  

    
    loss_pred = beta*(scores_batch* decoding_contribution / scores_batch.sum()) # scores (N , M) *  decoding_part (N, M) ........
    return loss, loss_rank, loss_recovery, loss_pred


criterion_recovery = nn.L1Loss()
criterion_pred = nn.L1Loss(reduction=  'none')

batch_size = 64
dataset = CustomDataset(question_embeddings, answer_embeddings , scores_q_a , 4*batch_size, non_zero_indices , device)
dataloader = DataLoader(dataset, batch_size= batch_size , shuffle= True)

optimizer= torch.optim.AdamW(models_dummy.parameters(), lr=0.01)  # type: ignore

alpha = 0
beta= 0
total_train_loss = 0
for  k ,batch in enumerate(dataloader) :
    batch_indices, corpus_indices = batch
   
    question_batch  = question_embeddings[batch_indices]
    their_answers_batch = answer_embeddings[batch_indices]
    # answer_batch , scores_batch = answer_embeddings[batch_indices], 

    scores_batch = scores_q_a[batch_indices.unsqueeze(1), corpus_indices]
    answer_batch =  answer_embeddings[corpus_indices]  # (N, M , dim)
    M = answer_batch.size(1)
    N = answer_batch.size(0)
    
    answer_batch_flattened = answer_batch .view(M*N,-1)

    optimizer.zero_grad()


    question_encoded= models_dummy(question_batch)
    answer_encoded_flattened = models_dummy(answer_batch_flattened)
    their_answers_encoded = models_dummy(their_answers_batch)
    answer_decoded_flattened = models_dummy.decode(answer_batch_flattened)

    answer_encoded = answer_encoded_flattened.view((N, M , input_dim))
    answer_decoded = answer_decoded_flattened.view((N, M , input_dim))
    

    loss,_,_,_ = ranking_loss(criterion_recovery,criterion_pred, question_batch,their_answers_batch, answer_batch , scores_batch, question_encoded,their_answers_encoded , answer_encoded , answer_decoded , alpha, beta) 
    loss.backward()
    optimizer.step()

    total_train_loss += loss.item()

avg_train_loss = total_train_loss/len(dataloader) 

print(f'Train Loss: {avg_train_loss:.6f}')


Train Loss: 0.000000


In [None]:
def training_step(writer, models, train_dataset, valid_dataset, question_embeddings, ground_truth, num_train, num_epochs, batch_size, learning_rate , tau, alpha, beta, hparams, k_values,path,patience, commentary, factor):    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 
    # valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 
    device_ = torch.device('cpu')
    k_max = max(k_values) 
    print('We ended the dataloader part .')

    optimizer= torch.optim.AdamW(models.parameters(), lr=learning_rate)  # type: ignore
    # early_stopping = EarlyStopping(patience=patience , path = path)
    evaluating  = 0
    # Loss functions :
    criterion_recovery = nn.L1Loss()
    criterion_pred = nn.L1Loss(reduce = False)
    # Initialize TensorBoard writer
    writer.add_text('Run Commentary', commentary, 0)

    # Log hyperparameters
    writer.add_hparams(hparams, {})
    Step = 0 
    for epoch in range(num_epochs):
        print(f'We started the process of epoch {epoch + 1}')
        models.train()
        total_train_loss = 0
        total_train_reg = 0
        total_train_triplets = 0
        for batch in train_loader:
            batch_indices = batch[0]
            question_batch , answer_batch , scores_batch = question_embeddings[batch_indices],answer_embeddings[batch_indices], scores_q_a[batch_indices[:, None], batch_indices]
            question_encoded, answer_encoded = models_dummy(question_batch), models_dummy(answer_batch)
            answer_decoded = models_dummy.decode(answer_batch)

            loss, loss_rank, loss_recovery, loss_pred  = ranking_loss(criterion_recovery =  criterion_recovery,criterion_pred = criterion_pred, question_batch = question_batch, answer_batch = answer_batch, scores = scores_batch, question_encoded = question_encoded, answer_encoded = answer_encoded , answer_decoded = answer_encoded ,alpha = alpha , beta = beta) 
