In [2]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support

# FIX: Remove double SDOBenchmark-data-full path
DATA_DIR = r"F:\DevProjects\Projects\Solar_Flare_Prediction\data\SDOBenchmark-data-full"

TRAIN_IMG_DIR = os.path.join(DATA_DIR, "training")
TEST_IMG_DIR  = os.path.join(DATA_DIR, "test")

TRAIN_CSV = os.path.join(TRAIN_IMG_DIR, "train.csv")
TEST_CSV  = os.path.join(TEST_IMG_DIR,  "test.csv")

# Read CSVs with correct header settings
train_df = pd.read_csv(TRAIN_CSV, header=None, 
                       names=["id", "start", "end", "peak_flux"])
test_df  = pd.read_csv(TEST_CSV, header=None, 
                       names=["id", "start", "end", "peak_flux"])

# Convert peak_flux to numeric
train_df["peak_flux"] = pd.to_numeric(train_df["peak_flux"], errors="coerce")
test_df["peak_flux"]  = pd.to_numeric(test_df["peak_flux"],  errors="coerce")

# Drop NaN rows
train_df = train_df.dropna(subset=["peak_flux"]).reset_index(drop=True)
test_df  = test_df.dropna(subset=["peak_flux"]).reset_index(drop=True)

print(train_df.head())
print("Train shape:", train_df.shape)


                            id                          start  \
0  11390_2012_01_05_17_06_01_0  2012-01-05 05:06:01.000000000   
1  11390_2012_01_05_17_19_01_0  2012-01-05 05:19:01.000000000   
2  11390_2012_01_05_17_19_01_1  2012-01-06 05:19:00.000000000   
3  11390_2012_01_06_17_20_58_0  2012-01-06 05:20:58.000000000   
4  11390_2012_01_04_07_22_01_0  2012-01-03 19:22:01.000000000   

                             end     peak_flux  
0  2012-01-05 17:06:01.000000000  8.000000e-07  
1  2012-01-05 17:19:01.000000000  1.647059e-06  
2  2012-01-06 17:19:00.000000000  1.647059e-06  
3  2012-01-06 17:20:58.000000000  1.164706e-06  
4  2012-01-04 07:22:01.000000000  2.235294e-06  
Train shape: (8336, 4)


In [3]:
def get_sample_dir(img_root, sample_id):
    """
    sample_id like: 11390_2012_01_05_17_19_01_0
    Returns: img_root/11390/2012_01_05_17_19_01_0
    """
    region, rest = sample_id.split("_", 1)
    return os.path.join(img_root, region, rest)

CHANNEL_TAGS = ["94", "131", "171", "193", "211",
                "304", "335", "1700", "continuum", "magnetogram"]

def parse_file(fname):
    """
    Example: 2012-01-05T051901__171.jpg
    Returns: (2012-01-05T051901, 171)
    """
    base = fname[:-4]  # drop .jpg
    ts_str, tag = base.split("__")
    return ts_str, tag

def load_sample_from_id(img_root, sample_id):
    """
    Load all 40 images (4 timestamps x 10 channels) for one sample.
    Returns: (40, 256, 256) array
    """
    sample_dir = get_sample_dir(img_root, sample_id)
    fnames = sorted(os.listdir(sample_dir))

    # Group files by timestamp
    by_ts = {}
    for f in fnames:
        ts_str, tag = parse_file(f)
        by_ts.setdefault(ts_str, {})[tag] = f

    # Pick up to 4 earliest timestamps
    ts_list = sorted(by_ts.keys())[:4]

    imgs = []
    for ts in ts_list:
        tag2file = by_ts[ts]
        for tag in CHANNEL_TAGS:
            if tag in tag2file:
                img_path = os.path.join(sample_dir, tag2file[tag])
                img = Image.open(img_path).convert("L")
                imgs.append(np.array(img))
            else:
                # missing channel -> zero image
                imgs.append(np.zeros((256, 256), dtype=np.uint8))

    imgs = np.array(imgs)  # (n_imgs, 256, 256)
    
    # Pad to exactly 40 if needed
    if imgs.shape[0] < 40:
        pad = np.zeros((40 - imgs.shape[0], 256, 256), dtype=np.uint8)
        imgs = np.concatenate([imgs, pad], axis=0)
    
    return imgs[:40]  # ensure exactly 40


In [4]:
EPS = 1e-9
train_df["log_peak_flux"] = np.log10(train_df["peak_flux"].clip(lower=EPS))
test_df["log_peak_flux"]  = np.log10(test_df["peak_flux"].clip(lower=EPS))

train_df["is_MX"] = (train_df["peak_flux"] >= 1e-5).astype(int)
test_df["is_MX"]  = (test_df["peak_flux"]  >= 1e-5).astype(int)

print("M/X class distribution in train_df:")
print(train_df["is_MX"].value_counts())


M/X class distribution in train_df:
is_MX
0    7822
1     514
Name: count, dtype: int64


In [5]:
class SDOBenchmarkDataset(Dataset):
    def __init__(self, df, img_root):
        self.df = df.reset_index(drop=True)
        self.img_root = img_root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sample_id = row["id"]
        log_peak_flux = row["log_peak_flux"]
        is_MX = row["is_MX"]

        # Load 40 images
        imgs = load_sample_from_id(self.img_root, sample_id)  # (40, 256, 256)
        
        # Reshape to (T=4, C=10, H=256, W=256)
        imgs = imgs.reshape(4, 10, 256, 256)
        
        # Convert to torch float [0, 1]
        imgs = torch.from_numpy(imgs).float() / 255.0

        # Targets
        reg_target = torch.tensor(log_peak_flux, dtype=torch.float32)
        cls_target = torch.tensor(is_MX, dtype=torch.float32)
        
        return imgs, reg_target, cls_target


In [6]:
# Use only first 300 samples for fast training
subset_df = train_df.iloc[:300].copy()

train_df_small, val_df_small = train_test_split(
    subset_df,
    test_size=0.2,
    random_state=42
)

train_dataset = SDOBenchmarkDataset(train_df_small, TRAIN_IMG_DIR)
val_dataset   = SDOBenchmarkDataset(val_df_small,   TRAIN_IMG_DIR)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=2, shuffle=False, num_workers=0)

# Sanity check
imgs_batch, reg_t, cls_t = next(iter(train_loader))
print("Train batch imgs shape:", imgs_batch.shape)  # Should be [2, 4, 10, 256, 256]
print("Train batch reg targets (first 2):", reg_t[:2])
print("Train batch cls targets (first 2):", cls_t[:2])


Train batch imgs shape: torch.Size([2, 4, 10, 256, 256])
Train batch reg targets (first 2): tensor([-6.1099, -6.4823])
Train batch cls targets (first 2): tensor([0., 0.])


In [7]:
class FlareCNNMultiTask(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(10, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128, 64)
        
        # Two heads: regression (log-flux) + classification (M/X)
        self.reg_head = nn.Linear(64, 1)
        self.cls_head = nn.Linear(64, 1)

    def forward(self, x):
        # x: (B, T=4, C=10, H=256, W=256)
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)  # (B*T, 10, 256, 256)

        # Conv blocks
        x = self.pool(F.relu(self.conv1(x)))          # (B*T, 32, 128, 128)
        x = self.pool(F.relu(self.conv2(x)))          # (B*T, 64, 64, 64)
        x = self.pool(F.relu(self.conv3(x)))          # (B*T, 128, 32, 32)

        # Global average pooling
        x = F.adaptive_avg_pool2d(x, 1)               # (B*T, 128, 1, 1)
        x = x.view(B, T, 128)                         # (B, T=4, 128)
        x = x.mean(dim=1)                             # (B, 128)

        # FC layer
        x = F.relu(self.fc1(x))                       # (B, 64)

        # Two outputs
        reg_out = self.reg_head(x).squeeze(-1)        # (B,)
        cls_logit = self.cls_head(x).squeeze(-1)      # (B,)
        
        return reg_out, cls_logit


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = FlareCNNMultiTask().to(device)

reg_criterion = nn.L1Loss()                    # MAE for regression
cls_criterion = nn.BCEWithLogitsLoss()         # BCE for classification
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


Device: cuda


In [9]:
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for imgs, reg_t, cls_t in loader:
        imgs = imgs.to(device)
        reg_t = reg_t.to(device)
        cls_t = cls_t.to(device)

        optimizer.zero_grad()
        reg_out, cls_logit = model(imgs)

        reg_loss = reg_criterion(reg_out, reg_t)
        cls_loss = cls_criterion(cls_logit, cls_t)
        loss = reg_loss + cls_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
    
    return total_loss / len(loader.dataset)


In [10]:
def eval_metrics(model, loader, device):
    model.eval()
    reg_preds, reg_trues = [], []
    cls_probs, cls_trues = [], []
    
    with torch.no_grad():
        for imgs, reg_t, cls_t in loader:
            imgs = imgs.to(device)
            reg_t = reg_t.to(device)
            cls_t = cls_t.to(device)

            reg_out, cls_logit = model(imgs)
            prob = torch.sigmoid(cls_logit)

            reg_preds.append(reg_out.cpu())
            reg_trues.append(reg_t.cpu())
            cls_probs.append(prob.cpu())
            cls_trues.append(cls_t.cpu())

    reg_preds = torch.cat(reg_preds)
    reg_trues = torch.cat(reg_trues)
    cls_probs = torch.cat(cls_probs)
    cls_trues = torch.cat(cls_trues)

    # Regression metrics
    log_mae = torch.mean(torch.abs(reg_preds - reg_trues)).item()
    flux_pred = 10 ** reg_preds
    flux_true = 10 ** reg_trues
    flux_mae = torch.mean(torch.abs(flux_pred - flux_true)).item()

    # Classification metrics
    cls_pred = (cls_probs >= 0.5).int().numpy()
    cls_true = cls_trues.int().numpy()
    precision, recall, f1, _ = precision_recall_fscore_support(
        cls_true, cls_pred, average="binary", zero_division=0
    )
    
    return log_mae, flux_mae, precision, recall, f1


In [11]:
EPOCHS = 5

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_log_mae, val_flux_mae, prec, rec, f1 = eval_metrics(model, val_loader, device)
    
    print(f"Epoch {epoch}: train_loss={train_loss:.3e}, "
          f"val_log_MAE={val_log_mae:.3e}, val_flux_MAE={val_flux_mae:.3e}, "
          f"prec={prec:.3f}, rec={rec:.3f}, f1={f1:.3f}")


Epoch 1: train_loss=4.649e+00, val_log_MAE=1.734e+00, val_flux_MAE=4.527e-06, prec=0.000, rec=0.000, f1=0.000
Epoch 2: train_loss=2.048e+00, val_log_MAE=1.681e+00, val_flux_MAE=4.736e-06, prec=0.000, rec=0.000, f1=0.000
Epoch 3: train_loss=2.039e+00, val_log_MAE=1.689e+00, val_flux_MAE=4.617e-06, prec=0.000, rec=0.000, f1=0.000
Epoch 4: train_loss=1.991e+00, val_log_MAE=1.655e+00, val_flux_MAE=4.800e-06, prec=0.000, rec=0.000, f1=0.000
Epoch 5: train_loss=1.986e+00, val_log_MAE=1.662e+00, val_flux_MAE=5.141e-06, prec=0.000, rec=0.000, f1=0.000


In [12]:
CHECKPOINT_DIR = "models"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

ckpt_path = os.path.join(CHECKPOINT_DIR, "flare_cnn_multitask_subset_e5.pth")
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": EPOCHS,
}, ckpt_path)

print("Saved checkpoint to:", ckpt_path)


Saved checkpoint to: models\flare_cnn_multitask_subset_e5.pth
