In [1]:
from beir.datasets.data_loader import GenericDataLoader

  from tqdm.autonotebook import tqdm


In [2]:
from beir import util
# BEIR dataset to use; the name is substituted into both the download URL and the local path.
dataset = "nfcorpus"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
# Directory handed to the downloader.
# NOTE(review): absolute, user-specific path — adjust when running in another environment.
data_path = f"/dss/dsshome1/07/ra65bex2/srawat/{dataset}"

# Download and unzip the dataset
# `data_path` is rebound to the value returned by the downloader; later cells pass it to the loader.
data_path = util.download_and_unzip(url, data_path)
# Display the resulting path.
data_path

'/dss/dsshome1/07/ra65bex2/srawat/nfcorpus/nfcorpus'

In [3]:
# Load the corpus, the queries and the relevance judgments (qrels) of the "test" split.
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

  0%|          | 0/3633 [00:00<?, ?it/s]

In [4]:
import random

# Build one {"anchor", "positive", "negatives"} record per query listed in `qrels`.
contrastive_pairs = []
c = 0  # number of queries skipped because a record could not be assembled

# Dedicated, seeded generator so the sampled negatives are the same on every run.
rng = random.Random(42)
# Hoisted out of the loop: the corpus ids do not change between queries. A list
# in dict order gives `sample` a stable population, unlike a list derived from
# a set, whose order is not guaranteed between runs.
all_doc_ids = list(corpus.keys())

for query_id, relevant_docs in qrels.items():
    # Without judged documents there is no positive. Skip explicitly instead of
    # silently reusing the positive text left over from the previous query.
    if not relevant_docs:
        c += 1
        continue
    try:
        query_text = queries[query_id]
        # Resolve every judged document so that an id missing from the corpus
        # skips the whole query.
        positive_texts = [corpus[doc_id]["text"] for doc_id in relevant_docs]
        # NOTE(review): only the last judged document is used as the positive
        # and any relevance grade is ignored — confirm this is intended.
        positive = positive_texts[-1]
        positive_doc_ids = set(relevant_docs)
        negative_doc_ids = [doc_id for doc_id in all_doc_ids if doc_id not in positive_doc_ids]
        # Raises ValueError when fewer than 5 non-relevant documents exist.
        negative_doc_samples = rng.sample(negative_doc_ids, k=5)
        negatives = [corpus[neg_doc_id]["text"] for neg_doc_id in negative_doc_samples]
    except (KeyError, ValueError):
        # Unknown query/document id, missing "text" field, or too few negatives:
        # count the query as skipped. Any other error is a real bug and propagates.
        c += 1
        continue
    contrastive_pairs.append({
        "anchor": query_text,
        "positive": positive,
        "negatives": negatives
    })

In [5]:
# Inspect the first five records.
contrastive_pairs[0:5]

[{'anchor': 'Do Cholesterol Statin Drugs Cause Breast Cancer?',
  'positive': 'Muscle pain and weakness are frequent complaints in patients receiving 3-hydroxymethylglutaryl coenzymeA (HMG CoA) reductase inhibitors (statins). Many patients with myalgia have creatine kinase levels that are either normal or only marginally elevated, and no obvious structural defects have been reported in patients with myalgia only. To investigate further the mechanism that mediates statin-induced skeletal muscle damage, skeletal muscle biopsies from statin-treated and non-statin-treated patients were examined using both electron microscopy and biochemical approaches. The present paper reports clear evidence of skeletal muscle damage in statin-treated patients, despite their being asymptomatic. Though the degree of overall damage is slight, it has a characteristic pattern that includes breakdown of the T-tubular system and subsarcolemmal rupture. These characteristic structural abnormalities observed in t

In [6]:
# Number of queries that were skipped while building the pairs.
c

0

In [7]:
# Number of records that were built.
len(contrastive_pairs)

323

In [8]:
# Number of queries with relevance judgments, for comparison with the record count.
len(qrels)

323

In [9]:
from torch.utils.data import DataLoader

In [10]:
class ContrastiveDataset:
    """Indexable view over contrastive records.

    Each record is a mapping with the keys "anchor", "positive" and
    "negatives"; indexing yields those three values as a tuple, in that order.
    """

    def __init__(self, pairs):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.pairs = pairs

    def __len__(self):
        """Return how many records are available."""
        record_count = len(self.pairs)
        return record_count

    def __getitem__(self, idx):
        """Return the record at ``idx`` as an (anchor, positive, negatives) tuple."""
        record = self.pairs[idx]
        return tuple(record[field] for field in ("anchor", "positive", "negatives"))

In [11]:
# Wrap the records so they can be indexed as (anchor, positive, negatives) tuples.
contrastive_dataset = ContrastiveDataset(contrastive_pairs)

In [12]:
# Batches of 32 records for evaluation. Shuffling is disabled: the metric is an
# average over all queries, so order is irrelevant, and a fixed order keeps the
# batch composition (and therefore padding) identical between runs.
data_loader = DataLoader(contrastive_dataset, batch_size=32, shuffle=False)

In [13]:
# Number of batches the loader yields per pass.
len(data_loader)

11

In [None]:
import torch

# Location of the checkpoint to evaluate.
# NOTE(review): absolute, user-specific path — adjust when running in another environment.
file_path="/dss/dsshome1/07/ra65bex2/srawat/contrastive_learning/v1.1/app_baseline/checkpoint_epoch_3.pth"
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint saved from a GPU be opened on a CPU-only
# machine; the object can be moved to another device afterwards with `.to(...)`.
# NOTE(review): newer PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False there if the file is trusted.
lora_model = torch.load(file_path, map_location="cpu")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the loaded model onto that device; evaluate_mrr sends its tokenized inputs to the same `device`.
lora_model = lora_model.to(device)

In [None]:
from transformers import AutoTokenizer, AutoModel
# Module-level tokenizer that evaluate_mrr uses to encode queries and documents.
# NOTE(review): assumes the checkpoint was trained with the bert-base-uncased vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
def cosine_distance(x, y):
    """Return ``1 - cosine_similarity(x, y)``, taken over the last dimension.

    Args:
        x: First tensor.
        y: Second tensor, compared against ``x`` along ``dim=-1``.

    Returns:
        Tensor of distances with the last dimension reduced away.
    """
    similarity = torch.nn.functional.cosine_similarity(x, y, dim=-1)
    return 1 - similarity

In [None]:
def evaluate_mrr(model, data_loader_val, distance_fn):
    """Compute the mean reciprocal rank (MRR) of the positive document per query.

    For every query the positive is ranked against that query's negatives by
    ``distance_fn`` (smaller distance ranks higher). The reciprocal of the
    positive's 1-based rank is averaged over all queries.

    Relies on the module-level ``tokenizer`` and ``device``: texts are
    tokenized with ``tokenizer`` and the resulting tensors are moved to
    ``device`` before being passed to ``model``.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``. It is put
            into eval mode and left there.
        data_loader_val: Iterable of batches where ``batch[0]`` holds the
            anchor texts, ``batch[1]`` the positive texts, and ``batch[2]`` is
            an iterable of negative groups, each group being one text per
            anchor in the batch.
        distance_fn: Callable ``(anchors, candidates) -> distances`` returning
            one distance per row.

    Returns:
        The MRR as a float, or ``0.0`` when the loader yields no queries.
    """
    def _embed(texts):
        # First-token hidden state ([CLS] for a BERT tokenizer) of each text.
        encoded = tokenizer(
            list(texts), return_tensors='pt', padding=True, truncation=True, max_length=512
        ).to(device)
        return model(**encoded).last_hidden_state[:, 0, :]

    model.eval()

    total_rr = 0.0
    num_queries = 0

    with torch.no_grad():
        for batch in data_loader_val:
            anchor_text, positive_text, negative_texts = batch[0], batch[1], batch[2]

            anchor_embedding = _embed(anchor_text)
            positive_embedding = _embed(positive_text)

            # pos_dist: one distance per query; neg_dist: one column per negative group.
            pos_dist = distance_fn(anchor_embedding, positive_embedding)
            neg_dist = torch.stack(
                [distance_fn(anchor_embedding, _embed(neg)) for neg in negative_texts],
                dim=-1,
            )

            # Rank of the positive = 1 + number of negatives strictly closer to
            # the anchor. Counting instead of sorting is deterministic: a tie
            # never pushes the positive down, whereas an unstable sort would
            # place tied entries arbitrarily.
            positive_rank = 1 + (neg_dist < pos_dist.unsqueeze(1)).sum(dim=1)
            total_rr += torch.sum(1.0 / positive_rank.float()).item()  # Reciprocal rank
            num_queries += positive_rank.numel()

    # An empty loader would otherwise divide by zero.
    if num_queries == 0:
        return 0.0
    return total_rr / num_queries

In [None]:
# Score the loaded model on the contrastive pairs using cosine distance.
# NOTE(review): the variable is named "validation", but the data was loaded with split="test" — confirm the naming.
mrr_validation = evaluate_mrr(lora_model, data_loader, cosine_distance)
print(mrr_validation)