In [None]:
import torch
import pandas as pd
import numpy as np


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Load the pre-built table that later cells read the `claim`, `wiki_index` and
# `wiki_text` columns from. The path is relative to the notebook's working directory.
merged_df = pd.read_parquet('processed_df/merged_df.parquet')

In [None]:
# Materialise the row index as an explicit `claim_index` column.
# The guard keeps this cell idempotent: without it, re-running the cell would
# reset the index a second time and leave two index-derived columns behind.
if 'claim_index' not in merged_df.columns:
    merged_df = merged_df.reset_index().rename(columns={'index': 'claim_index'})
merged_df.head()

In [None]:
# One row per distinct (wiki_index, wiki_text) pair. `claim_index` is carried
# along from the first row in which that pair occurs.
wiki_df = (
    merged_df[['claim_index', 'wiki_index', 'wiki_text']]
    .drop_duplicates(subset=['wiki_index', 'wiki_text'], keep='first')
)

In [None]:
wiki_df.head()

In [None]:
# Claims in merged_df row order, so position i in `claims` is row i of merged_df.
claims = merged_df['claim'].tolist()

# Passages to encode, taken in wiki_df row order. A set() was used here before,
# but its iteration order is arbitrary, so the position of an encoded passage
# could not be mapped back to a wiki_df row (or reproduced between runs).
# Building the list from wiki_df keeps "position k in wiki_text" identical to
# "wiki_df.iloc[k]". Note that wiki_df is de-duplicated on
# (wiki_index, wiki_text), so the same text may appear more than once if it is
# stored under several wiki_index values.
wiki_text = wiki_df['wiki_text'].tolist()


In [None]:
claims[:3]

In [None]:
len(claims)

In [None]:
len(wiki_text)

In [None]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch

# Tokenizer and encoder are loaded from the same DPR question-encoder checkpoint.
# The encoder is placed on `device`; the tokenizer stays on the host.
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    'facebook/dpr-question_encoder-multiset-base'
)
question_encoder = DPRQuestionEncoder.from_pretrained(
    'facebook/dpr-question_encoder-multiset-base'
)
question_encoder = question_encoder.to(device)

def encode_questions(questions, batch_size=32):
    """Embed question/claim strings with the DPR question encoder.

    Uses the notebook-level ``question_encoder``, ``question_tokenizer`` and
    ``device`` objects.

    Args:
        questions: Sequence of strings (list, tuple, pandas Series, ...). It is
            sliced positionally in chunks of ``batch_size``.
        batch_size: Number of strings tokenized and encoded per forward pass.

    Returns:
        A CPU tensor with one row per input string: the encoder's
        ``pooler_output`` of every batch, concatenated along dim 0.

    Note:
        With an empty ``questions`` there is nothing to concatenate and
        ``torch.cat`` raises.
    """
    question_encoder.eval()
    question_embeddings = []
    num_questions = len(questions)

    for start_idx in range(0, num_questions, batch_size):
        # Clamp the end so the progress line reports the real slice on the last batch.
        end_idx = min(start_idx + batch_size, num_questions)
        print(f"Encoding indices: {start_idx}:{end_idx}")

        # list(...) lets non-list sequences through; the tokenizer expects a list of str.
        batch_questions = list(questions[start_idx:end_idx])

        # Call the tokenizer directly; `batch_encode_plus` is the deprecated spelling
        # of the same operation.
        inputs = question_tokenizer(batch_questions, return_tensors='pt', padding=True, truncation=True, max_length=512)

        # Move tokenized inputs to the model's device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no autograd graph is needed
        with torch.no_grad():
            batch_embeddings = question_encoder(**inputs).pooler_output
        # Keep results on the host so device memory only ever holds one batch
        question_embeddings.append(batch_embeddings.cpu())

        # Drop device references before the next batch is allocated
        del inputs, batch_embeddings
        torch.cuda.empty_cache()

    return torch.cat(question_embeddings, dim=0)

In [None]:
# Encode every claim with the question encoder; 256 overrides the function's
# default batch size of 32.
batch_size = 256
claims_embeddings = encode_questions(claims, batch_size=batch_size)

In [None]:
claims_embeddings.shape

In [None]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Tokenizer and encoder are loaded from the same DPR context-encoder checkpoint.
# The encoder is placed on `device`; the tokenizer stays on the host.
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
    'facebook/dpr-ctx_encoder-multiset-base'
)
context_encoder = DPRContextEncoder.from_pretrained(
    'facebook/dpr-ctx_encoder-multiset-base'
)
context_encoder = context_encoder.to(device)

def encode_contexts(contexts, batch_size=32):
    """Embed passage strings with the DPR context encoder.

    Uses the notebook-level ``context_encoder``, ``context_tokenizer`` and
    ``device`` objects.

    Args:
        contexts: Sequence of strings (list, tuple, pandas Series, ...). It is
            sliced positionally in chunks of ``batch_size``.
        batch_size: Number of strings tokenized and encoded per forward pass.

    Returns:
        A CPU tensor with one row per input string, in input order: the
        encoder's ``pooler_output`` of every batch, concatenated along dim 0.

    Note:
        With an empty ``contexts`` there is nothing to concatenate and
        ``torch.cat`` raises.
    """
    context_encoder.eval()
    context_embeddings = []
    num_contexts = len(contexts)

    for start_idx in range(0, num_contexts, batch_size):
        # Clamp the end so the progress line reports the real slice on the last batch.
        end_idx = min(start_idx + batch_size, num_contexts)
        print(f"Encoding indices: {start_idx}:{end_idx}")

        # list(...) lets non-list sequences through; the tokenizer expects a list of str.
        batch_contexts = list(contexts[start_idx:end_idx])

        # Call the tokenizer directly; `batch_encode_plus` is the deprecated spelling
        # of the same operation.
        inputs = context_tokenizer(batch_contexts, return_tensors='pt', padding=True, truncation=True, max_length=512)

        # Move tokenized inputs to the model's device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no autograd graph is needed
        with torch.no_grad():
            batch_embeddings = context_encoder(**inputs).pooler_output
        # Keep results on the host so device memory only ever holds one batch
        context_embeddings.append(batch_embeddings.cpu())

        # Drop device references before the next batch is allocated
        del inputs, batch_embeddings
        torch.cuda.empty_cache()

    return torch.cat(context_embeddings, dim=0)

In [None]:
# Encode every passage in `wiki_text` with the context encoder; row k of the
# result corresponds to wiki_text[k]. 256 overrides the default batch size of 32.
batch_size = 256
context_embeddings = encode_contexts(wiki_text, batch_size=batch_size)

In [None]:
context_embeddings.shape

In [None]:
# Put both embedding matrices on `device` so the similarity matmul runs there.
claims_embeddings, context_embeddings = (
    tensor.to(device) for tensor in (claims_embeddings, context_embeddings)
)


In [None]:

def compute_similarity_in_batches(claims_embeddings, context_embeddings, batch_size=100):
    """Find, for every claim, the context with the highest dot-product score.

    Scores are computed one slice of claims at a time so the full
    claims x contexts score matrix never has to exist at once.

    Args:
        claims_embeddings: 2-D tensor, one row per claim.
        context_embeddings: 2-D tensor, one row per context, on the same
            device and with the same number of columns as ``claims_embeddings``.
        batch_size: Number of claim rows scored per matmul.

    Returns:
        A NumPy integer array of length ``claims_embeddings.size(0)``; entry i
        is the row index into ``context_embeddings`` of the best match for
        claim i.

    Raises:
        ValueError: If there are claims to score but no contexts.
    """
    num_claims = claims_embeddings.size(0)
    num_contexts = context_embeddings.size(0)
    top_passages_indices = np.zeros(num_claims, dtype=int)

    if num_claims > 0 and num_contexts == 0:
        raise ValueError("context_embeddings is empty; cannot take an argmax over zero contexts")

    # The transpose is the same for every batch, so take it once.
    contexts_t = context_embeddings.T

    for start_idx in range(0, num_claims, batch_size):
        end_idx = min(start_idx + batch_size, num_claims)
        print(f"Computing from {start_idx} : {end_idx}")
        batch_scores = torch.matmul(claims_embeddings[start_idx:end_idx], contexts_t)

        # Reduce on the device and copy only the winning indices to the host,
        # instead of transferring the whole (batch x contexts) score matrix.
        top_passages_batch = torch.argmax(batch_scores, dim=1)
        top_passages_indices[start_idx:end_idx] = top_passages_batch.cpu().numpy()

        # Clearing memory
        del batch_scores, top_passages_batch
        torch.cuda.empty_cache()

    return top_passages_indices

# NOTE: despite its name, `embeddings` holds the best-matching context row index
# for each claim (the function's returned index array), not embedding vectors.
embeddings = compute_similarity_in_batches(claims_embeddings, context_embeddings, batch_size=100)

In [None]:
len(embeddings)

In [None]:
# Sanity check: smallest and largest retrieved context index.
lowest_index, highest_index = embeddings.min(), embeddings.max()
print(lowest_index, highest_index)

In [None]:
merged_df.head()

In [None]:
wiki_df.head()

In [None]:
# Peek at the reference passage of the first row. An explicit row position is
# used because no loop variable is bound at this point on a fresh kernel.
merged_df.iloc[0]['wiki_text']

In [None]:
# Spot-check the first few claims: retrieved passage vs. the passage paired with
# the claim in merged_df.
# The retrieved index is a row position in `context_embeddings`, i.e. a position
# in the `wiki_text` list that was passed to encode_contexts, so it is resolved
# against that same list. This guarantees the lookup lines up by construction.
for i, claim in enumerate(claims[:10]):
    similar_context_index = embeddings[i]
    wiki_passage = wiki_text[similar_context_index]

    print(f"Claim: {claim}")
    print(f"Most similar context: {wiki_passage}")
    print(f"Actual: {merged_df.iloc[i]['wiki_text']}\n\n")


In [None]:
import os

import h5py

# Persist the context embeddings as a NumPy array in an HDF5 file.
# The tensor may live on the GPU, so bring it to the host before converting.
context_embeddings_cpu = context_embeddings.cpu().numpy()

# h5py does not create missing parent directories; make sure the target exists.
os.makedirs('embeddings', exist_ok=True)

with h5py.File('embeddings/merged_embeddings.h5', 'w') as file:
    file.create_dataset('merged_embeddings', data=context_embeddings_cpu)
