In [1]:
import argparse
import os
import random
import re
from itertools import combinations
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import pandas as pd
from datasets import load_from_disk
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def extract_dataset_params(dataset_name: str) -> Tuple[str, str, str, str]:
    """Split a dataset name of the form ``<base>_<source>_<cutoff>_<lookback>``.

    The name is split on every underscore and the first four fields are
    returned as strings; any additional fields are ignored.

    Args:
        dataset_name: directory/dataset name to parse.

    Returns:
        ``(base_name, source, cutoff, lookback)``.

    Raises:
        ValueError: if the name has fewer than four underscore-separated fields
            (previously this surfaced as an uninformative ``IndexError``).
    """
    parts = dataset_name.split('_')
    if len(parts) < 4:
        raise ValueError(
            f"Dataset name {dataset_name!r} has {len(parts)} underscore-separated "
            "field(s); expected at least 4 (<base>_<source>_<cutoff>_<lookback>)"
        )
    base_name, source, cutoff, lookback = parts[:4]
    return base_name, source, cutoff, lookback

def load_all_datasets(directory: str) -> Dict[str, Any]:
    """Load every Hugging Face dataset saved as a subdirectory of ``directory``.

    Subdirectories are visited in sorted order, so the returned dict (and
    anything that iterates it, e.g. the pairwise comparison) has a reproducible
    order instead of whatever order the filesystem happens to list entries in.
    Non-directory entries are ignored.

    Loading is best-effort per dataset: a subdirectory that ``load_from_disk``
    fails on is reported and skipped so one bad directory does not abort the
    rest.

    Args:
        directory: path whose subdirectories each hold one saved dataset.

    Returns:
        Mapping from subdirectory name to the loaded dataset object.

    Raises:
        FileNotFoundError: if ``directory`` does not exist or is not a
            directory (previously this silently returned an empty dict).
    """
    root = Path(directory)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {directory}")

    datasets = {}
    for path in sorted(root.glob("*")):
        if not path.is_dir():
            continue
        dataset_name = path.name
        print(f"Loading dataset: {dataset_name}")
        try:
            dataset = load_from_disk(str(path))
        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print(f"  - Error loading dataset {dataset_name}: {str(e)}")
            continue
        datasets[dataset_name] = dataset
        print(f"  - Loaded {len(dataset)} entries")

    return datasets



In [3]:
# Configuration. `directory` holds one saved retrieval dataset per subdirectory;
# set the RETRIEVAL_DIR environment variable to run outside the original cluster.
directory = os.environ.get("RETRIEVAL_DIR", "/is/cluster/fast/sgoel/forecasting/news/retrieval")
output = "retrieval_comparison_results.csv"  # presumably written by a later cell — confirm
samples = 3  # NOTE(review): unused in this section; presumably consumed by a later cell

print(f"Loading datasets from {directory}")
datasets = load_all_datasets(directory)

NameError: name 'args' is not defined

In [None]:
def get_article_identifier(article: Dict) -> str:
    """Return a string key for ``article`` built from its URL and title.

    The two fields are joined with a single underscore; a field absent from
    the dict contributes an empty string.
    """
    fields = (article.get('url', ''), article.get('title', ''))
    return "{}_{}".format(*fields)

def compare_article_lists(
    list1: List[Dict],
    list2: List[Dict],
    key: Optional[Callable[[Dict], str]] = None,
) -> Tuple[int, Set[str], Set[str], float]:
    """
    Compare two ranked lists of articles and return information about differences.

    Articles are matched by identifier: ``get_article_identifier`` by default,
    or ``key`` when supplied. An article's position is the 0-based index of its
    FIRST occurrence in its list, so a repeated entry further down the list
    does not count as a rank change.

    Args:
        list1: first ranked list of article dicts.
        list2: second ranked list of article dicts.
        key: optional function mapping an article dict to its identifier.

    Returns:
        - Number of common articles
        - Set of article IDs unique to list1
        - Set of article IDs unique to list2
        - Average absolute position change over common articles (0.0 if none)
    """
    identify = get_article_identifier if key is None else key

    def _first_positions(articles: List[Dict]) -> Dict[str, int]:
        # identifier -> index of its first occurrence (setdefault keeps the earliest)
        positions: Dict[str, int] = {}
        for rank, article in enumerate(articles):
            positions.setdefault(identify(article), rank)
        return positions

    pos_map1 = _first_positions(list1)
    pos_map2 = _first_positions(list2)

    # dict key views support set algebra and return plain sets
    unique_to_list1 = pos_map1.keys() - pos_map2.keys()
    unique_to_list2 = pos_map2.keys() - pos_map1.keys()
    common_articles = pos_map1.keys() & pos_map2.keys()

    position_changes = [abs(pos_map1[a] - pos_map2[a]) for a in common_articles]
    avg_position_change = (
        sum(position_changes) / len(position_changes) if position_changes else 0.0
    )

    return len(common_articles), unique_to_list1, unique_to_list2, avg_position_change

def _compare_dataset_pair(ds1: Any, ds2: Any, max_samples: int = 10) -> Dict[str, Any]:
    """Aggregate retrieval differences between two datasets, question by question.

    Rows are paired by index (row ``q`` of ``ds1`` vs row ``q`` of ``ds2``);
    NOTE(review): this assumes both datasets list the same questions in the
    same order — nothing here verifies that. Only the first
    ``min(len(ds1), len(ds2))`` rows are compared. Rows lacking a
    ``retrieved_articles`` field, or with an empty list on either side, are
    skipped and not counted.

    Returns:
        Dict of raw counters plus up to ``max_samples`` questions whose article
        sets differ.
    """
    common_length = min(len(ds1), len(ds2))

    total_questions = 0
    questions_with_different_articles = 0
    questions_with_reranking = 0
    total_reranked_articles = 0
    sum_position_changes = 0
    different_article_samples = []

    for q_idx in tqdm(range(common_length), desc="Comparing questions"):
        # Read each row once instead of re-indexing the dataset up to 3 times.
        row1 = ds1[q_idx]
        row2 = ds2[q_idx]
        if "retrieved_articles" not in row1 or "retrieved_articles" not in row2:
            continue

        articles1 = row1["retrieved_articles"]
        articles2 = row2["retrieved_articles"]
        if not articles1 or not articles2:
            continue

        total_questions += 1
        common_count, unique_to_1, unique_to_2, avg_position_change = compare_article_lists(
            articles1, articles2
        )

        if unique_to_1 or unique_to_2:
            questions_with_different_articles += 1
            if len(different_article_samples) < max_samples:
                different_article_samples.append({
                    "q_idx": q_idx,
                    "question": row1.get("question", ""),
                    "unique_to_1": list(unique_to_1),
                    "unique_to_2": list(unique_to_2),
                    "articles1": articles1,
                    "articles2": articles2
                })

        if common_count > 0 and avg_position_change > 0:
            questions_with_reranking += 1
            # counts ALL common articles of a reranked question, not only the moved ones
            total_reranked_articles += common_count
            sum_position_changes += avg_position_change

    return {
        "total_questions": total_questions,
        "questions_with_different_articles": questions_with_different_articles,
        "questions_with_reranking": questions_with_reranking,
        "total_reranked_articles": total_reranked_articles,
        # mean of per-question averages, taken over reranked questions only
        "avg_position_change": (
            sum_position_changes / questions_with_reranking if questions_with_reranking > 0 else 0
        ),
        "samples": different_article_samples,
    }


def run_comparison(datasets: Dict[str, Any], max_samples: int = 10) -> pd.DataFrame:
    """Compare every pair of datasets sharing a base name and source.

    Dataset names are parsed with ``extract_dataset_params``; names that do not
    parse are reported and skipped instead of aborting the whole run. Pairs are
    formed in the iteration order of ``datasets`` (each unordered pair once).

    Args:
        datasets: mapping from dataset name to a row-indexable dataset.
        max_samples: maximum number of differing questions kept per pair in
            the ``samples`` column.

    Returns:
        One row per compared pair with cutoff/lookback of each side, question
        counts and percentages, reranking statistics, and sample differences.
    """
    # Parse each name once rather than once per pair.
    parsed = []
    for name in datasets:
        try:
            params = extract_dataset_params(name)
        except (IndexError, ValueError):
            print(f"Skipping {name}: name does not match <base>_<source>_<cutoff>_<lookback>")
            continue
        parsed.append((name, params))

    results = []
    for (name1, params1), (name2, params2) in combinations(parsed, 2):
        base1, source1, cutoff1, lookback1 = params1
        base2, source2, cutoff2, lookback2 = params2

        # Only compare datasets with same base name and source
        if base1 != base2 or source1 != source2:
            continue

        print(f"Comparing {name1} vs {name2}")
        stats = _compare_dataset_pair(datasets[name1], datasets[name2], max_samples)

        total = stats["total_questions"]
        results.append({
            "dataset1": name1,
            "dataset2": name2,
            "cutoff1": cutoff1,
            "cutoff2": cutoff2,
            "lookback1": lookback1,
            "lookback2": lookback2,
            "total_questions": total,
            "questions_with_different_articles": stats["questions_with_different_articles"],
            "questions_with_different_articles_pct": (
                stats["questions_with_different_articles"] / total * 100 if total > 0 else 0
            ),
            "questions_with_reranking": stats["questions_with_reranking"],
            "questions_with_reranking_pct": (
                stats["questions_with_reranking"] / total * 100 if total > 0 else 0
            ),
            "total_reranked_articles": stats["total_reranked_articles"],
            "avg_position_change": stats["avg_position_change"],
            "samples": stats["samples"]
        })

    return pd.DataFrame(results)