In [1]:
"""
Bi-Encoder Implementation for Scientific Claim Source Retrieval

This module implements a bi-encoder architecture for retrieving scientific papers 
that support claims made in tweets. The implementation uses a two-stage retrieval process:
1. First stage: Bi-encoder retrieval to get top-k candidates
2. Second stage: Re-ranking or filtering of candidates (currently a pass-through that keeps the top `top_k_evaluate` first-stage candidates)
"""

'\nBi-Encoder Implementation for Scientific Claim Source Retrieval\n\nThis module implements a bi-encoder architecture for retrieving scientific papers \nthat support claims made in tweets. The implementation uses a two-stage retrieval process:\n1. First stage: Bi-encoder retrieval to get top-k candidates\n2. Second stage: Re-ranking or filtering of candidates\n'

In [10]:
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util, InputExample, losses
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from typing import List, Dict, Tuple, Optional
import logging

# Module-wide logging: records at INFO and above are emitted via the root
# configuration set here; everything below reports through this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

In [11]:
class Config:
    """Configuration for the bi-encoder retrieval pipeline.

    Every data/model location is derived from one base directory, so the
    notebook can run on another machine without editing absolute paths.

    Args:
        base_dir: Directory holding the subtask 4b data files and the
            fine-tuned model folder. If omitted, the ``CHECKTHAT_DATA_DIR``
            environment variable is used, then the original author's checkout
            (which yields exactly the paths this class used to hard-code).
    """

    _DEFAULT_BASE_DIR = "/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval"

    def __init__(self, base_dir: Optional[str] = None):
        if base_dir is None:
            base_dir = os.environ.get("CHECKTHAT_DATA_DIR", self._DEFAULT_BASE_DIR)
        self.base_dir = base_dir

        # Model configuration
        self.model_name = "multi-qa-MiniLM-L6-cos-v1"  # Base model for bi-encoder
        self.fine_tuned_model_path = os.path.join(base_dir, "fine_tuned_sbert")  # Where the fine-tuned model is saved/loaded

        # Data paths
        self.collection_path = os.path.join(base_dir, "subtask4b_collection_data.pkl")
        self.train_path = os.path.join(base_dir, "subtask4b_query_tweets_train.tsv")
        self.dev_path = os.path.join(base_dir, "subtask4b_query_tweets_dev.tsv")

        # Retrieval parameters
        self.top_k_retrieve = 20  # Number of candidates to retrieve in first stage
        self.top_k_evaluate = 5   # Number of final results kept (and scored as MRR@k)

        # Training parameters
        self.batch_size = 16
        self.learning_rate = 2e-5
        self.num_epochs = 4
        self.warmup_steps = 100
        self.seed = 42  # Random seed for reproducible fine-tuning / shuffling

        # Device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Output paths (output_dir is relative to the current working directory)
        self.output_dir = "outputs"
        self.train_submission_file = "nlp_train_submission.tsv"
        self.dev_submission_file = "nlp_dev_submission.tsv"

In [12]:
class DataProcessor:
    """Loads the paper collection / tweet queries and builds fine-tuning pairs."""

    @staticmethod
    def load_collection_data(path: str) -> pd.DataFrame:
        """Load the paper collection and add a ``combined_text`` column.

        ``combined_text`` is ``title + ' ' + abstract`` with missing values
        replaced by empty strings.

        Raises:
            Whatever ``pd.read_pickle`` or the column access raises (e.g. a
            ``KeyError`` if ``title``/``abstract`` are missing); the error is
            logged and re-raised.
        """
        try:
            # SECURITY: unpickling can execute arbitrary code — only load trusted files.
            df = pd.read_pickle(path)
            df['combined_text'] = df['title'].fillna('') + ' ' + df['abstract'].fillna('')
            return df
        except Exception as e:
            logger.error(f"Error loading collection data: {e}")
            raise

    @staticmethod
    def load_query_data(path: str) -> pd.DataFrame:
        """Load a tab-separated query (tweet) file; errors are logged and re-raised."""
        try:
            return pd.read_csv(path, sep='\t')
        except Exception as e:
            logger.error(f"Error loading query data: {e}")
            raise

    @staticmethod
    def prepare_training_examples(df_train: pd.DataFrame, df_collection: pd.DataFrame) -> List[InputExample]:
        """Build (tweet, paper text) positive pairs for fine-tuning.

        Rows are skipped when their ``cord_uid`` is not in the collection or
        their ``tweet_text`` is missing; the number skipped is logged so lost
        training data does not go unnoticed.

        Returns:
            One ``InputExample(texts=[tweet, combined_text])`` per usable row,
            in ``df_train`` order.
        """
        paper_dict = dict(zip(df_collection['cord_uid'], df_collection['combined_text']))

        examples = []
        skipped = 0
        # Column-wise zip avoids building a Series per row as iterrows() does.
        for tweet, cord_uid in zip(df_train['tweet_text'], df_train['cord_uid']):
            # A NaN tweet would reach the tokenizer as a float and fail there.
            if cord_uid not in paper_dict or pd.isna(tweet):
                skipped += 1
                continue
            examples.append(InputExample(texts=[tweet, paper_dict[cord_uid]]))

        if skipped:
            logger.warning(
                f"Skipped {skipped} of {len(df_train)} training rows "
                f"(cord_uid not in collection or missing tweet text)"
            )
        return examples

In [14]:
class BiEncoder:
    """Bi-encoder model for scientific claim source retrieval.

    Wraps a ``SentenceTransformer``. Tweets (queries) and papers (documents)
    are embedded independently with the same model, so relevance can be scored
    by cosine similarity between the two embeddings.
    """

    def __init__(self, config: Config):
        """Load the base model ``config.model_name`` and move it to ``config.device``."""
        self.config = config
        self.model = SentenceTransformer(config.model_name)
        self.model.to(config.device)
        logger.info(f"Initialized model: {config.model_name}")

    def _encode(self, texts: List[str], show_progress_bar: Optional[bool]) -> torch.Tensor:
        """Embed ``texts`` into a single tensor (one row per text).

        ``show_progress_bar=None`` means "auto": show a bar only when more than
        one text is encoded. Retrieval encodes one query at a time, and a bar
        per call would flood the output with thousands of one-step bars.
        """
        if show_progress_bar is None:
            show_progress_bar = len(texts) > 1
        return self.model.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=show_progress_bar,
            device=self.config.device
        )

    def encode_queries(self, queries: List[str], show_progress_bar: Optional[bool] = None) -> torch.Tensor:
        """Encode a list of queries into dense vectors (see ``_encode`` for the progress-bar default)."""
        return self._encode(queries, show_progress_bar)

    def encode_documents(self, documents: List[str], show_progress_bar: Optional[bool] = None) -> torch.Tensor:
        """Encode a list of documents into dense vectors (see ``_encode`` for the progress-bar default)."""
        return self._encode(documents, show_progress_bar)

    def fine_tune(self, train_examples: List[InputExample]):
        """Fine-tune on (tweet, paper) pairs, then save to ``config.fine_tuned_model_path``.

        Raises:
            ValueError: if ``train_examples`` is empty. Without this check an
                untouched base model would be saved as if it were fine-tuned.
        """
        if not train_examples:
            raise ValueError("No training examples to fine-tune on")

        train_dataloader = DataLoader(
            train_examples,
            shuffle=True,
            batch_size=self.config.batch_size
        )

        # MultipleNegativesRankingLoss treats the other papers in each batch as
        # negatives for a tweet.
        # NOTE(review): if two tweets in one batch cite the same paper, that
        # paper becomes a false negative. sentence_transformers'
        # NoDuplicatesDataLoader is designed to prevent this — consider it.
        train_loss = losses.MultipleNegativesRankingLoss(self.model)

        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=self.config.num_epochs,
            warmup_steps=self.config.warmup_steps,
            show_progress_bar=True,
            optimizer_params={'lr': self.config.learning_rate}
        )

        self.model.save(self.config.fine_tuned_model_path)
        logger.info(f"Fine-tuned model saved to {self.config.fine_tuned_model_path}")

In [15]:
class Retriever:
    """Two-stage retrieval: bi-encoder candidate search, then a final selection step."""

    def __init__(self, model: BiEncoder, config: Config):
        """Store the bi-encoder used for query encoding and the pipeline config."""
        self.model = model
        self.config = config

    def first_stage_retrieval(
        self,
        query: str,
        corpus_embeddings: torch.Tensor,
        paper_ids: List[str]
    ) -> Tuple[List[str], List[float]]:
        """First stage: return the top-k papers by cosine similarity to ``query``.

        Args:
            query: Tweet text.
            corpus_embeddings: Document embeddings whose row order matches
                ``paper_ids``. Assumed to be on the same device as the query
                embedding — TODO confirm when mixing CPU and GPU.
            paper_ids: Paper identifiers aligned with ``corpus_embeddings`` rows.

        Returns:
            ``(retrieved_ids, scores)`` in descending score order, with at most
            ``config.top_k_retrieve`` entries.
        """
        query_embedding = self.model.encode_queries([query])[0]
        cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]

        # topk selects k items without fully sorting the corpus. k is clamped so
        # a corpus smaller than top_k_retrieve does not raise.
        k = min(self.config.top_k_retrieve, cos_scores.shape[0])
        top_scores, top_indices = torch.topk(cos_scores, k=k)

        # One .tolist() per tensor, instead of one .item() per element (each of
        # which forces a device sync on GPU).
        retrieved_ids = [paper_ids[i] for i in top_indices.tolist()]
        return retrieved_ids, top_scores.tolist()

    def second_stage_processing(
        self,
        candidates: List[str],
        scores: List[float]
    ) -> List[str]:
        """Second stage: keep the first ``config.top_k_evaluate`` candidates.

        ``scores`` is currently unused. It is accepted so that a re-ranker or
        score-threshold filter can replace this pass-through without changing
        callers.
        """
        return candidates[:self.config.top_k_evaluate]

In [16]:
def evaluate_mrr(
    predictions: List[List[str]],
    gold_standard: List[str],
    k: Optional[int] = None
) -> float:
    """Calculate Mean Reciprocal Rank (MRR), optionally cut off at rank ``k``.

    Each query contributes ``1 / rank`` of the first prediction equal to its
    gold id (1-based rank), or 0 if the gold id is absent (or beyond ``k``).

    Args:
        predictions: Ranked predicted ids per query.
        gold_standard: One gold id per query, aligned with ``predictions``.
        k: If given, only the first ``k`` predictions per query count (MRR@k).

    Returns:
        Mean reciprocal rank in [0, 1]; 0.0 when there are no queries.

    Raises:
        ValueError: if the two lists differ in length. ``zip`` would otherwise
            silently drop the extra items and skew the mean.
    """
    if len(predictions) != len(gold_standard):
        raise ValueError(
            f"predictions ({len(predictions)}) and gold_standard "
            f"({len(gold_standard)}) must have the same length"
        )
    if not predictions:
        return 0.0

    mrr_total = 0.0
    for preds, gold in zip(predictions, gold_standard):
        ranked = preds if k is None else preds[:k]
        for rank, pred in enumerate(ranked, start=1):
            if pred == gold:
                mrr_total += 1.0 / rank
                break
    return mrr_total / len(predictions)

In [17]:
def process_dataset(
    retriever: Retriever,
    df: pd.DataFrame,
    corpus_embeddings: torch.Tensor,
    paper_ids: List[str],
    config: Config,
    is_training: bool = False
) -> Tuple[List[List[str]], Optional[float]]:
    """Run two-stage retrieval for every tweet in ``df``.

    Args:
        retriever: Retriever providing the first- and second-stage steps.
        df: Query frame with a ``tweet_text`` column (and ``cord_uid`` unless
            ``is_training``).
        corpus_embeddings: Document embeddings aligned with ``paper_ids``.
        paper_ids: Paper identifiers for the corpus rows.
        config: Pipeline config; ``top_k_evaluate`` labels the reported metric.
        is_training: If True, skip MRR computation.

    Returns:
        ``(predictions, mrr)``. There is one ranked id list per row of ``df``.
        ``mrr`` is ``None`` when ``is_training`` is True.
    """
    predictions = []
    for tweet in tqdm(df['tweet_text'], total=len(df), desc="Processing queries"):
        candidates, scores = retriever.first_stage_retrieval(
            tweet,
            corpus_embeddings,
            paper_ids
        )
        predictions.append(retriever.second_stage_processing(candidates, scores))

    mrr = None
    if not is_training:
        mrr = evaluate_mrr(predictions, df['cord_uid'].tolist())
        # Label with the configured cut-off rather than a hard-coded 5.
        logger.info(f"MRR@{config.top_k_evaluate}: {mrr:.4f}")

    return predictions, mrr

In [18]:
def save_predictions(
    predictions: List[List[str]],
    df: pd.DataFrame,
    output_path: str,
    config: Config
):
    """Save predictions to a TSV with columns ``post_id`` and ``preds``.

    ``preds`` holds ``str(list)`` of the ranked ids for each post, row-aligned
    with ``df``. The file is written to
    ``os.path.join(config.output_dir, output_path)``; the directory is created
    if it is missing.

    Raises:
        ValueError: if ``predictions`` and ``df`` differ in length.
    """
    if len(predictions) != len(df):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(df)} rows"
        )

    full_path = os.path.join(config.output_dir, output_path)
    if config.output_dir:
        os.makedirs(config.output_dir, exist_ok=True)

    submission_df = pd.DataFrame({
        'post_id': df['post_id'].to_numpy(),
        'preds': [str(preds) for preds in predictions]
    })
    submission_df.to_csv(full_path, sep='\t', index=False)
    # Report the real on-disk location, not just the bare file name.
    logger.info(f"Predictions saved to {full_path}")

In [19]:
def main():
    """Run the pipeline end to end.

    Steps: load data; fine-tune the bi-encoder (or reload a saved one); encode
    the corpus; write train/dev predictions; log the dev MRR.
    """
    import random  # only needed here, for seeding

    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)

    # Seed every RNG that affects fine-tuning (DataLoader shuffling, dropout)
    # so Restart-&-Run-All reproduces the same model.
    seed = getattr(config, "seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load and prepare data
    data_processor = DataProcessor()
    df_collection = data_processor.load_collection_data(config.collection_path)
    df_train = data_processor.load_query_data(config.train_path)
    df_dev = data_processor.load_query_data(config.dev_path)

    logger.info(f"Loaded {len(df_collection)} papers, {len(df_train)} training queries, and {len(df_dev)} development queries")

    # Initialize model (always loads the base model first)
    model = BiEncoder(config)

    # Fine-tune once; later runs reuse the saved model.
    if not os.path.exists(config.fine_tuned_model_path):
        logger.info("Fine-tuning model on training set...")
        train_examples = data_processor.prepare_training_examples(df_train, df_collection)
        model.fine_tune(train_examples)
    else:
        logger.info("Loading fine-tuned model...")
        # Put the reloaded model on the configured device, as BiEncoder does
        # for the base model.
        model.model = SentenceTransformer(config.fine_tuned_model_path, device=str(config.device))

    retriever = Retriever(model, config)

    logger.info("Encoding corpus...")
    corpus_embeddings = model.encode_documents(df_collection['combined_text'].tolist())
    paper_ids = df_collection['cord_uid'].tolist()

    # Training set: predictions only (MRR on seen data is not informative).
    logger.info("Processing training set...")
    train_predictions, _ = process_dataset(
        retriever,
        df_train,
        corpus_embeddings,
        paper_ids,
        config,
        is_training=True
    )
    save_predictions(
        train_predictions,
        df_train,
        config.train_submission_file,
        config
    )

    # Development set: predictions plus MRR.
    logger.info("Processing development set...")
    dev_predictions, dev_mrr = process_dataset(
        retriever,
        df_dev,
        corpus_embeddings,
        paper_ids,
        config,
        is_training=False
    )
    save_predictions(
        dev_predictions,
        df_dev,
        config.dev_submission_file,
        config
    )

    logger.info(f"Final MRR@{config.top_k_evaluate} on development set: {dev_mrr:.4f}")
    logger.info("Processing complete!")

In [20]:
if __name__ == "__main__":
    # True both when run as a script and inside a notebook kernel, so this
    # cell launches the full pipeline.
    main()

INFO:__main__:Loaded 7718 papers, 12853 training queries, and 1400 development queries
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: multi-qa-MiniLM-L6-cos-v1
Downloading .gitattributes: 100%|██████████| 791/791 [00:00<00:00, 1.77MB/s]
Downloading config.json: 100%|██████████| 190/190 [00:00<00:00, 965kB/s]
Downloading README.md: 100%|██████████| 11.6k/11.6k [00:00<00:00, 10.6MB/s]
Downloading config.json: 100%|██████████| 612/612 [00:00<00:00, 1.85MB/s]
Downloading (…)ce_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 562kB/s]
Downloading data_config.json: 100%|██████████| 25.5k/25.5k [00:00<00:00, 29.8MB/s]
Downloading model.safetensors: 100%|██████████| 90.9M/90.9M [00:01<00:00, 55.8MB/s]
Downloading model.onnx: 100%|██████████| 90.4M/90.4M [00:01<00:00, 54.7MB/s]
Downloading model_O1.onnx: 100%|██████████| 90.4M/90.4M [00:01<00:00, 60.7MB/s]
Downloading model_O2.onnx: 100%|██████████| 90.3M/90.3M [00:01<00:00, 55.3MB/s]
Downloading mo