In [13]:
import os
import json
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

# MedRAG Corpus Statistics

This notebook computes corpus statistics for the 4 MedRAG sources: PubMed, Wikipedia, Textbooks, and StatPearls.

**Measurements:**
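- **#Doc.**: Number of unique documents, counted by document-ID field (falls back to the snippet count when no snippet carries a document ID)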
- **#Snippets**: Number of text chunks
- **Avg. L**: Average snippet length in characters

In [14]:
# Paths
# The defaults are the absolute locations this notebook was written against.
# Set MEDRAG_CORPUS_DIR / MEDRAG_HF_CACHE_DIR in the environment to point the
# notebook at a different checkout without editing this cell.
BASE_CORPUS_DIR = Path(os.environ.get(
    'MEDRAG_CORPUS_DIR',
    '/data/wang/junh/githubs/mirage_medrag/MedRAG/src/data/corpus',
))
HF_CACHE_DIR = Path(os.environ.get(
    'MEDRAG_HF_CACHE_DIR',
    '/data/wang/junh/githubs/mirage_medrag/MedRAG/src/data/hf_cache',
))

# Echo the resolved locations and whether they exist, so a misconfigured
# environment is visible before any corpus is scanned.
print('Corpus directory:', BASE_CORPUS_DIR)
print('HF cache directory:', HF_CACHE_DIR)
print('Corpus exists:', BASE_CORPUS_DIR.exists())
print('HF cache exists:', HF_CACHE_DIR.exists())

Corpus directory: /data/wang/junh/githubs/mirage_medrag/MedRAG/src/data/corpus
HF cache directory: /data/wang/junh/githubs/mirage_medrag/MedRAG/src/data/hf_cache
Corpus exists: True
HF cache exists: True


In [15]:
def get_text_content(obj):
    """Extract the text content from a parsed JSON record.

    The fields ``content``, ``contents``, ``text`` and ``body`` are tried in
    that order; the first one whose value is a string is returned.

    Args:
        obj: The value produced by ``json.loads`` for one line. Anything that
            is not a dict (``None``, a number, a string, a list) is treated
            as having no text instead of raising ``TypeError``.

    Returns:
        The text as a ``str``, or ``""`` when no candidate field holds a string.
    """
    # A JSONL line may legally decode to a scalar or list; indexing those by
    # field name would raise, so bail out early.
    if not isinstance(obj, dict):
        return ""

    for field in ('content', 'contents', 'text', 'body'):
        value = obj.get(field)
        if isinstance(value, str):
            return value
    return ""

def get_document_id(obj):
    """Extract the document ID from a parsed JSON record.

    The fields ``PMID``, ``pmid``, ``paper_id``, ``paperId``, ``document_id``
    and ``doc_id`` are tried in that order. A field whose value is ``None`` or
    an empty string is skipped, so a null placeholder neither yields the
    bogus ID ``"None"`` nor hides a usable ID stored under a later field.

    Args:
        obj: The value produced by ``json.loads`` for one line. Non-dict
            values are treated as having no ID.

    Returns:
        The ID converted to ``str``, or ``None`` when no usable ID is present.
    """
    if not isinstance(obj, dict):
        return None

    for field in ('PMID', 'pmid', 'paper_id', 'paperId', 'document_id', 'doc_id'):
        value = obj.get(field)
        # Compare explicitly instead of using truthiness: 0 is a valid ID.
        if value is None or value == "":
            continue
        return str(value)
    return None

def process_corpus(corpus_path):
    """Compute snippet statistics for a single corpus directory.

    Every ``*.jsonl`` file directly inside ``corpus_path / 'chunk'`` is read
    line by line; each non-blank line that decodes to a JSON object counts as
    one snippet. Lines that are not valid JSON, or that decode to something
    other than an object, are skipped and reported in a summary message.

    Args:
        corpus_path: ``pathlib.Path`` of the corpus directory. Its last path
            component is used as the corpus name.

    Returns:
        A dict with the keys ``'Corpus'``, ``'#Doc.'``, ``'#Snippets'`` and
        ``'Avg. L'`` (mean snippet length in characters, rounded to an
        integer), or ``None`` when the chunk directory does not exist or
        contains no JSONL files.
    """
    corpus_name = corpus_path.name
    chunk_dir = corpus_path / 'chunk'

    if not chunk_dir.exists():
        print(f"No chunk directory found for {corpus_name}")
        return None

    chunk_files = list(chunk_dir.glob('*.jsonl'))
    if not chunk_files:
        print(f"No JSONL files found for {corpus_name}")
        return None

    total_snippets = 0
    total_chars = 0
    skipped_lines = 0
    unique_docs = set()

    print(f"Processing {corpus_name} ({len(chunk_files)} files)...")

    for file_path in tqdm(chunk_files, desc=f"{corpus_name}"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        skipped_lines += 1
                        continue

                    # Skip scalars/lists one line at a time; letting them
                    # reach the field lookups would raise and make the outer
                    # handler abandon the rest of this file.
                    if not isinstance(obj, dict):
                        skipped_lines += 1
                        continue

                    total_snippets += 1
                    total_chars += len(get_text_content(obj))

                    doc_id = get_document_id(obj)
                    if doc_id:
                        unique_docs.add(doc_id)

        except Exception as e:
            # Best effort: report the unreadable file and keep going. Snippets
            # counted from it before the failure stay in the totals.
            print(f"Error reading {file_path}: {e}")
            continue

    if skipped_lines:
        print(f"{corpus_name}: skipped {skipped_lines} malformed or non-object lines")

    avg_length = total_chars / total_snippets if total_snippets > 0 else 0
    # When no snippet carried a document ID, every snippet is counted as its
    # own document.
    n_docs = len(unique_docs) if unique_docs else total_snippets

    return {
        'Corpus': corpus_name,
        '#Doc.': n_docs,
        '#Snippets': total_snippets,
        'Avg. L': round(avg_length)
    }

In [None]:
# 🎯 SIMPLIFIED CORRECT SOLUTION: Direct path processing
print("📊 MedRAG Corpus Statistics (Correct Paths)")
print("=" * 50)


def _hf_snapshot_chunk_dir(dataset_name):
    """Locate the chunk directory of a MedRAG dataset in the HF cache.

    Looks under ``HF_CACHE_DIR / 'datasets--MedRAG--<dataset_name>' /
    'snapshots'`` and, when several snapshot directories are present, picks
    the most recently modified one so the choice does not depend on
    directory listing order.

    Returns:
        ``<snapshot>/chunk`` as a ``Path``, or ``None`` when no snapshot
        directory exists.
    """
    snapshots_dir = HF_CACHE_DIR / f'datasets--MedRAG--{dataset_name}' / 'snapshots'
    snapshots = [entry for entry in snapshots_dir.glob('*') if entry.is_dir()]
    if not snapshots:
        return None
    newest = max(snapshots, key=lambda entry: entry.stat().st_mtime)
    return newest / 'chunk'


# Define correct paths for each corpus: StatPearls lives in the local corpus
# directory, the other three are read from the Hugging Face cache.
corpus_paths = {
    'statpearls': BASE_CORPUS_DIR / 'statpearls' / 'chunk',
}
for _dataset in ('pubmed', 'wikipedia', 'textbooks'):
    _chunk_dir = _hf_snapshot_chunk_dir(_dataset)
    if _chunk_dir is None:
        # A missing dataset is reported and left out instead of aborting the
        # whole cell with an IndexError.
        print(f"❌ {_dataset}: no cached snapshot found under {HF_CACHE_DIR}")
    else:
        corpus_paths[_dataset] = _chunk_dir

def process_corpus_simple(corpus_name, chunk_dir):
    """Compute snippet statistics for one corpus, given its chunk directory.

    Every ``*.jsonl`` file directly inside ``chunk_dir`` is read line by
    line; each non-blank line that decodes to a JSON object counts as one
    snippet. Lines that are not valid JSON, or that decode to something other
    than an object, are skipped and reported in a summary message.

    Args:
        corpus_name: Label used in progress output and in the result row.
        chunk_dir: ``pathlib.Path`` of the directory holding the JSONL files.

    Returns:
        A dict with the keys ``'Corpus'``, ``'#Doc.'``, ``'#Snippets'`` and
        ``'Avg. L'`` (mean snippet length in characters, rounded to an
        integer), or ``None`` when ``chunk_dir`` does not exist or contains
        no JSONL files.
    """
    if not chunk_dir.exists():
        print(f"❌ {corpus_name}: Chunk directory not found")
        return None

    chunk_files = list(chunk_dir.glob('*.jsonl'))
    if not chunk_files:
        print(f"❌ {corpus_name}: No JSONL files found")
        return None

    print(f"📁 {corpus_name}: {len(chunk_files)} files")

    total_snippets = 0
    total_chars = 0
    skipped_lines = 0
    unique_docs = set()

    # This loop must stay inside the function body: at column 0 it ends the
    # definition early and the cell fails with an IndentationError.
    for file_path in tqdm(chunk_files, desc=f"{corpus_name}"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        skipped_lines += 1
                        continue

                    # Skip scalars/lists one line at a time; letting them
                    # reach the field lookups would raise and make the outer
                    # handler abandon the rest of this file.
                    if not isinstance(obj, dict):
                        skipped_lines += 1
                        continue

                    total_snippets += 1
                    total_chars += len(get_text_content(obj))

                    doc_id = get_document_id(obj)
                    if doc_id:
                        unique_docs.add(doc_id)

        except Exception as e:
            # Best effort: report the unreadable file and keep going. Snippets
            # counted from it before the failure stay in the totals.
            print(f"Error reading {file_path}: {e}")
            continue

    if skipped_lines:
        print(f"⚠️ {corpus_name}: skipped {skipped_lines} malformed or non-object lines")

    avg_length = total_chars / total_snippets if total_snippets > 0 else 0
    # When no snippet carried a document ID, every snippet is counted as its
    # own document.
    n_docs = len(unique_docs) if unique_docs else total_snippets

    return {
        'Corpus': corpus_name,
        '#Doc.': n_docs,
        '#Snippets': total_snippets,
        'Avg. L': round(avg_length)
    }

# Run the scan for every configured corpus, keeping only the ones that
# produced a statistics row.
final_results = []
for corpus_name, chunk_dir in corpus_paths.items():
    stats = process_corpus_simple(corpus_name, chunk_dir)
    if not stats:
        continue
    final_results.append(stats)

# Render the summary table, or say so when nothing could be processed.
if not final_results:
    print("No data processed!")
else:
    df_final = pd.DataFrame(final_results)

    # Companion columns expressed in millions, one decimal place.
    df_final['#Doc. (M)'] = df_final['#Doc.'].div(1_000_000).round(1)
    df_final['#Snippets (M)'] = df_final['#Snippets'].div(1_000_000).round(1)

    # Blank line, rule, title, rule — emitted in a single call.
    print("\n" + "=" * 50, "📊 MEDRAG CORPUS STATISTICS", "=" * 50, sep="\n")
    display(df_final.loc[:, ['Corpus', '#Doc. (M)', '#Snippets (M)', 'Avg. L']])

    # Grand totals use the raw counts, not the rounded millions.
    total_docs = df_final['#Doc.'].sum()
    total_snippets = df_final['#Snippets'].sum()
    print(f"\n🎯 TOTAL: {total_docs:,} docs, {total_snippets:,} snippets")

📊 MedRAG Corpus Statistics (Correct Paths)
📁 statpearls: 9625 files


statpearls: 100%|██████████| 9625/9625 [00:03<00:00, 2875.13it/s]


📁 pubmed: 1166 files


pubmed: 100%|██████████| 10/10 [00:01<00:00,  6.08it/s]


📁 wikipedia: 646 files


wikipedia: 100%|██████████| 646/646 [06:45<00:00,  1.59it/s]


📁 textbooks: 18 files


textbooks: 100%|██████████| 18/18 [00:02<00:00,  8.35it/s]


📊 MEDRAG CORPUS STATISTICS





Unnamed: 0,Corpus,#Doc. (M),#Snippets (M),Avg. L
0,statpearls,0.4,0.4,516
1,pubmed,0.1,0.1,915
2,wikipedia,29.9,29.9,682
3,textbooks,0.1,0.1,777



🎯 TOTAL: 30,524,583 docs, 30,524,583 snippets
