**PHASE 1: DATA PREPROCESSING - Collect Citation Data**

Purpose: Fetch citation counts from Semantic Scholar API  
Input: arxiv_metadata_features.pkl  
Output: citation_data_full.pkl (stratified sample, batched collection)  
Strategy: 1M paper sample, 50K batches, progress saved automatically  
ML Involved: None - Data collection via API  
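Runtime: estimated ~1.5 hours per 50K papers (500K = 15 hrs, 1M = 30 hrs); the run recorded below was much slower (~4 s per paper) and was interrupted before the first batch finished  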
Note: Citations kept separate, merged later in analysis phase

In [1]:
# setup and load data

import pandas as pd
import numpy as np
import requests
import time
from tqdm import tqdm
import os

# pull the arXiv metadata/feature table into memory
metadata_path = 'data/processed/arxiv_metadata_features.pkl'
df = pd.read_pickle(metadata_path)
print(f"Loaded: {len(df):,} papers")

# per-batch citation files are written here; exist_ok makes this a no-op
# when the folder is already present
batches_path = 'data/processed/citations_batches'
os.makedirs(batches_path, exist_ok=True)

Loaded: 2,384,622 papers


In [2]:
# target sample size

TARGET_SAMPLE = 1_000_000  # ambitious but seeing if I can get 1m count

# stratified sampling by year and top_level_domain
# every (year, domain) group contributes the same fraction of its papers, so
# the sample keeps the time/field mix of the full table

# capped at 1.0 because sample(frac > 1) raises without replace=True;
# max(..., 1) avoids a ZeroDivisionError on an empty table
sample_fraction = min(1.0, TARGET_SAMPLE / max(len(df), 1))

strata = ['year', 'top_level_domain']

# Selecting the columns explicitly (grouping columns included) keeps 'year' and
# 'top_level_domain' in each group passed to the lambda. Calling apply() on the
# bare groupby is deprecated for exactly that reason, and the pandas warning
# says a future version will exclude the grouping columns, which would break
# the value_counts() calls below. The rows drawn are the same as before: each
# group is still sampled with its own random_state=42.
df_sample = (
    df.groupby(strata, group_keys=False)[df.columns.tolist()]
    .apply(lambda x: x.sample(frac=sample_fraction, random_state=42))
    .reset_index(drop=True)
)

# shuffle randomly so each collection batch is itself a mix of years/domains
df_sample = df_sample.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"Sampled: {len(df_sample):,} papers")
print("\nYear distribution:")
print(df_sample['year'].value_counts().sort_index().tail(10))
print("\nDomain distribution:")
print(df_sample['top_level_domain'].value_counts().head(10))

# save sample IDs for reference
df_sample[['id', 'year', 'primary_category']].to_pickle(
    'data/processed/citation_sample_ids.pkl'
)
print("\n✓ Sample IDs saved")

  df_sample = df.groupby(['year', 'top_level_domain'], group_keys=False).apply(


Sampled: 999,995 papers

Year distribution:
year
2016     48376
2017     49634
2018     56243
2019     63127
2020     72433
2021     77204
2022     78464
2023     89728
2024    108887
2025    102317
Name: count, dtype: int64

Domain distribution:
top_level_domain
cs          269247
math        206833
cond-mat    106321
astro-ph    102131
physics      75531
quant-ph     41671
hep-ph       35975
hep-th       27828
eess         26331
stat         23476
Name: count, dtype: int64

✓ Sample IDs saved


In [3]:
# check for existing batches

batch_dir = 'data/processed/citations_batches'

# Only finished batch files count as "collected". A *_temp.pkl file is a
# partial checkpoint that fetch_citations_batch removes once its batch has been
# written, so treating its ids as done would drop those papers from the sample
# without them ever reaching a finished batch.
existing_batches = sorted(
    f for f in os.listdir(batch_dir)
    if f.startswith('batch_') and not f.endswith('_temp.pkl')
)

if existing_batches:
    print(f"Found {len(existing_batches)} existing batches:")

    # load and combine existing batches
    existing_dfs = []
    for batch_file in existing_batches:
        batch_df = pd.read_pickle(os.path.join(batch_dir, batch_file))
        existing_dfs.append(batch_df)
        print(f"{batch_file}: {len(batch_df):,} papers")

    df_existing = pd.concat(existing_dfs, ignore_index=True)

    # a batch saved with zero rows has no 'id' column; treat it as empty
    # instead of failing with a KeyError
    if 'id' in df_existing.columns:
        already_collected = set(df_existing['id'].values)
    else:
        already_collected = set()
    print(f"\nTotal collected: {len(already_collected):,} papers")

    # filter sample to exclude already collected citation papers
    # (safe to re-run: filtering an already filtered sample changes nothing)
    df_sample = df_sample[~df_sample['id'].isin(already_collected)].reset_index(drop=True)
    print(f"Remaining to collect: {len(df_sample):,} papers")
else:
    print("No existing batches found. Starting fresh.")
    already_collected = set()

No existing batches found. Starting fresh.


In [4]:
def _rate_limit_delay(response, attempt, base=5, cap=60):
    """seconds to wait after an HTTP 429 before asking again"""
    # use the server's Retry-After value when it is a plain number of seconds;
    # a missing or date-formatted header falls back to exponential back-off
    # (5s, 10s, 20s, ...) -- the first wait matches the original fixed 5s sleep
    header = response.headers.get('Retry-After')
    try:
        wait = float(header)
    except (TypeError, ValueError):
        wait = base * (2 ** attempt)
    return min(max(wait, 0), cap)


def fetch_citations_batch(arxiv_ids, batch_name, start_idx=0,
                          out_dir='data/processed/citations_batches',
                          max_retries=5, resume=True):
    """
    fetch citations for a list of ArXiv IDs
    saves progress every 1000 papers

    arxiv_ids   : list of arXiv id strings to look up
    batch_name  : file stem; the batch is written to <out_dir>/<batch_name>.pkl
    start_idx   : kept for backward compatibility; not used
    out_dir     : folder for the batch file and its temporary checkpoint
    max_retries : extra attempts for one paper after an HTTP 429 response
    resume      : if a <batch_name>_temp.pkl checkpoint exists, keep its rows
                  and only request the ids that are not in it yet

    returns a DataFrame with columns id, citation_count, ss_year

    A paper that answers with any non-200 status (including a 429 that is still
    failing after max_retries) or raises during the request is stored with
    citation_count 0 and counted under "Errors", so a stored 0 does not
    necessarily mean the paper has no citations.

    NOTE(review): this makes one request per paper. Semantic Scholar documents
    a batch lookup endpoint and API keys with higher rate limits -- worth
    checking before a 1M-paper run.
    """
    os.makedirs(out_dir, exist_ok=True)
    final_file = os.path.join(out_dir, f'{batch_name}.pkl')
    temp_file = os.path.join(out_dir, f'{batch_name}_temp.pkl')

    # pick up rows from an interrupted run of the same batch
    results = []
    if resume and os.path.exists(temp_file):
        results = pd.read_pickle(temp_file).to_dict('records')
    restored = len(results)
    done_ids = {row['id'] for row in results}
    pending = [arxiv_id for arxiv_id in arxiv_ids if arxiv_id not in done_ids]

    errors = 0
    rate_limits = 0
    attempts_allowed = max(0, max_retries) + 1
    params = {"fields": "citationCount,year"}

    print(f"\nFetching batch: {batch_name}")
    print(f"Papers to fetch: {len(arxiv_ids)}")
    if restored:
        print(f"Restored from checkpoint: {restored:,}")

    for i, arxiv_id in enumerate(tqdm(pending, desc=batch_name)):
        citation_count, ss_year = 0, None
        try:
            # semantic Scholar API
            url = f"https://api.semanticscholar.org/graph/v1/paper/ARXIV:{arxiv_id}"

            # rate limited -> wait and ask for the SAME paper again
            # (previously the loop moved on and the paper was never recorded)
            for attempt in range(attempts_allowed):
                response = requests.get(url, params=params, timeout=10)
                if response.status_code != 429:
                    break
                rate_limits += 1
                if attempt < attempts_allowed - 1:
                    time.sleep(_rate_limit_delay(response, attempt))

            if response.status_code == 200:
                data = response.json()
                citation_count = data.get('citationCount', 0)
                ss_year = data.get('year', None)
            else:
                errors += 1
        except Exception:
            # deliberate best-effort: one bad request must not end a
            # multi-hour collection; the paper is recorded as an error row
            errors += 1
            citation_count, ss_year = 0, None

        results.append({
            'id': arxiv_id,
            'citation_count': citation_count,
            'ss_year': ss_year
        })

        # small delay to respect API
        time.sleep(0.1)

        # save temporary checkpoint every 1000 papers
        if (i + 1) % 1000 == 0:
            pd.DataFrame(results).to_pickle(temp_file)

    # create final dataframe
    df_results = pd.DataFrame(results)

    # save batch
    df_results.to_pickle(final_file)

    # remove temp file if exists
    if os.path.exists(temp_file):
        os.remove(temp_file)

    print(f"\n✓ Batch complete!")
    print(f"  Successful: {len(pending) - errors:,}")
    print(f"  Errors: {errors:,}")
    print(f"  Rate limits hit: {rate_limits:,}")

    return df_results

In [5]:
# configuration

BATCH_SIZE = 50_000  # papers per batch
MAX_BATCHES = 10      # cap on finished batch files overall: 10 batches = 500K papers (raise to 20 for the full 1M)

# verify everything is loaded before starting
# (this has to run before the planning code, which already needs df_sample)

try:
    print(f"✓ df_sample exists: {len(df_sample):,} papers")
    print(f"✓ BATCH_SIZE: {BATCH_SIZE:,}")
    print(f"✓ MAX_BATCHES: {MAX_BATCHES}")
except NameError as e:
    print(f"x ERROR: {e}")
    print("Run cells 1-4 first!")
    raise

# work out what is already on disk
# Finished batches are identified by the ids they contain, not by slice
# position: if df_sample has been filtered since a batch was written, slice 0
# no longer holds the papers stored in batch_00, so skipping by file name alone
# would skip papers that were never fetched. *_temp.pkl checkpoints are ignored
# here because they are not finished batches.
batch_dir = 'data/processed/citations_batches'
finished_nums = []
finished_ids = set()
for fname in os.listdir(batch_dir):
    stem = fname[len('batch_'):-len('.pkl')]
    if not (fname.startswith('batch_') and fname.endswith('.pkl') and stem.isdigit()):
        continue
    finished_nums.append(int(stem))
    finished = pd.read_pickle(os.path.join(batch_dir, fname))
    if 'id' in finished.columns:  # a zero-row batch has no columns
        finished_ids.update(finished['id'])

# papers still missing, in the original (shuffled) sample order
df_todo = df_sample[~df_sample['id'].isin(finished_ids)].reset_index(drop=True)

# new files continue after the highest finished batch number
first_batch_num = max(finished_nums) + 1 if finished_nums else 0

# calculate batches (ceiling division: no empty trailing batch when the
# remaining count is an exact multiple of BATCH_SIZE)
batches_allowed = max(0, MAX_BATCHES - len(finished_nums))
batches_needed = (len(df_todo) + BATCH_SIZE - 1) // BATCH_SIZE
num_batches = min(batches_allowed, batches_needed)

print(f"Will collect {num_batches} batches of {BATCH_SIZE:,} papers")
print(f"Total target: {min(num_batches * BATCH_SIZE, len(df_todo)):,} papers")
print(f"Estimated time: {(num_batches * 1.5):.1f} hours at 33K/hour\n")

# collect each batch
for offset in range(num_batches):
    start_idx = offset * BATCH_SIZE
    end_idx = min(start_idx + BATCH_SIZE, len(df_todo))

    batch_ids = df_todo['id'].iloc[start_idx:end_idx].tolist()
    batch_num = first_batch_num + offset
    batch_name = f"batch_{batch_num:02d}"

    print(f"BATCH {offset + 1}/{num_batches}")
    print(f"Papers {start_idx:,} to {end_idx:,}")

    # fetch citations
    df_batch = fetch_citations_batch(batch_ids, batch_name)

    print(f"\n✓ Saved: data/processed/citations_batches/{batch_name}.pkl")
    print(f"Progress: {end_idx:,} / {len(df_todo):,} papers")

print("\n" + "-"*60)
print("All batches complete!")

Will collect 10 batches of 50,000 papers
Total target: 500,000 papers
Estimated time: 15.0 hours at 33K/hour

✓ df_sample exists: 999,995 papers
✓ BATCH_SIZE: 50,000
✓ MAX_BATCHES: 10
BATCH 1/10
Papers 0 to 50,000

Fetching batch: batch_00
Papers to fetch: 50000


batch_00:  17%|█▋        | 8560/50000 [9:43:50<47:06:27,  4.09s/it]


KeyboardInterrupt: 

In [None]:
# combine all batches into single file

batch_dir = 'data/processed/citations_batches'
# finished batches only; *_temp.pkl files are partial checkpoints
batch_files = sorted([f for f in os.listdir(batch_dir) if f.startswith('batch_') and not f.endswith('_temp.pkl')])

print(f"Combining {len(batch_files)} batches...")

all_batches = []
for batch_file in batch_files:
    df_batch = pd.read_pickle(os.path.join(batch_dir, batch_file))
    print(f"  {batch_file}: {len(df_batch):,} papers")
    # a zero-row batch has no columns and adds nothing to the result
    if len(df_batch):
        all_batches.append(df_batch)

if not all_batches:
    # pd.concat([]) raises ValueError; with nothing collected yet there is
    # nothing to save, so say so and leave any earlier combined file untouched
    print("\nx No completed batches to combine yet")
else:
    df_citations = pd.concat(all_batches, ignore_index=True)

    # the table is merged back onto the metadata by id later, so every id must
    # appear once; if an id was fetched in more than one batch keep the row
    # from the later batch file
    duplicate_rows = int(df_citations['id'].duplicated(keep='last').sum())
    if duplicate_rows:
        print(f"\nDropping {duplicate_rows:,} duplicate ids")
        df_citations = (
            df_citations
            .drop_duplicates(subset='id', keep='last')
            .reset_index(drop=True)
        )

    print(f"\nTotal citations collected: {len(df_citations):,}")
    print(f"Papers with citations > 0: {(df_citations['citation_count'] > 0).sum():,}")
    print(f"Average citations: {df_citations['citation_count'].mean():.1f}")

    # save combined file

    df_citations.to_pickle('data/processed/citation_data_full.pkl')
    print("\n✓ Saved: data/processed/citation_data_full.pkl")

In [None]:
import os

# sanity check on the combined output: report its size and the basic
# distribution of citation counts, or say that it has not been written yet

if not os.path.exists('data/processed/citation_data_full.pkl'):
    print("x File not created yet")
else:
    # size on disk, bytes -> MB
    size_mb = os.path.getsize('data/processed/citation_data_full.pkl') / 1024**2
    df_check = pd.read_pickle('data/processed/citation_data_full.pkl')

    print("✓✓✓ Success! ✓✓✓")
    print(f"Citations collected: {len(df_check):,}")
    print(f"File size: {size_mb:.1f} MB")
    print(f"\nCitation stats:")
    print(df_check['citation_count'].describe())