# SDO Solar Flare Prediction Model GPU Training

This notebook imports the existing code from the SDO Models repository and runs the model training on GPU in Google Colab.

## 1. Environment Setup and Repository Cloning

In [None]:
# Install required packages from requirements.txt
%pip install torch>=1.12.0 torchvision>=0.13.0 numpy>=1.22.0 pandas>=1.4.0 scikit-learn>=1.0.0 \
    matplotlib>=3.5.0 seaborn>=0.11.0 tqdm>=4.64.0 pillow>=9.0.0 h5py>=3.7.0 \
    opencv-python>=4.5.0 pytorch-lightning>=1.8.0 transformers>=4.21.0 captum>=0.5.0 shap>=0.41.0

# Report the installed PyTorch build and, when one is visible, the CUDA device
import torch

gpu_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {gpu_available}")
if gpu_available:
    # Device index 0 is queried; the CUDA version is the one PyTorch reports
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Clone the repository (replace with your actual repository URL)
!git clone https://github.com/your-username/SDOModels.git /content/SDOModels
%cd /content/SDOModels

# Setup for Google Drive for saving models
from google.colab import drive
drive.mount('/content/drive')
!mkdir -p /content/drive/MyDrive/SDOBenchmark/models
!mkdir -p /content/drive/MyDrive/SDOBenchmark/results

: 

## 2. Set Up Python Path

In [None]:
# Add the repository to Python path
import sys
import os
sys.path.append('/content/SDOModels')

# Check the repository structure
!ls -la

## 3. Download and Prepare Dataset

In [None]:
# Define dataset paths
DATA_URL = "https://github.com/i4Ds/SDOBenchmark/archive/data-full.zip"  # archive to fetch
DOWNLOAD_PATH = "/content/data-full.zip"        # where the archive is saved
EXTRACT_PATH = "/content"                       # directory the archive is unpacked into
DATASET_PATH = "/content/SDOBenchmark_data"     # final dataset location used by later cells

# Create download and extraction functions
import os
import urllib.request
from zipfile import ZipFile
import logging
import sys

# Configure logging.
# basicConfig() does nothing when the root logger already has handlers, in
# which case the INFO-level progress messages below would not be shown with
# this configuration. force=True (Python 3.8+) removes existing root handlers
# first so the level and format requested here always take effect.
# NOTE(review): hosted notebook kernels often pre-configure the root logger;
# confirm that replacing those handlers is acceptable in your environment.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)

# Create necessary directories
os.makedirs(DATASET_PATH, exist_ok=True)

# Function to download data with progress reporting
def download_with_progress(url, output_path):
    """Download ``url`` to ``output_path`` while printing progress to stdout.

    The data is written to ``output_path + ".part"`` first and renamed to
    ``output_path`` only after the transfer has finished. An interrupted or
    failed download therefore never leaves a truncated file at
    ``output_path`` that a later "file already exists" check could mistake
    for a complete download.

    Args:
        url: Address of the file to fetch.
        output_path: Destination file path.

    Returns:
        True if the download succeeded, False otherwise (the error is logged).
    """
    partial_path = f"{output_path}.part"

    def report_progress(block_num, block_size, total_size):
        """Progress hook for urlretrieve: blocks so far, block size, total bytes."""
        downloaded = block_num * block_size
        if total_size > 0:
            # The last block is usually short, so the product can overshoot
            downloaded = min(downloaded, total_size)
            percent = downloaded * 100 / total_size
            sys.stdout.write(f"\rDownloaded {downloaded/1024/1024:.1f} MB of {total_size/1024/1024:.1f} MB ({percent:.1f}%)")
        else:
            # Total size is unknown (reported as -1 when the server sends no
            # Content-Length) or zero: show the running amount only
            sys.stdout.write(f"\rDownloaded {downloaded/1024/1024:.1f} MB")
        sys.stdout.flush()

    try:
        logger.info(f"Downloading data from {url}")
        urllib.request.urlretrieve(url, partial_path, reporthook=report_progress)
        # Publish the file under its final name only once it is complete
        os.replace(partial_path, output_path)
        logger.info("\nDownload completed successfully!")
        return True
    except Exception as e:
        logger.error(f"Error downloading data: {str(e)}")
        # Best-effort removal of the incomplete file
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            pass
        return False

# Function to extract dataset with validation
def extract_with_validation(zip_path, extract_path):
    """Unpack ``zip_path`` into ``extract_path`` and confirm the result.

    Extraction counts as successful when the directory
    ``<extract_path>/SDOBenchmark-data-full`` exists afterwards.

    Args:
        zip_path: Path of the zip archive to unpack.
        extract_path: Directory the archive is unpacked into.

    Returns:
        True on success. False when the expected directory is missing or
        when any exception is raised while extracting (the error is logged).
    """
    try:
        logger.info(f"Extracting {zip_path} to {extract_path}")
        with ZipFile(zip_path, 'r') as archive:
            members = archive.namelist()
            logger.info(f"Found {len(members)} files in the archive")
            archive.extractall(extract_path)

        # The archive is expected to unpack into this top-level directory
        extracted_dir = os.path.join(extract_path, "SDOBenchmark-data-full")
        if not os.path.exists(extracted_dir):
            logger.error(f"Extraction failed: {extracted_dir} not found")
            return False

        logger.info("Extraction completed successfully!")
        return True
    except Exception as exc:
        logger.error(f"Error extracting data: {str(exc)}")
        return False

# Function to organize extracted data
def organize_data(source_dir, target_dir):
    """Copy the extracted dataset into ``target_dir`` and check its layout.

    The contents of ``source_dir`` are merged into ``target_dir`` (existing
    files are overwritten). Afterwards the ``training`` and ``test``
    sub-directories must exist in ``target_dir``; a missing ``meta_data.csv``
    inside either of them is only reported as a warning.

    Args:
        source_dir: Directory holding the extracted archive contents.
        target_dir: Directory the dataset should end up in.

    Returns:
        True when both sub-directories are present after the copy. False when
        one is missing or when copying raised an exception (the error is
        logged).
    """
    # Local import: shutil is not among the notebook's top-level imports
    import shutil

    try:
        logger.info(f"Moving files from {source_dir} to {target_dir}")
        os.makedirs(target_dir, exist_ok=True)
        # copytree replaces the former `cp -r` shell command: paths containing
        # spaces or shell metacharacters are handled correctly, and a failed
        # copy raises instead of being silently ignored.
        # dirs_exist_ok requires Python 3.8+.
        shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)

        # Check for metadata files
        for subdir in ("training", "test"):
            meta_file = os.path.join(target_dir, subdir, "meta_data.csv")
            if os.path.exists(meta_file):
                logger.info(f"Found {subdir} metadata file at {meta_file}")
            else:
                logger.warning(f"{subdir.capitalize()} metadata file not found at {meta_file}")

        # Both split directories are mandatory
        for subdir in ("training", "test"):
            if not os.path.exists(os.path.join(target_dir, subdir)):
                logger.error(f"Failed to copy {subdir} directory")
                return False
            logger.info(f"Successfully copied {subdir} directory")
        return True
    except Exception as e:
        logger.error(f"Error organizing data: {str(e)}")
        return False

# Locate the metadata files of both splits and check whether they are present
training_meta = os.path.join(DATASET_PATH, "training", "meta_data.csv")
test_meta = os.path.join(DATASET_PATH, "test", "meta_data.csv")
metadata_exists = all(os.path.exists(path) for path in (training_meta, test_meta))

# The dataset is ready only when both split folders and both metadata files exist
dataset_ready = (
    os.path.exists(os.path.join(DATASET_PATH, "training"))
    and os.path.exists(os.path.join(DATASET_PATH, "test"))
    and metadata_exists
)

if dataset_ready:
    logger.info(f"Dataset already exists at {DATASET_PATH} with metadata files")
else:
    # 1. Download dataset (an archive that is already on disk is reused)
    if os.path.exists(DOWNLOAD_PATH):
        logger.info(f"Using existing download at {DOWNLOAD_PATH}")
    elif not download_with_progress(DATA_URL, DOWNLOAD_PATH):
        raise RuntimeError("Failed to download the dataset")

    # 2. Extract dataset (an existing extraction directory is reused)
    extracted_dir = os.path.join(EXTRACT_PATH, "SDOBenchmark-data-full")
    if os.path.exists(extracted_dir):
        logger.info(f"Using existing extracted data at {extracted_dir}")
    elif not extract_with_validation(DOWNLOAD_PATH, EXTRACT_PATH):
        raise RuntimeError("Failed to extract the dataset")

    # 3. Organize data into DATASET_PATH
    if not organize_data(extracted_dir, DATASET_PATH):
        raise RuntimeError("Failed to organize the dataset")

# Verify dataset structure and check for metadata files
logger.info("Dataset structure verification:")
!ls -la {DATASET_PATH}

# Check for metadata files in test and training folders
if os.path.exists(training_meta):
    logger.info(f"Training metadata file found: {training_meta}")
    !head {training_meta}
else:
    logger.error(f"Training metadata file not found at expected path: {training_meta}")
    raise FileNotFoundError(f"Training metadata file not found: {training_meta}")

if os.path.exists(test_meta):
    logger.info(f"Test metadata file found: {test_meta}")
    !head {test_meta}
else:
    logger.error(f"Test metadata file not found at expected path: {test_meta}")
    raise FileNotFoundError(f"Test metadata file not found: {test_meta}")

## 4. Import and Test Dataset Functionality

In [None]:
# Import the dataset functionality from the repository
from data.preprocessing import SDOBenchmarkDataset, get_data_loaders, SDODataAugmentation
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Resolve the metadata CSV of each split
training_meta = os.path.join(DATASET_PATH, "training", "meta_data.csv")
test_meta = os.path.join(DATASET_PATH, "test", "meta_data.csv")
metadata_files = (("Training", training_meta), ("Test", test_meta))

# Fail fast if either file is missing, before anything is printed
for split_label, meta_path in metadata_files:
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"{split_label} metadata file not found: {meta_path}")

for split_label, meta_path in metadata_files:
    print(f"{split_label} metadata file exists: {os.path.exists(meta_path)}")

# Build the data loaders from the existing metadata files
loader_kwargs = dict(
    data_path=DATASET_PATH,
    metadata_path={
        'train': training_meta,
        'test': test_meta
    },
    batch_size=8,
    img_size=128,
    num_workers=2,
)
data_loaders = get_data_loaders(**loader_kwargs)

# Report how many batches each split yields.
# NOTE(review): a 'val' loader is expected in the result although only
# 'train' and 'test' metadata are passed in; presumably get_data_loaders
# creates the validation split itself. Confirm against the repository code.
for split_key, split_name in (('train', 'training'), ('val', 'validation'), ('test', 'test')):
    print(f"Number of {split_name} batches: {len(data_loaders[split_key])}")

# Fetch one training batch and report what it contains
batch = next(iter(data_loaders['train']))
print("Batch keys:", batch.keys())
for modality, key in (("Magnetogram", 'magnetogram'), ("EUV", 'euv')):
    print(f"{modality} shape: {batch[key].shape}")

# Plot one example from the batch, one panel per modality.
# NOTE(review): the [sample, 0, channel] indexing assumes the tensors are laid
# out as (batch, time, channel, height, width). Confirm against the dataset class.
sample_idx = 0
panels = (
    (batch['magnetogram'][sample_idx, 0, 0], 'gray', "Magnetogram (t=0)"),
    # Channel index 1 is taken to be the 131Å channel
    (batch['euv'][sample_idx, 0, 1], 'hot', "EUV 131Å (t=0)"),
)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for axis, (image, colormap, title) in zip(axes, panels):
    axis.imshow(image.numpy(), cmap=colormap)
    axis.set_title(title)
    axis.axis('off')

plt.tight_layout()
plt.show()

# Print the metadata that accompanies the plotted sample
print(f"Sample ID: {batch['sample_id'][sample_idx]}")
print(f"Peak flux (log10): {batch['peak_flux'][sample_idx].item()}")
print(f"GOES class: {batch['goes_class'][sample_idx]}")

## 5. Import Model Architecture

In [None]:
# Import model components
from models.model import SolarFlareModel, SolarFlareLoss, PhysicsInformedRegularization

# Keyword arguments used to construct the SolarFlareModel
model_config = dict(
    magnetogram_channels=1,
    euv_channels=8,
    pretrained=True,
    freeze_backbones=False,
    use_attention=True,
    fusion_method='concat',
    temporal_type='lstm',
    temporal_hidden_size=512,
    temporal_num_layers=2,
    dropout=0.1,
    final_hidden_size=512,
    use_uncertainty=True,
    use_multi_task=True,
)

# Initialize model and report its total parameter count
model = SolarFlareModel(**model_config)
param_count = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {param_count} parameters")

# Keyword arguments used to construct the loss function
loss_config = dict(
    regression_weight=1.0,
    c_vs_0_weight=0.5,
    m_vs_c_weight=0.5,
    m_vs_0_weight=0.5,
    use_uncertainty=True,
    uncertainty_weight=0.1,
    use_multi_task=True,
)

criterion = SolarFlareLoss(**loss_config)

# Create physics-informed regularization
physics_reg = PhysicsInformedRegularization(weight=0.1)

# Place the model on the GPU when one is available, otherwise keep it on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Model moved to device: {device}")

## 6. Set Up Training Module

In [None]:
# Import training module
from training.train import SolarFlareModule

# Configuration for the network architecture
model_config = dict(
    magnetogram_channels=1,
    euv_channels=8,
    pretrained=True,
    freeze_backbones=False,
    use_attention=True,
    fusion_method='concat',
    temporal_type='lstm',
    temporal_hidden_size=512,
    temporal_num_layers=2,
    dropout=0.1,
    final_hidden_size=512,
    use_uncertainty=True,
    use_multi_task=True,
)

# Configuration for the loss terms and their weights
loss_config = dict(
    regression_weight=1.0,
    c_vs_0_weight=0.5,
    m_vs_c_weight=0.5,
    m_vs_0_weight=0.5,
    use_uncertainty=True,
    uncertainty_weight=0.1,
    use_multi_task=True,
    physics_reg_weight=0.1,
    dynamic_weighting=True,
)

# Configuration for the optimizer and learning-rate schedule
optimizer_config = dict(
    lr=5e-5,
    weight_decay=0.001,
    scheduler='cosine',
    use_warmup=True,
    warmup_epochs=5,
    t_0=20,
    t_mult=2,
    eta_min=1e-7,
)

# Metadata CSV of each split
training_meta = os.path.join(DATASET_PATH, "training", "meta_data.csv")
test_meta = os.path.join(DATASET_PATH, "test", "meta_data.csv")

# Configuration describing the dataset location and loading parameters
data_config = dict(
    data_path=DATASET_PATH,
    metadata_path={
        'train': training_meta,
        'test': test_meta
    },
    batch_size=8,
    img_size=128,
    num_workers=2,
    sample_type='all',
)

# Create the PyTorch Lightning module; this rebinds `model`
model = SolarFlareModule(
    model_config=model_config,
    loss_config=loss_config,
    optimizer_config=optimizer_config,
    data_config=data_config,
)

## 7. Train the Model

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Create output directories with correct paths
MODEL_OUTPUT_DIR = '/content/drive/MyDrive/SDOBenchmark/models'
LOGS_OUTPUT_DIR = '/content/drive/MyDrive/SDOBenchmark/logs'
RESULTS_OUTPUT_DIR = '/content/drive/MyDrive/SDOBenchmark/results'

for output_dir in (MODEL_OUTPUT_DIR, LOGS_OUTPUT_DIR, RESULTS_OUTPUT_DIR):
    os.makedirs(output_dir, exist_ok=True)

# Configure callbacks
# Keep the three checkpoints with the lowest val_loss, plus the latest one
checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_OUTPUT_DIR,
    filename='sdo_flare_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min',
    save_last=True
)

# Stop once val_loss has not improved for 10 validation checks
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min'
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Configure the TensorBoard logger.
# It is deliberately NOT named `logger`: that global is the standard-library
# logger used by the dataset download/extract helpers, and rebinding it would
# break those helpers if they were called again after this cell.
tb_logger = TensorBoardLogger(
    save_dir=LOGS_OUTPUT_DIR,
    name='sdo_flare_model',
    default_hp_metric=False
)

# Configure trainer
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    logger=tb_logger,
    log_every_n_steps=10,
    gradient_clip_val=1.0,
    precision=16 if torch.cuda.is_available() else 32  # Use mixed precision on GPU
)

# Train the model
trainer.fit(
    model,
    train_dataloaders=data_loaders['train'],
    val_dataloaders=data_loaders['val']
)

# Print best model path
print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")
# best_model_score stays None until a checkpoint has been saved (for example
# when training is interrupted before the first validation pass); formatting
# None with ':.4f' would raise a TypeError.
best_score = checkpoint_callback.best_model_score
if best_score is not None:
    print(f"Best validation loss: {best_score:.4f}")
else:
    print("Best validation loss: n/a (no checkpoint was saved)")

## 8. Evaluate the Model

In [None]:
# Load the best model checkpoint
best_model_path = checkpoint_callback.best_model_path
# best_model_path is empty when no checkpoint was written; fail with a clear
# message instead of an obscure error from the checkpoint loader.
if not best_model_path:
    raise RuntimeError(
        "No best checkpoint is available; run the training cell to completion "
        "before evaluating."
    )
model = SolarFlareModule.load_from_checkpoint(best_model_path)
model.eval()

# Run test evaluation
test_results = trainer.test(model, dataloaders=data_loaders['test'])
print(f"Test results: {test_results}")

# Save evaluation results
import json
results_file = os.path.join(RESULTS_OUTPUT_DIR, 'test_results.json')
with open(results_file, 'w') as f:
    json.dump(test_results, f)
print(f"Saved test results to: {results_file}")

## 9. Visualize Results and Predictions

In [None]:
# Visualize predictions on test set
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc

# Get predictions on test set
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

all_preds = []
all_targets = []
all_sample_ids = []

# Classification output name -> batch key holding the matching ground truth
CLASS_TARGET_KEYS = {
    'c_vs_0': 'is_c_or_above',
    'm_vs_c': 'is_m_or_above',
    'm_vs_0': 'is_m_vs_quiet',
}

with torch.no_grad():
    for batch in data_loaders['test']:
        # Move inputs to device
        magnetogram = batch['magnetogram'].to(device)
        euv = batch['euv'].to(device)

        # Get predictions
        outputs = model(magnetogram, euv)

        # The regression output may be a tuple; its first element is the
        # flux prediction
        regression = outputs['regression']
        if isinstance(regression, tuple):
            regression = regression[0]
        pred_flux = regression.cpu().numpy()
        target_flux = batch['peak_flux'].numpy()

        # Copy every classification output to host memory once per batch so
        # the per-sample .item() calls below read CPU tensors rather than
        # each triggering its own device-to-host transfer
        class_outputs = {head: outputs[head].cpu() for head in CLASS_TARGET_KEYS}

        # Store predictions and targets, one dict per sample
        for i, flux_row in enumerate(pred_flux):
            pred = {'flux': flux_row[0]}
            target = {'flux': target_flux[i].item()}
            for head, target_key in CLASS_TARGET_KEYS.items():
                pred[head] = class_outputs[head][i].item()
                target[head] = batch[target_key][i].item()
            all_preds.append(pred)
            all_targets.append(target)
            all_sample_ids.append(batch['sample_id'][i])

# Convert to numpy arrays for easier plotting
pred_flux = np.array([p['flux'] for p in all_preds])
target_flux = np.array([t['flux'] for t in all_targets])

pred_c_vs_0 = np.array([p['c_vs_0'] for p in all_preds])
target_c_vs_0 = np.array([t['c_vs_0'] for t in all_targets])

# Scatter predicted against true flux, with the identity line as reference
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(target_flux, pred_flux, alpha=0.5)
ax.plot([-8, -3], [-8, -3], 'r--')  # Diagonal line (perfect prediction)
ax.set_xlabel('True log10(Peak Flux)')
ax.set_ylabel('Predicted log10(Peak Flux)')
ax.set_title('Regression Performance')
ax.grid(True, alpha=0.3)
fig.tight_layout()

# Store the figure in the plots folder below the results directory
PLOTS_DIR = os.path.join(RESULTS_OUTPUT_DIR, 'plots')
os.makedirs(PLOTS_DIR, exist_ok=True)
regression_plot_file = os.path.join(PLOTS_DIR, 'regression_performance.png')
fig.savefig(regression_plot_file)
plt.show()
print(f"Saved regression plot to: {regression_plot_file}")

# Classification results for the "C-class or above" task: ROC curve on the
# left, confusion matrix on the right
fig, (roc_ax, cm_ax) = plt.subplots(1, 2, figsize=(10, 5))

# ROC curve
fpr, tpr, _ = roc_curve(target_c_vs_0, pred_c_vs_0)
roc_auc = auc(fpr, tpr)
roc_ax.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
roc_ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level reference
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curve (C-class or above)')
roc_ax.legend(loc='lower right')
roc_ax.grid(True, alpha=0.3)

# Confusion matrix
# NOTE(review): thresholding at 0.5 presumes the model outputs probabilities
# rather than raw logits. Confirm against the model's forward pass.
pred_c_vs_0_binary = (pred_c_vs_0 > 0.5).astype(int)
cm = confusion_matrix(target_c_vs_0, pred_c_vs_0_binary)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=cm_ax)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('True')
cm_ax.set_title('Confusion Matrix (C-class or above)')

fig.tight_layout()
classification_plot_file = os.path.join(PLOTS_DIR, 'classification_performance.png')
fig.savefig(classification_plot_file)
plt.show()
print(f"Saved classification plot to: {classification_plot_file}")

# Collect the per-sample results into a table: the sample id followed by a
# true_/pred_ column pair for every predicted quantity
prediction_columns = {'sample_id': all_sample_ids}
for quantity in ('flux', 'c_vs_0', 'm_vs_c', 'm_vs_0'):
    prediction_columns[f'true_{quantity}'] = [t[quantity] for t in all_targets]
    prediction_columns[f'pred_{quantity}'] = [p[quantity] for p in all_preds]
predictions_df = pd.DataFrame(prediction_columns)

# Write the table as CSV (without the index column) for further analysis
PREDICTIONS_PATH = os.path.join(RESULTS_OUTPUT_DIR, 'predictions.csv')
predictions_df.to_csv(PREDICTIONS_PATH, index=False)
print(f"Saved predictions to: {PREDICTIONS_PATH}")