In [1]:
from datasets import load_dataset

In [2]:
# Load every split of MS MARCO v1.1 to inspect the available features and row
# counts. Note: `dataset` is rebound to the train split alone in a later cell.
dataset = load_dataset('ms_marco', 'v1.1')

In [3]:
# Display the DatasetDict summary (splits, features, num_rows).
dataset

DatasetDict({
    validation: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 10047
    })
    train: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 82326
    })
    test: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 9650
    })
})

In [4]:
# Reload only the train split; from here on `dataset` is a single Dataset,
# not a DatasetDict.
dataset = load_dataset('ms_marco', 'v1.1', split='train')

In [5]:
dataset

Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 82326
})

In [6]:
# Confirm that split='train' yields a Dataset rather than a DatasetDict.
type(dataset)

datasets.arrow_dataset.Dataset

In [7]:
def preprocess(example, num_negatives=5):
    """Turn one MS MARCO row into a (query, positive, negatives) triple.

    The first passage flagged ``is_selected == 1`` becomes the positive. The first
    ``num_negatives`` passages flagged ``is_selected == 0`` become the negatives,
    kept in their original order.

    Args:
        example: one MS MARCO example with ``query`` and ``passages``, where
            ``passages`` holds parallel ``is_selected`` / ``passage_text`` lists.
        num_negatives: how many negatives to keep (default 5, the previous
            hard-coded value).

    Returns:
        ``{"query", "positive", "negatives"}``. If the example has no positive
        passage or fewer than ``num_negatives`` negatives, all three values are
        ``None`` so that a later ``filter`` can drop the row.
    """
    passages = example["passages"]
    positive_passages = []
    negative_passages = []
    # Single pass over the parallel lists instead of two separate enumerate loops.
    for selected, text in zip(passages["is_selected"], passages["passage_text"]):
        if selected == 1:
            positive_passages.append(text)
        elif selected == 0:
            negative_passages.append(text)

    if positive_passages and len(negative_passages) >= num_negatives:
        return {
            "query": example["query"],
            "positive": positive_passages[0],
            "negatives": negative_passages[:num_negatives],
        }
    return {"query": None, "positive": None, "negatives": None}

In [8]:
# Apply preprocess to every training row, dropping the raw MS MARCO columns.
processed_data = dataset.map(preprocess, remove_columns=dataset.column_names)

In [9]:
# Drop rows that preprocess marked as unusable (it returns None for all fields).
processed_data = processed_data.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [10]:
# Re-key every processed row from query/positive/negatives into the
# anchor/positive/negatives dicts consumed by ContrastiveDataset.
contrastive_pairs = [
    {
        "anchor": row["query"],
        "positive": row["positive"],
        "negatives": row["negatives"],
    }
    for row in processed_data
]

In [11]:
# Number of usable training triples left after filtering.
len(contrastive_pairs)

74538

In [12]:
from torch.utils.data import DataLoader

In [13]:
class ContrastiveDataset:
    """Map-style dataset over pre-built contrastive examples.

    ``pairs`` is a sequence of dicts with the keys ``"anchor"``, ``"positive"``
    and ``"negatives"``. Indexing returns those three values as a tuple in that
    order, which the DataLoader collates into ``batch[0]``, ``batch[1]`` and
    ``batch[2]``.
    """

    def __init__(self, pairs):
        # Stored by reference; records are looked up lazily in __getitem__.
        self.pairs = pairs

    def __len__(self):
        """Number of stored pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negatives)`` for the pair at ``idx``."""
        record = self.pairs[idx]
        return tuple(record[field] for field in ("anchor", "positive", "negatives"))

In [14]:
# Wrap the pairs so torch's DataLoader can batch them.
contrastive_dataset = ContrastiveDataset(contrastive_pairs)

In [15]:

# With the default collate, `negatives` presumably arrives transposed: one sequence
# per negative slot, each holding batch_size strings. This matches how the training
# loop iterates `for neg in negative_texts` (confirm).
# NOTE(review): shuffle=True without a seeded generator makes batch order irreproducible.
data_loader = DataLoader(contrastive_dataset, batch_size=32, shuffle=True)

In [16]:
# Batches per epoch.
# NOTE(review): 7 is inconsistent with 74,538 pairs at batch_size 32 (~2,330 batches);
# the output was probably produced from a subset or stale state. Re-run to confirm.
len(data_loader)

7

In [17]:
from transformers import AutoTokenizer, AutoModel
# Tokenizer for the pretrained BERT checkpoint used as the embedding backbone.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [18]:
# Bare encoder (no task head); its last_hidden_state is mean-pooled into
# sentence embeddings later in the notebook.
model = AutoModel.from_pretrained("bert-base-uncased")

In [19]:
from peft import LoraConfig, get_peft_model

In [20]:
# LoRA adapter config. Only task_type is set, so rank, alpha, dropout and
# target modules fall back to PEFT's defaults.
# NOTE(review): the 294,912 trainable params printed below are consistent with
# r=8 on two 768x768 projections in each of 12 layers. Confirm against the
# installed PEFT version before relying on it.
lora_config = LoraConfig(
    task_type= "FEATURE_EXTRACTION"
)

In [21]:
# Wrap the base model with LoRA adapters.
lora_model = get_peft_model(model, lora_config)

In [22]:
# Report trainable vs. total parameters (only ~0.27% remain trainable).
lora_model.print_trainable_parameters()

trainable params: 294,912 || all params: 109,777,152 || trainable%: 0.2686


In [23]:
import torch

In [24]:
def cosine_distance(x, y):
    """Cosine distance ``1 - cos(x, y)`` taken along the last dimension.

    Returns 0 for vectors pointing the same way and 2 for opposite vectors.
    Inputs broadcast exactly as in ``torch.nn.functional.cosine_similarity``.
    """
    similarity = torch.nn.functional.cosine_similarity(x, y, dim=-1)
    return 1 - similarity

In [25]:
def info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn, temperature=1.0):
    """InfoNCE loss with one positive and K negatives per anchor.

    Logits are the negated distances ``-distance_fn(anchor, candidate)`` divided
    by ``temperature``. The positive sits in column 0, so every row's target class
    is 0 and the loss is a cross-entropy over ``1 + K`` candidates.

    Args:
        anchor_embedding: anchor embeddings. ``distance_fn(anchor, x)`` must
            return one value per anchor (assumed shape ``(batch, dim)``; TODO confirm).
        positive_embedding: positive embeddings aligned row-wise with the anchors.
        negative_embedding: sequence of K tensors, each aligned row-wise with the
            anchors (the k-th negative for every anchor).
        distance_fn: callable returning a distance (smaller = more similar).
        temperature: divisor applied to the logits. The default 1.0 reproduces
            the original behaviour. With a bounded distance such as cosine distance,
            the logits span only [-2, 0], so the softmax stays near uniform and the
            loss sits close to ``log(1 + K)``. A smaller temperature sharpens it.

    Returns:
        Scalar mean cross-entropy loss.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    pos_dist = distance_fn(anchor_embedding, positive_embedding)
    neg_dist = torch.stack([distance_fn(anchor_embedding, neg) for neg in negative_embedding], dim=-1)

    # Column 0 holds the positive, hence the all-zero targets.
    logits = torch.cat([-pos_dist.unsqueeze(1), -neg_dist], dim=1) / temperature
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, labels)

In [26]:
import torch.optim as optim
# AdamW over all parameters. print_trainable_parameters showed that only the LoRA
# adapter weights are trainable, so in practice only those are updated.
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5)

In [27]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [28]:
# NOTE(review): the optimizer above was constructed before the model is moved.
# This presumably works because Module.to() updates parameters in place, but
# PyTorch recommends moving the model before building the optimizer.
lora_model = lora_model.to(device)

In [34]:
# Build validation triples with the same preprocessing as the training split.
dataset_val = load_dataset('ms_marco', 'v1.1', split='validation')
processed_data_val = dataset_val.map(preprocess, remove_columns=dataset_val.column_names)
processed_data_val = processed_data_val.filter(lambda x: x['query'] is not None and x['positive'] is not None)
contrastive_pairs_val = [
    {"anchor": row["query"], "positive": row["positive"], "negatives": row["negatives"]}
    for row in processed_data_val
]
contrastive_dataset_val = ContrastiveDataset(contrastive_pairs_val)
# Evaluation needs no shuffling; a fixed order keeps the MRR computation
# (including its floating-point summation order) reproducible across runs.
data_loader_val = DataLoader(contrastive_dataset_val, batch_size=32, shuffle=False)
len(data_loader_val)

7

In [35]:
def avg_embedding(inputs, model):
    """Mean-pool the model's last hidden state over non-padding tokens.

    Args:
        inputs: tokenizer output passed as ``model(**inputs)``. It must contain
            ``attention_mask`` (1 for real tokens, 0 for padding).
        model: callable whose output exposes ``last_hidden_state``, assumed to be
            ``(batch, seq_len, hidden)``; the mask is broadcast over the last axis.

    Returns:
        Tensor of shape ``(batch, hidden)`` holding the masked mean embedding.
        Rows whose mask is all zeros yield zeros instead of NaN.
    """
    # Named `outputs` rather than `input` to avoid shadowing the builtin.
    outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Cast the integer mask to the hidden-state dtype so the products stay floating point.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Token counts are whole numbers. Clamping at 1 turns 0/0 (NaN) into 0/1 and
    # leaves every real count unchanged.
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts

In [36]:
def evaluate_mrr(model, data_loader_val, distance_fn):
    """Mean reciprocal rank of each query's positive among its own candidates.

    Each query's candidates are its positive passage plus the K negatives from
    the loader (5 with the current ``preprocess``), ranked by ``-distance_fn``.
    This is a re-ranking MRR over those ``1 + K`` candidates, not a full-corpus
    MS MARCO MRR@10.

    Relies on the notebook-level ``tokenizer``, ``device`` and ``avg_embedding``.
    Switches the model to eval mode for scoring and restores its previous
    train/eval mode afterwards, even if scoring raises.

    Args:
        model: encoder passed to ``avg_embedding``.
        data_loader_val: yields ``(anchor_texts, positive_texts, negative_texts)``
            batches, where ``negative_texts`` is a sequence of K text batches.
        distance_fn: distance between embeddings (smaller = more similar).

    Returns:
        MRR as a float in (0, 1].

    Raises:
        ValueError: if the loader yields no queries.
    """
    was_training = model.training
    model.eval()

    total_rr = 0.0
    num_queries = 0

    def encode(texts):
        # Tokenize one batch of texts and mean-pool it into embeddings.
        batch_inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        return avg_embedding(batch_inputs, model)

    try:
        with torch.no_grad():
            for anchor_texts, positive_texts, negative_texts in data_loader_val:
                anchor_embedding = encode(anchor_texts)
                positive_embedding = encode(positive_texts)

                pos_sim = -distance_fn(anchor_embedding, positive_embedding)
                neg_sim = torch.stack(
                    [-distance_fn(anchor_embedding, encode(neg)) for neg in negative_texts], dim=-1
                )

                # 1-based rank = 1 + number of negatives scoring strictly higher than
                # the positive. Ties go to the positive, matching a stable descending
                # sort with the positive first; unlike an unstable sort, this is deterministic.
                positive_rank = 1 + (neg_sim > pos_sim.unsqueeze(1)).sum(dim=1)
                total_rr += (1.0 / positive_rank.float()).sum().item()
                num_queries += positive_rank.numel()
    finally:
        model.train(was_training)

    if num_queries == 0:
        raise ValueError("data_loader_val produced no queries; cannot compute MRR")
    return total_rr / num_queries

In [40]:
# Training configuration.
num_epochs = 3
import os
# Directory for checkpoints and metrics. Set the CHECKPOINT_DIR environment
# variable to override it instead of editing this user-specific absolute path.
save_dir = os.environ.get("CHECKPOINT_DIR", "/dss/dsshome1/07/ra65bex2/srawat/app_average")

In [38]:
import time
# Per-epoch records (loss, validation MRR, wall time), appended by the training
# loop and dumped to JSON at the end of the notebook.
epoch_metrics = []

In [39]:
# Fine-tune the LoRA adapters with InfoNCE. After every epoch: save a checkpoint,
# evaluate validation MRR and record the metrics.
os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
for epoch in range(num_epochs):
    start_time = time.time()
    # Re-enable training mode each epoch (evaluate_mrr switches the model to eval).
    lora_model.train()

    total_loss = 0.0
    for anchor_texts, positive_texts, negative_texts in data_loader:
        anchor_inputs = tokenizer(anchor_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        positive_inputs = tokenizer(positive_texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

        anchor_embedding = avg_embedding(anchor_inputs, lora_model)
        positive_embedding = avg_embedding(positive_inputs, lora_model)
        # negative_texts holds one batch of texts per negative slot.
        negative_embedding = [avg_embedding(tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device), lora_model) for neg in negative_texts]

        loss = info_nce_loss(anchor_embedding, positive_embedding, negative_embedding, distance_fn=cosine_distance)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # NOTE(review): torch.save on the whole PeftModel pickles the object, so loading it
    # needs matching code/library versions. lora_model.save_pretrained(...) or saving a
    # state_dict would be more portable; kept as-is to preserve the checkpoint format.
    save_path = os.path.join(save_dir, f"average_checkpoint_epoch_{epoch+1}.pth")
    torch.save(lora_model, save_path)
    print(f"EPOCH {epoch+1}:")
    print(f"Checkpoint saved: {save_path}")

    avg_train_loss = total_loss / len(data_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss}")
    mrr_validation = evaluate_mrr(lora_model, data_loader_val, cosine_distance)
    print(f"Mean Reciprocal Rank (MRR) for validation set: {mrr_validation:.4f}")
    elapsed_minutes = (time.time() - start_time) / 60
    print(f"Epoch {epoch+1} took {elapsed_minutes:.4f} minutes.")
    print("\n")
    epoch_metrics.append({
        'epoch': epoch + 1,
        'training_loss': avg_train_loss,
        'mrr_validation': mrr_validation,
        'time_taken_minutes': elapsed_minutes,
    })

EPOCH 1:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/average_checkpoint_epoch_1.pth
Epoch 1/3, Training Loss: 1.7842329570225306
Mean Reciprocal Rank (MRR) for validation set: 0.4153
Epoch 1 took 0.1062 minutes.


EPOCH 2:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/average_checkpoint_epoch_2.pth
Epoch 2/3, Training Loss: 1.7833352770124162
Mean Reciprocal Rank (MRR) for validation set: 0.4137
Epoch 2 took 0.0542 minutes.


EPOCH 3:
Checkpoint saved: /dss/dsshome1/07/ra65bex2/srawat/average_checkpoint_epoch_3.pth
Epoch 3/3, Training Loss: 1.781571354184832
Mean Reciprocal Rank (MRR) for validation set: 0.4137
Epoch 3 took 0.0549 minutes.




In [43]:
import json

In [44]:
# Persist the per-epoch metrics next to the checkpoints (indented for readability).
metrics_path = os.path.join(save_dir, 'average_epoch_metrics.json')
with open(metrics_path, 'w') as f:
    json.dump(epoch_metrics, f, indent=2)