# Linear Probing Results Analysis

This notebook analyzes the results from linear probing experiments run on the HPC cluster. It includes:
- Within-species results (Human and Mouse separately)
- Cross-species results
- Visualization of AUROC performance and ROC curves
- Summary tables of AUROC mean and standard deviation for every setup


### Setup


In [None]:
import os
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, roc_curve
from dotenv import load_dotenv

# Pull in variables from a local .env file, if any; the data root is read from
# the environment when the results are loaded further down.
load_dotenv()

# Matplotlib rc overrides layered on top of the seaborn base theme
# (same look as the deprecated notebooks).
_THEME_RC = {
    # Figure sizing (inches); adjust for single-column layouts
    "figure.figsize": (10, 8),
    # Axes
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "axes.linewidth": 1.0,
    "axes.labelpad": 8,
    "axes.grid": True,
    # Grid
    "grid.linewidth": 0.4,
    "grid.alpha": 0.6,
    # Lines and markers
    "lines.linewidth": 1.5,
    "lines.markersize": 5,
    # Ticks and legend
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "legend.frameon": False,
    # Export settings
    "savefig.dpi": 300,
    "savefig.transparent": True,  # transparent background for vector exports
    "pdf.fonttype": 42,  # embed TrueType fonts (important for Illustrator)
    "ps.fonttype": 42,
}

sns.set_theme(
    context="paper",  # smaller scale, intended for journal figures
    style="whitegrid",  # clean background with a subtle grid
    font="DejaVu Sans",  # one consistent sans-serif everywhere
    font_scale=1.4,  # bump text size a little for readability
    palette="Set2",
    rc=_THEME_RC,
)

# Matplotlib tight layout by default
plt.rcParams["figure.autolayout"] = True


### Load Results


In [None]:
def load_pickle_if_present(path, label):
    """Unpickle ``path`` if it exists, reporting what happened on stdout.

    Parameters
    ----------
    path : str or Path
        Location of the pickle file.
    label : str
        Capitalised name used in the messages (e.g. ``"Human"``); it is
        lower-cased for the "Loading ..." line.

    Returns
    -------
    object or None
        The unpickled object, or ``None`` when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        print(f"Warning: {label} results not found at {path}")
        return None
    print(f"Loading {label.lower()} results from {path}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at result files produced by our own probing runs.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def count_cv_setups(payload, model_type="Linear"):
    """Count the distinct setups that have cross-validation results.

    Result keys have the form ``"<setup>::<data_type>"``, so the number of keys
    over-counts setups (one key per data type); only the part before ``"::"``
    is counted here.
    """
    cv_results = payload.get("results", {}).get("CrossValidation", {}).get(model_type, {})
    return len({key.split("::")[0] for key in cv_results})


# Data root comes from the environment, with a repo-relative fallback.
DATA_DIR = os.getenv("INFLAMM_DEBATE_FM_DATA_ROOT", "../data/")

DATA_ROOT = Path(DATA_DIR)
RESULTS_DIR = DATA_ROOT / "probing_results"

# Check if results directory exists (loading is still attempted below so that
# each missing file gets its own warning)
if not RESULTS_DIR.exists():
    print(f"Warning: Results directory not found at {RESULTS_DIR}")
    print("Please ensure probing experiments have been run and results are saved.")
else:
    print(f"Results directory: {RESULTS_DIR}")

# Within-species result files
within_species_dir = RESULTS_DIR / "within_species"
human_results_path = within_species_dir / "human_results.pkl"
mouse_results_path = within_species_dir / "mouse_results.pkl"

# Cross-species result file: prefer the combined results, fall back to the
# individual file
cross_species_dir = RESULTS_DIR / "cross_species"
cross_species_results_path = cross_species_dir / "cross_species_results_combined.pkl"
if not cross_species_results_path.exists():
    cross_species_results_path = cross_species_dir / "cross_species_results.pkl"

# Load results; each variable stays None when its file is missing
human_results = load_pickle_if_present(human_results_path, "Human")
if human_results is not None:
    print(f"  Loaded: {count_cv_setups(human_results)} setups")

mouse_results = load_pickle_if_present(mouse_results_path, "Mouse")
if mouse_results is not None:
    print(f"  Loaded: {count_cv_setups(mouse_results)} setups")

cross_species_results = load_pickle_if_present(cross_species_results_path, "Cross-species")
if cross_species_results_path.exists():
    print("  Loaded cross-species results")
else:
    # Check for partial bootstrap files
    bootstrap_files = list(cross_species_dir.glob("cross_species_results_bs_*.pkl"))
    if bootstrap_files:
        print(f"  Found {len(bootstrap_files)} partial bootstrap files. Please combine them first.")


### Plotting Functions


In [None]:
def plot_auroc_summary(all_results, model_type="Linear", setup_order=None, title_suffix=""):
    """
    Barplot summary of AUROC for CV and LODO across setups and data types.

    Parameters
    ----------
    all_results : dict
        Nested mapping ``{validation_type: {model_type: {"<setup>::<data_type>": (mean, std)}}}``
        with validation types "CrossValidation" and "LODO" and data types
        "Raw" and "Embedding". Missing or malformed entries are drawn as gaps.
    model_type : str
        Which model's results to plot.
    setup_order : list of str, optional
        Setups to show, left to right. Defaults to the sorted setups found in
        the cross-validation results.
    title_suffix : str
        Appended to the title. No title is drawn when this is empty.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    data_types = ["Raw", "Embedding"]
    val_types = ["CrossValidation", "LODO"]
    colors = sns.color_palette("Set2", n_colors=len(data_types))  # one colour per data type
    hatches = ["//", ""]  # one hatch per validation type
    width = 0.2
    n_bars = len(val_types) * len(data_types)  # bars per setup group

    # Extract setups from results if not provided
    if setup_order is None:
        cv_results = all_results.get("CrossValidation", {}).get(model_type, {})
        setup_order = sorted({key.split("::")[0] for key in cv_results})

    fig, ax = plt.subplots(figsize=(14, 8))
    x = np.arange(len(setup_order))

    for i, val_type in enumerate(val_types):
        # Same dict for every data type and setup, so look it up once per validation type.
        results_dict = all_results.get(val_type, {}).get(model_type, {})
        for j, data_type in enumerate(data_types):
            heights, errs = [], []
            for setup in setup_order:
                entry = results_dict.get(f"{setup}::{data_type}")
                # Accept lists as well as tuples so an entry is not silently
                # dropped just because of its container type.
                if isinstance(entry, (tuple, list)) and len(entry) == 2:
                    heights.append(entry[0])
                    errs.append(entry[1])
                else:
                    heights.append(np.nan)
                    errs.append(0.0)
            ax.bar(
                x + (i * len(data_types) + j) * width,
                heights,
                width=width,
                yerr=errs,
                capsize=3,
                label=f"{val_type}-{data_type}",
                color=colors[j],
                hatch=hatches[i],
            )

    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * (n_bars - 1) / 2)
    ax.set_xticklabels(setup_order, rotation=0, ha="center", fontsize=8)
    ax.set_ylabel("AUROC")
    ax.set_ylim(0, 1.19)
    # Chance level
    ax.axhline(0.5, color="k", linestyle="--", alpha=0.3, linewidth=0.8)

    if title_suffix:
        ax.set_title(f"AUROC Summary - {model_type} Model{title_suffix}", fontsize=14)

    # Stagger every other tick label downwards so neighbouring labels do not collide.
    for i, tick in enumerate(ax.xaxis.get_major_ticks()):
        if i % 2 != 0:
            tick.set_pad(15)

    # Legend sits inside the axes; presumably the y-limit above 1.0 is there to
    # leave it some headroom.
    ax.legend(loc="upper right", fontsize=8)

    fig.tight_layout()
    plt.show()


def _plot_roc_grid(all_roc_data, val_type, model_type, setup_order, title, figsize, cols=3):
    """Draw one ROC panel per setup for a single validation scheme.

    Shared implementation behind ``plot_cv_roc`` and ``plot_lodo_roc``.

    Parameters
    ----------
    all_roc_data : dict
        ``{validation_type: {model_type: {"<setup>::<data_type>": [(fpr, tpr), ...]}}}``.
    val_type : str
        Validation key to read ("CrossValidation" or "LODO").
    model_type : str
        Model key to read.
    setup_order : list of str or None
        Panels to draw, in order; ``None`` means the sorted setups found for
        ``val_type``.
    title : str
        Figure-level title.
    figsize : tuple
        Figure size in inches.
    cols : int
        Number of panel columns; rows are added as needed.
    """
    data_types = ["Raw", "Embedding"]
    colors = sns.color_palette("Set2", n_colors=len(data_types))
    roc_by_key = all_roc_data.get(val_type, {}).get(model_type, {})

    # Extract setups from ROC data if not provided
    if setup_order is None:
        setup_order = sorted({key.split("::")[0] for key in roc_by_key})

    n = len(setup_order)
    if n == 0:
        # plt.subplots rejects a zero-row grid, so bail out with a message instead.
        print(f"No {val_type} ROC data available for the {model_type} model; nothing to plot.")
        return

    rows = (n + cols - 1) // cols
    # squeeze=False always returns a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(
        rows, cols, figsize=figsize, sharex=True, sharey=True, squeeze=False
    )
    axes = axes.flatten()

    for i, (ax, setup) in enumerate(zip(axes, setup_order)):
        for data_type, color in zip(data_types, colors):
            curves = roc_by_key.get(f"{setup}::{data_type}") or []
            for curve in curves:
                fpr, tpr = curve[:2]  # ignore any extra fields stored with a curve
                ax.plot(fpr, tpr, color=color, alpha=0.5, lw=1)
        ax.plot([0, 1], [0, 1], "k--", lw=0.8, alpha=0.4)  # chance diagonal
        ax.set_title(setup, fontsize=9)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        # Only label outer axes. The axes share their tick formatters, so
        # set_xticklabels([]) / set_yticklabels([]) on one panel would blank
        # the labels on every panel; tick_params toggles visibility per panel.
        is_left = i % cols == 0
        is_bottom = i + cols >= n  # no used panel underneath this one
        ax.tick_params(labelleft=is_left, labelbottom=is_bottom)
        if is_left:
            ax.set_ylabel("TPR")
        if is_bottom:
            ax.set_xlabel("FPR")

    # Tick locations are shared too, so setting them on one panel is enough.
    ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    axes[0].set_xticks(ticks)
    axes[0].set_yticks(ticks)

    # Hide unused subplots
    for ax in axes[n:]:
        fig.delaxes(ax)

    fig.suptitle(title, fontsize=14)
    # Leave room for the suptitle. tight_layout recomputes the subplot spacing,
    # so no separate subplots_adjust call is made.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


def plot_cv_roc(all_roc_data, model_type="Linear", setup_order=None, title_suffix=""):
    """
    Plot CV ROC curves for each setup (3-column grid, all lines shown).

    Reads ``all_roc_data["CrossValidation"][model_type]``, whose keys have the
    form ``"<setup>::<data_type>"`` and whose values are lists of ``(fpr, tpr)``
    curves. Raw and Embedding curves share a panel and differ by colour.
    Prints a message and draws nothing when there are no setups.
    """
    _plot_roc_grid(
        all_roc_data,
        "CrossValidation",
        model_type,
        setup_order,
        title=f"Cross-Validation ROC Curves - {model_type} Model{title_suffix}",
        figsize=(16, 8),
    )


def plot_lodo_roc(all_roc_data, model_type="Linear", setup_order=None, title_suffix=""):
    """
    Plot LODO ROC curves for each setup (3-column grid, all lines shown).

    Same layout as ``plot_cv_roc`` but reads ``all_roc_data["LODO"][model_type]``.
    Prints a message and draws nothing when there are no setups.
    """
    _plot_roc_grid(
        all_roc_data,
        "LODO",
        model_type,
        setup_order,
        title=f"LODO ROC Curves - {model_type} Model{title_suffix}",
        figsize=(10, 8),
    )


In [None]:
# Within-species results: Human (Linear model)
if human_results is None:
    print("Human results not available. Skipping plots.")
else:
    # Derive the setup order once from the CV results so all three figures
    # show the setups in the same order.
    example_dict = human_results.get("results", {}).get("CrossValidation", {}).get("Linear", {})
    setup_order = sorted(set(key.split("::")[0] for key in example_dict))

    human_banner = "=" * 60
    print(human_banner)
    print("Human - Linear Model")
    print(human_banner)

    # Arguments common to every human figure
    human_plot_kwargs = dict(
        model_type="Linear",
        setup_order=setup_order,
        title_suffix=" - Human",
    )

    # AUROC summary
    plot_auroc_summary(human_results.get("results", {}), **human_plot_kwargs)

    # CV ROC curves
    plot_cv_roc(human_results.get("roc_data", {}), **human_plot_kwargs)

    # LODO ROC curves
    plot_lodo_roc(human_results.get("roc_data", {}), **human_plot_kwargs)


### Within-Species Results: Mouse


In [None]:
if mouse_results is None:
    print("Mouse results not available. Skipping plots.")
else:
    # Derive the setup order once from the CV results so all three figures
    # show the setups in the same order.
    example_dict = mouse_results.get("results", {}).get("CrossValidation", {}).get("Linear", {})
    setup_order = sorted(set(key.split("::")[0] for key in example_dict))

    mouse_banner = "=" * 60
    print(mouse_banner)
    print("Mouse - Linear Model")
    print(mouse_banner)

    # Arguments common to every mouse figure
    mouse_plot_kwargs = dict(
        model_type="Linear",
        setup_order=setup_order,
        title_suffix=" - Mouse",
    )

    # AUROC summary
    plot_auroc_summary(mouse_results.get("results", {}), **mouse_plot_kwargs)

    # CV ROC curves
    plot_cv_roc(mouse_results.get("roc_data", {}), **mouse_plot_kwargs)

    # LODO ROC curves
    plot_lodo_roc(mouse_results.get("roc_data", {}), **mouse_plot_kwargs)


### Cross-Species Results


In [None]:
def plot_roc_facet_clean(all_roc_data, model_type="Linear", setup_order=None):
    """
    Clean ROC facet grid for cross-species results of one model type.

    One panel per setup in a 3-column grid, with up to four line groups per
    panel: Raw/Embedding (colour) x Human→Mouse/Mouse→Human (line style).

    - Only the leftmost column and the lowest used panel of each column show
      axis labels and tick labels
    - Only the last panel shows a legend
    - Optional setup_order to enforce subplot order

    Parameters
    ----------
    all_roc_data : dict
        ``{model_type: {data_type: {"<setup> (<direction>)": roc}}}`` where
        ``roc`` is either a single ``(fpr, tpr, auroc)`` tuple or a sequence of
        ``(fpr, tpr[, auroc])`` curves (bootstrap replicates).
    model_type : str
        Which model's curves to plot.
    setup_order : list of str, optional
        Panel order; setups without "Raw" data for ``model_type`` are skipped.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. A message is printed and
        nothing is drawn when there are no setups.
    """
    data_types = ["Raw", "Embedding"]
    directions = ["Human→Mouse", "Mouse→Human"]
    colors = {"Raw": "C0", "Embedding": "C1"}
    linestyles = {"Human→Mouse": "-", "Mouse→Human": "--"}
    roc_by_type = all_roc_data.get(model_type, {})

    # Extract setups (keys look like "<setup> (<direction>)")
    setups = sorted({key.split(" (")[0] for key in roc_by_type.get("Raw", {})})
    if setup_order is not None:
        setups = [s for s in setup_order if s in setups]

    n = len(setups)
    if n == 0:
        # A zero-row grid cannot be created, so report and stop instead of raising.
        print(f"No cross-species ROC data available for the {model_type} model; nothing to plot.")
        return

    cols = 3
    rows = (n + cols - 1) // cols
    fig = plt.figure(figsize=(cols * 5, rows * 4))

    for i, setup in enumerate(setups):
        ax = fig.add_subplot(rows, cols, i + 1)
        is_last = i == n - 1
        for data_type in data_types:
            for direction in directions:
                key = f"{setup} ({direction})"
                if key not in roc_by_type.get(data_type, {}):
                    continue
                roc_data = roc_by_type[data_type][key]
                style = dict(color=colors[data_type], linestyle=linestyles[direction])

                # Single curve: (fpr, tpr, auroc). The check on the third element
                # keeps a tuple of exactly three bootstrap curves out of this branch.
                if (
                    isinstance(roc_data, tuple)
                    and len(roc_data) == 3
                    and not isinstance(roc_data[2], (tuple, list))
                ):
                    fpr, tpr, auroc = roc_data
                    ax.plot(
                        fpr,
                        tpr,
                        lw=2,
                        label=f"{data_type} {direction} AUROC={auroc:.2f}",
                        **style,
                    )
                elif isinstance(roc_data, (list, tuple)) and len(roc_data) > 0:
                    # Multiple ROC curves (bootstrap results)
                    curves = [
                        item
                        for item in roc_data
                        if isinstance(item, (tuple, list)) and len(item) >= 2
                    ]
                    aurocs = [item[2] for item in curves if len(item) > 2]
                    for k, item in enumerate(curves):
                        label = None
                        # One legend entry per group (not one per replicate),
                        # on the last panel only, carrying the mean AUROC.
                        if is_last and k == 0:
                            label = f"{data_type} {direction}"
                            if aurocs:
                                label += f" AUROC={np.mean(aurocs):.2f}"
                        ax.plot(item[0], item[1], lw=1, alpha=0.6, label=label, **style)

        ax.plot([0, 1], [0, 1], "--", color="black", lw=1, alpha=0.7)  # chance diagonal
        ax.set_title(setup, fontsize=9)

        # x labels on the lowest used panel of each column
        if i + cols >= n:
            ax.set_xlabel("FPR")
        else:
            ax.tick_params(labelbottom=False)

        # y labels on the leftmost column only
        if i % cols == 0:
            ax.set_ylabel("TPR")
        else:
            ax.tick_params(labelleft=False)

        # Legend only on the last panel, and only if something was labelled
        if is_last and ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8)

    fig.suptitle(f"Cross-Species ROC Curves - {model_type} Model", fontsize=14, y=0.995)
    fig.tight_layout()
    plt.show()


def plot_auroc_bar_clean_sns_top_legend(all_results, model_type="Linear", setup_order=None):
    """
    AUROC barplot for cross-species setups.

    One panel per setup in a 3-column grid.

    - Groups: Raw/Embedding side by side for each direction
    - Colors by Raw vs Embedding using the seaborn "Set2" palette
    - Panels in the leftmost column show the Y-axis label 'AUROC'
    - Legend on the first panel only, at matplotlib's default location
    - Horizontal line at 0.5 for random chance

    Parameters
    ----------
    all_results : dict
        ``{model_type: {data_type: {"<setup> (<direction>)": result}}}`` where
        ``result`` is a dict with a "mean" entry, a tuple/list whose first item
        is the mean, or a bare number. Anything else is drawn as a gap.
    model_type : str
        Which model's results to plot.
    setup_order : list of str, optional
        Panel order; setups without "Raw" results for ``model_type`` are skipped.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. A message is printed and
        nothing is drawn when there are no setups.
    """

    def mean_auroc(result):
        # Reduce one stored result to its mean AUROC; NaN for unknown formats.
        if isinstance(result, dict):
            return result.get("mean", np.nan)
        if isinstance(result, (tuple, list)) and len(result) >= 1:
            return result[0]
        if isinstance(result, (int, float, np.number)) and not isinstance(result, bool):
            return float(result)
        return np.nan

    directions = ["Human→Mouse", "Mouse→Human"]
    data_types = ["Raw", "Embedding"]
    palette = sns.color_palette("Set2", len(data_types))  # Raw, Embedding
    results_by_type = all_results.get(model_type, {})

    # Keys look like "<setup> (<direction>)"
    setups = sorted({key.split(" (")[0] for key in results_by_type.get("Raw", {})})
    if setup_order is not None:
        setups = [s for s in setup_order if s in setups]

    n = len(setups)
    if n == 0:
        # plt.subplots rejects a zero-row grid, so report and stop instead of raising.
        print(f"No cross-species results available for the {model_type} model; nothing to plot.")
        return

    cols = 3
    rows = (n + cols - 1) // cols
    # squeeze=False always returns a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), squeeze=False)
    axes = axes.flatten()

    # Bar positions: 0=H→M, 1=M→H
    x = np.arange(len(directions))
    width = 0.35

    for i, (ax, setup) in enumerate(zip(axes, setups)):
        for j, data_type in enumerate(data_types):
            per_key = results_by_type.get(data_type, {})
            heights = [mean_auroc(per_key.get(f"{setup} ({dir_})")) for dir_ in directions]
            ax.bar(
                x + (j - 0.5) * width,
                heights,
                width=width,
                color=palette[j],
                edgecolor="black",
                # Only the first panel contributes legend entries.
                label=data_type if i == 0 else "_nolegend_",
            )

        # Horizontal line at 0.5
        ax.axhline(0.5, color="black", linestyle="--", linewidth=1, alpha=0.7)

        ax.set_xticks(x)
        ax.set_xticklabels(directions)
        ax.set_ylim(0, 1.05)
        ax.set_title(setup)

        # Y-axis label and tick labels only on leftmost column
        if i % cols == 0:
            ax.set_ylabel("AUROC")
        else:
            ax.tick_params(labelleft=False)

    # Remove extra axes
    for ax in axes[n:]:
        fig.delaxes(ax)

    axes[0].legend()

    fig.suptitle(f"Cross-Species AUROC - {model_type} Model", fontsize=14, y=0.995)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


In [None]:
if cross_species_results is None:
    print("Cross-species results not available. Skipping plots.")
else:
    # Derive the setup order once from the Linear/Raw results so both figures
    # show the setups in the same order (keys look like "<setup> (<direction>)").
    example_dict = cross_species_results.get("results", {}).get("Linear", {}).get("Raw", {})
    setup_order = sorted(set(key.split(" (")[0] for key in example_dict))

    cross_banner = "=" * 60
    print(cross_banner)
    print("Cross-Species Results - Linear Model")
    print(cross_banner)

    # Arguments common to both cross-species figures
    cross_plot_kwargs = dict(model_type="Linear", setup_order=setup_order)

    # AUROC bar plot
    plot_auroc_bar_clean_sns_top_legend(
        cross_species_results.get("results", {}), **cross_plot_kwargs
    )

    # ROC facet grid
    plot_roc_facet_clean(cross_species_results.get("roc_data", {}), **cross_plot_kwargs)


### Summary Statistics


In [None]:
# Create summary tables for all results


def _mean_std(result):
    """Return ``(mean, std)`` from a stored AUROC entry; NaNs for unknown formats.

    Accepts a dict with a "mean" (and optional "std") entry, or a tuple/list
    whose first two items are mean and std.
    """
    if isinstance(result, dict) and "mean" in result:
        return result["mean"], result.get("std", np.nan)
    if isinstance(result, (tuple, list)) and len(result) >= 2:
        return result[0], result[1]
    return np.nan, np.nan


def _within_species_table(species, species_results):
    """Flatten one species' within-species results into a DataFrame (one row per entry)."""
    columns = ["Species", "Validation", "Model", "Data Type", "Setup", "AUROC Mean", "AUROC Std"]
    results = species_results.get("results", {})
    rows = []
    for val_type in ["CrossValidation", "LODO"]:
        for model_type in ["Linear", "Nonlinear"]:
            for setup_key, result in results.get(val_type, {}).get(model_type, {}).items():
                # Keys look like "<setup>::<data_type>"; partition never raises,
                # even for a key without the separator.
                setup, _, data_type = setup_key.partition("::")
                mean_val, std_val = _mean_std(result)
                rows.append({
                    "Species": species,
                    "Validation": val_type,
                    "Model": model_type,
                    "Data Type": data_type,
                    "Setup": setup,
                    "AUROC Mean": mean_val,
                    "AUROC Std": std_val,
                })
    return pd.DataFrame(rows, columns=columns)


def _cross_species_table(cross_results):
    """Flatten the cross-species results into a DataFrame (one row per entry)."""
    columns = ["Model", "Data Type", "Setup", "AUROC Mean", "AUROC Std"]
    results = cross_results.get("results", {})
    rows = []
    for model_type in ["Linear", "Nonlinear"]:
        for data_type in ["Raw", "Embedding"]:
            for setup_key, result in results.get(model_type, {}).get(data_type, {}).items():
                mean_val, std_val = _mean_std(result)
                rows.append({
                    "Model": model_type,
                    "Data Type": data_type,
                    "Setup": setup_key,
                    "AUROC Mean": mean_val,
                    "AUROC Std": std_val,
                })
    return pd.DataFrame(rows, columns=columns)


def _print_table(title, table):
    """Print a section title, a rule, and the table (or a placeholder if it is empty)."""
    print(title)
    print("-" * 80)
    if table.empty:
        print("No results found")
    else:
        print(table.to_string(index=False))


print("=" * 80)
print("Summary Statistics")
print("=" * 80)

# Within-species summaries. Each df_* is defined whenever its results were
# loaded, and is an empty frame when those results contain no entries.
if human_results is not None:
    df_human = _within_species_table("Human", human_results)
    _print_table("\nHuman - Within-Species Results:", df_human)

if mouse_results is not None:
    df_mouse = _within_species_table("Mouse", mouse_results)
    _print_table("\nMouse - Within-Species Results:", df_mouse)

# Cross-species summary
if cross_species_results is not None:
    df_cross = _cross_species_table(cross_species_results)
    _print_table("\nCross-Species Results:", df_cross)

print("\n" + "=" * 80)
