In [1]:
from datasets import load_dataset

In [2]:
dataset_val = load_dataset('ms_marco', 'v1.1', split='validation')

In [3]:
dataset_val

Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 10047
})

In [4]:
import random
import re
import nltk
from nltk.corpus import stopwords, wordnet
nltk.download('stopwords')
nltk.download('wordnet')
stop_words = set(stopwords.words('english'))

def clean_text(text):
    return re.sub(r'[^\w\s.!?]', '', text)

def select_words(text, num_words):
    words = [w for w in text.split() if w.lower() not in stop_words]
    return random.sample(words, min(num_words, len(words)))

def introduce_typo(word):
    if len(word) > 1:
        idx = random.randint(0, len(word) - 1)
        return word[:idx] + random.choice('abcdefghijklmnopqrstuvwxyz') + word[idx+1:]
    return word

def introduce_noise(word):
    noise_chars = ['@', '#', '$', '%', '&', '*']
    if len(word) > 1:
        idx = random.randint(0, len(word) - 1)
        return word[:idx] + random.choice(noise_chars) + word[idx+1:]
    return word

def replace_with_synonym(word):
    synonyms = [syn.lemmas()[0].name() for syn in wordnet.synsets(word) if syn.lemmas()]
    return random.choice(synonyms) if synonyms else word

def corrupt_word(word, method):
    if method == 'typo':
        return introduce_typo(word)
    elif method == 'noise':
        return introduce_noise(word)
    elif method == 'synonym':
        return replace_with_synonym(word)
    return word

def corrupt_text(text):
    corrupted_words = []  
    words_to_corrupt=select_words(clean_text(text), num_words=1)
    for word in clean_text(text).split():
        if word in words_to_corrupt:
            corruption_method = random.choice(['typo', 'noise', 'synonym'])
            corrupted_words.append(corrupt_word(word, corruption_method))
        else:
            corrupted_words.append(word)
    return ' '.join(corrupted_words)

[nltk_data] Downloading package stopwords to
[nltk_data]     /dss/dsshome1/07/ra65bex2/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /dss/dsshome1/07/ra65bex2/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [5]:
def preprocess(example):
    positive_passages=[]
    for p in enumerate(example['passages']["is_selected"]):
        if p[1]==1:
            positive_passages.append(example["passages"]["passage_text"][p[0]])
    negative_passages = []
    for p in enumerate(example['passages']["is_selected"]):
        if p[1]==0:
            negative_passages.append(example["passages"]["passage_text"][p[0]])
    if (len(positive_passages)>0 and len(negative_passages)>=5):
        positive = positive_passages[0]
        negatives = negative_passages[:5]
        return {
            "query": corrupt_text(example["query"]),
            "positive": positive,
            "negatives": negatives
        }
    else:
        return {"query": None, "positive": None, "negatives": None}

In [6]:
processed_data_val = dataset_val.map(preprocess, remove_columns=dataset_val.column_names)

In [7]:
processed_data_val = processed_data_val.filter(lambda x: x['query'] is not None and x['positive'] is not None)

In [8]:
contrastive_pairs_val = []
for item in processed_data_val:
    query = item["query"]
    positive = item["positive"]
    negatives = item["negatives"]
    contrastive_pairs_val.append({
        "anchor": query,
        "positive": positive,
        "negatives": negatives
    })

In [9]:
len(contrastive_pairs_val)

9084

In [10]:
from torch.utils.data import DataLoader

In [11]:
contrastive_pairs_val[2]

{'anchor': 'what is a furu@cle boil',
 'positive': 'A boil, also called a furuncle, is a deep folliculitis, infection of the hair follicle. It is most commonly caused by infection by the bacterium Staphylococcus aureus, resulting in a painful swollen area on the skin caused by an accumulation of pus and dead tissue. Signs and symptoms [edit]. Boils are bumpy, red, pus-filled lumps around a hair follicle that are tender, warm, and very painful. They range from pea-sized to golf ball-sized. A yellow or white point at the center of the lump can be seen when the boil is ready to drain or discharge pus.',
 'negatives': ['Knowledge center. A boil, also known as a furuncle is a skin abscess, a painful bump that forms under the skin-it is full of puss. A carbuncle is collection of boils that develop under the skin. When bacteria infect hair follicles they can swell up and turn into boils. ',
  'When the hair follicle becomes infected, the skin around it also becomes inflamed. The furuncle look

In [12]:
class ContrastiveDataset:
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        return item["anchor"], item["positive"], item["negatives"]

In [13]:
contrastive_dataset_val = ContrastiveDataset(contrastive_pairs_val)

In [14]:
data_loader_val = DataLoader(contrastive_dataset_val, batch_size=32, shuffle=True)

In [15]:
len(data_loader_val)

284

In [16]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [17]:
import torch
file_path="/dss/dsshome1/07/ra65bex2/srawat/contrastive_learning/v1.1/app_average/average_checkpoint_epoch_3.pth"
lora_model = torch.load(file_path)

  lora_model = torch.load(file_path)


In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lora_model = lora_model.to(device)

In [19]:
def cosine_distance(x, y):
    return 1 - torch.nn.functional.cosine_similarity(x, y, dim=-1)

In [20]:
def avg_embedding(inputs, model):
    input=model(**inputs)
    input_last_hidden_state=input.last_hidden_state
    input_attention_mask = inputs['attention_mask']
    input_masked_embeddings = input_last_hidden_state * input_attention_mask.unsqueeze(-1)
    input_sum_embeddings = torch.sum(input_masked_embeddings, dim=1)
    input_token_counts = torch.sum(input_attention_mask, dim=1).unsqueeze(-1)
    input_avg_embeddings = input_sum_embeddings / input_token_counts
    return(input_avg_embeddings)

In [21]:
def evaluate_mrr(model, data_loader_val, distance_fn):
    model.eval()  

    total_rr = 0.0
    num_queries = 0

    with torch.no_grad():
        for batch in data_loader_val:
            anchor_text = batch[0]
            positive_text = batch[1]
            negative_texts = batch[2]

            anchor_input = tokenizer(anchor_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
            positive_input = tokenizer(positive_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

            anchor_embedding = avg_embedding(anchor_input, model)
            positive_embedding = avg_embedding(positive_input, model)
            negative_embedding = [avg_embedding(tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device), model) for neg in negative_texts]

            pos_dist = distance_fn(anchor_embedding, positive_embedding)
            neg_dist = torch.stack([distance_fn(anchor_embedding, neg) for neg in negative_embedding], dim=-1)
            all_similarities=torch.cat([-pos_dist.unsqueeze(1), -neg_dist], dim=1)
            
            sorted_similarities, sorted_indices = torch.sort(all_similarities, dim=1, descending=True)

            # Find the rank of the first relevant (positive) document
            positive_rank = (sorted_indices == 0).nonzero(as_tuple=True)[1] + 1  # +1 to make rank 1-based
            total_rr += torch.sum(1.0 / positive_rank.float()).item()  # Reciprocal rank
            num_queries += len(positive_rank)

    mrr = total_rr / num_queries
    return mrr

In [22]:
mrr_validation = evaluate_mrr(lora_model, data_loader_val, cosine_distance)
print(mrr_validation)

0.5509870956249565
