# Statistical Analysis

#### Imports

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os
import seaborn as sns
import itertools
from typing import Optional

In [2]:
def set_project_root_cwd(root_dir_name: str, max_depth: int = 5) -> Optional[Path]:
    """Change the working directory to the nearest ancestor named ``root_dir_name``.

    The search starts at the current working directory and moves upward one
    directory at a time, inspecting at most ``max_depth`` directories (the
    starting directory counts as the first one).

    Args:
        root_dir_name: Directory name (not a full path) that marks the project root.
        max_depth: Maximum number of directories to inspect.

    Returns:
        The matching directory as a ``Path`` after ``os.chdir`` succeeded, or
        ``None`` when no match was found or any error occurred. Failures are
        reported with ``print`` rather than raised.
    """
    try:
        original_cwd = Path(os.getcwd())
        # Starting directory first, then every ancestor up to the filesystem root.
        lineage = (original_cwd, *original_cwd.parents)

        # Pairing with range() caps the walk at max_depth candidates.
        for _, candidate in zip(range(max_depth), lineage):
            if candidate.name != root_dir_name:
                continue
            os.chdir(candidate)
            print(f"CWD set to project root: {os.getcwd()}")
            return candidate

        # Raised inside the try block on purpose: the handler below turns it
        # into the printed message and the None return value.
        raise FileNotFoundError(f"Project root '{root_dir_name}' not found within {max_depth} levels above {original_cwd}")

    except Exception as e:
        print(f"Error setting CWD: {e}")
        return None

In [3]:
# Switch the working directory to the project root so that the relative paths
# used by the following cells resolve regardless of where the notebook was started.
set_project_root_cwd('rag_uncertainty')

CWD set to project root: /vol/tensusers8/fgerding/rag_uncertainty


PosixPath('/vol/tensusers8/fgerding/rag_uncertainty')

In [4]:
data_path = Path("./results/results.json")
print("File exists:", data_path.exists())

# Parse the complete results file into memory (this is the full dataset, not a preview).
# The encoding is given explicitly so decoding does not depend on the platform locale.
with open(data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

print(len(data), "records loaded.")


File exists: True
450 records loaded.


#### Configuration

In [None]:
# Record keys of the UE metrics (presumably "uncertainty estimation" - confirm naming).
UE_METRICS = ["semantic_entropy_global", "sum_eigen"]

# Record keys of the RAFE summary scores.
RAFE_METRICS = ["rafe_overall_score"]

# Every metric analysed below: UE metrics first, then RAFE metrics.
ALL_METRICS = [*UE_METRICS, *RAFE_METRICS]

# Directory (relative to the working directory) that receives CSVs and plots.
OUTPUT_DIR = Path("data/analysis_outputs/")

#### Data Processing Modules

In [None]:
def build_dataframes(data, metrics=None):
    """Parse raw JSON-like records into three DataFrames for analysis.

    Args:
        data: Iterable of dict records. Each record may carry "category",
            "question", "answer", the metric keys, the "rafe_gen_*" summary
            counts, a "rafe_details" list of claim dicts and an
            "atomic_facts" list.
        metrics: Optional list of record keys to copy into the example-level
            frame. Defaults to the module-level ``ALL_METRICS``.

    Returns:
        Tuple ``(example_df, claims_df, atomic_df)``:
          * example_df - one row per record ("example_id" is the record's
            position in ``data``).
          * claims_df - one row per entry of "rafe_details".
          * atomic_df - one row per entry of "atomic_facts".
        The claim and atomic frames keep their column schema even when they
        contain no rows.
    """
    if metrics is None:
        metrics = ALL_METRICS
    metrics = list(metrics)

    example_rows = []
    claim_rows = []
    atomic_rows = []

    for i, rec in enumerate(data):
        # Example-level data
        example_rows.append({
            "example_id": i,
            "category": rec.get("category"),
            "question": rec.get("question"),
            "answer": rec.get("answer"),
            # `or ""` also covers an explicit null answer, for which
            # rec.get("answer", "") would return None and break len().
            "answer_len": len(rec.get("answer") or ""),
            # Unpack metrics dynamically
            **{k: rec.get(k) for k in metrics},
            # RAFE specific details
            "rafe_gen_supported": rec.get("rafe_gen_supported"),
            "rafe_gen_not_supported": rec.get("rafe_gen_not_supported"),
            "rafe_gen_irrelevant": rec.get("rafe_gen_irrelevant"),
            "rafe_gen_total_claims": rec.get("rafe_gen_total_claims"),
        })

        # Claim-level data (a missing or null list is treated as no claims)
        for j, c in enumerate(rec.get("rafe_details") or []):
            claim_rows.append({
                "example_id": i,
                "claim_id": j,
                "claim_text": c.get("claim"),
                "safe_label": c.get("label"),
                "n_evidence": len(c.get("evidence") or []),
            })

        # Atomic facts data
        for j, txt in enumerate(rec.get("atomic_facts") or []):
            atomic_rows.append({
                "example_id": i,
                "atomic_fact_id": j,
                "atomic_fact_text": txt,
            })

    example_df = pd.DataFrame(example_rows)
    # Explicit columns keep the schema stable when there are no rows at all.
    claims_df = pd.DataFrame(
        claim_rows,
        columns=["example_id", "claim_id", "claim_text", "safe_label", "n_evidence"],
    )
    atomic_df = pd.DataFrame(
        atomic_rows,
        columns=["example_id", "atomic_fact_id", "atomic_fact_text"],
    )

    return example_df, claims_df, atomic_df

def _first_present(rec, *keys):
    """Return the first non-None value found in ``rec`` under ``keys``, else None."""
    for key in keys:
        value = rec.get(key)
        if value is not None:
            return value
    return None


def check_consistency(data):
    """
    Verifies data integrity by comparing reported summary counts
    against counts recomputed from the per-claim detail list.

    For every record the labels in "rafe_details" are counted and compared
    with the reported "rafe_gen_supported", "rafe_gen_not_supported",
    "rafe_gen_irrelevant" and "rafe_gen_total" / "rafe_gen_total_claims"
    values. The older "safe_*" spellings of these keys are accepted as a
    fallback when the "rafe_*" key is absent or None.

    Args:
        data: Iterable of dict records.

    Returns:
        DataFrame with "example_id" plus one "diff_<kind>" column per count
        (reported minus recomputed). A reported value that is missing counts
        as 0. The frame keeps these columns even for empty input.

    Side effects:
        Prints how many examples have all differences equal to zero.
    """
    count_keys = ("supported", "not_supported", "irrelevant", "total")
    diff_cols = [f"diff_{key}" for key in count_keys]

    records = []
    for i, rec in enumerate(data):
        # Prefer the key used by the rest of the notebook; without it every
        # count would be 0 and the check would pass vacuously.
        details = _first_present(rec, "rafe_details", "safe_details") or []
        labels = [c.get("label") for c in details]
        counts = {
            "supported": labels.count("SUPPORTED"),
            "not_supported": labels.count("NOT_SUPPORTED"),
            "irrelevant": labels.count("IRRELEVANT"),
            "total": len(labels),
        }

        row = {"example_id": i}
        for key in count_keys:
            candidates = [f"rafe_gen_{key}", f"safe_gen_{key}"]
            if key == "total":
                # The total is stored under "<prefix>_gen_total_claims" when
                # the plain "<prefix>_gen_total" key is not present.
                candidates += ["rafe_gen_total_claims", "safe_gen_total_claims"]
            reported = _first_present(rec, *candidates)

            # Default to 0 if None for calculation, strictly
            reported_val = reported if reported is not None else 0

            row[f"diff_{key}"] = reported_val - counts[key]
        records.append(row)

    # Explicit columns so the selection below also works for empty input.
    df = pd.DataFrame(records, columns=["example_id", *diff_cols])

    # Check if all difference columns are 0
    consistent_mask = (df[diff_cols] == 0).all(axis=1)

    print(f"Consistency Check: {consistent_mask.sum()}/{len(df)} examples fully consistent.")
    return df

#### Statistical Analysis

In [7]:
def compute_correlations(df, metrics_x, metrics_y, label):
    """Computes Pearson and Spearman correlations between metric sets.

    Every combination of one column from ``metrics_x`` and one from
    ``metrics_y`` is evaluated once. A metric is never correlated with
    itself, and an unordered pair is reported only once (A-B, not also B-A)
    no matter whether the two arguments are the same list object.

    Args:
        df: DataFrame containing the metric columns.
        metrics_x: Column names used as "metric_1".
        metrics_y: Column names used as "metric_2".
        label: Value stored in the "type" column of every returned row.

    Returns:
        DataFrame with columns "type", "metric_1", "metric_2", "pearson_r",
        "spearman_r" and "n_samples". Pairs with fewer than 3 rows where both
        values are present are skipped. The columns exist even when no pair
        qualifies.
    """
    columns = ["type", "metric_1", "metric_2", "pearson_r", "spearman_r", "n_samples"]
    rows = []
    seen_pairs = set()

    for x in metrics_x:
        for y in metrics_y:
            # Order-insensitive key: (A, B) and (B, A) map to the same entry.
            pair = frozenset((x, y))
            if x == y or pair in seen_pairs:
                continue
            # Recorded up front: the reversed pair uses the same rows, so it
            # would yield the same result (or be skipped for the same reason).
            seen_pairs.add(pair)

            sub = df[[x, y]].dropna()
            if len(sub) < 3:  # Insufficient data
                continue

            rows.append({
                "type": label,
                "metric_1": x,
                "metric_2": y,
                "pearson_r": sub[x].corr(sub[y], method="pearson"),
                "spearman_r": sub[x].corr(sub[y], method="spearman"),
                "n_samples": len(sub),
            })

    # Explicit columns keep downstream concat/groupby working on empty results.
    return pd.DataFrame(rows, columns=columns)

#### Visualization Modules

In [8]:
def generate_all_scatter_plots(df, metrics, plot_dir):
    """
    Write one scatter-plot PDF per unordered pair of metrics.

    Pairs come from itertools.combinations, so no metric is plotted against
    itself and each pair appears once. Rows with a missing value in either
    column are dropped; pairs left with fewer than 3 rows are skipped.
    ``plot_dir`` is created if necessary.
    """
    plot_dir.mkdir(parents=True, exist_ok=True)

    for m1, m2 in itertools.combinations(metrics, 2):
        points = df[[m1, m2]].dropna()
        if len(points) >= 3:
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.scatter(points[m1], points[m2], alpha=0.6, s=15, edgecolors='none')
            ax.set_xlabel(m1)
            ax.set_ylabel(m2)
            ax.set_title(f"{m1}\nvs\n{m2}", fontsize=9)

            fig.tight_layout()
            fig.savefig(plot_dir / f"scatter_{m1}_vs_{m2}.pdf")
            # Release the figure so a long metric list does not accumulate open figures.
            plt.close(fig)

def plot_correlation_heatmap(df, metrics, save_path):
    """Generates a Pearson correlation matrix heatmap and saves it to ``save_path``.

    Args:
        df: DataFrame containing the metric columns.
        metrics: Column names to include in the correlation matrix.
        save_path: Output file path (str or Path). Its parent directory is
            created if it does not exist, matching generate_all_scatter_plots.
    """
    corr = df[metrics].corr(method='pearson')

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    # Colour scale fixed to [-1, 1] and centred on 0 so colours are comparable between runs.
    sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", center=0, vmin=-1, vmax=1, ax=ax)
    ax.set_title("Metric Correlation Matrix (Pearson)")
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

def plot_histograms(df, metrics, plot_dir):
    """Generates distribution histograms for all metrics.

    One PDF named "hist_<metric>.pdf" is written per metric that has at least
    one non-missing value.

    Args:
        df: DataFrame containing the metric columns.
        metrics: Column names to plot.
        plot_dir: Output directory (str or Path). It is created if it does
            not exist, matching generate_all_scatter_plots.
    """
    plot_dir = Path(plot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)

    for col in metrics:
        sub = df[col].dropna()
        if sub.empty:
            # Nothing to draw for a column without any values.
            continue

        fig, ax = plt.subplots(figsize=(5, 4))
        ax.hist(sub, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
        ax.set_title(f"Distribution: {col}", fontsize=10)
        ax.set_xlabel("Score")
        ax.set_ylabel("Frequency")
        fig.tight_layout()
        fig.savefig(plot_dir / f"hist_{col}.pdf")
        plt.close(fig)

#### Execution

In [9]:
print("Starting Analysis Pipeline...")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
PLOT_DIR = OUTPUT_DIR / "plots"
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# 1. Build Data
ex_df, claim_df, atom_df = build_dataframes(data)
cons_df = check_consistency(data)
print(f"Data Loaded -> Examples: {ex_df.shape}, Claims: {claim_df.shape}")

# 2. Compute Correlations (Between and Within groups)
# RAFE_METRICS is the name defined in the configuration cell; the previously
# used SAFE_METRICS is not defined anywhere in this notebook and raised a
# NameError on a fresh kernel. The "type" labels keep their earlier "SAFE"
# wording so new results line up with the labels shown in earlier outputs.
print("Computing correlations...")
corr_ue_rafe = compute_correlations(ex_df, UE_METRICS, RAFE_METRICS, label="UE_vs_SAFE")
corr_within_ue = compute_correlations(ex_df, UE_METRICS, UE_METRICS, label="Within_UE")
corr_within_rafe = compute_correlations(ex_df, RAFE_METRICS, RAFE_METRICS, label="Within_SAFE")

all_corrs = pd.concat([corr_ue_rafe, corr_within_ue, corr_within_rafe], ignore_index=True)

# 3. Visualization
print("Generating plots...")
generate_all_scatter_plots(ex_df, ALL_METRICS, PLOT_DIR)
plot_histograms(ex_df, ALL_METRICS, PLOT_DIR)
plot_correlation_heatmap(ex_df, ALL_METRICS, PLOT_DIR / "correlation_heatmap.pdf")

# 4. Save Artifacts
ex_df.to_csv(OUTPUT_DIR / "examples.csv", index=False)
claim_df.to_csv(OUTPUT_DIR / "claims.csv", index=False)
atom_df.to_csv(OUTPUT_DIR / "atomic_facts.csv", index=False)
cons_df.to_csv(OUTPUT_DIR / "consistency_report.csv", index=False)
all_corrs.to_csv(OUTPUT_DIR / "correlations.csv", index=False)

print("\nSummary of Correlations by Type")
if all_corrs.empty:
    # Without any rows there may be no "type" column to group by.
    print("No metric pair had enough data to compute a correlation.")
else:
    print(all_corrs.groupby("type")[["pearson_r", "spearman_r"]].mean())
print(f"\nAnalysis complete. All outputs saved to: {OUTPUT_DIR.resolve()}")

Starting Analysis Pipeline...
Consistency Check: 450/450 examples fully consistent.
Data Loaded -> Examples: (450, 12), Claims: (2133, 5)
Computing correlations...
Generating plots...

Summary of Correlations by Type
            pearson_r  spearman_r
type                             
UE_vs_SAFE  -0.107894   -0.171139
Within_UE    0.667426    0.810184

Analysis complete. All outputs saved to: /vol/tensusers8/fgerding/rag_uncertainty/data/analysis_outputs
