In [17]:
import pandas as pd
import pickle
import torch
from sentence_transformers import SentenceTransformer, util
from torch.nn.functional import normalize
import re
from collections import Counter
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

In [18]:
class QueryExpander:
    #def __init__(self, model_name='msmarco-distilbert-base-tas-b'):
    def __init__(self, model_name='multi-qa-mpnet-base-cos-v1'):
        self.model = SentenceTransformer(model_name)
        self.word_pool = None
        self.word_embeddings = None
        self.tfidf_vectorizer = None
        
    def build_word_pool(self, papers_df, min_freq=5, max_terms=100):
        """Build a domain-specific word pool from the corpus"""
        # Create text column if it doesn't exist
        if 'text' not in papers_df.columns:
            papers_df['text'] = papers_df['title'] + '. ' + papers_df['abstract']
        
        # Combine all paper texts
        all_texts = papers_df['text'].tolist()
        
        # Initialize TF-IDF vectorizer
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=max_terms,
            min_df=min_freq,
            stop_words='english'
        )
        
        # Fit and transform the texts
        tfidf_matrix = self.tfidf_vectorizer.fit_transform(all_texts)
        
        # Get feature names (terms)
        self.word_pool = self.tfidf_vectorizer.get_feature_names_out()
        
        # Pre-compute embeddings for the word pool
        self.word_embeddings = self.model.encode(
            self.word_pool.tolist(),
            convert_to_tensor=True,
            show_progress_bar=True
        )
        
        return self.word_pool
    
    def clean_text(self, text):
        """Clean and normalize text"""
        text = text.lower()
        text = re.sub(r"http\S+|www\S+", "", text)  # remove URLs
        text = re.sub(r"[@#]\w+", "", text)         # remove @mentions and #hashtags
        text = re.sub(r"[^\w\s\-/]", "", text)      # keep alphanum + dash/slash
        return text.strip()
    
    def expand_query(self, query, top_n=3, expansion_weight=0.3):
        """
        Expand query using semantic similarity and TF-IDF weighting
        
        Args:
            query: Original query text
            top_n: Number of terms to add
            expansion_weight: Weight for expanded terms (0-1)
        """
        # Clean the query
        clean_query = self.clean_text(query)
        
        # Get query embedding
        query_emb = self.model.encode(clean_query, convert_to_tensor=True)
        
        # Calculate similarity scores
        scores = util.cos_sim(query_emb, self.word_embeddings)[0]
        
        # Get top terms
        top_ids = scores.topk(top_n).indices
        top_terms = [self.word_pool[i] for i in top_ids]
        
        # Calculate term weights based on similarity scores
        term_weights = scores[top_ids].tolist()
        term_weights = [w * expansion_weight for w in term_weights]
        
        # Create weighted expansion
        expanded_terms = []
        for term, weight in zip(top_terms, term_weights):
            # Repeat terms based on their weight
            repeat_count = max(1, int(weight * 3))
            expanded_terms.extend([term] * repeat_count)
        
        # Combine original query with expanded terms
        expanded_query = f"{clean_query} {' '.join(expanded_terms)}"
        
        return expanded_query
    
    def batch_expand_queries(self, queries, top_n=3, expansion_weight=0.3):
        """Expand a batch of queries"""
        return [self.expand_query(q, top_n, expansion_weight) for q in queries]

In [19]:
class RetrievalSystem:
    def __init__(self, model_name='msmarco-distilbert-base-tas-b'):
        self.model = SentenceTransformer(model_name)
        self.paper_embeddings = None
        self.papers_df = None
        
    def load_papers(self, papers_df):
        """Load and encode paper collection"""
        self.papers_df = papers_df.copy()  # Create a copy to avoid modifying the original
        if 'text' not in self.papers_df.columns:
            self.papers_df['text'] = self.papers_df['title'] + '. ' + self.papers_df['abstract']
        
        # Encode papers
        self.paper_embeddings = self.model.encode(
            self.papers_df['text'].tolist(),
            show_progress_bar=True,
            convert_to_tensor=True
        )
        
        # Normalize embeddings
        self.paper_embeddings = normalize(self.paper_embeddings, p=2, dim=1)
        
    def get_topk_predictions(self, query_embeddings, top_k=5, batch_size=16):
        """Get top-k predictions using batched processing"""
        predictions = []
        
        for start_idx in range(0, len(query_embeddings), batch_size):
            end_idx = min(start_idx + batch_size, len(query_embeddings))
            query_batch = query_embeddings[start_idx:end_idx]
            query_norm = normalize(query_batch, p=2, dim=1)
            
            # Calculate similarity scores
            similarity_matrix = torch.matmul(query_norm, self.paper_embeddings.T)
            _, top_k_indices = torch.topk(similarity_matrix, k=top_k, dim=1)
            
            # Get predictions
            batch_predictions = [
                self.papers_df.iloc[indices.tolist()]['cord_uid'].tolist()
                for indices in top_k_indices
            ]
            predictions.extend(batch_predictions)
            
        return predictions

In [20]:
def evaluate_performance(data, col_gold, col_pred, list_k=[1, 5, 10]):
    """Evaluate retrieval performance using MRR"""
    d_performance = {}
    for k in list_k:
        scores = []
        for _, row in data.iterrows():
            gold = row[col_gold]
            preds = row[col_pred]
            if isinstance(preds, str):
                try:
                    preds = eval(preds)
                except:
                    preds = []
            if gold in preds[:k]:
                rank = preds[:k].index(gold) + 1
                scores.append(1.0 / rank)
            else:
                scores.append(0.0)
        d_performance[k] = sum(scores) / len(scores) if scores else 0.0
    return d_performance

In [21]:
def main():
    # Load data
    with open("/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_collection_data.pkl", "rb") as f:
        papers_df = pickle.load(f)
    
    train_df = pd.read_csv("/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_query_tweets_train.tsv", sep="\t", 
                          names=["post_id", "tweet_text", "cord_uid"])
    dev_df = pd.read_csv("/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/subtask4b_query_tweets_dev.tsv", sep="\t", 
                        names=["post_id", "tweet_text", "cord_uid"])
    
    # Initialize systems
    expander = QueryExpander()
    retrieval = RetrievalSystem()
    
    # Build word pool and load papers
    expander.build_word_pool(papers_df)
    retrieval.load_papers(papers_df)
    
    # Expand queries
    train_df['expanded_text'] = expander.batch_expand_queries(
        train_df['tweet_text'].tolist(),
        top_n=3,
        expansion_weight=0.3
    )
    dev_df['expanded_text'] = expander.batch_expand_queries(
        dev_df['tweet_text'].tolist(),
        top_n=3,
        expansion_weight=0.3
    )
    
    # Encode expanded queries
    train_query_embeddings = retrieval.model.encode(
        train_df['expanded_text'].tolist(),
        show_progress_bar=True,
        convert_to_tensor=True
    )
    dev_query_embeddings = retrieval.model.encode(
        dev_df['expanded_text'].tolist(),
        show_progress_bar=True,
        convert_to_tensor=True
    )
    
    # Get predictions
    train_df['preds'] = retrieval.get_topk_predictions(train_query_embeddings)
    dev_df['preds'] = retrieval.get_topk_predictions(dev_query_embeddings)
    
    # Evaluate
    train_mrr = evaluate_performance(train_df, 'cord_uid', 'preds')
    dev_mrr = evaluate_performance(dev_df, 'cord_uid', 'preds')
    
    print("Train MRR:", train_mrr)
    print("Dev MRR:", dev_mrr)

In [22]:
if __name__ == "__main__":
    main() 

Batches: 100%|██████████| 4/4 [00:00<00:00, 12.64it/s]
Batches: 100%|██████████| 242/242 [11:25<00:00,  2.83s/it]
Batches: 100%|██████████| 402/402 [01:58<00:00,  3.38it/s]
Batches: 100%|██████████| 44/44 [00:12<00:00,  3.42it/s]


Train MRR: {1: 0.33016959701260307, 5: 0.39050490119806874, 10: 0.39050490119806874}
Dev MRR: {1: 0.33547466095645967, 5: 0.38929336188436786, 10: 0.38929336188436786}
