In [12]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from FlagEmbedding import BGEM3FlagModel


import json
from tqdm.notebook import tqdm

import numpy as np
from typing import List, Tuple, Dict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import torch

### Label Data

In [3]:
# Training questions; the 'question' and 'article_ids' columns are used by later cells.
df_questions = pd.read_csv(filepath_or_buffer='questions_train.csv')

In [4]:
# Ground-truth article IDs, one raw 'article_ids' value per question
# (Eval_Retrieval later parses each value as comma-separated IDs).
all_ids_labels = df_questions['article_ids'].tolist()


### Load predictions

In [5]:
import json
import gzip

# Load the first-stage retrieval output: a gzip-compressed JSON object whose
# entries are looked up by stringified question index ("0", "1", ...) below.
with gzip.open('all_predictions_RSChunk.json.gz', mode='rt', encoding='utf-8') as fh:
    all_predictions = json.loads(fh.read())

In [6]:
# Unpack the per-question records (keyed "0".."N-1") into two aligned lists:
# the candidate article IDs and the matching candidate texts.
predicted_ids, predicted_article = [], []
for idx in range(len(all_predictions)):
    record = all_predictions[str(idx)]
    predicted_ids.append(record['predicted_ids'])
    predicted_article.append(record['predictions_articles'])

### Reranking function

In [26]:
import numpy as np
from typing import List
from tqdm import tqdm

def optimize_reranking(
    predicted_ids: List[List[int]],
    predicted_articles: List[List[str]],
    questions: List[str],
    ids_labels: List[str],
    bge_m3,
    w_d: float = 0.4,
    w_s: float = 0.2,
    w_c: float = 0.4,
    batch_size: int = 12,
    max_length: int = 512
) -> List[List[int]]:
    """Rerank each question's retrieved candidates with BGE-M3 hybrid scores.

    Every (question, candidate article) pair is scored with
    ``bge_m3.compute_score``. The candidate IDs are then re-ordered by the
    weighted 'colbert+sparse+dense' score, highest first. Ties keep the
    first-stage retrieval order.

    Args:
        predicted_ids: Per question, the candidate article IDs from the
            first-stage retriever, aligned element-wise with
            ``predicted_articles``.
        predicted_articles: Per question, the candidate texts to score
            against the question, in the same order as ``predicted_ids``.
        questions: Question texts, aligned with the two lists above.
        ids_labels: Unused; kept so existing callers keep working.
        bge_m3: Model exposing ``compute_score(sentence_pairs, batch_size=...,
            max_passage_length=..., weights_for_different_modes=...)``. It must
            return a dict holding a 'colbert+sparse+dense' score sequence
            (e.g. ``FlagEmbedding.BGEM3FlagModel``).
        w_d: Weight of the dense score.
        w_s: Weight of the sparse (lexical) score.
        w_c: Weight of the ColBERT (multi-vector) score.
        batch_size: Number of (question, article) pairs per model forward pass.
        max_length: Maximum passage length forwarded as ``max_passage_length``.

    Returns:
        One list of IDs per question, reranked by score. If scoring a question
        fails (model error, or ID/text lists of different length), its
        original candidate order is kept and the error is printed.
    """

    def _score_candidates(question: str, articles: List[str]) -> np.ndarray:
        """Return one combined relevance score per article as a 1-D float array."""
        sentence_pairs = [[question, article] for article in articles]
        scores = bge_m3.compute_score(
            sentence_pairs,
            batch_size=batch_size,  # let the model do the batching
            max_passage_length=max_length,
            weights_for_different_modes=[w_d, w_s, w_c],  # [dense, sparse, colbert]
        )
        # compute_score returns plain Python lists. Convert them so the scores can
        # be negated and sorted (``-list`` raises TypeError, which previously
        # made every short candidate list silently skip reranking).
        return np.asarray(scores['colbert+sparse+dense'], dtype=float).reshape(-1)

    reranked_predictions = []

    for q, arts, preds in tqdm(zip(questions, predicted_articles, predicted_ids),
                              total=len(questions),
                              desc="Processing questions"):
        if len(arts) == 0:
            # Nothing to score; keep the (empty) original ordering as-is.
            reranked_predictions.append(preds)
            continue
        try:
            similarities = _score_candidates(q, arts)
            if len(similarities) != len(preds):
                raise ValueError(
                    f"{len(preds)} candidate IDs but {len(similarities)} scored articles"
                )
            # Stable sort on the negated scores gives descending order while
            # keeping the retriever's order among equal scores.
            reranked_indices = np.argsort(-similarities, kind="stable")
            reranked_predictions.append([preds[idx] for idx in reranked_indices])

        except Exception as e:
            # Best-effort: report the failure and keep the first-stage order.
            print(f"Error processing question: {q[:100]}...")
            print(f"Error: {str(e)}")
            reranked_predictions.append(preds)

    return reranked_predictions

### Embedding model

In [2]:
# BGE-M3 model used by optimize_reranking to score (question, article) pairs.
# Use the GPU with fp16 when available. Otherwise fall back to CPU in full
# precision instead of failing on a hard-coded device='cuda'.
bge_m3 = BGEM3FlagModel('BAAI/bge-m3',
                       use_fp16=torch.cuda.is_available(),
                       device='cuda' if torch.cuda.is_available() else 'cpu')

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [None]:
# Rerank every question's candidates with BGE-M3 (the recorded run below took ~42 min for 886 questions).
reranked_preds = optimize_reranking(
    bge_m3=bge_m3,
    questions=df_questions['question'].tolist(),
    predicted_ids=predicted_ids,
    predicted_articles=predicted_article,
    ids_labels=all_ids_labels,
    w_d=0.4,  # dense weight
    w_s=0.2,  # sparse weight
    w_c=0.4,  # ColBERT weight
)

Processing questions: 100%|██████████| 886/886 [42:18<00:00,  2.87s/it]  


#### Compute metrics

In [29]:
def Eval_Retrieval(all_predictions, articles_ids, top_k=20):
    """Compute retrieval metrics at cut-off ``top_k``, averaged over questions.

    Args:
        all_predictions: One ranked list of predicted article IDs per question.
        articles_ids: Ground-truth article IDs per question, aligned with
            ``all_predictions``. Each entry may be a comma-separated string
            (e.g. "12,40"), a single integer, or a list/tuple/set of IDs.
        top_k: Only the first ``top_k`` predictions of each list are scored.

    Returns:
        dict with keys "mean_precision", "mean_recall", "mean_f1_score",
        "MAP" and "MRR", each averaged over all questions (all 0.0 when there
        are no questions). AP is normalized by the number of relevant
        articles. MRR uses the first relevant article within the top ``top_k``
        and counts 0 when there is none.

    Raises:
        ValueError: if the two inputs do not have the same length.
    """

    def _parse_ids(ids):
        """Normalize one ground-truth entry to a list of int article IDs."""
        if isinstance(ids, (list, tuple, set)):
            return [int(x) for x in ids]
        # A comma-separated string such as "12,40". str() also covers a lone
        # integer ID that pandas parsed as a number. Empty tokens are skipped.
        return [int(tok) for tok in str(ids).split(',') if tok.strip()]

    if len(all_predictions) != len(articles_ids):
        raise ValueError(
            f"Got {len(all_predictions)} prediction lists but "
            f"{len(articles_ids)} ground-truth entries"
        )

    # Ensure every ground-truth entry is a list of int IDs.
    articles_ids = [_parse_ids(ids) for ids in articles_ids]

    # Per-question metric values
    precisions = []
    recalls = []
    f1_scores = []
    average_precisions = []
    reciprocal_ranks = []

    # Total number of questions
    Q = len(all_predictions)
    if Q == 0:
        return {
            "mean_precision": 0.0,
            "mean_recall": 0.0,
            "mean_f1_score": 0.0,
            "MAP": 0.0,
            "MRR": 0.0
        }

    for preds, true_ids in zip(all_predictions, articles_ids):
        # Keep only the top_k predictions
        preds = preds[:top_k]

        # Sets for the overlap computations
        preds_set = set(preds)
        true_set = set(true_ids)

        # True positives, false positives and false negatives
        tp = len(preds_set & true_set)  # predicted and relevant
        fp = len(preds_set - true_set)  # predicted but not relevant
        fn = len(true_set - preds_set)  # relevant but not predicted

        # Precision, recall, F1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

        # Average Precision. Each relevant article is credited only at its first
        # occurrence, so duplicate IDs in the ranking (e.g. several chunks of
        # the same article) cannot push AP above 1.
        ap = 0
        relevant_count = 0
        already_hit = set()
        for rank, pred in enumerate(preds, 1):  # rank starts at 1
            if pred in true_set and pred not in already_hit:
                already_hit.add(pred)
                relevant_count += 1
                ap += relevant_count / rank
        ap /= len(true_set) if len(true_set) > 0 else 1
        average_precisions.append(ap)

        # Reciprocal rank of the first relevant prediction (0 if none in top_k)
        mrr = 0
        for rank, pred in enumerate(preds, 1):
            if pred in true_set:
                mrr = 1 / rank
                break
        reciprocal_ranks.append(mrr)

    # Averages over all questions
    mean_precision = sum(precisions) / Q
    mean_recall = sum(recalls) / Q
    mean_f1 = sum(f1_scores) / Q
    mean_ap = sum(average_precisions) / Q
    mean_mrr = sum(reciprocal_ranks) / Q

    return {
        "mean_precision": mean_precision,
        "mean_recall": mean_recall,
        "mean_f1_score": mean_f1,
        "MAP": mean_ap,
        "MRR": mean_mrr
    }


In [32]:
# Score the reranked lists against the ground truth, looking at the top 3 only
Eval_Retrieval(all_predictions=reranked_preds, articles_ids=all_ids_labels, top_k=3)

{'mean_precision': 0.20786305492851767,
 'mean_recall': 0.2296581251630902,
 'mean_f1_score': 0.17475715458591964,
 'MAP': 0.2194387862703324,
 'MRR': 0.3502633559066967}

Mean Precision (0.208): about 20.8% of the top-3 retrieved articles are relevant
Mean Recall (0.230): about 23.0% of the relevant articles appear in the top 3
F1 Score (0.175): low harmonic mean of precision and recall
MAP (0.219): low mean average precision (normalized by the number of relevant articles), so relevant articles are often missing from, or ranked low in, the top 3
MRR (0.350): mean reciprocal rank of the first relevant article within the top 3 (questions with no relevant article in the top 3 count as 0)