In [1]:
# Training RET+ Models for Solar Flare Prediction (M5 × 72)

from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_training_data
import numpy as np

# --- Configuration ---
flare_classes = ["M5"]         # Only M5 for now, "C", "M", "M5"
time_windows = ["72"]          # 72-hour window, "24", "48", "72"
input_shape = (10, 9)
epochs = 300
# batch_size = 512
batch_size = 512

# --- Loop over class × horizon ---
for flare_class in flare_classes:
    for time_window in time_windows:
        print(f"🚀 Training model for flare class {flare_class} with {time_window}h window")

        # Load & prepare training data
        X_train, y_train = get_training_data(str(time_window), flare_class)

        # Initialize wrapper and train (this will early-stop, save best weights + metadata, and return the model dir)
        model = RETPlusWrapper(input_shape)
        model_dir = model.train(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=batch_size,
            flare_class=flare_class,
            time_window=time_window,
            in_memory_dataset=True
        )

        # Report where everything landed
        print(f"✅ Best weights and metadata stored in: {model_dir}")
        print("-" * 60)

2025-05-28 02:24:23.030804: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-28 02:24:23.030869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-28 02:24:23.144224: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-28 02:24:23.274591: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-28 02:24:49.242674: W tensorflow/core/common_

TensorFlow backend version: 2.15.0
SUCCESS: PyTorch found GPU: Quadro RTX 6000
PyTorch CUDA version: 12.6
PyTorch version: 2.7.0+cu126
Python version: 3.11.12

🚀 Training model for flare class M5 with 72h window
Training on: Quadro RTX 6000 (23 GB)
Note: CO2 tracking unavailable: BaseEmissionsTracker.__init__() got an unexpected keyword argument 'offline_mode'
Epoch 1/1 - loss: 0.0338 - acc: 0.9929 - tss: 0.0042 - gamma: 0.00 - time: 33.4s

📊 Training completed in 33.5s (0.01h)
   • Average epoch time: 33.4s
   • GPU: Quadro RTX 6000
   • Average GPU power: 3.8W
Model saved to models/EVEREST-v2.5-M5-72h
✅ Best weights and metadata stored in: models/EVEREST-v2.5-M5-72h
------------------------------------------------------------


In [2]:
import os
import re
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from solarknowledge_ret_plus import RETPlusWrapper
from utils import get_testing_data

# Constants
input_shape = (10, 9)
threshold = 0.5
base_dir = "models"
flare_classes = ["C"]
horizons = ["72"]

# Version extraction regex
version_pattern = re.compile(r"EVEREST-v([\d.]+)-([A-Z0-9]+)-(\d+)h")

def get_latest_model_path(flare_class, time_window):
    candidates = []
    for dirname in os.listdir(base_dir):
        match = version_pattern.fullmatch(dirname)
        if match:
            version, fclass, thours = match.groups()
            if fclass == flare_class and thours == time_window:
                candidates.append((tuple(map(int, version.split("."))), dirname))
    if not candidates:
        return None
    latest = sorted(candidates)[-1][1]
    return os.path.join(base_dir, latest, "model_weights.pt")

# Iterate through flare × horizon
for flare_class in flare_classes:
    for time_window in horizons:
        model_path = get_latest_model_path(flare_class, time_window)
        if not model_path or not os.path.exists(model_path):
            print(f"⚠️ No model found for {flare_class}-{time_window}h.")
            continue

        print(f"\n🔍 Testing EVEREST on {flare_class}-class, {time_window}h horizon")
        print(f"Using model: {model_path}")

        # Load test data
        X_test, y_test = get_testing_data(time_window, flare_class)
        y_test = np.array(y_test)

        # Load model and weights
        model = RETPlusWrapper(input_shape)
        model.load(model_path)

        # Predict
        probs = model.predict_proba(X_test)
        y_pred = (probs >= threshold).astype(int).squeeze()

        # Metrics
        cm = confusion_matrix(y_test, y_pred)
        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred)
        rec = recall_score(y_test, y_pred)
        tss = rec + cm[0,0] / (cm[0,0] + cm[0,1] + 1e-8) - 1

        # Output
        print("Confusion matrix:\n", cm)
        print(f"Accuracy:  {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall:    {rec:.4f}")
        print(f"TSS:       {tss:.4f}")

In [3]:
# Updated: Training RET+ Models for Solar Flare Prediction

from solarknowledge_ret_plus import RETPlusWrapper
# MODIFICATION: Import get_testing_data (or your equivalent function for validation/test data)
from utils import get_training_data, get_testing_data
import numpy as np

# --- Configuration ---
flare_classes = ["M"]       # Example: "C", "M", "M5"
time_windows = ["48", "72"]   # Example: "24", "48", "72"
input_shape = (10, 9)         # (sequence_length, num_features)
epochs = 300
batch_size = 512

# --- Loop over class × horizon ---
for flare_class in flare_classes:
    for time_window in time_windows:
        print(f"🚀 Training model for flare class {flare_class} with {time_window}h window")

        # Load & prepare training data
        # Ensure time_window is passed as a string if get_training_data expects it
        X_train, y_train = get_training_data(str(time_window), flare_class)

        # MODIFICATION: Load & prepare testing/validation data
        # Using X_val, y_val as variable names for clarity, as this set is used for validation
        X_val, y_val = get_testing_data(str(time_window), flare_class)

        # Check if data was loaded successfully (optional but good practice)
        if X_train is None or y_train is None:
            print(f"❌ Training data not found for {flare_class} / {time_window}h. Skipping this combination.")
            continue
        if X_val is None or y_val is None:
            print(f"❌ Validation (testing) data not found for {flare_class} / {time_window}h. Skipping this combination.")
            # Alternatively, you could decide to train without validation if that's an acceptable fallback,
            # but then the RETPlusWrapper.train method would need to handle X_val=None, y_val=None.
            # For early stopping on a test/validation set, this data is crucial.
            continue

        # Initialize wrapper
        # The input_shape parameter in RETPlusWrapper's __init__ is used to construct RETPlusModel.
        model = RETPlusWrapper(input_shape=input_shape) # Pass the defined input_shape

        # MODIFICATION: Pass validation data (X_val, y_val) to the train method
        model_dir = model.train(
            X_train,
            y_train,
            X_val,          # Pass validation X data
            y_val,          # Pass validation y data
            epochs=epochs,
            batch_size=batch_size,
            flare_class=flare_class,
            time_window=str(time_window) # Ensure time_window is a string
        )

        # Report where everything landed
        if model_dir: # Check if training completed and returned a directory
            print(f"✅ Best weights and metadata stored in: {model_dir}")
        else:
            print(f"⚠️ Model training for {flare_class} / {time_window}h did not complete successfully or was skipped.")
        print("-" * 60)

print("🏁 All training combinations processed.")

TypeError: RETPlusWrapper.train() got multiple values for argument 'epochs'