In [1]:
# Training RET+ Models for Solar Flare Prediction (M5 × 72)

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_training_data
import numpy as np

# --- Configuration ---
flare_classes = ["C"]         # Only M5 for now, "C", "M", "M5"
time_windows = ["48", "72"]          # 72-hour window, "24", "48", "72"
input_shape = (10, 9)
epochs = 300
batch_size = 512

# --- Loop over class × horizon ---
for flare_class in flare_classes:
    for time_window in time_windows:
        print(f"🚀 Training model for flare class {flare_class} with {time_window}h window")

        # Load & prepare training data
        X_train, y_train = get_training_data(str(time_window), flare_class)

        # Initialize wrapper and train (this will early-stop, save best weights + metadata, and return the model dir)
        model = RETPlusWrapper(input_shape)
        model_dir = model.train(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=batch_size,
            flare_class=flare_class,
            time_window=time_window
        )

        # Report where everything landed
        print(f"✅ Best weights and metadata stored in: {model_dir}")
        print("-" * 60)

2025-05-15 16:33:35.837900: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-15 16:33:35.837986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-15 16:33:35.940990: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-15 16:33:36.133468: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-15 16:34:03.525939: W tensorflow/core/common_

TensorFlow backend version: 2.15.0
SUCCESS: PyTorch found GPU: Quadro RTX 6000
PyTorch CUDA version: 12.6
PyTorch version: 2.7.0+cu126
Python version: 3.11.12

🚀 Training model for flare class C with 48h window
Epoch 1/300 - loss: 0.0011 - acc: 0.6851 - tss: 0.3730 - gamma: 0.00
Epoch 2/300 - loss: 0.0010 - acc: 0.7002 - tss: 0.4037 - gamma: 0.04
Epoch 3/300 - loss: 0.0010 - acc: 0.7173 - tss: 0.4395 - gamma: 0.08
Epoch 4/300 - loss: 0.0009 - acc: 0.7360 - tss: 0.4774 - gamma: 0.12
Epoch 5/300 - loss: 0.0008 - acc: 0.7542 - tss: 0.5139 - gamma: 0.16
Epoch 6/300 - loss: 0.0008 - acc: 0.7709 - tss: 0.5473 - gamma: 0.20
Epoch 7/300 - loss: 0.0007 - acc: 0.7869 - tss: 0.5789 - gamma: 0.24
Epoch 8/300 - loss: 0.0006 - acc: 0.8016 - tss: 0.6080 - gamma: 0.28
Epoch 9/300 - loss: 0.0006 - acc: 0.8153 - tss: 0.6349 - gamma: 0.32
Epoch 10/300 - loss: 0.0005 - acc: 0.8278 - tss: 0.6593 - gamma: 0.36
Epoch 11/300 - loss: 0.0005 - acc: 0.8381 - tss: 0.6795 - gamma: 0.40
Epoch 12/300 - loss: 0.0005 

In [1]:
import os
import re
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data

# Constants
input_shape = (100, 9)
threshold = 0.5
base_dir = "models"
flare_classes = ["M5"]
horizons = ["24", "48", "72"]

# Version extraction regex
version_pattern = re.compile(r"EVEREST-v([\d.]+)-([A-Z0-9]+)-(\d+)h")

def get_latest_model_path(flare_class, time_window):
    candidates = []
    for dirname in os.listdir(base_dir):
        match = version_pattern.fullmatch(dirname)
        if match:
            version, fclass, thours = match.groups()
            if fclass == flare_class and thours == time_window:
                candidates.append((tuple(map(int, version.split("."))), dirname))
    if not candidates:
        return None
    latest = sorted(candidates)[-1][1]
    return os.path.join(base_dir, latest, "model_weights.pt")

# Iterate through flare × horizon
for flare_class in flare_classes:
    for time_window in horizons:
        model_path = get_latest_model_path(flare_class, time_window)
        if not model_path or not os.path.exists(model_path):
            print(f"⚠️ No model found for {flare_class}-{time_window}h.")
            continue

        print(f"\n🔍 Testing EVEREST on {flare_class}-class, {time_window}h horizon")
        print(f"Using model: {model_path}")

        # Load test data
        X_test, y_test = get_testing_data(time_window, flare_class)
        y_test = np.array(y_test)

        # Load model and weights
        model = RETPlusWrapper(input_shape)
        model.load(model_path)

        # Predict
        probs = model.predict_proba(X_test)
        y_pred = (probs >= threshold).astype(int).squeeze()

        # Metrics
        cm = confusion_matrix(y_test, y_pred)
        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred)
        rec = recall_score(y_test, y_pred)
        tss = rec + cm[0,0] / (cm[0,0] + cm[0,1] + 1e-8) - 1

        # Output
        print("Confusion matrix:\n", cm)
        print(f"Accuracy:  {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall:    {rec:.4f}")
        print(f"TSS:       {tss:.4f}")

2025-05-16 11:54:57.290390: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-16 11:54:57.290464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-16 11:54:57.291542: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-16 11:54:57.298583: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-16 11:55:23.724801: W tensorflow/core/common_

TensorFlow backend version: 2.15.0
SUCCESS: PyTorch found GPU: Quadro RTX 6000
PyTorch CUDA version: 12.6
PyTorch version: 2.7.0+cu126
Python version: 3.11.12


🔍 Testing EVEREST on M5-class, 24h horizon
Using model: models/EVEREST-v4.5-M5-24h/model_weights.pt
Confusion matrix:
 [[47596    75]
 [   52    52]]
Accuracy:  0.9973
Precision: 0.4094
Recall:    0.5000
TSS:       0.4984

🔍 Testing EVEREST on M5-class, 48h horizon
Using model: models/EVEREST-v1.0-M5-48h/model_weights.pt


RuntimeError: Error(s) in loading state_dict for RETPlusModel:
	size mismatch for pos.pe: copying a param with shape torch.Size([1, 10, 128]) from checkpoint, the shape in current model is torch.Size([1, 100, 128]).