In [None]:
import json
import numpy as np
from pathlib import Path
import re
from collections import defaultdict
import Levenshtein
import time


## Helper Functions

In [None]:
def load_jsonl(file_path):
    """Load a JSONL file into a list of parsed records.

    Each non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines (e.g. a stray empty line at the end of the file)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: path (str or Path) to a UTF-8 encoded JSONL file.

    Returns:
        list of decoded values, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def extract_checkpoint_number(filename):
    """Return the integer following ``checkpoint_`` in *filename*, or -1 if absent.

    Example: 'predictions_checkpoint_02625.jsonl' -> 2625. Only the first
    ``checkpoint_<digits>`` occurrence is used.
    """
    found = re.search(r'checkpoint_(\d+)', filename)
    if found is None:
        return -1
    return int(found.group(1))

def find_highest_checkpoint(pass_dir, pattern='predictions_final_*.jsonl'):
    """Return the prediction file with the highest checkpoint number in *pass_dir*.

    Files are matched with ``Path.glob(pattern)`` and ranked by
    ``extract_checkpoint_number`` (files without ``checkpoint_<n>`` in their
    name rank as -1). Ties are broken by filename so the choice does not
    depend on the filesystem's arbitrary directory-listing order.

    Args:
        pass_dir: directory (str or Path) to search; a missing directory
            simply yields no matches.
        pattern: glob pattern for candidate files.

    Returns:
        The selected ``Path``, or None when no file matches.
    """
    checkpoint_files = list(Path(pass_dir).glob(pattern))
    if not checkpoint_files:
        return None

    # (number, name) keeps the selection deterministic when numbers tie.
    return max(checkpoint_files, key=lambda f: (extract_checkpoint_number(f.name), f.name))

def calculate_consistency_stats(pass1_data, pass2_data, pass3_data):
    """Calculate consistency statistics across 3 inference passes.

    Each ``passN_data`` is a list of records with at least ``'ipis_id'`` and
    ``'target'`` keys. Records are aligned by ``ipis_id``. Pass 1 defines the
    example set, and ids present only in pass 2 or pass 3 are ignored.

    Returns:
        dict with
          - total, all_match, has_difference: example counts;
          - determinism_rate: percent of examples whose targets are identical
            in all three passes (0.0 when there are no examples);
          - mean/median/min/max_edit_dist: over differing examples only, the
            largest pairwise Levenshtein distance among the three targets
            (0 when every example matches).

    Raises:
        ValueError: if an ``ipis_id`` from pass 1 is missing in pass 2 or 3.
    """
    # Build dictionaries keyed by ipis_id (a later duplicate id overwrites an earlier one).
    pass1_dict = {item['ipis_id']: item['target'] for item in pass1_data}
    pass2_dict = {item['ipis_id']: item['target'] for item in pass2_data}
    pass3_dict = {item['ipis_id']: item['target'] for item in pass3_data}

    # Fail with a descriptive error up front instead of a bare KeyError mid-loop.
    for pass_name, other in (('pass 2', pass2_dict), ('pass 3', pass3_dict)):
        missing = pass1_dict.keys() - other.keys()
        if missing:
            examples = sorted(missing, key=str)[:3]
            raise ValueError(
                f"{len(missing)} ipis_id(s) from pass 1 missing in {pass_name}, e.g. {examples}"
            )

    all_match = 0
    has_difference = 0
    edit_distances = []

    for ipis_id, t1 in pass1_dict.items():
        t2 = pass2_dict[ipis_id]
        t3 = pass3_dict[ipis_id]

        if t1 == t2 == t3:
            all_match += 1
        else:
            has_difference += 1
            # Worst-case disagreement: the largest of the three pairwise distances.
            d12 = Levenshtein.distance(t1, t2)
            d13 = Levenshtein.distance(t1, t3)
            d23 = Levenshtein.distance(t2, t3)
            edit_distances.append(max(d12, d13, d23))

    total = len(pass1_dict)
    determinism_rate = all_match / total * 100 if total else 0.0

    return {
        'total': total,
        'all_match': all_match,
        'has_difference': has_difference,
        'determinism_rate': determinism_rate,
        'mean_edit_dist': np.mean(edit_distances) if edit_distances else 0,
        'median_edit_dist': np.median(edit_distances) if edit_distances else 0,
        'min_edit_dist': np.min(edit_distances) if edit_distances else 0,
        'max_edit_dist': np.max(edit_distances) if edit_distances else 0,
    }

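## Analyze All LoRA Ranks

For each LoRA rank, three inference passes are compared: `pass_1`, `pass_2` and `pass_3`. Rank r=64 has no `pass_3`, so its `submission` run serves as the third pass. Within each pass folder, the `predictions_final_*.jsonl` file with the highest checkpoint number is used.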

In [None]:
from pathlib import Path
from collections import defaultdict
from io import StringIO

base_path = Path("solution/task_proofreading/02_inference")
lora_ranks = [8, 16, 32, 64, 128]
results = defaultdict(list)

# Folder (and report label) used as the third pass for each rank. r=64 has no
# pass_3 directory; its final "submission" run stands in as the third pass.
DEFAULT_THIRD_PASS = ("pass_3", "Pass 3")
THIRD_PASS_BY_RANK = {64: ("submission", "Submission")}
SEPARATOR = "=" * 100

output_buffer = StringIO()


def write(msg):
    """Append one line to the report buffer (printed once at the end of the cell)."""
    output_buffer.write(msg + "\n")


def resolve_pass_files(rank_dir, rank):
    """Return ``[(label, path_or_None), ...]`` for the three passes of one rank.

    Each path is the file chosen by ``find_highest_checkpoint`` for that pass
    folder, or None when the folder has no matching prediction file.
    """
    third_dir, third_label = THIRD_PASS_BY_RANK.get(rank, DEFAULT_THIRD_PASS)
    return [
        ("Pass 1", find_highest_checkpoint(rank_dir / "pass_1")),
        ("Pass 2", find_highest_checkpoint(rank_dir / "pass_2")),
        (third_label, find_highest_checkpoint(rank_dir / third_dir)),
    ]


write("Analyzing inference consistency across LoRA ranks...")
write("")

for rank in lora_ranks:
    rank_dir = base_path / f"inference_checkpoints_lora_r{rank}"

    if not rank_dir.exists():
        write(f"[SKIP] LoRA rank {rank}: directory not found")
        write(SEPARATOR)
        continue

    write(f"Processing LoRA rank {rank}...")

    # Only the pass folders above are ever read, so an OLD/ folder is always
    # ignored; the note is reported for r=128 only (presumably the one rank
    # with a leftover OLD/ run — confirm if other ranks gain one).
    if rank == 128 and (rank_dir / "OLD").exists():
        write(f"  Note: Skipping OLD folder for rank {rank}")

    pass_files = resolve_pass_files(rank_dir, rank)

    if not all(path for _, path in pass_files):
        write(f"  [WARN] Missing passes for rank {rank}")
        for label, path in pass_files:
            write(f"     {label}: {'OK' if path else 'MISSING'}")
        write(SEPARATOR)
        continue

    for label, path in pass_files:
        write(f"  {label}: {path.name}")

    stats = calculate_consistency_stats(*(load_jsonl(path) for _, path in pass_files))
    stats["rank"] = rank
    results[rank].append(stats)

    write(f"  Determinism rate: {stats['determinism_rate']:.2f}%")
    write(f"  Examples with differences: {stats['has_difference']}/{stats['total']}")
    if stats["has_difference"] > 0:
        write(f"  Mean edit distance: {stats['mean_edit_dist']:.2f}")
    write(SEPARATOR)

write("Analysis complete!")

print(output_buffer.getvalue())

## Aggregate Statistics by LoRA Rank

In [None]:
# Collapse each rank's per-run stats into mean/std/min/max summaries.
# The analysis cell appends one run per rank, so every std is currently 0.
aggregated_stats = {}

for rank, runs in sorted(results.items()):
    if not runs:
        continue

    det_rates = [run['determinism_rate'] for run in runs]
    # Only runs with at least one differing example carry a meaningful edit distance.
    edit_means = [run['mean_edit_dist'] for run in runs if run['has_difference'] > 0]

    summary = {
        'determinism_mean': np.mean(det_rates),
        'determinism_std': 0 if len(det_rates) <= 1 else np.std(det_rates),
        'determinism_min': np.min(det_rates),
        'determinism_max': np.max(det_rates),
    }
    if edit_means:
        summary.update(
            edit_dist_mean=np.mean(edit_means),
            edit_dist_std=np.std(edit_means) if len(edit_means) > 1 else 0,
            edit_dist_min=np.min(edit_means),
            edit_dist_max=np.max(edit_means),
        )
    else:
        summary.update(edit_dist_mean=0, edit_dist_std=0, edit_dist_min=0, edit_dist_max=0)
    summary['total_examples'] = runs[0]['total']
    summary['n_runs'] = len(runs)

    aggregated_stats[rank] = summary

rule = "=" * 80
print("Aggregated statistics by LoRA rank:")
print(rule)
for rank in sorted(aggregated_stats):
    agg = aggregated_stats[rank]
    print(f"\nLoRA Rank {rank}:")
    print(f"  Determinism rate: {agg['determinism_mean']:.2f}% (±{agg['determinism_std']:.2f})")
    print(f"  Range: [{agg['determinism_min']:.2f}%, {agg['determinism_max']:.2f}%]")
    if agg['edit_dist_mean'] > 0:
        print(f"  Mean edit distance: {agg['edit_dist_mean']:.2f} (±{agg['edit_dist_std']:.2f})")
        print(f"  Range: [{agg['edit_dist_min']:.2f}, {agg['edit_dist_max']:.2f}]")
print(rule)

## Generate LaTeX Table

In [None]:
# Render the consistency table as LaTeX (booktabs rules + tabularx).
# Rows follow `lora_ranks` from the analysis cell so the table cannot drift
# from the ranks that were actually analyzed; ranks without results show "---".
latex_lines = [
    "\\begin{table}[!htbp]",
    "\\centering",
    "\\small",
    "\\caption{Inference consistency across LoRA ranks (3 passes per rank)}",
    "\\label{tab:inference_consistency}",
    "",
    "\\begin{tabularx}{\\columnwidth}{Xrr}",
    "\\toprule",
    "\\textbf{LoRA Rank} & \\textbf{Determinism (\\%)\\textsuperscript{1}} & \\textbf{Mean Edit Distance\\textsuperscript{2}} \\\\",
    "\\midrule",
]

for rank in lora_ranks:
    rank_agg = aggregated_stats.get(rank)
    if rank_agg is None:
        det_val = edit_val = "---"
    else:
        det_val = f"{rank_agg['determinism_mean']:.2f}"
        # An edit-distance mean of 0 means no run had a differing example.
        if rank_agg['edit_dist_mean'] > 0:
            edit_val = f"{rank_agg['edit_dist_mean']:.2f}"
        else:
            edit_val = "---"
    latex_lines.append(f"r={rank} & {det_val} & {edit_val} \\\\")

latex_lines += [
    "\\bottomrule",
    "\\end{tabularx}",
    "",
    "\\vspace{2pt}",
    "\\raggedright",
    # NOTE(review): temperature=0.3 is hard-coded text; confirm it matches the inference config.
    "\\footnotesize{\\textsuperscript{1} Percentage of examples with identical outputs across 3 inference passes (temperature=0.3)}",
    "",
    "\\footnotesize{\\textsuperscript{2} Levenshtein distance (characters) for examples with differences}",
    "\\end{table}",
]

print("\n".join(latex_lines))

## Summary Statistics

In [None]:
# Plain-text recap: which ranks were compared and how deterministic each one was.
divider = "=" * 80
analyzed_ranks = sorted(aggregated_stats.keys())

print("\n" + divider)
print("SUMMARY")
print(divider)
print(f"\nAnalyzed LoRA ranks: {analyzed_ranks}")
print("Note: r=64 uses pass_1, pass_2, and 'submission' folder (3 passes total)")
print("Note: r=128 uses pass_1, pass_2, pass_3 (OLD folder skipped)")
print("\nDeterminism across ranks:")
for rank in analyzed_ranks:
    det_mean = aggregated_stats[rank]['determinism_mean']
    print(f"  r={rank:3d}: {det_mean:5.2f}% deterministic")
print(divider)
