# CNU Validation: Calibrated Neighborhood Uncertainty

Validates the theoretical framework empirically:
1. Monotonicity of u(x) — does the uncertainty score rank prediction risk correctly?
2. Per-regime coverage — do the conformal intervals achieve the target coverage within each regime?
3. Ablation — how much does each uncertainty primitive contribute?

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from src.calibration import NeighborhoodUncertainty, CNUCalibrator, NeighborhoodFeatures

# Settings
plt.rcParams['figure.dpi'] = 150
FIG_DIR = Path('../figures/paper')
# parents=True also creates ../figures when it does not exist yet; without it
# mkdir raises FileNotFoundError on a fresh checkout. exist_ok=True keeps
# re-running this cell harmless.
FIG_DIR.mkdir(parents=True, exist_ok=True)

## 1. Load Data

In [None]:
# --- Primary training set (train.csv) ---------------------------------------
train_df = pd.read_csv('../data/raw/train.csv')
print(f"Train columns: {train_df.columns.tolist()}")
print(f"Train shape: {train_df.shape}")

# Column discovery: take the first header whose lower-cased name contains the
# keyword ('smiles' for structures; 'tm' or 'melt' for the melting point).
# An IndexError here means no header matched.
_train_cols_lower = [(col, col.lower()) for col in train_df.columns]
smiles_col = [col for col, low in _train_cols_lower if 'smiles' in low][0]
tm_col = [col for col, low in _train_cols_lower if 'tm' in low or 'melt' in low][0]
print(f"Using SMILES column: {smiles_col}, Tm column: {tm_col}")

In [None]:
# --- Second source: SMILES / melting-point CSV ------------------------------
smp_df = pd.read_csv('../data/raw/smiles_melting_point.csv')
print(f"SMP columns: {smp_df.columns.tolist()}")
print(f"SMP shape: {smp_df.shape}")

# Case-insensitive substring search over the header names. The SMILES column
# is mandatory (IndexError if absent); the Tm column is optional (None if absent).
smp_smiles_col = [col for col in smp_df.columns if 'smiles' in col.lower()][0]
smp_tm_cols = [
    col for col in smp_df.columns
    if any(key in col.lower() for key in ('melt', 'point'))
]
smp_tm_col = None if not smp_tm_cols else smp_tm_cols[0]
print(f"SMP SMILES column: {smp_smiles_col}, Tm column: {smp_tm_col}")

# When a Tm column exists: coerce it to numbers (unparsable entries become
# NaN), then drop rows missing either the Tm value or the SMILES string.
if smp_tm_col:
    smp_df[smp_tm_col] = pd.to_numeric(smp_df[smp_tm_col], errors='coerce')
    smp_df = smp_df.dropna(subset=[smp_tm_col, smp_smiles_col])
    print(f"SMP after cleaning: {len(smp_df)} rows")

In [None]:
# Load Bradley data (optional third source). Any failure (e.g. missing file,
# no Excel reader installed, no recognisable columns) is reported and the
# source is skipped by setting bradley_df = None.
try:
    bradley_df = pd.read_excel('../data/raw/BradleyMeltingPointDataset.xlsx')
    print(f"Bradley columns: {bradley_df.columns.tolist()}")

    # Find columns (str() so non-string header labels do not break .lower())
    brad_columns = list(bradley_df.columns)
    brad_smiles_col = [c for c in brad_columns if 'smiles' in str(c).lower()][0]
    # Prefer a column containing "mpc" (e.g. "mpC"). Only if none exists, fall
    # back to any column containing "mp". A single `'mp' in name` test would
    # also match headers such as "compound" or "temp" that may come earlier.
    brad_tm_candidates = (
        [c for c in brad_columns if 'mpc' in str(c).lower()]
        or [c for c in brad_columns if 'mp' in str(c).lower()]
    )
    brad_tm_col = brad_tm_candidates[0]  # IndexError -> handled below

    # Coerce to numbers (unparsable -> NaN), then drop incomplete rows.
    bradley_df[brad_tm_col] = pd.to_numeric(bradley_df[brad_tm_col], errors='coerce')
    bradley_df = bradley_df.dropna(subset=[brad_tm_col, brad_smiles_col])
    print(f"Bradley after cleaning: {len(bradley_df)} rows")
except Exception as e:
    # Deliberate best-effort: the notebook continues without Bradley data.
    print(f"Could not load Bradley: {e}")
    bradley_df = None

In [None]:
# Combine datasets. Sources are appended in priority order (train.csv, then
# SMP, then Bradley); the de-duplication step below keeps the earliest copy.
all_smiles = list(train_df[smiles_col])
# Coerce like the other sources (unparsable -> NaN, removed below).
# NOTE(review): train.csv Tm is assumed to already be in Kelvin; no unit
# conversion is applied to it — confirm.
all_tms = pd.to_numeric(train_df[tm_col], errors='coerce').to_numpy(dtype=float).copy()

# Add SMP
if smp_tm_col:
    smp_tms = smp_df[smp_tm_col].values.astype(float)
    # Unit heuristic: a mean below 200 is taken to mean the column is in
    # Celsius (200 K is about -73 °C), so convert to Kelvin.
    if np.nanmean(smp_tms) < 200:
        smp_tms = smp_tms + 273.15
    all_smiles.extend(list(smp_df[smp_smiles_col]))
    all_tms = np.concatenate([all_tms, smp_tms])
    print(f"Added {len(smp_df)} SMP molecules")

# Add Bradley (same Celsius-vs-Kelvin heuristic as SMP)
if bradley_df is not None and len(bradley_df) > 0:
    brad_tms = bradley_df[brad_tm_col].values.astype(float)
    if np.nanmean(brad_tms) < 200:
        brad_tms = brad_tms + 273.15
    all_smiles.extend(list(bradley_df[brad_smiles_col]))
    all_tms = np.concatenate([all_tms, brad_tms])
    print(f"Added {len(bradley_df)} Bradley molecules")

print(f"Total: {len(all_smiles)} molecules")

# Clean: drop rows with a NaN Tm or a missing/blank SMILES. SMP and Bradley
# rows with missing SMILES were already dropped above; train.csv rows were not.
smiles_ok = np.array(
    [isinstance(s, str) and s.strip() != '' for s in all_smiles], dtype=bool
)
valid_mask = ~np.isnan(all_tms) & smiles_ok
all_smiles = [s for s, v in zip(all_smiles, valid_mask) if v]
all_tms = all_tms[valid_mask]
print(f"After NaN removal: {len(all_smiles)} molecules")

# De-duplicate by exact SMILES string, keeping the first occurrence (train.csv
# wins over SMP, SMP over Bradley). Without this, a molecule present in two
# sources can end up in both the index (train) split and the calibration
# split. The neighbourhood lookup can then find an exact copy of a
# calibration molecule, making residuals and coverage look better than on
# genuinely unseen molecules.
# NOTE: exact string match only; non-canonical SMILES of the same molecule
# are not detected.
dup_mask = pd.Series(all_smiles, dtype=object).duplicated(keep='first').to_numpy()
if dup_mask.any():
    all_smiles = [s for s, d in zip(all_smiles, dup_mask) if not d]
    all_tms = all_tms[~dup_mask]
print(f"After de-duplication: {len(all_smiles)} molecules "
      f"({int(dup_mask.sum())} duplicates removed)")

# Split 90/10 (index / calibration) with a fixed seed for reproducibility
n = len(all_smiles)
if n == 0:
    raise ValueError("No molecules left after cleaning; cannot build train/calibration splits")
np.random.seed(42)
perm = np.random.permutation(n)
n_train = int(0.9 * n)

train_idx, calib_idx = perm[:n_train], perm[n_train:]
train_smiles = [all_smiles[i] for i in train_idx]
train_tms = all_tms[train_idx]
calib_smiles = [all_smiles[i] for i in calib_idx]
calib_tms = all_tms[calib_idx]

print(f"\nTrain: {len(train_smiles)}, Calib: {len(calib_smiles)}")

## 2. Build Index and Calibrate

In [None]:
from src.models.hierarchical_mp_v8 import HierarchicalMPPredictorV8

# Hierarchical predictor with 5 regimes; alpha=0.10 corresponds to the 90%
# coverage target (1 - alpha) used in the plots below.
predictor = HierarchicalMPPredictorV8(alpha=0.10, n_regimes=5)
# Build the neighbourhood index from the training split only.
predictor.fit_index(train_smiles, train_tms)

In [None]:
# Calibrate CNU on the held-out calibration split. The returned `result` is
# read in later cells for regime_quantiles, regime_counts, coverage_achieved,
# weights and global_quantile.
result = predictor.fit_calibration(calib_smiles, calib_tms)

## 3. Monotonicity Validation

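**Lemma (Risk Ranking)**: If u(x) is monotone in each primitive uncertainty source and the empirical residual quantiles increase with u(x), then u(x) is a valid risk-ranking statistic. The check below uses the mean absolute error (MAE) per u(x) decile as a proxy for these residual quantiles, and tolerates decreases of up to 5 K between adjacent deciles as noise.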

In [None]:
# Validate monotonicity: compare prediction error across deciles of u(x) on the
# calibration split (mono_df is used below via its 'decile' and 'mae' columns).
mono_df = predictor.validate_monotonicity(calib_smiles, calib_tms)
# Sort by decile so diff(), the trend fit and any first/last lookup run from
# the lowest to the highest u(x), regardless of the row order returned.
mono_df = mono_df.sort_values('decile').reset_index(drop=True)
print(mono_df)

# Largest drop in MAE (K) between adjacent deciles still accepted as noise.
MONOTONE_TOL_K = 5.0

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(mono_df['decile'], mono_df['mae'], color='steelblue', edgecolor='black')
ax.set_xlabel('Uncertainty Score u(x) Decile', fontsize=12)
ax.set_ylabel('Mean Absolute Error (K)', fontsize=12)
ax.set_title('Monotonicity Check: MAE vs. u(x) Decile', fontsize=14)
# Tick exactly the deciles present instead of assuming they are labelled 1..10.
ax.set_xticks(mono_df['decile'])

# Add trend line (degree-1 least-squares fit of MAE on decile)
z = np.polyfit(mono_df['decile'], mono_df['mae'], 1)
p = np.poly1d(z)
ax.plot(mono_df['decile'], p(mono_df['decile']), 'r--', linewidth=2, label=f'Trend: slope={z[0]:.1f}')
ax.legend()

plt.tight_layout()
plt.savefig(FIG_DIR / 'fig_monotonicity.png', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'fig_monotonicity.pdf', bbox_inches='tight')
plt.show()

# Check monotonicity: no adjacent-decile drop larger than the noise tolerance
is_monotone = bool((mono_df['mae'].diff().dropna() >= -MONOTONE_TOL_K).all())
print(f"\nMonotonicity check: {'PASSED' if is_monotone else 'FAILED'}")
print(f"MAE range: {mono_df['mae'].min():.1f}K to {mono_df['mae'].max():.1f}K")
print(f"MAE range: {mono_df['mae'].min():.1f}K to {mono_df['mae'].max():.1f}K")

## 4. Ablation Study

Contribution of each uncertainty primitive: coverage $(1 - s_1)$, neighbour disagreement $\sigma_w$, sparsity $1/k_{\mathrm{eff}}$, and ambiguity (gap).

In [None]:
# Ablation: per-configuration interval width and empirical coverage, as
# returned by predictor.compute_ablation (read via its 'config', 'mean_width'
# and 'coverage' columns).
ablation_df = predictor.compute_ablation(calib_smiles, calib_tms)
print(ablation_df)

# Two side-by-side horizontal bar charts sharing the same configuration order
# (first configuration at the top, hence invert_yaxis).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax1, ax2 = axes
ablation_panels = (
    (ax1, 'mean_width', 'steelblue', 'Mean Interval Width (K)', 'Interval Width by Configuration'),
    (ax2, 'coverage', 'seagreen', 'Empirical Coverage', 'Coverage by Configuration'),
)
for panel_ax, metric, bar_color, x_label, panel_title in ablation_panels:
    panel_ax.barh(ablation_df['config'], ablation_df[metric], color=bar_color)
    panel_ax.set_xlabel(x_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=14)
    panel_ax.invert_yaxis()

# Reference line for the 90% coverage target, on the coverage panel only.
ax2.axvline(0.90, color='red', linestyle='--', label='Target 90%')
ax2.legend()

plt.tight_layout()
plt.savefig(FIG_DIR / 'fig_ablation.png', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'fig_ablation.pdf', bbox_inches='tight')
plt.show()

## 5. Per-Regime Coverage

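**Theorem 2**: For every regime $r$, the conformal intervals satisfy $P\big(Y \in C(X) \mid \mathrm{regime}(X) = r\big) \ge 1 - \alpha$ (here $\alpha = 0.10$, i.e. a 90% coverage target).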

In [None]:
# Per-regime analysis: one row per regime in result.regime_quantiles, with its
# conformal quantile, its count from result.regime_counts (presumably the
# number of calibration points in that regime — confirm) and its empirical
# coverage. Regimes absent from the count/coverage dicts are reported as 0.
# A comprehension is used so no loop variable overwrites the module-level `n`
# (dataset size) defined by the split cell.
regime_data = [
    {
        'regime': regime,
        'quantile': q,
        'n': result.regime_counts.get(regime, 0),
        'coverage': result.coverage_achieved.get(regime, 0),
    }
    for regime, q in result.regime_quantiles.items()
]

# Explicit columns keep sort_values('regime') valid even with no regimes.
regime_df = pd.DataFrame(
    regime_data, columns=['regime', 'quantile', 'n', 'coverage']
).sort_values('regime')
print(regime_df)

# Colour thresholds for the coverage bars: green when coverage >= 0.88,
# orange when >= 0.85, red otherwise (target is 0.90).
COVERAGE_OK, COVERAGE_WARN = 0.88, 0.85

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Quantiles
ax1.bar(regime_df['regime'], regime_df['quantile'], color='steelblue')
ax1.set_xlabel('Regime', fontsize=12)
ax1.set_ylabel('Conformal Quantile (K)', fontsize=12)
ax1.set_title('Per-Regime Error Quantiles', fontsize=14)
ax1.tick_params(axis='x', rotation=45)

# Coverage
colors = [
    'green' if c >= COVERAGE_OK else 'orange' if c >= COVERAGE_WARN else 'red'
    for c in regime_df['coverage']
]
ax2.bar(regime_df['regime'], regime_df['coverage'], color=colors)
ax2.axhline(0.90, color='red', linestyle='--', label='Target 90%')
ax2.set_xlabel('Regime', fontsize=12)
ax2.set_ylabel('Empirical Coverage', fontsize=12)
ax2.set_title('Per-Regime Coverage', fontsize=14)
ax2.tick_params(axis='x', rotation=45)
ax2.legend()

plt.tight_layout()
plt.savefig(FIG_DIR / 'fig_regime_coverage.png', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'fig_regime_coverage.pdf', bbox_inches='tight')
plt.show()

## 6. Learned Weights Analysis

In [None]:
# Learned weights: one non-negative weight per uncertainty primitive, in the
# order of the labels below (the plot title states they are fitted via NNLS).
weights = result.weights
primitives = ['1 - s₁\n(coverage)', 'σ_w\n(disagreement)', '1/k_eff\n(sparsity)', 'ambiguity\n(gap)']
if len(weights) != len(primitives):
    raise ValueError(
        f"Expected {len(primitives)} weights (one per primitive), got {len(weights)}"
    )

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(primitives, weights, color='steelblue', edgecolor='black')
ax.set_ylabel('Learned Weight (w ≥ 0)', fontsize=12)
ax.set_title('Learned Uncertainty Weights via NNLS', fontsize=14)

for i, w in enumerate(weights):
    # Offset the label 3 points above the bar top. An offset in display units
    # stays constant whatever the scale of the weights, unlike a data-unit
    # offset such as w + 0.01.
    ax.annotate(
        f'{w:.3f}', (i, w),
        xytext=(0, 3), textcoords='offset points',
        ha='center', va='bottom', fontsize=11,
    )

plt.tight_layout()
plt.savefig(FIG_DIR / 'fig_learned_weights.png', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'fig_learned_weights.pdf', bbox_inches='tight')
plt.show()

# float() so NumPy scalars print as plain numbers (NumPy 2 shows np.float64(...)).
print(f"Weights: {dict(zip(['w_cov', 'w_var', 'w_sparse', 'w_ambig'], (float(w) for w in weights)))}")

## 7. Summary Statistics

In [None]:
# Summary of the validation results computed in the cells above.
print("="*60)
print("CNU VALIDATION SUMMARY")
print("="*60)
print(f"\nLearned weights: {weights}")
print(f"Global quantile: {result.global_quantile:.1f}K")
print(f"\nMonotonicity: {'PASSED' if is_monotone else 'FAILED'}")
# Order by decile so first/last really are the lowest/highest u(x) deciles.
_mono_sorted = mono_df.sort_values('decile')
print(f"  MAE in lowest u(x) decile: {_mono_sorted['mae'].iloc[0]:.1f}K")
print(f"  MAE in highest u(x) decile: {_mono_sorted['mae'].iloc[-1]:.1f}K")
print("\nPer-regime coverage:")
# itertuples keeps per-column dtypes. iterrows would upcast an all-numeric row
# to float, printing counts like "n=123.0".
for row in regime_df.itertuples(index=False):
    status = '✓' if row.coverage >= 0.88 else '⚠'
    print(f"  {status} {row.regime}: {row.coverage:.1%} (n={row.n})")