In [1]:
# Training RET+ Models for Solar Flare Prediction (C, M, M5 × 24, 48, 72)

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_training_data, supported_flare_class
import numpy as np
import os

# --- Configuration ---
flare_classes = ["M"] # C, and M5 taken out temporarily
time_windows = ["72"] # M has already been trained on 24hrs
input_shape = (10, 9)
epochs = 300
batch_size = 512

# --- Loop over all class × horizon combinations ---
for flare_class in flare_classes:
    for time_window in time_windows:
        print(f"🚀 Training model for flare class {flare_class} with {time_window}h window")

        # Path to save the model
        model_path = f"retplus_weights_{flare_class}_{time_window}.pt"

        # Load & prepare training data
        X_train, y_train = get_training_data(time_window, flare_class)
        X_train = np.array(X_train)
        y_train = np.array(y_train)

        # Initialize and train model
        model = RETPlusWrapper(input_shape)
        model.train(X_train, y_train, epochs=epochs, batch_size=batch_size)

        # Save trained weights
        model.save(model_path, flare_class, time_window)
        print(f"✅ Saved model to {model_path}\n{'-'*60}")

2025-05-11 18:02:05.922284: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-11 18:02:05.922364: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-11 18:02:06.098627: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-11 18:02:06.210240: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-11 18:02:31.227527: W tensorflow/core/common_

Python version: 3.11.12
Tensorflow bakcend version: 2.15.0

🚀 Training model for flare class M with 72h window
Epoch 1/300 - loss: 0.1354 - acc: 0.9488 - tss: 0.0090 - gamma: 0.00
Epoch 2/300 - loss: 0.1156 - acc: 0.9507 - tss: 0.0661 - gamma: 0.04
Epoch 3/300 - loss: 0.0990 - acc: 0.9535 - tss: 0.1418 - gamma: 0.08
Epoch 4/300 - loss: 0.0834 - acc: 0.9585 - tss: 0.2554 - gamma: 0.12
Epoch 5/300 - loss: 0.0713 - acc: 0.9621 - tss: 0.3357 - gamma: 0.16
Epoch 6/300 - loss: 0.0626 - acc: 0.9649 - tss: 0.4124 - gamma: 0.20
Epoch 7/300 - loss: 0.0554 - acc: 0.9680 - tss: 0.4824 - gamma: 0.24
Epoch 8/300 - loss: 0.0495 - acc: 0.9703 - tss: 0.5362 - gamma: 0.28
Epoch 9/300 - loss: 0.0449 - acc: 0.9721 - tss: 0.5740 - gamma: 0.32
Epoch 10/300 - loss: 0.0412 - acc: 0.9736 - tss: 0.6049 - gamma: 0.36
Epoch 11/300 - loss: 0.0379 - acc: 0.9750 - tss: 0.6343 - gamma: 0.40
Epoch 12/300 - loss: 0.0348 - acc: 0.9765 - tss: 0.6608 - gamma: 0.44
Epoch 13/300 - loss: 0.0323 - acc: 0.9775 - tss: 0.6832 - 

In [2]:
# RET+ testing block

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

# Config
flare_class = "M5"
time_window = "72"
input_shape = (10, 9)
model_path = f"retplus_weights_{flare_class}_{time_window}.pt"
mc_passes = 30
threshold = 0.5

# Load test data
X_test, y_test = get_training_data(time_window, flare_class)
y_test = np.array(y_test)

# Load model
model = RETPlusWrapper(input_shape)
model.load(model_path)

# MC prediction
probs = model.predict_proba(X_test)
y_pred = (probs >= threshold).astype(int).squeeze()

# Metrics
cm = confusion_matrix(y_test, y_pred)
acc = accuracy_score(y_test, y_pred)
prec = precision_score(y_test, y_pred)
rec = recall_score(y_test, y_pred)
tss = rec + cm[0,0] / (cm[0,0] + cm[0,1] + 1e-8) - 1

print("Confusion matrix:\n", cm)
print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"TSS:       {tss:.4f}")

OutOfMemoryError: CUDA out of memory. Tried to allocate 10.15 GiB. GPU 0 has a total capacity of 23.46 GiB of which 9.42 GiB is free. Including non-PyTorch memory, this process has 14.04 GiB memory in use. Of the allocated memory 13.79 GiB is allocated by PyTorch, and 50.25 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)