# Imports

In [None]:
import os
import glob
import json
import random
import warnings
from pathlib import Path
from collections import defaultdict
from datetime import datetime

import logging
logging.getLogger("absl").setLevel(logging.ERROR)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF

from transformers import (
    ViTImageProcessor,
    ViTModel,
    get_cosine_schedule_with_warmup
)

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    f1_score,
    roc_auc_score,
    average_precision_score
)

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Seeding

In [None]:
def set_seed(SEED):
    """Seed every RNG used by the notebook and build DataLoader seeding helpers.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        SEED: Integer base seed.

    Returns:
        A ``(worker_init_fn, generator)`` tuple: ``worker_init_fn(worker_id)``
        re-seeds ``random`` and NumPy with ``SEED + worker_id``; ``generator``
        is a ``torch.Generator`` seeded with ``SEED``.
    """
    # Process-wide RNG streams.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # Trade cuDNN autotuning for repeatable kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    def worker_init_fn(worker_id):
        # Each worker gets its own, but fixed, seed.
        worker_seed = SEED + worker_id
        for seeder in (random.seed, np.random.seed):
            seeder(worker_seed)

    # Generator.manual_seed returns the generator itself.
    generator = torch.Generator().manual_seed(SEED)
    return worker_init_fn, generator

# Configurations

In [None]:
# Global run configuration. Seeding happens first so everything below is reproducible.
SEED = 2024
worker_init_fn, g_seed = set_seed(SEED)

# Number of past samples per input sequence (passed to build_sequences as seq_len).
SEQ_LEN = 4

BATCH_SIZE = 32
NUM_EPOCHS = 30
# Square side length (pixels) every image is resized to before encoding.
IMG_SIZE = 128

# A sample counts as a flare when peak_flux >= this value.
# NOTE(review): presumably W/m^2 (1e-5 would be the M-class boundary) - confirm units.
FLARE_THRESHOLD = 1e-5


TRAIN_DIR = "/kaggle/input/sdobenchmark_full/training"
TEST_DIR  = "/kaggle/input/sdobenchmark_full/test"

# Checkpoint name used for both the image processor and the ViTModel backbone.
MODEL_NAME = "facebook/deit-tiny-patch16-224"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")

# Load metadata

In [None]:
# Per-sample metadata; the code below relies on the "id", "start" and "peak_flux" columns.
df_train = pd.read_csv(os.path.join(TRAIN_DIR, "meta_data.csv"))
df_test  = pd.read_csv(os.path.join(TEST_DIR, "meta_data.csv"))

# Per-sample binary label (1 = peak flux at or above the flare threshold).
# Sequence labels are derived separately in build_sequences.
df_train["label"] = (df_train["peak_flux"] >= FLARE_THRESHOLD).astype(int)
df_test["label"]  = (df_test["peak_flux"] >= FLARE_THRESHOLD).astype(int)

# Sequence building (24h horizon)

In [None]:
def build_sequences(df, seq_len, horizon_hours, flare_threshold=None):
    """Build fixed-length history windows with a forecast label per active region.

    Samples are grouped by the prefix of ``id`` before the first underscore
    (``sid``) and sorted by start time. For every anchor sample ``i`` with at
    least ``seq_len`` predecessors, the ``seq_len`` samples before it form the
    input window. The label is 1 when any sample of the same group starting in
    ``(t0, t0 + horizon_hours]`` has ``peak_flux >= flare_threshold``, where
    ``t0`` is the anchor's start time. Windows whose oldest sample is less than
    one hour before ``t0`` are dropped.

    Args:
        df: DataFrame with ``id``, ``start`` and ``peak_flux`` columns.
        seq_len: Number of past samples per window.
        horizon_hours: Length of the forecast window in hours.
        flare_threshold: Flux threshold for a positive label. Defaults to the
            module-level ``FLARE_THRESHOLD`` when omitted.

    Returns:
        List of dicts with keys ``past_ids``, ``past_offsets`` (hours before
        ``t0``, oldest first), ``sid`` and ``label``.
    """
    if flare_threshold is None:
        flare_threshold = FLARE_THRESHOLD

    df = df.copy()
    df["sid"] = df["id"].apply(lambda x: x.split("_")[0])
    df["start_time"] = pd.to_datetime(df["start"])
    horizon = pd.Timedelta(hours=horizon_hours)

    sequences = []

    for sid, group in df.groupby("sid"):
        group = group.sort_values("start_time").reset_index(drop=True)
        start_times = group["start_time"]
        # Thresholding once per group instead of once per window.
        is_flare = group["peak_flux"] >= flare_threshold

        for i in range(seq_len, len(group)):
            past = group.iloc[i - seq_len:i]
            t0 = start_times.iloc[i]

            assert past["id"].str.startswith(sid).all(), \
                "Sequence contains samples from multiple active regions"

            past_offsets = [
                (t0 - t).total_seconds() / 3600.0
                for t in past["start_time"]
            ]

            assert all(o >= 0 for o in past_offsets), \
                "Negative time offset detected (future leakage)"

            assert past_offsets == sorted(past_offsets, reverse=True), \
                "Past offsets not in non-increasing order"

            # The first offset belongs to the oldest sample, i.e. it is the
            # history span. Skip short histories before scanning the horizon.
            if past_offsets[0] < 1.0:
                continue

            # Strictly after t0, so the filter itself guarantees that no past
            # or anchor sample can leak into the label.
            in_horizon = (start_times > t0) & (start_times <= t0 + horizon)
            label = int(is_flare[in_horizon].any())

            sequences.append({
                "past_ids": past["id"].tolist(),
                "past_offsets": past_offsets,
                "sid": sid,
                "label": label
            })
    return sequences

In [None]:
# Build SEQ_LEN-sample history windows labelled with a 24-hour forecast horizon.
train_seqs = build_sequences(df_train, SEQ_LEN, horizon_hours=24)
test_seqs  = build_sequences(df_test,  SEQ_LEN, horizon_hours=24)

In [None]:
# Number of windows produced for each split.
print("Train sequences:", len(train_seqs))
print("Test  sequences:", len(test_seqs))

In [None]:
# numpy is already imported at the top of the notebook; re-imported so the cell runs standalone.
import numpy as np

# Sequence-level labels before the train/validation split.
# NOTE: train_labels is reassigned later in the weighted-sampler cell.
train_labels = [s["label"] for s in train_seqs]
test_labels  = [s["label"] for s in test_seqs]

# Mean of 0/1 labels = fraction of positive sequences.
print("Train positives:", np.mean(train_labels))
print("Test  positives:", np.mean(test_labels))

In [None]:
# Spot-check one arbitrary training sequence (index 3).
s = train_seqs[3]

print("SID:", s["sid"])
print("Label:", s["label"])
print("Past IDs:", s["past_ids"])
print("Past offsets (h):", s["past_offsets"])
# The largest offset belongs to the oldest sample, i.e. the history span.
print("History span (h):", max(s["past_offsets"]))

In [None]:
# Flatten every past offset (hours before the anchor) across all training sequences.
all_offsets = [o for s in train_seqs for o in s["past_offsets"]]

print("Offset stats (hours):")
print("  min:", min(all_offsets))
print("  max:", max(all_offsets))
print("  mean:", sum(all_offsets) / len(all_offsets))

# Views Map

In [None]:
# Image-file suffixes (the part after "__" in each file name) to load for every sample.
# The list order also fixes the order in which views are stacked by the dataset.
WANTED_WL = [
   "94.jpg","131.jpg","171.jpg","193.jpg","211.jpg", "304.jpg","335.jpg","continuum.jpg","1700.jpg","magnetogram.jpg"
]

# Full list of suffixes, kept for reference when WANTED_WL is trimmed for experiments:
# "94.jpg","131.jpg","171.jpg","193.jpg","211.jpg", "304.jpg","335.jpg","continuum.jpg","1700.jpg","magnetogram.jpg"

# Already imported at the top of the notebook; repeated so the cell runs standalone.
from collections import defaultdict
import os, glob
import pandas as pd

def build_views_map(base_dir, wanted_wl=None):
    """Index the image files below ``base_dir`` by sample id and wavelength.

    Files are expected at ``base_dir/<sid>/<ts>/<time>__<wavelength>.jpg``.
    The sample id is ``"<sid>_<ts>"``. When a sample has several files for the
    same wavelength, the one with the latest ``<time>`` wins.

    Args:
        base_dir: Root directory of one dataset split.
        wanted_wl: Iterable of file suffixes (e.g. ``"94.jpg"``) to keep.
            Defaults to the module-level ``WANTED_WL``.

    Returns:
        ``defaultdict(dict)`` mapping sample id -> {wavelength suffix -> path}.
        Files whose name cannot be parsed are skipped; their count is reported
        through a single logging warning instead of being silently dropped.
    """
    wanted = frozenset(WANTED_WL if wanted_wl is None else wanted_wl)

    views = defaultdict(dict)
    latest = defaultdict(dict)  # sample id -> {wavelength -> newest timestamp seen}
    skipped = 0

    for p in glob.glob(os.path.join(base_dir, "*", "*", "*.jpg")):
        sample_dir = os.path.dirname(p)
        ts = os.path.basename(sample_dir)
        sid = os.path.basename(os.path.dirname(sample_dir))
        mid = f"{sid}_{ts}"

        try:
            # Exactly one "__" separator is expected; anything else is malformed.
            t_str, wl = os.path.basename(p).split("__")
        except ValueError:
            skipped += 1
            continue

        if wl not in wanted:
            continue

        seen = latest[mid]
        try:
            t_img = pd.to_datetime(t_str)
            is_newer = wl not in seen or t_img > seen[wl]
        except (ValueError, OverflowError, TypeError):
            # Unparsable timestamp, or timestamps that cannot be compared.
            skipped += 1
            continue

        if is_newer:
            seen[wl] = t_img
            views[mid][wl] = p

    if skipped:
        logging.getLogger(__name__).warning(
            "build_views_map: skipped %d unparsable file(s) under %s",
            skipped, base_dir
        )

    return views

# One index per split: sample id -> {wavelength suffix -> image path}.
views_map_train = build_views_map(TRAIN_DIR)
views_map_test  = build_views_map(TEST_DIR)

In [None]:
def sanity_check_views(views_map, name):
    """Print how many wavelengths are available per sample in ``views_map``.

    Args:
        views_map: Mapping of sample id -> {wavelength -> path}.
        name: Label printed as the report header.

    An empty map is reported explicitly instead of failing on ``min([])``
    or a division by zero.
    """
    wl_counts = [len(wls) for wls in views_map.values()]

    print(f"{name}:")
    print("  Samples:", len(views_map))

    if not wl_counts:
        # Usually means the dataset path is wrong or no file name parsed.
        print("  No samples found")
        return

    print("  Min WL count:", min(wl_counts))
    print("  Max WL count:", max(wl_counts))
    print("  Mean WL count:", sum(wl_counts)/len(wl_counts))

# Report wavelength availability for both splits.
sanity_check_views(views_map_train, "TRAIN")
sanity_check_views(views_map_test,  "TEST")

In [None]:
# Samples missing at least one view.
# NOTE: the literal 10 equals len(WANTED_WL) as currently defined; keep them in sync.
few_wl = [mid for mid, wls in views_map_train.items() if len(wls) < 10]

print("Samples with < 10 wavelengths:", len(few_wl))
print("Fraction:", len(few_wl) / len(views_map_train))

# SID-Based Split (NO LEAKAGE)

In [None]:
# Group sequences by active region (sid) so the split can be done per region.
sid_to_seqs = defaultdict(list)
for s in train_seqs:
    sid_to_seqs[s["sid"]].append(s)

sids = list(sid_to_seqs.keys())
# Region-level label for stratification: 1 if any of its sequences is positive.
sid_labels = [int(any(seq["label"] for seq in sid_to_seqs[sid])) for sid in sids]

# 75/25 split over regions, stratified on the region-level label.
train_sids, val_sids = train_test_split(
    sids, test_size=0.25, random_state=SEED, stratify=sid_labels
)

# Expand the region split back into sequence lists.
train_seqs_by_sid, val_seqs_by_sid = [], []
for sid in train_sids:
    train_seqs_by_sid.extend(sid_to_seqs[sid])
for sid in val_sids:
    val_seqs_by_sid.extend(sid_to_seqs[sid])

In [None]:
# No active region may appear on both sides of the split.
assert set(train_sids).isdisjoint(set(val_sids)), \
    "Active region leakage between train and validation!"

In [None]:
# Split sizes in active regions (ARs) and in sequences.
print("Train ARs:", len(train_sids))
print("Val   ARs:", len(val_sids))

print("Train sequences:", len(train_seqs_by_sid))
print("Val   sequences:", len(val_seqs_by_sid))

In [None]:
def ar_pos_rate(sids):
    """Return the fraction of regions in ``sids`` with at least one positive sequence.

    Reads the module-level ``sid_to_seqs`` mapping.
    """
    flagged = 0
    for sid in sids:
        if any(seq["label"] for seq in sid_to_seqs[sid]):
            flagged += 1
    return flagged / len(sids)

# Stratification check: the two rates should be close.
print("Train AR positive rate:", ar_pos_rate(train_sids))
print("Val   AR positive rate:", ar_pos_rate(val_sids))

In [None]:
def seq_pos_rate(seqs):
    """Return the mean of the ``"label"`` values over ``seqs``."""
    positives = 0
    for seq in seqs:
        positives += seq["label"]
    return positives / len(seqs)

# Sequence-level class balance of each side of the split.
print("Train seq positive rate:", seq_pos_rate(train_seqs_by_sid))
print("Val   seq positive rate:", seq_pos_rate(val_seqs_by_sid))

# Augmentation

In [None]:
class RandomDiscreteRotation:
    """With probability ``p``, rotate an image by an angle drawn from ``angles``.

    Randomness comes from Python's ``random`` module; the rotation itself is
    delegated to ``TF.rotate``.
    """

    def __init__(self, angles=(0, 90, 180, 270), p=0.3):
        self.angles = angles
        self.p = p

    def __call__(self, img):
        # Guard clause: most calls return the input untouched.
        if not random.random() < self.p:
            return img
        chosen = random.choice(self.angles)
        return TF.rotate(img, chosen)

In [None]:
# Augmentation pipeline: random horizontal flip followed by a random multiple-of-90 rotation.
# Only used by the dataset when apply_augmentation=True.
transform_safe = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.3),
    RandomDiscreteRotation(angles=(0, 90, 180, 270), p=0.3),
])

# Processor 

In [None]:
# Image preprocessor of the pretrained checkpoint, with its output size overridden to IMG_SIZE.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME, size=IMG_SIZE)

# Dataset

In [None]:
# Position of each wavelength suffix in WANTED_WL (not referenced elsewhere in the visible notebook).
WL_TO_IDX = {wl: i for i, wl in enumerate(WANTED_WL)}

class SolarFlareSequenceDataset(Dataset):
    """Dataset yielding multi-wavelength image sequences for flare forecasting.

    Each item corresponds to one entry of ``sequences`` (as produced by
    ``build_sequences``). For every past sample id and every suffix in
    ``WANTED_WL`` the matching image is loaded; missing or masked views are
    replaced by an all-zero tensor.

    Args:
        sequences: List of dicts with ``past_ids``, ``past_offsets``, ``sid``
            and ``label``.
        views_map: Mapping sample id -> {wavelength suffix -> image path}.
        apply_augmentation: When True, ``transform_safe`` is applied to every
            loaded image.
        masked_wavelength: Single wavelength suffix to blank out. Takes
            precedence over ``masked_wavelengths``.
        masked_wavelengths: Iterable of wavelength suffixes to blank out.
    """

    def __init__(
        self,
        sequences,
        views_map,
        apply_augmentation=False,
        masked_wavelength=None,
        masked_wavelengths=None
    ):
        self.sequences = sequences
        self.views_map = views_map
        self.apply_augmentation = apply_augmentation

        if masked_wavelength is not None:
            self.masked_wls = {masked_wavelength}
        elif masked_wavelengths is not None:
            self.masked_wls = set(masked_wavelengths)
        else:
            self.masked_wls = None

        # Placeholder for missing/masked views; cloned on use so callers can
        # never mutate the shared tensor.
        # NOTE(review): zeros are inserted after the processor's normalisation
        # step, so they presumably do not correspond to a black image - confirm.
        self.blank = torch.zeros(3, IMG_SIZE, IMG_SIZE)

    def __len__(self):
        return len(self.sequences)

    def _load_view(self, path):
        """Load one image file and return its processed ``(3, H, W)`` tensor."""
        # The context manager releases the file handle as soon as the pixel
        # data has been copied by convert().
        with Image.open(path) as raw:
            img_pil = raw.convert("RGB")
        img_pil = img_pil.resize((IMG_SIZE, IMG_SIZE))

        # NOTE(review): augmentation is sampled independently per image, so
        # views of the same sample may receive different flips/rotations.
        if self.apply_augmentation:
            img_pil = transform_safe(img_pil)

        return processor(
            img_pil,
            return_tensors="pt"
        )["pixel_values"].squeeze(0)

    def __getitem__(self, idx):
        """Return ``(x, y, offsets, sid)`` for sequence ``idx``.

        ``x`` stacks frames as (time, wavelength, 3, height, width); ``y`` is
        the float label; ``offsets`` are the hours between each past sample
        and the anchor; ``sid`` is the active-region id string.
        """
        s = self.sequences[idx]
        frames = []

        for mid in s["past_ids"]:
            # One lookup per sample instead of one per wavelength.
            sample_views = self.views_map[mid] if mid in self.views_map else {}

            imgs = []
            for wl in WANTED_WL:
                masked = self.masked_wls is not None and wl in self.masked_wls
                if not masked and wl in sample_views:
                    img = self._load_view(sample_views[wl])
                else:
                    img = self.blank.clone()
                imgs.append(img)

            frames.append(torch.stack(imgs))

        x = torch.stack(frames)          # (T, n_wavelengths, 3, H, W)
        y = torch.tensor(s["label"]).float()
        offsets = torch.tensor(s["past_offsets"]).float()

        return x, y, offsets, s["sid"]

# Model Pipeline

In [None]:
class CNNStem(nn.Module):
    """Small convolutional front-end applied to each image before the ViT.

    Expands 3 channels to 32 with a 3x3 convolution, applies GELU, halves the
    spatial resolution with max pooling and projects back to 3 channels with
    a 1x1 convolution.
    """

    def __init__(self):
        super().__init__()
        # Layer order is fixed: it determines the state-dict keys (net.0, net.3).
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 3, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [None]:
class CNN_ViT_Attn(nn.Module):
    """Sequence classifier: per-image CNN stem + ViT, wavelength pooling, temporal transformer.

    Every image of every time step is encoded independently (optional CNN
    stem, then the ViT CLS token). The per-wavelength features of a time step
    are pooled into one vector, optionally augmented with first differences
    along time and a time embedding, passed through a transformer encoder over
    the time axis, mean-pooled over valid steps and mapped to a single logit.

    Args:
        vit_name: Checkpoint name passed to ``ViTModel.from_pretrained``.
        dropout_p: Dropout used in the transformer layers and before the head.
        n_heads: Attention heads of the temporal transformer.
        n_layers: Number of temporal transformer layers.
        use_cnn_stem: Run ``CNNStem`` before the ViT.
        use_vit: Must be True; the alternative backbone is not implemented.
        use_wl_attn: Learned softmax weighting over wavelengths instead of a mean.
        use_deltas: Concatenate temporal first differences to the features.
        use_transformer: Apply the temporal transformer encoder.
        use_time_emb: Add a learned projection of ``1 / (1 + offset)``.

    Raises:
        NotImplementedError: If ``use_vit`` is False.
    """

    def __init__(
        self,
        vit_name,
        dropout_p=0.2,
        n_heads=4,
        n_layers=2,
        use_cnn_stem=True,
        use_vit=True,
        use_wl_attn=True,
        use_deltas=True,
        use_transformer=True,
        use_time_emb=True,
    ):
        super().__init__()

        self.use_cnn_stem = use_cnn_stem
        self.use_vit = use_vit
        self.use_wl_attn = use_wl_attn
        self.use_deltas = use_deltas
        self.use_transformer = use_transformer
        self.use_time_emb = use_time_emb

        # CNN stem (identity when disabled, so self.cnn always exists)
        self.cnn = CNNStem() if use_cnn_stem else nn.Identity()

        # Spatial encoder; D is the ViT hidden size and the width of all later layers
        if use_vit:
            self.vit = ViTModel.from_pretrained(vit_name)
            D = self.vit.config.hidden_size
        else:
            raise NotImplementedError("CNN-only backbone not shown here")

        # Wavelength attention: one scalar score per wavelength feature
        if use_wl_attn:
            self.wl_attn = nn.Linear(D, 1)

        # Temporal deltas double the feature width before projection back to D
        in_dim = D * 2 if use_deltas else D
        self.input_proj = nn.Linear(in_dim, D)

        # Time embedding: scalar recency value -> D
        if use_time_emb:
            self.time_proj = nn.Linear(1, D)

        # Temporal transformer
        if use_transformer:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=D,
                nhead=n_heads,
                dim_feedforward=D * 4,
                dropout=dropout_p,
                activation="gelu",
                batch_first=True
            )
            self.transformer = nn.TransformerEncoder(
                encoder_layer, num_layers=n_layers
            )

        self.dropout = nn.Dropout(dropout_p)
        self.head = nn.Linear(D, 1)

    def forward(self, x, offsets, attn_mask):
        """
        Compute one logit per sequence.

        x:          (B, T, W, 3, H, W_)  batch, time steps, wavelengths, channels, height, width
        offsets:    (B, T)   hours between each past sample and the anchor
        attn_mask:  (B, T)   True = valid, False = padding

        Returns a tensor of shape (B,) with unnormalised logits.
        """

        # W is the number of wavelengths, W_ the image width.
        B, T, W, C, H, W_ = x.shape

        # Fold batch, time and wavelength into one axis so every image is encoded independently.
        x = x.view(B * T * W, C, H, W_)

        if self.use_cnn_stem:
            x = self.cnn(x)
            # The stem's max-pool halves the resolution; resize back to IMG_SIZE
            # (F.interpolate's default interpolation mode is used).
            x = F.interpolate(x, size=(IMG_SIZE, IMG_SIZE))

        # ViT: keep only the first (CLS) token of the last hidden state.
        feats = self.vit(
            x, interpolate_pos_encoding=True
        ).last_hidden_state[:, 0]  # CLS token

        feats = feats.view(B, T, W, -1)

        # Wavelength aggregation: softmax-weighted sum over dim 2, or a plain mean.
        if self.use_wl_attn:
            attn = torch.softmax(
                self.wl_attn(feats).squeeze(-1), dim=2
            )
            feat = (feats * attn.unsqueeze(-1)).sum(dim=2)
        else:
            feat = feats.mean(dim=2)

        # Temporal deltas: difference to the previous step, zero for the first step.
        if self.use_deltas:
            delta = feat[:, 1:] - feat[:, :-1]
            # Pad one zero row at the front of the time axis so delta has length T again.
            delta = F.pad(delta, (0, 0, 1, 0))
            feat = torch.cat([feat, delta], dim=-1)

        feat = self.input_proj(feat)

        # Time embedding: recency 1 / (1 + offset) lies in (0, 1] for non-negative offsets.
        if self.use_time_emb:
            rec = 1.0 / (1.0 + offsets.unsqueeze(-1))
            feat = feat + self.time_proj(rec)

        # Transformer (mask-aware): PyTorch expects True for positions to ignore.
        if self.use_transformer:
            key_padding_mask = ~attn_mask  # True = PAD
            feat = self.transformer(
                feat,
                src_key_padding_mask=key_padding_mask
            )

        # Masked mean pooling over time; the clamp avoids dividing by zero for all-padding rows.
        mask = attn_mask.unsqueeze(-1).float()
        feat = feat * mask
        pooled = feat.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        pooled = self.dropout(pooled)
        return self.head(pooled).squeeze(1)

# Metrics & utils

In [None]:
def safe_confusion(y_true, y_pred):
    """Return ``(tn, fp, fn, tp)`` as Python ints.

    The label set is fixed to ``[0, 1]`` so the matrix is always 2x2, even
    when one class is absent from the inputs.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # tolist() converts the NumPy integers to plain Python ints.
    (tn, fp), (fn, tp) = matrix.tolist()
    return tn, fp, fn, tp

In [None]:
def compute_binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute confusion counts and forecast skill scores at a fixed threshold.

    Args:
        y_true: Array-like of binary ground-truth labels.
        y_prob: Array-like of predicted probabilities.
        threshold: Probabilities ``>= threshold`` are predicted positive.

    Returns:
        Dict with the confusion counts (TP, FP, TN, FN) and Recall (POD),
        Precision, FAR, CSI, FPR, TSS, HSS, Accuracy and F1. Ratios use a
        ``1e-6`` term in the denominator, so empty denominators yield 0.
    """
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)

    TN, FP, FN, TP = safe_confusion(y_true, y_pred)

    POD = TP / (TP + FN + 1e-6)
    Precision = TP / (TP + FP + 1e-6)
    FAR = FP / (TP + FP + 1e-6)
    CSI = TP / (TP + FP + FN + 1e-6)
    FPR = FP / (FP + TN + 1e-6)

    TSS = POD - FPR
    HSS = compute_hss(TP, TN, FP, FN)

    # Accuracy is well defined when only one class is present, so it is
    # always computed over all samples (the previous single-class fallback
    # only looked at the first element).
    Acc = float(accuracy_score(y_true, y_pred))

    F1 = f1_score(y_true, y_pred, zero_division=0)

    return {
        "TP": TP, "FP": FP, "TN": TN, "FN": FN,
        "Recall": POD,
        "Precision": Precision,
        "FAR": FAR,
        "CSI": CSI,
        "FPR": FPR,
        "TSS": TSS,
        "HSS": HSS,
        "Accuracy": Acc,
        "F1": F1,
    }

In [None]:
def tune_threshold(probs, labels, metric="tss", thresholds=None):
    """Scan decision thresholds and return the best one with its score.

    Args:
        probs: NumPy array of predicted probabilities.
        labels: Array-like of binary ground-truth labels.
        metric: ``"tss"`` selects the true skill statistic; any other value
            selects F1.
        thresholds: Candidate thresholds; defaults to 0.00 ... 1.00 in steps
            of 0.01.

    Returns:
        ``(best_threshold, best_metric)``. On ties the first (lowest)
        threshold in iteration order is kept.
    """
    if thresholds is None:
        thresholds = np.arange(0.0, 1.01, 0.01)

    best_threshold, best_metric = 0.5, -1
    for cutoff in thresholds:
        hard = (probs >= cutoff).astype(int)
        tn, fp, fn, tp = confusion_matrix(labels, hard, labels=[0,1]).ravel()

        if metric == "tss":
            hit_rate = tp / (tp + fn + 1e-6)
            false_alarm_rate = fp / (fp + tn + 1e-6)
            score = hit_rate - false_alarm_rate
        else:
            prec = tp / (tp + fp + 1e-6)
            rec = tp / (tp + fn + 1e-6)
            score = 2 * (prec * rec) / (prec + rec + 1e-6)

        # Strict comparison keeps the earliest threshold among equals.
        if score > best_metric:
            best_threshold, best_metric = cutoff, score

    return best_threshold, best_metric

In [None]:
def aggregate_by_sid(probs, labels, sids, mode="max"):
    """Collapse sequence-level predictions to one row per active region.

    Args:
        probs: Sequence-level probabilities.
        labels: Sequence-level labels.
        sids: Active-region id of each sequence.
        mode: Aggregation mode; only ``"max"`` is supported.

    Returns:
        DataFrame with columns ``sid``, ``prob`` and ``label`` holding the
        per-region maximum of each value.

    Raises:
        ValueError: If ``mode`` is not ``"max"``.
    """
    frame = pd.DataFrame({"sid": sids, "prob": probs, "label": labels})

    if mode != "max":
        raise ValueError("Unsupported aggregation mode")

    per_region = frame.groupby("sid")[["prob", "label"]].max()
    return per_region.reset_index()

In [None]:
def compute_hss(tp, tn, fp, fn):
    """Return the Heidke skill score for the given confusion counts.

    Computed as ``2 (tp*tn - fp*fn) / ((tp+fn)(fn+tn) + (tp+fp)(fp+tn))``
    with ``1e-6`` added to the denominator, so all-zero counts give 0.
    """
    chance_term = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    skill_term = 2 * (tp * tn - fp * fn)
    return skill_term / (chance_term + 1e-6)

In [None]:
import numpy as np

def compute_ece(probs, labels, n_bins=10):
    """Return the expected calibration error over equal-width probability bins.

    Args:
        probs: Array-like of predicted probabilities in ``[0, 1]``.
        labels: Array-like of binary labels, same length as ``probs``.
        n_bins: Number of equal-width bins on ``[0, 1]``.

    Returns:
        Weighted mean of ``|accuracy - confidence|`` over the non-empty bins,
        or ``0.0`` for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)

    N = len(probs)
    if N == 0:
        return 0.0

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0

    for i in range(n_bins):
        # Bins are (lo, hi], except the first which is [0, hi]: otherwise
        # probabilities of exactly 0 would belong to no bin while still
        # counting towards N.
        above_lower = probs >= bins[i] if i == 0 else probs > bins[i]
        mask = above_lower & (probs <= bins[i + 1])

        count = int(mask.sum())
        if count == 0:
            continue

        acc = np.mean(labels[mask])
        conf = np.mean(probs[mask])
        ece += (count / N) * abs(acc - conf)

    return float(ece)

# Datasets

In [None]:
# Train and validation sequences both come from the training split, hence the same views map.
# Augmentation is disabled for all three datasets.
train_ds = SolarFlareSequenceDataset(
    train_seqs_by_sid,
    views_map_train,
    apply_augmentation=False
)

val_ds = SolarFlareSequenceDataset(
    val_seqs_by_sid,
    views_map_train,
    apply_augmentation=False
    
)
test_ds = SolarFlareSequenceDataset(
    test_seqs,
    views_map_test,
    apply_augmentation=False
)

# Weighted Sampler

In [None]:
# Labels of the post-split training sequences (replaces the earlier, pre-split train_labels).
train_labels = np.array([s["label"] for s in train_seqs_by_sid])

# Number of negative and positive sequences, in that order.
class_sample_count = np.array([
    np.sum(train_labels == 0),
    np.sum(train_labels == 1)
])

# Inverse-frequency weights: total / (2 * count), so both classes get equal total weight.
# NOTE(review): a class with zero sequences makes its weight infinite.
class_weights = class_sample_count.sum() / (2.0 * class_sample_count)
# Index the per-class weights by label to obtain one weight per sequence.
samples_weight = class_weights[train_labels]
samples_weight = torch.from_numpy(samples_weight).float()

# Dataloaders ( Train / VAL / Test )

In [None]:
def get_train_loader(epoch, train_ds, samples_weight, early_epochs_threshold=20):
    """Build the training DataLoader for the given epoch.

    Up to and including ``early_epochs_threshold`` the loader draws
    class-balanced samples with replacement (``WeightedRandomSampler``);
    afterwards it falls back to plain shuffling.

    Args:
        epoch: Current epoch number.
        train_ds: Training dataset.
        samples_weight: One sampling weight per dataset item.
        early_epochs_threshold: Last epoch that uses weighted sampling.

    Returns:
        A ``DataLoader`` over ``train_ds``.
    """
    # Settings shared by both sampling strategies.
    common = {
        "batch_size": BATCH_SIZE,
        "num_workers": 4,
        "worker_init_fn": worker_init_fn,
        "pin_memory": (DEVICE.type == "cuda"),
    }

    if epoch <= early_epochs_threshold:
        balanced = WeightedRandomSampler(
            weights=samples_weight,
            num_samples=len(samples_weight),
            replacement=True,
            generator=g_seed
        )
        return DataLoader(train_ds, sampler=balanced, **common)

    return DataLoader(train_ds, shuffle=True, **common)

# Evaluation loaders keep dataset order (shuffle=False).
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    worker_init_fn=worker_init_fn,
    generator=g_seed,
    pin_memory=(DEVICE.type == "cuda")
)

test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    worker_init_fn=worker_init_fn,
    generator=g_seed,
    pin_memory=(DEVICE.type == "cuda")
)

# Model Initiate

In [None]:
# Configuration of this run: wavelength attention and time embedding are switched off,
# so wavelengths are mean-pooled and offsets are not used by the model.
model = CNN_ViT_Attn(
    MODEL_NAME,
    use_vit=True,
    use_cnn_stem=True,
    use_wl_attn=False,
    use_deltas=True,
    use_transformer=True,
    use_time_emb=False
).to(DEVICE)

In [None]:
# Echo the feature flags actually stored on the model.
print("use_vit:", model.use_vit)
print("use_cnn_stem:", model.use_cnn_stem)
print("use_transformer:", model.use_transformer)
print("use_wl_attn:", model.use_wl_attn)
print("use_deltas:", model.use_deltas)
print("use_time_emb:", model.use_time_emb)

In [None]:
# Parameter counts in millions. This cell precedes the ViT-freezing cell, so when run
# in notebook order the trainable count still includes the ViT weights.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params/1e6:.2f}M")
print(f"Trainable params: {trainable_params/1e6:.2f}M")

# Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss computed directly from logits.

    ``loss = alpha_t * (1 - p_t) ** gamma * (-log p_t)``, averaged over the
    batch, where ``p_t`` is the predicted probability of the true class.

    Args:
        alpha: Weight of the positive class; negatives get ``1 - alpha``.
        gamma: Focusing exponent; 0 reduces to weighted cross-entropy.

    Targets are expected to be hard 0/1 labels.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        # Work in float32 regardless of the autocast dtype of the logits.
        logits = logits.float()
        targets = targets.float()

        # -log(p_t) evaluated in logit space. Unlike log(sigmoid(x) + eps) this
        # neither saturates nor loses its gradient for confidently wrong logits.
        ce = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        pt = torch.exp(-ce)

        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        loss = alpha_t * ((1 - pt) ** self.gamma) * ce

        return loss.mean()



# alpha=0.5 weights both classes equally; only the focusing term (gamma) is active.
criterion = FocalLoss(alpha=0.5, gamma=2.0)

# Optimizer & hyperparameters

In [None]:
# Freeze the ViT backbone. build_optimizer only collects parameters with
# requires_grad=True, so frozen ViT weights never enter the optimizer.
if model.use_vit:
    for p in model.vit.parameters():
        p.requires_grad = False

# Scheduler

In [None]:
def build_scheduler(optimizer, train_loader, num_epochs, warmup_frac=0.1):
    """Create a per-step cosine schedule with linear warmup.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        train_loader: Loader whose length gives the steps per epoch.
        num_epochs: Number of epochs the schedule spans.
        warmup_frac: Fraction of all steps used for warmup (at least 1 step).

    Returns:
        The scheduler returned by ``get_cosine_schedule_with_warmup``.
    """
    total_steps = len(train_loader) * num_epochs

    warmup_steps = int(warmup_frac * total_steps)
    if warmup_steps < 1:
        warmup_steps = 1

    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

In [None]:
# Learning rates (LR_*) and weight decays (WD_*) for the two parameter families
# used by build_optimizer: the ViT backbone and every other ("head") module.
LR_VIT  = 1e-5
WD_HEAD = 5e-4

LR_HEAD = 1e-4
WD_VIT  = 5e-5


# Optimizer 
def build_optimizer(model):
    """Create an AdamW optimizer with one parameter group per sub-module.

    Non-backbone modules (CNN stem, input projection, head and the enabled
    optional modules) use ``LR_HEAD`` / ``WD_HEAD``; the ViT uses ``LR_VIT`` /
    ``WD_VIT``. Only parameters with ``requires_grad=True`` are included, and
    modules without such parameters contribute no group.
    """
    param_groups = []

    def add_group(module, lr, weight_decay):
        # Frozen or parameter-free modules are skipped entirely.
        trainable = [p for p in module.parameters() if p.requires_grad]
        if trainable:
            param_groups.append({
                "params": trainable,
                "lr": lr,
                "weight_decay": weight_decay
            })

    head_modules = [model.cnn, model.input_proj, model.head]

    # Optional modules exist only when their feature flag is set.
    optional = (
        (model.use_transformer, "transformer"),
        (model.use_time_emb, "time_proj"),
        (model.use_wl_attn, "wl_attn"),
    )
    for enabled, attr in optional:
        if enabled:
            head_modules.append(getattr(model, attr))

    for module in head_modules:
        if module is not None:
            add_group(module, LR_HEAD, WD_HEAD)

    if model.use_vit and model.vit is not None:
        add_group(model.vit, LR_VIT, WD_VIT)

    # Alternative backbone attribute, present only on some model variants.
    if hasattr(model, "cnn_backbone"):
        add_group(model.cnn_backbone, LR_HEAD, WD_HEAD)

    return torch.optim.AdamW(param_groups)

# Epoch-1 loader; its length gives the steps per epoch for the scheduler.
# (It was previously constructed twice; once is enough.)
train_loader = get_train_loader(1, train_ds, samples_weight)

optimizer = build_optimizer(model)

scheduler = build_scheduler(
    optimizer,
    train_loader,
    NUM_EPOCHS,
    warmup_frac=0.1
)

# Gradient scaling is only meaningful on CUDA; stay disabled on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

In [None]:
# The scheduler's step count is derived from one loader, so both sampling
# strategies (weighted at epoch 1, shuffled at epoch 30) must yield the same number of batches.
loader_epoch_1 = get_train_loader(1, train_ds, samples_weight)
loader_epoch_30 = get_train_loader(30, train_ds, samples_weight)

print(len(loader_epoch_1), len(loader_epoch_30))
assert len(loader_epoch_1) == len(loader_epoch_30)

In [None]:
# Debug aid: show which build_scheduler signature is currently defined in the kernel.
import inspect
print(inspect.signature(build_scheduler))

In [None]:
# Smoke test: fetch a single batch to verify the data pipeline end to end.
# Failures are printed rather than raised so the notebook keeps running.
try:
    xb, yb, offs, sids = next(iter(train_loader))
    print("SMOKE: batch shapes:", xb.shape, yb.shape, offs.shape)
except Exception as e:
    print("SMOKE test failed:", e)

In [None]:
from collections import defaultdict

def ar_coverage_by_wavelength(seqs, views_map):
    """Summarise, per wavelength, how many active regions have image data.

    For every wavelength in ``WANTED_WL`` the result reports the number of
    regions seen (``AR_total``), regions with at least one sequence containing
    that wavelength (``AR_with_data``), and how many of those have a positive
    (``AR_pos``) or negative (``AR_neg``) sequence with data. A region with
    both kinds of sequences is counted in both sets.

    Returns:
        DataFrame with one row per wavelength, including ``pos_ratio`` =
        ``AR_pos / AR_with_data``.
    """
    set_names = ("AR_total", "AR_with_data", "AR_pos", "AR_neg")
    wl_stats = defaultdict(lambda: {key: set() for key in set_names})

    for seq in seqs:
        region = seq["sid"]
        outcome_key = "AR_pos" if seq["label"] == 1 else "AR_neg"

        # Union of wavelengths available in any frame of the sequence.
        available = set()
        for mid in seq["past_ids"]:
            if mid in views_map:
                available.update(views_map[mid].keys())

        for wl in WANTED_WL:
            bucket = wl_stats[wl]
            bucket["AR_total"].add(region)
            if wl not in available:
                continue
            bucket["AR_with_data"].add(region)
            bucket[outcome_key].add(region)

    rows = [
        {
            "wavelength": wl.replace(".jpg",""),
            "AR_total": len(d["AR_total"]),
            "AR_with_data": len(d["AR_with_data"]),
            "AR_pos": len(d["AR_pos"]),
            "AR_neg": len(d["AR_neg"]),
            "pos_ratio": len(d["AR_pos"]) / max(1, len(d["AR_with_data"]))
        }
        for wl, d in wl_stats.items()
    ]

    return pd.DataFrame(rows)

# Coverage table for the training portion of the split; the bare name displays it in the notebook.
df_wl_coverage = ar_coverage_by_wavelength(train_seqs_by_sid, views_map_train)
df_wl_coverage

In [None]:
def sequence_presence_by_wavelength(seqs, views_map):
    """Count, per wavelength, the sequences having that wavelength in at least one frame.

    Returns:
        DataFrame with one row per wavelength in ``WANTED_WL`` and the columns
        ``wavelength``, ``seq_total``, ``seq_present`` and ``presence_ratio``.
    """
    counts = defaultdict(lambda: {"present": 0, "total": 0})

    for seq in seqs:
        for wl in WANTED_WL:
            tally = counts[wl]
            tally["total"] += 1
            # any() stops at the first frame that has the wavelength.
            if any(mid in views_map and wl in views_map[mid] for mid in seq["past_ids"]):
                tally["present"] += 1

    rows = [
        {
            "wavelength": wl.replace(".jpg",""),
            "seq_total": d["total"],
            "seq_present": d["present"],
            "presence_ratio": d["present"] / max(1, d["total"])
        }
        for wl, d in counts.items()
    ]

    return pd.DataFrame(rows)

# Presence table for the training portion of the split; the bare name displays it in the notebook.
df_seq_presence = sequence_presence_by_wavelength(train_seqs_by_sid, views_map_train)
df_seq_presence

In [None]:
# Add the negative counterpart of pos_ratio; clip(lower=1) avoids dividing by zero.
# pos_ratio and neg_ratio need not sum to 1 because a region can be in both sets.
df = df_wl_coverage.copy()
df["neg_ratio"] = df["AR_neg"] / df["AR_with_data"].clip(lower=1)

df[["wavelength", "AR_with_data", "pos_ratio", "neg_ratio"]]

# Training loop

In [None]:
# Run identifier used in every output file name.
run_name = f"forecast_{24}h"
# NOTE(review): not referenced anywhere in the visible notebook; the config below records "all".
N_FREEZE_EPOCHS = 5          

# Snapshot of the hyper-parameters of this run.
# NOTE(review): horizon (24), flare threshold (1e-5) and focal parameters are repeated
# as literals here; keep them in sync with the values actually used above.
run_config = {
    "run_name": run_name,
    "forecast_horizon_hours": 24,
    "batch_size": BATCH_SIZE,

    "flare_threshold": 1e-5,
    "split_seed": SEED,

    "model_name": MODEL_NAME,
    "num_epochs": NUM_EPOCHS,
    "n_freeze_epochs": "all",

    "optimizer": "AdamW",
    "lr_head": LR_HEAD,
    "lr_vit": LR_VIT,
    "wd_head": WD_HEAD,
    "wd_vit": WD_VIT,
    "scheduler": "CosineLR + linear warmup",

    "focal_alpha": 0.5,
    "focal_gamma": 2.0,
}

# Written before training; rewritten with the best-epoch results afterwards.
with open(f"run_config_{run_name}_{SEED}.json", "w") as f:
    json.dump(run_config, f, indent=2)

In [None]:
# Best-so-far bookkeeping for model selection.
best_val_tss = -float("inf")
best_threshold = 0.5
best_epoch = -1
best_val_metrics = None

# Early stopping: stop after `patience` consecutive epochs without improvement.
patience = 5
wait = 0
# TSS differences up to this size are treated as ties.
TSS_EPS = 1e-4

# Per-epoch curves, saved to JSON/CSV after training.
history = {
    "epoch": [],
    "train_loss": [],
    "val_loss": [],
    "val_tss": [],
    "threshold": [],
}

# Mixed precision only on CUDA; this replaces the scaler created in the optimizer cell.
use_amp = DEVICE.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

In [None]:
# Main training loop: one training pass, one validation pass, AR-level threshold
# tuning, checkpointing of the best epoch and early stopping.
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")

    # Rebuilt every epoch because the sampling strategy depends on the epoch number.
    train_loader = get_train_loader(epoch, train_ds, samples_weight)

    model.train()
    train_loss = 0.0

    for x, y, offsets, _ in tqdm(train_loader, desc="Training", leave=False):
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        offsets = offsets.to(DEVICE)

        # Every sequence has the same number of frames, so no position is padding.
        attn_mask = torch.ones(
            x.size(0), x.size(1),
            dtype=torch.bool,
            device=x.device
        )

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(x, offsets, attn_mask).view(-1)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # The schedule is defined per optimisation step, not per epoch.
        # NOTE(review): this also advances when the scaler skipped the optimizer step - confirm intended.
        scheduler.step()

        train_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss /= max(1, len(train_loader))

    model.eval()
    val_loss = 0.0
    all_probs, all_labels, all_sids = [], [], []

    # Validation runs without autocast.
    with torch.no_grad():
        for x, y, offsets, sids in val_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            offsets = offsets.to(DEVICE)

            attn_mask = torch.ones(
                x.size(0), x.size(1),
                dtype=torch.bool,
                device=x.device
            )

            logits = model(x, offsets, attn_mask).view(-1)
            probs = torch.sigmoid(logits)
            loss = criterion(logits, y)

            val_loss += loss.item()
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            all_sids.extend(sids)

    val_loss /= max(1, len(val_loader))

    # Collapse sequence predictions to one (max prob, max label) row per active region.
    df_val_ar = aggregate_by_sid(
        np.array(all_probs),
        np.array(all_labels),
        np.array(all_sids),
        mode="max"
    )

    # The decision threshold is re-tuned on the validation regions every epoch.
    threshold, val_tss = tune_threshold(
        df_val_ar["prob"].values,
        df_val_ar["label"].values,
        metric="tss"
    )

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_tss"].append(val_tss)
    history["threshold"].append(threshold)

    print(
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"AR-TSS: {val_tss:.3f} | "
        f"Thr: {threshold:.3f}"
    )

    # Model selection: a clearly higher TSS wins; within TSS_EPS the epoch whose
    # tuned threshold is closer to 0.5 is preferred.
    is_better = False
    if val_tss > best_val_tss + TSS_EPS:
        is_better = True
    elif abs(val_tss - best_val_tss) <= TSS_EPS:
        if abs(threshold - 0.5) < abs(best_threshold - 0.5):
            is_better = True

    if is_better:
        best_val_tss = val_tss
        best_threshold = threshold
        best_epoch = epoch
        wait = 0

        best_val_metrics = compute_binary_metrics(
            y_true=df_val_ar["label"].values,
            y_prob=df_val_ar["prob"].values,
            threshold=best_threshold
        )

        # The checkpoint carries the tuned threshold so evaluation can reuse it.
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "epoch": best_epoch,
                "threshold": float(best_threshold),
                "val_tss": float(best_val_tss),
                "metrics": best_val_metrics,
            },
            f"best_model_{run_name}_{SEED}.pth"
        )

        print("✓ Best model saved")

    else:
        wait += 1
        if wait >= patience:
            print("Early stopping triggered")
            break

In [None]:
# Validation metrics of the checkpointed epoch.
# NOTE: best_val_metrics stays None if no epoch was ever accepted as best.
print("\n===== FINAL VALIDATION METRICS (BEST EPOCH) =====")
print(f"Best Epoch: {best_epoch}")
print(f"Threshold: {best_threshold:.3f}")

for k, v in best_val_metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

In [None]:
print("\nTraining finished.")
print("Best AR-level TSS:", best_val_tss)

# Persist the per-epoch curves in both JSON and CSV form.
with open(f"training_history_{run_name}_{SEED}.json", "w") as f:
    json.dump(history, f, indent=2)

pd.DataFrame(history).to_csv(
    f"training_history_{run_name}_{SEED}.csv",
    index=False
)

# Append the selection results and overwrite the config file written before training.
run_config.update({
    "best_epoch": best_epoch,
    "best_val_tss": float(best_val_tss),
    "best_threshold": float(best_threshold),
})

with open(f"run_config_{run_name}_{SEED}.json", "w") as f:
    json.dump(run_config, f, indent=2)

# Evaluation on test set

In [None]:
# Already imported at the top of the notebook; repeated so the cell runs standalone.
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    accuracy_score,
    f1_score
)

# LOAD BEST MODEL
ckpt = torch.load(
    f"best_model_{run_name}_{SEED}.pth",
    map_location=DEVICE
)

model.load_state_dict(ckpt["model_state_dict"])
# Reuse the threshold tuned on validation; it is not re-tuned on the test set here.
best_threshold = ckpt["threshold"]
model.eval()

all_probs, all_labels, all_sids = [], [], []

with torch.no_grad():
    for x, y, offsets, sids in test_loader:
        x = x.to(DEVICE)
        offsets = offsets.to(DEVICE)
        y = y.to(DEVICE)

        # No padding: all positions are valid.
        attn_mask = torch.ones(
            x.size(0), x.size(1),
            dtype=torch.bool,
            device=x.device
        )

        logits = model(x, offsets, attn_mask).view(-1)
        probs = torch.sigmoid(logits)

        all_probs.extend(probs.cpu().numpy())
        all_labels.extend(y.cpu().numpy())
        all_sids.extend(sids)

# One row per active region: max probability and max label over its sequences.
df_test_ar = aggregate_by_sid(
    np.array(all_probs),
    np.array(all_labels),
    np.array(all_sids),
    mode="max"
)

y_true = df_test_ar["label"].values
y_prob = df_test_ar["prob"].values
y_pred = (y_prob >= best_threshold).astype(int)

tn, fp, fn, tp = confusion_matrix(
    y_true, y_pred, labels=[0, 1]
).ravel()

# Same formulas (including the 1e-6 denominator term) as compute_binary_metrics.
POD = tp / (tp + fn + 1e-6)
FPR = fp / (fp + tn + 1e-6)
Precision = tp / (tp + fp + 1e-6)
FAR = fp / (tp + fp + 1e-6)
CSI = tp / (tp + fp + fn + 1e-6)
TSS = POD - FPR
HSS = compute_hss(tp, tn, fp, fn)
Accuracy = accuracy_score(y_true, y_pred)
F1 = f1_score(y_true, y_pred)
# NOTE(review): the ranking metrics below presumably need both classes present in y_true - confirm.
AUROC = roc_auc_score(y_true, y_prob)
AUPRC = average_precision_score(y_true, y_prob)

print(f"\n=== AR-LEVEL TEST METRICS (24h FORECAST) ===")
print(f"Samples: {len(y_true)} | Positives: {y_true.sum()} | Negatives: {len(y_true)-y_true.sum()}")
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
print(f"TSS: {TSS:.3f}")
print(f"HSS: {HSS:.3f}")
print(f"POD (Recall): {POD:.3f}")
print(f"FPR: {FPR:.3f}")
print(f"Precision: {Precision:.3f}")
print(f"FAR: {FAR:.3f}")
print(f"CSI: {CSI:.3f}")
print(f"Accuracy: {Accuracy:.3f}")
print(f"F1: {F1:.3f}")
print(f"AUROC: {AUROC:.4f}")
print(f"AUPRC: {AUPRC:.4f}")

# Per-region predictions, annotated with the threshold and horizon used.
df_test_ar_out = df_test_ar.copy()
df_test_ar_out["pred"] = y_pred
df_test_ar_out["threshold"] = best_threshold
df_test_ar_out["forecast_horizon_hours"] = 24

df_test_ar_out.to_csv(
    f"test_ar_predictions_{run_name}_{SEED}.csv",
    index=False
)

print(f"✓ Saved test_ar_predictions_{run_name}_{SEED}.csv")

# Single-row summary of all test metrics.
metrics_dict = {
    "Forecast_Horizon_h": 24,
    "Samples": len(y_true),
    "Positives": int(y_true.sum()),
    "Negatives": int(len(y_true) - y_true.sum()),
    "TN": tn,
    "FP": fp,
    "FN": fn,
    "TP": tp,
    "TSS": TSS,
    "HSS": HSS,
    "POD_Recall": POD,
    "FPR": FPR,
    "Precision": Precision,
    "FAR": FAR,
    "CSI": CSI,
    "Accuracy": Accuracy,
    "F1": F1,
    "AUROC": AUROC,
    "AUPRC": AUPRC,
    "Threshold": best_threshold,
}

pd.DataFrame([metrics_dict]).to_csv(
    f"test_metrics_summary_{run_name}_{SEED}.csv",
    index=False
)

print(f"✓ Saved test_metrics_summary_{run_name}_{SEED}.csv")

In [None]:
# Regions vs. sequences actually evaluated in the test pass.
print("Unique ARs in test:", len(set(all_sids)))
print("Total sequences in test:", len(all_sids))

# testing threshold

In [None]:
# Already imported at the top of the notebook; repeated so the cell runs standalone.
from sklearn.metrics import confusion_matrix

# 17 evenly spaced thresholds from 0.1 to 0.9 (step 0.05).
thresholds = np.linspace(0.1, 0.9, 17)
rows = []

# Sensitivity analysis on the test predictions (y_true / y_prob from the evaluation cell).
# NOTE(review): picking a threshold from this table would tune on the test set;
# the reported metrics above use the validation-tuned threshold.
print("Running threshold sweep on TEST set...")

for t in thresholds:
    y_pred_t = (y_prob >= t).astype(int)

    cm = confusion_matrix(y_true, y_pred_t, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # These assignments overwrite the POD/FPR/TSS/... variables of the evaluation cell.
    POD = tp / (tp + fn + 1e-6)
    FPR = fp / (fp + tn + 1e-6)
    TSS = POD - FPR
    Precision = tp / (tp + fp + 1e-6)
    FAR = fp / (tp + fp + 1e-6)

    # Cast to built-in types so the table holds plain Python numbers.
    rows.append({
        "threshold": float(t),
        "TSS": float(TSS),
        "POD": float(POD),
        "FPR": float(FPR),
        "Precision": float(Precision),
        "FAR": float(FAR),
        "TP": int(tp),
        "FP": int(fp),
        "FN": int(fn),
        "TN": int(tn),
    })

df_threshold_analysis = pd.DataFrame(rows)

print(df_threshold_analysis)   

df_threshold_analysis.to_csv(
    f"test_threshold_sweep_{run_name}_{SEED}.csv",
    index=False
)

print(f"✓ Saved test_threshold_sweep_{run_name}_{SEED}.csv")