In [1]:
from beir.datasets.data_loader import GenericDataLoader

  from tqdm.autonotebook import tqdm


In [2]:
data_path = "/dss/dsshome1/07/ra65bex2/srawat/climate-fever/climate-fever"

In [3]:
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

  0%|          | 0/5416593 [00:00<?, ?it/s]

In [4]:
contrastive_pairs=[]
import random
c=0
for query_id, relevant_docs in qrels.items():
    try:
        query_text = queries[query_id]
        for doc_id in relevant_docs:
            positive = corpus[doc_id]["text"]
        #print(relevant_docs)
        positive_doc_ids = set(relevant_docs)
        all_doc_ids = set(corpus.keys())
        negative_doc_ids = all_doc_ids - positive_doc_ids
        negative_doc_ids=list(negative_doc_ids)
        negative_doc_samples = random.sample(negative_doc_ids, k=5)
        negatives=[]
        for neg_doc_id in negative_doc_samples:
            negative_doc_text = corpus[neg_doc_id]["text"]
            negatives.append(negative_doc_text)
        contrastive_pairs.append({
            "anchor": query_text,
            "positive": positive,
            "negatives": negatives
        })
    except:
        c=c+1

In [5]:
contrastive_pairs[0:5]

[{'anchor': 'Global warming is driving polar bears toward extinction',
  'positive': "Global warming , also referred to as climate change , is the observed century-scale rise in the average temperature of the Earth 's climate system and its related effects . Multiple lines of scientific evidence show that the climate system is warming . Many of the observed changes since the 1950s are unprecedented in the instrumental temperature record which extends back to the mid 19th century , and in paleoclimate proxy records over thousands of years .   In 2013 , the Intergovernmental Panel on Climate Change ( IPCC ) Fifth Assessment Report concluded that `` It is extremely likely that human influence has been the dominant cause of the observed warming since the mid-20th century . '' The largest human influence has been emission of greenhouse gases such as carbon dioxide , methane and nitrous oxide . Climate model projections summarized in the report indicated that during the 21st century the glob

In [6]:
c

0

In [7]:
len(contrastive_pairs)

1535

In [8]:
len(qrels)

1535

In [9]:
from torch.utils.data import DataLoader

In [10]:
class ContrastiveDataset:
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        return item["anchor"], item["positive"], item["negatives"]

In [11]:
contrastive_dataset = ContrastiveDataset(contrastive_pairs)

In [12]:
data_loader = DataLoader(contrastive_dataset, batch_size=32, shuffle=True)

In [13]:
len(data_loader)

48

In [14]:
import torch
file_path="/dss/dsshome1/07/ra65bex2/srawat/contrastive_learning/v1.1/app_baseline/checkpoint_epoch_3.pth"
lora_model = torch.load(file_path)

  lora_model = torch.load(file_path)


In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lora_model = lora_model.to(device)

In [16]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [17]:
def cosine_distance(x, y):
    return 1 - torch.nn.functional.cosine_similarity(x, y, dim=-1)

In [18]:
def evaluate_mrr(model, data_loader_val, distance_fn):
    model.eval()

    total_rr = 0.0
    num_queries = 0

    with torch.no_grad():
        for batch in data_loader_val:
            anchor_text = batch[0]
            positive_text = batch[1]
            negative_texts = batch[2]

            anchor_input = tokenizer(anchor_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
            positive_input = tokenizer(positive_text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

            anchor_embedding = model(**anchor_input).last_hidden_state[:, 0, :]
            positive_embedding = model(**positive_input).last_hidden_state[:, 0, :]
            negative_embedding = [model(**tokenizer(neg, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)).last_hidden_state[:, 0, :] for neg in negative_texts]

            pos_dist = distance_fn(anchor_embedding, positive_embedding)
            neg_dist = torch.stack([distance_fn(anchor_embedding, neg) for neg in negative_embedding], dim=-1)
            all_similarities=torch.cat([-pos_dist.unsqueeze(1), -neg_dist], dim=1)

            sorted_similarities, sorted_indices = torch.sort(all_similarities, dim=1, descending=True)

            # Find the rank of the first relevant (positive) document
            positive_rank = (sorted_indices == 0).nonzero(as_tuple=True)[1] + 1  # +1 to make rank 1-based
            total_rr += torch.sum(1.0 / positive_rank.float()).item()  # Reciprocal rank
            num_queries += len(positive_rank)

    mrr = total_rr / num_queries
    return mrr

In [19]:
mrr_validation = evaluate_mrr(lora_model, data_loader, cosine_distance)
print(mrr_validation)

0.40052117835426954
