In [1]:
# RET+ precision-tuned evaluation block

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

# --- Config ---
flare_classes = ["M5"]
time_windows = ["72"]
input_shape = (10, 9)
thresholds = np.linspace(0.1, 0.9, 81)  # candidate cut-offs 0.10 .. 0.90 in 0.01 steps
min_recall = 0.75  # thresholds whose recall falls below this floor are not eligible

# --- Evaluation Loop ---
# For every (flare class, time window) pair: load the saved weights, score the
# test set, then sweep `thresholds` and keep the cut-off that maximises a
# precision-weighted blend of precision, F1 and TSS among thresholds that
# still reach `min_recall`.
# NOTE(review): the threshold is selected on the same test split the metrics
# are reported on, so the printed numbers are optimistic; use a validation
# split for selection if unbiased figures are needed.
for flare in flare_classes:
    for time in time_windows:
        model_path = f"models/EVEREST-v1.0-{flare}-{time}h/model_weights.pt"

        # Load data
        X_test, y_test = get_testing_data(time, flare)
        y_test = np.array(y_test)

        # Load model
        model = RETPlusWrapper(input_shape)
        model.load(model_path)

        # Predict probabilities. ravel() always yields a 1-D vector, whereas
        # squeeze() would return a 0-d array for a single-sample test set.
        probs = model.predict_proba(X_test).ravel()

        # Threshold tuning
        best_score = -np.inf
        best_thresh = None
        best_metrics = {}

        for t in thresholds:
            y_pred = (probs >= t).astype(int)

            # Check the recall floor first so the remaining metrics are only
            # computed for thresholds that can actually be selected.
            rec = recall_score(y_test, y_pred, zero_division=0)
            if rec < min_recall:
                continue

            # labels=[0, 1] pins the matrix to 2x2 even when only one class
            # occurs in y_test/y_pred; otherwise cm[0, 1] raises IndexError.
            cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
            acc = accuracy_score(y_test, y_pred)
            prec = precision_score(y_test, y_pred, zero_division=0)
            f1 = f1_score(y_test, y_pred, zero_division=0)
            # TSS = recall + specificity - 1; the epsilon guards a test set
            # without negatives.
            tss = rec + cm[0, 0] / (cm[0, 0] + cm[0, 1] + 1e-8) - 1

            # Precision-weighted scoring rule
            score = 0.6 * prec + 0.2 * f1 + 0.2 * tss

            if score > best_score:
                best_score = score
                best_thresh = t
                best_metrics = {
                    'confusion_matrix': cm,
                    'accuracy': acc,
                    'precision': prec,
                    'recall': rec,
                    'f1': f1,
                    'tss': tss
                }

        # --- Print Results ---
        if best_thresh is None:
            # Every candidate was rejected by the recall floor; formatting
            # None with ':.2f' would raise, so report and move on.
            print(f"\nNo threshold for {model_path} reached recall >= {min_recall}; "
                  "nothing to report.")
            continue

        print(f"\n🎯 Best threshold for {model_path}: {best_thresh:.2f}")
        print("Confusion matrix:\n", best_metrics['confusion_matrix'])
        print(f"Accuracy:  {best_metrics['accuracy']:.4f}")
        print(f"Precision: {best_metrics['precision']:.4f}")
        print(f"Recall:    {best_metrics['recall']:.4f}")
        print(f"F1:        {best_metrics['f1']:.4f}")
        print(f"TSS:       {best_metrics['tss']:.4f}")

2025-05-15 10:26:33.280411: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-15 10:26:33.280477: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-15 10:26:33.281530: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-15 10:26:33.288650: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-15 10:27:01.400581: W tensorflow/core/common_

TensorFlow backend version: 2.15.0
SUCCESS: PyTorch found GPU: Quadro RTX 6000
PyTorch CUDA version: 12.6
PyTorch version: 2.7.0+cu126
Python version: 3.11.12


🎯 Best threshold for models/EVEREST-v1.0-M5-72h/model_weights.pt: 0.38
Confusion matrix:
 [[71559    66]
 [   26    78]]
Accuracy:  0.9987
Precision: 0.5417
Recall:    0.7500
F1:        0.6290
TSS:       0.7491


In [1]:
# plot_attention_and_uncertainty.py

import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import spearmanr, kendalltau

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data
from model_tracking import get_latest_version

# ── Configuration ──────────────────────────────────────────────────────────────
# Tunables for the whole plotting cell.
flare       = "C"
time_window = "72"
input_shape = (10, 9)
threshold   = 0.5   # probability cut-off used to turn scores into 0/1 predictions
n_per_cat   = 5     # heatmaps drawn per TP/TN/FP/FN bucket
batch_size  = 512   # rows per forward pass in the batched sections below

# Pick the first available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Output folders, one per plot family; created up front.
out_attention = "attention_plots"
out_uncert    = "uncertainty_plots"
out_evt       = "evt_plots"
for _plot_dir in (out_attention, out_uncert, out_evt):
    os.makedirs(_plot_dir, exist_ok=True)

# ── Load test data & model ────────────────────────────────────────────────────
X_test, y_test = get_testing_data(time_window, flare)
y_test = np.array(y_test)

# Weights live under models/EVEREST-v<version>-<flare>-<window>h/.
version    = get_latest_version(flare, time_window)
model_dir  = f"models/EVEREST-v{version}-{flare}-{time_window}h"
weights_fp = os.path.join(model_dir, "model_weights.pt")

model = RETPlusWrapper(input_shape)
model.load(weights_fp)
# Move the wrapped network to the chosen device and switch it to eval mode.
model.model.to(device).eval()

# ── 1) Attention heatmaps ────────────────────────────────────────────────────
# Score the test set once, bucket every sample into TP/TN/FP/FN at `threshold`,
# then draw a 1xT heatmap of the captured attention output for a few samples
# per bucket.
probs = model.predict_proba(X_test).ravel()
preds = (probs >= threshold).astype(int)

idx_tp = np.where((preds==1) & (y_test==1))[0]
idx_tn = np.where((preds==0) & (y_test==0))[0]
idx_fp = np.where((preds==1) & (y_test==0))[0]
idx_fn = np.where((preds==0) & (y_test==1))[0]
categories = {"TP": idx_tp, "TN": idx_tn, "FP": idx_fp, "FN": idx_fn}

att_storage = {}
def att_hook(module, inp, out):
    """Forward hook: stash the hooked module's latest output as a NumPy array.

    Assumes `out` has shape [batch, T, 1] (per the original comment; the
    indexing below relies on it) — TODO confirm against the att_pool module.
    """
    att_storage['w'] = out.detach().cpu().numpy()

# Fixed seed so the same samples (and therefore the same file names) are
# produced on every run; change it to draw a different set.
attention_seed = 0
sample_rng = random.Random(attention_seed)

# Keep the handle so the hook can be removed afterwards. Without removal the
# hook piles up on every re-run of this cell and keeps copying the attention
# tensor to the CPU during the full-dataset passes further down.
hook_handle = model.model.att_pool.register_forward_hook(att_hook)
try:
    for cat, idxs in categories.items():
        if len(idxs) == 0:
            continue
        chosen = sample_rng.sample(list(idxs), min(n_per_cat, len(idxs)))
        for i in chosen:
            x_np = X_test[i:i+1]  # shape (1, T, F)
            x_t  = torch.tensor(x_np, dtype=torch.float32).to(device)
            # Inference only: no autograd graph is needed to trigger the hook.
            with torch.no_grad():
                _ = model.model(x_t)
            weights = att_storage['w'][0, :, 0]  # shape (T,)

            fig, ax = plt.subplots(figsize=(6, 1.5))
            im = ax.imshow(weights[np.newaxis, :], aspect='auto')
            ax.set_title(f"{cat} idx={i} true={y_test[i]} prob={probs[i]:.2f}")
            ax.set_yticks([])
            ax.set_xticks(np.arange(len(weights)))
            ax.set_xticklabels(np.arange(1, len(weights) + 1))
            fig.colorbar(im, ax=ax, label="Attention weight")
            fig.tight_layout()

            fn = os.path.join(out_attention,
                              f"{flare}_{time_window}h_{cat}_{i}.png")
            fig.savefig(fn, dpi=200)
            plt.close(fig)
finally:
    hook_handle.remove()

# ── 2) Epistemic & Aleatoric variance violin plots ────────────────────────────
# Collect the model's 'evid' output for all test samples in batches.
all_mu, all_v, all_a, all_b = [], [], [], []
for start in range(0, len(X_test), batch_size):
    xb = torch.tensor(X_test[start:start+batch_size],
                      dtype=torch.float32).to(device)
    with torch.no_grad():
        evid = model.model(xb)['evid']  # shape [B,4]
    # NOTE(review): the column order (mu, v, a, b) is assumed here — confirm
    # against the model's evidential head.
    mu,v,a,b = evid.split(1, dim=-1)
    all_mu.append(mu.cpu().numpy())
    all_v.append(v.cpu().numpy())
    all_a.append(a.cpu().numpy())
    all_b.append(b.cpu().numpy())

mu_arr = np.vstack(all_mu).squeeze(-1)
v_arr  = np.vstack(all_v).squeeze(-1)
a_arr  = np.vstack(all_a).squeeze(-1)
b_arr  = np.vstack(all_b).squeeze(-1)

# Compute total, epistemic, and aleatoric variances.
# NOTE(review): this splits b / ((a - 1) * v) into a 1/a share ("epistemic")
# and a (1 - 1/a) share ("aleatoric"). The usual Normal-Inverse-Gamma
# decomposition is aleatoric = b / (a - 1) and epistemic = b / (v * (a - 1));
# confirm the split below is intended. Values are negative if a < 1.
var_total     = b_arr / ((a_arr - 1) * v_arr + 1e-12)
var_epistemic = var_total / a_arr
var_aleatoric = var_total * (1 - 1/a_arr)

# Build DataFrame: one row per test sample, tagged with its prediction
# outcome at `threshold` (uses `preds` / `y_test` from the section above).
df_unc = pd.DataFrame({
    'Epistemic': var_epistemic.flatten(),
    'Aleatoric': var_aleatoric.flatten(),
    'Outcome': np.where((preds==1)&(y_test==1),'TP',
                np.where((preds==0)&(y_test==0),'TN',
                np.where((preds==1)&(y_test==0),'FP','FN')))
})

outcomes = ['TP','TN','FP','FN']
for col in ['Epistemic','Aleatoric']:
    # violinplot raises on an empty dataset, so only outcomes that actually
    # occur are drawn. Each keeps its original x position, leaving a visible
    # gap (with its tick label) for any missing outcome.
    present = [
        (pos, df_unc.loc[df_unc['Outcome']==o, col].values)
        for pos, o in enumerate(outcomes)
    ]
    present = [(pos, vals) for pos, vals in present if len(vals) > 0]
    if not present:
        print(f"No samples available for {col} violin plot; skipping.")
        continue

    fig, ax = plt.subplots(figsize=(6,4))
    parts = ax.violinplot(
        [vals for _, vals in present],
        positions=[pos for pos, _ in present],
        showmeans=True,
        showmedians=True,
        showextrema=False
    )
    ax.set_xticks(range(len(outcomes)))
    ax.set_xticklabels(outcomes)
    ax.set_xlabel("Prediction Outcome")
    ax.set_ylabel(f"{col} Variance")
    ax.set_title(f"{col} Uncertainty by Outcome")
    for pc in parts['bodies']:
        pc.set_edgecolor('black')
        pc.set_alpha(0.7)
    fig.tight_layout()
    fn = os.path.join(out_uncert, f"{col}_uncertainty_violin.png")
    fig.savefig(fn, dpi=200)
    plt.close(fig)

# ── 3) EVT score distributions & correlations ────────────────────────────────
# For each flare class: load its latest model, run the test set through it in
# batches, and form a per-sample score as the product of the two columns of
# the model's 'gpd' output.
flare_classes = ["C"]
evt_scores = []
class_labels = []

for class_idx, fc in enumerate(flare_classes):
    X_t, y_t = get_testing_data(time_window, fc)
    X_t = np.array(X_t)
    version = get_latest_version(fc, time_window)
    wd = f"models/EVEREST-v{version}-{fc}-{time_window}h"
    wp = os.path.join(wd, "model_weights.pt")

    m = RETPlusWrapper(input_shape)
    m.load(wp)
    m.model.to(device).eval()

    xi_list, sig_list = [], []
    for start in range(0, len(X_t), batch_size):
        xb = torch.tensor(X_t[start:start+batch_size],
                          dtype=torch.float32).to(device)
        with torch.no_grad():
            gpd_out = m.model(xb)['gpd']  # shape [B,2]
        # NOTE(review): column order (xi, sigma) is assumed — confirm against
        # the model's GPD head.
        xi, sig = gpd_out.split(1, dim=-1)
        xi_list.append(xi.cpu().numpy().squeeze(-1))
        sig_list.append(sig.cpu().numpy().squeeze(-1))

    xi_arr  = np.concatenate(xi_list)
    sig_arr = np.concatenate(sig_list)
    score   = xi_arr * sig_arr

    evt_scores.append(score)
    # Ordinal class code (position in flare_classes), one entry per sample,
    # used as the rank variable for the correlations below.
    class_labels.append(np.full_like(score, fill_value=class_idx))

# Boxplot of EVT score by flare class.
# Tick labels are set separately: boxplot's `labels=` keyword is deprecated
# since Matplotlib 3.9 (renamed `tick_labels`), and set_xticklabels works on
# both older and newer releases.
fig, ax = plt.subplots(figsize=(6,4))
ax.boxplot(evt_scores)
ax.set_xticklabels(flare_classes)
ax.set_title("EVT Score (ξ·σ) by Flare Class")
ax.set_xlabel("Flare Class")
ax.set_ylabel("ξ·σ")
fig.tight_layout()
fig.savefig(os.path.join(out_evt, "evt_score_boxplot.png"), dpi=200)
plt.close(fig)

# Summary & correlations
print("EVT Score Summary by Class:")
for fc, score in zip(flare_classes, evt_scores):
    print(f"  {fc}: min={score.min():.4f}, mean={score.mean():.4f}, max={score.max():.4f}")

all_scores  = np.concatenate(evt_scores)
all_classes = np.concatenate(class_labels)

# Rank correlations need at least two distinct classes and non-constant scores.
if len(np.unique(all_classes)) > 1 and np.nanstd(all_scores) > 0:
    rho, _ = spearmanr(all_classes, all_scores)
    tau, _ = kendalltau(all_classes, all_scores)
    print(f"Spearman ρ = {rho:.3f}")
    print(f"Kendall τ  = {tau:.3f}")
else:
    print("Not enough variation to compute Spearman/Kendall correlations.")

print("Done. Check output folders:")
print(f"  Attention plots: {out_attention}")
print(f"  Uncertainty violin plots: {out_uncert}")
print(f"  EVT plots: {out_evt}")


2025-05-14 20:17:56.100450: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-14 20:17:56.100522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-14 20:17:56.101580: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-14 20:17:56.108909: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-14 20:18:24.840535: W tensorflow/core/common_

TensorFlow backend version: 2.15.0
SUCCESS: PyTorch found GPU: Quadro RTX 6000
PyTorch CUDA version: 12.6
PyTorch version: 2.7.0+cu126
Python version: 3.11.12

EVT Score Summary by Class:
  C: min=-40.9388, mean=-8.9765, max=1.9916
Not enough variation to compute Spearman/Kendall correlations.
Done. Check output folders:
  Attention plots: attention_plots
  Uncertainty violin plots: uncertainty_plots
  EVT plots: evt_plots
Representative attention maps:
  TP: attention_plots/C_72h_TP_35590.png  (prob=0.985)
  TN: attention_plots/C_72h_TN_57907.png  (prob=0.018)
  FP: attention_plots/C_72h_FP_13223.png  (prob=0.570)
  FN: attention_plots/C_72h_FN_5215.png  (prob=0.420)
Representative attention maps:
  TP: attention_plots/C_72h_TP_35590.png  (prob=0.985)
  TN: attention_plots/C_72h_TN_57907.png  (prob=0.018)
  FP: attention_plots/C_72h_FP_13223.png  (prob=0.570)
  FN: attention_plots/C_72h_FN_5215.png  (prob=0.420)
Representative attention maps:
  TP: attention_plots/C_72h_TP_35590.png  

In [4]:
import os, glob, re
import numpy as np

# --- Pick one representative attention heatmap per outcome category ---
# Requires `model` and `X_test` from the plotting cell to be live in the kernel.
# Configuration is set first so `threshold` is defined before it is used.
flare       = "C"
time_window = "72"
threshold   = 0.5

# Score the test set once; `probs` is indexed by the sample index embedded in
# each heatmap file name.
probs = model.predict_proba(X_test).ravel()
preds = (probs >= threshold).astype(int)

# list all your heatmap files (sorted so tie-breaking below is deterministic)
# NOTE: this also picks up PNGs left over from earlier runs of the plotting cell.
files = sorted(glob.glob(f"attention_plots/{flare}_{time_window}h_*.png"))

# pattern to capture category and index; the config strings are escaped so
# they are matched literally
pat = re.compile(
    rf"{re.escape(flare)}_{re.escape(time_window)}h_(TP|TN|FP|FN)_(\d+)\.png"
)

buckets = {"TP": [], "TN": [], "FP": [], "FN": []}
for fn in files:
    match = pat.search(os.path.basename(fn))
    if not match:
        continue
    cat, idx = match.group(1), int(match.group(2))
    if idx >= len(probs):
        # Stale file from a different/larger test set; its index has no
        # probability in the current run.
        continue
    p = float(probs[idx])
    buckets[cat].append((fn, p))

# pick one per category nearest to that category's median prob
picked = {}
for cat, lst in buckets.items():
    if not lst:
        continue
    ps = np.array([p for (_, p) in lst])
    med = np.median(ps)
    best = min(lst, key=lambda x: abs(x[1] - med))
    picked[cat] = best

print("Representative attention maps:")
for cat in ["TP","TN","FP","FN"]:
    if cat not in picked:
        # No heatmap exists for this category; report it instead of raising.
        print(f"  {cat}: (no heatmaps found)")
        continue
    fn, p = picked[cat]
    print(f"  {cat}: {fn}  (prob={p:.3f})")
