In [None]:
# Training RET+ Models for Solar Flare Prediction (all flare classes × horizons)
#
# Trains one RET+ model for every (flare class, time window) combination listed
# in the configuration below and records where each run stored its outputs.

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_training_data
import numpy as np

# --- Configuration ---
# Every combination of flare class × time window is trained (3 × 3 = 9 runs).
# To train a subset, trim the lists, e.g. flare_classes = ["M5"], time_windows = ["72"].
flare_classes = ["C", "M", "M5"]
time_windows = ["24", "48", "72"]          # prediction horizons in hours, passed as strings
input_shape = (10, 9)                      # passed to RETPlusWrapper; presumably (timesteps, features) — TODO confirm
epochs = 300
batch_size = 512

# Directory returned by each training run, keyed by (flare_class, time_window),
# so later cells can locate the trained weights without re-deriving paths.
trained_model_dirs = {}

# --- Loop over class × horizon ---
for flare_class in flare_classes:
    for time_window in time_windows:
        print(f"🚀 Training model for flare class {flare_class} with {time_window}h window")

        # Load & prepare training data
        X_train, y_train = get_training_data(time_window, flare_class)

        # Initialize wrapper and train. Per the original notes, train() early-stops,
        # saves best weights + metadata, and returns the model directory.
        model = RETPlusWrapper(input_shape)
        model_dir = model.train(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=batch_size,
            flare_class=flare_class,
            time_window=time_window
        )
        trained_model_dirs[(flare_class, time_window)] = model_dir

        # Report where everything landed
        print(f"✅ Best weights and metadata stored in: {model_dir}")
        print("-" * 60)

2025-05-14 13:43:28.529418: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-14 13:43:28.529517: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-14 13:43:28.677191: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-14 13:43:28.806718: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-14 13:43:57.013966: W tensorflow/core/common_

Python version: 3.11.12
Tensorflow bakcend version: 2.15.0

🚀 Training model for flare class C with 24h window
Epoch 1/300 - loss: 0.0009 - acc: 0.6948 - tss: 0.3905 - gamma: 0.00


In [None]:
# RET+ testing block
#
# Evaluates one trained RET+ model on the held-out test set for a single
# (flare class, time window) pair and reports confusion-matrix-based metrics,
# including the True Skill Statistic (TSS = TPR - FPR).

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

# Config
flare_class = "M5"
time_window = "72"
input_shape = (10, 9)
# NOTE(review): the training cell reports a directory returned by model.train();
# confirm this filename matches what that run actually writes.
model_path = f"retplus_weights_{flare_class}_{time_window}.pt"
# NOTE(review): mc_passes is not passed to predict_proba below, so it currently has
# no effect. Wire it through if the wrapper accepts a pass-count argument.
mc_passes = 30
threshold = 0.5   # probability cutoff: probs >= threshold → positive (1)

# Load test data
X_test, y_test = get_testing_data(time_window, flare_class)
y_test = np.array(y_test)

# Load model
model = RETPlusWrapper(input_shape)
model.load(model_path)

# Prediction
probs = model.predict_proba(X_test)
# reshape(-1) instead of squeeze(): squeeze() collapses a single-sample result to a
# 0-d array, which the sklearn metrics below reject.
y_pred = (probs >= threshold).astype(int).reshape(-1)

# Metrics
# Pin labels to [0, 1] so the matrix is always 2x2. Without this, a test set or
# prediction vector containing only one class yields a 1x1 matrix and indexing fails.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
acc = accuracy_score(y_test, y_pred)
# zero_division=0 returns the same value as sklearn's default but without the
# UndefinedMetricWarning (e.g. when no event is predicted positive).
prec = precision_score(y_test, y_pred, zero_division=0)
rec = recall_score(y_test, y_pred, zero_division=0)
# TSS = recall + specificity - 1; specificity is 0 when there are no negatives.
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
tss = rec + specificity - 1

print("Confusion matrix:\n", cm)
print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"TSS:       {tss:.4f}")