In [1]:
# RET+ precision-tuned evaluation block

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

# --- Config ---
# Every (flare class, time window) pair listed here selects one weights file
# and one test split to evaluate.
flare_classes = ["M5"]
time_windows = ["72"]

# Shape handed to the RETPlusWrapper constructor.
input_shape = (10, 9)

# Candidate decision thresholds: 81 evenly spaced values from 0.1 to 0.9
# inclusive, i.e. a step of 0.01.
thresholds = np.linspace(start=0.1, stop=0.9, num=81)

# Recall floor: a threshold whose recall is below this is never selected.
min_recall = 0.75

# --- Evaluation Loop ---
# For each (flare class, time window) pair: load the test split and the
# matching weights file, score the test set once, then sweep `thresholds` and
# keep the cut-off with the highest precision-weighted score among those whose
# recall is at least `min_recall`.
for flare in flare_classes:
    for time in time_windows:
        model_path = f"retplus_weights_{flare}_{time}.pt"

        # Load data
        X_test, y_test = get_testing_data(time, flare)
        y_test = np.array(y_test)

        # Load model
        model = RETPlusWrapper(input_shape)
        model.load(model_path)

        # Predict probabilities once; only the threshold varies below.
        probs = model.predict_proba(X_test).squeeze()

        # Threshold tuning
        best_score = -np.inf
        best_thresh = None
        best_metrics = {}

        for t in thresholds:
            y_pred = (probs >= t).astype(int)

            # Check the recall floor first so rejected thresholds do not pay
            # for the remaining metrics.
            rec = recall_score(y_test, y_pred, zero_division=0)
            if rec < min_recall:
                continue

            # Pin the label order so the matrix is always 2x2, even when the
            # labels and predictions together contain a single class;
            # otherwise cm[0, 1] raises IndexError.
            cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
            acc = accuracy_score(y_test, y_pred)
            prec = precision_score(y_test, y_pred, zero_division=0)
            f1 = f1_score(y_test, y_pred, zero_division=0)

            # TSS = recall + specificity - 1. Row 0 of cm holds the true
            # negatives and false positives; the epsilon guards against a
            # test set with no negative samples.
            true_neg, false_pos = cm[0, 0], cm[0, 1]
            tss = rec + true_neg / (true_neg + false_pos + 1e-8) - 1

            # Precision-weighted scoring rule
            score = 0.6 * prec + 0.2 * f1 + 0.2 * tss

            # Strict '>' keeps the lowest threshold when scores tie.
            if score > best_score:
                best_score = score
                best_thresh = t
                best_metrics = {
                    'confusion_matrix': cm,
                    'accuracy': acc,
                    'precision': prec,
                    'recall': rec,
                    'f1': f1,
                    'tss': tss
                }

        # --- Print Results ---
        if best_thresh is None:
            # Every candidate fell below the recall floor, so there is nothing
            # to report. Say so instead of crashing on the None format below.
            print(
                f"\n⚠️ No threshold for {model_path} reached "
                f"recall >= {min_recall}; skipping report."
            )
            continue

        print(f"\n🎯 Best threshold for {model_path}: {best_thresh:.2f}")
        print("Confusion matrix:\n", best_metrics['confusion_matrix'])
        print(f"Accuracy:  {best_metrics['accuracy']:.4f}")
        print(f"Precision: {best_metrics['precision']:.4f}")
        print(f"Recall:    {best_metrics['recall']:.4f}")
        print(f"F1:        {best_metrics['f1']:.4f}")
        print(f"TSS:       {best_metrics['tss']:.4f}")

2025-05-10 22:40:03.719892: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-10 22:40:03.719966: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-10 22:40:03.721081: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-10 22:40:03.728065: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-10 22:40:28.559261: W tensorflow/core/common_

Python version: 3.11.12
Tensorflow bakcend version: 2.15.0


🎯 Best threshold for retplus_weights_M5_72.pt: 0.28
Confusion matrix:
 [[71500   125]
 [   24    80]]
Accuracy:  0.9979
Precision: 0.3902
Recall:    0.7692
F1:        0.5178
TSS:       0.7675
