## Preprocess the data, then build, train, and run inference with a model that estimates the probability of hurricane genesis

### Libraries

In [None]:
import os
os.environ["TORCH_DISABLE_DYNAMO"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataclasses import dataclass
from pathlib import Path
import os

from typing import List, Tuple, Dict, Optional
import numpy as np
import pandas as pd
#from terratorch.registry import BACKBONE_REGISTRY
import sys, platform
import matplotlib.pyplot as plt
import cv2
#import matplotlib 
#matplotlib.use("Agg")

from __future__ import annotations

from dataclasses import dataclass

from PIL import Image
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset as TorchDataset
import typing as T
import cartopy.crs as ccrs


from tqdm import tqdm
from netCDF4 import Dataset, num2date
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_recall_curve,
    average_precision_score,
    roc_auc_score,
    accuracy_score,
    precision_recall_fscore_support,
    brier_score_loss, 
    roc_curve
)
   

import torch._dynamo
torch._dynamo.config.suppress_errors = True

#import torch, multiprocessing as mp
#if mp.get_start_method(allow_none=True) != "spawn":
#    mp.set_start_method("spawn", force=True)

print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)

import triton
#import triton.language as tl
#print("Triton version:", triton.__version__)

    

CUDA Available: True
CUDA Version: 12.1


### Configuration parameters

In [2]:
@dataclass
class GenesisCfg:
    """Central configuration: data paths, feature selection, mask building and training.

    Only annotated names (``name: type = value``) are dataclass fields and
    constructor arguments. The unannotated names (``HERE``, the directory
    paths, ``min_lat`` ... ``max_lon``) are plain class attributes.
    """
  
    # Working directory, captured once when this class body executes.
    HERE = Path(os.getcwd()).resolve()
    dataset_dir = HERE / "Database"
       
    surf_dir = dataset_dir / "surface_dataset"
    press_dir =  dataset_dir / "pressure_dataset"

    # Three aliases for the same Outputs directory (CSV indexes, caches, masks).
    out_root = output_dir = OUT_DIR = HERE / "Outputs" 

    # Destination for saved figures (mask overlays / galleries).
    results_dir = HERE / "results"

    # CSV file names, resolved relative to out_root by the loading functions.
    sfc_csv: str = "surface_index.csv"
    pres_csv: str = "pressure_index.csv"
    genesis_csv: str = "genesis.csv"
    track_csv: str = "track_estimate.csv"
    pres_csv_inference: str = "test_pressure_index.csv"
    sfc_csv_inference:  str = "test_surface_index.csv"
    pres_csv_train: str = "train_pressure_index.csv"
    sfc_csv_train:  str = "train_surface_index.csv"

    # Substrings searched for in lower-cased .nc file names to find a sample grid file.
    pressure_token: str = "pres"
    surface_token: str  = "sfc"    
        
    image_size: Tuple[int,int] = (224, 352)   # H,W (keep aspect)
    # Placeholders; get_latlon() overwrites these on the instance with the source-grid bounds.
    min_lat = 0.0
    max_lat =  0.0
    min_lon =  0.0
    max_lon =  0.0

    
    # Feature selection 
    # True: rank channels by PCA loadings; False: use the prefer_vars_* lists in order.
    apply_PCA: bool = False  
    # Upper bound on the number of image channels chosen per table.
    max_channels: int = 16
    # NOTE(review): not referenced in the visible code — confirm whether it is enforced elsewhere.
    min_channels: int = 8
    # Variable names; the matching input columns are "img_<VAR>".
    prefer_vars_surface: Tuple[str, ...] = ("SLP","U10M","V10M","T2M", "FROCEAN", "EFLUX", "HFLUX", "TQL", "TQI", "TQV")
    prefer_vars_pressure: Tuple[str, ...] = ("U","V","T","QV","OMEGA", "CLOUD")

    # Inference
    # NOTE(review): presumably the variable shown in inference plots; not used in the visible code.
    visual_feature: str = "FROCEAN" 
    
    # Mask variables
    # A pixel is positive when a genesis event happens within horizon_hours of the
    # row timestamp and lies within radius_km of that pixel (great-circle distance).
    horizon_hours: int = 24
    radius_km: float = 300.0           # pixels within ~300km become positives if genesis occurs next 24h
    earth_radius_km: float = 6371.0
    # Passed as lon_to_0360 when building genesis.csv; load_genesis_events() re-wraps
    # longitudes to [-180, 180) regardless.
    lon_0_360: bool = False             # set True if your grid longitudes are 0..360

    # Training
    # NOTE(review): consumed by training code outside this view — confirm usage there.
    batch_size: int = 4  
    num_workers: int = 16
    epochs: int = 100
    lr: float = 0.001
    use_amp: bool = True
    # Used for the global torch/NumPy seeds and PCA's random_state.
    seed: int = 42
    patience : int = 10                # early stopping patience
    min_delta: float = 0.0         # early stopping min delta
    
    # Weighting
    # NOTE(review): presumably the positive-class weight for the pixel loss — confirm in the training code.
    pos_weight: float = 15.0          
    
    # Memory / speed     
    pin_memory: bool = True     
    prefetch_factor: int = 4      
    persistent_workers: bool = True
    use_compile: bool = True
    grad_accum_steps: int = 1   

cfg = GenesisCfg()

# Seed the global torch and NumPy RNGs so runs are reproducible.
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)

# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

cuda


### Create the hurricane-genesis CSV used for training from the track (trajectory) data

In [3]:
# Function to build genesis CSV from track estimate CSV choosing the earliest hurricane point per storm
def build_genesis_from_track (
        track_csv: str, 
        out_csv: str = "genesis.csv", 
        lon_to_0360: bool = True,
        out_root: str | Path | None = None):

    base = Path(out_root)
    base.mkdir(parents=True, exist_ok=True)
    track_path = (base / track_csv) if (base / track_csv).exists() else Path(track_csv)
    if not track_path.exists():
        raise FileNotFoundError(f"Track CSV not found at: {track_path}")

    df = pd.read_csv(track_path)

    cols = {c.lower(): c for c in df.columns}
    required = {"time", "label", "storm_id", "lat", "lon"}
    missing = [c for c in required if c not in cols]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    df = df.rename(columns={
        cols["time"]: "time",
        cols["label"]: "Label",
        cols["storm_id"]: "Storm_id",
        cols["lat"]: "lat",
        cols["lon"]: "lon",
    })
 
    df["time"] = pd.to_datetime(df["time"], errors="coerce").dt.tz_localize(None)
    df = df.dropna(subset=["time", "lat", "lon"])

    df["Storm_id"] = df["Storm_id"].astype(str).str.strip()
    df = df.replace({"Storm_id": {"": np.nan}}).dropna(subset=["Storm_id"])

    df_h = df[df["Label"].astype(str).str.strip().str.lower() == "hurricane"].copy()
    if df_h.empty:        
        out_df = pd.DataFrame(columns=["Storm_id", "time", "lat", "lon"])
        out_df.to_csv(base / out_csv, index=False)
        print(f"No hurricane rows found.")
        return base / out_csv

    # Pick the earliest time per storm (genesis)
    idx = df_h[df_h["lat"].notna() & (df_h["lat"] != "")].groupby("Storm_id")["time"].idxmin()
    genesis = df_h.loc[idx, ["Storm_id", "time", "lat", "lon"]].sort_values(["Storm_id", "time"]).reset_index(drop=True)

    # Optional conversion to 0..360
    if lon_to_0360:
        genesis["lon"] = genesis["lon"].astype(float).apply(lambda x: x if x >= 0 else x + 360.0)

   
    out_path = base / out_csv
    genesis.to_csv(out_path, index=False)
    print(f"[GENESIS] Wrote {len(genesis)}")

build_genesis_from_track(
    cfg.track_csv,
    out_csv=cfg.genesis_csv,
    lon_to_0360=cfg.lon_0_360,
    out_root=cfg.out_root,
)

[GENESIS] Wrote 9


### Precompute masks

In [4]:
# Returns distances in km with the shape of the broadcast
def _spherical_dist_km(lat1, lon1, lat2, lon2, R=6371.0):   
    lat1 = np.asarray(lat1, dtype=float)
    lon1 = np.asarray(lon1, dtype=float)
    lat2 = np.asarray(lat2, dtype=float)
    lon2 = np.asarray(lon2, dtype=float)

    if lat1.shape == ():
        lat1 = np.broadcast_to(lat1, lat2.shape)
    if lon1.shape == ():
        lon1 = np.broadcast_to(lon1, lon2.shape)

    p = np.pi / 180.0
    dlat = (lat2 - lat1) * p
    dlon = (lon2 - lon1) * p
    a = np.sin(dlat / 2.0) ** 2 + np.cos(lat1 * p) * np.cos(lat2 * p) * np.sin(dlon / 2.0) ** 2
    a = np.clip(a, 0.0, 1.0)
    return 2.0 * R * np.arcsin(np.sqrt(a))

# Normalize longitude to either [-180,180] or [0,360]
def _lon_norm(lon: float, to_0360: bool) -> float:
    if to_0360:
        return lon if lon >= 0 else lon + 360.0
    return lon if lon <= 180 else lon - 360.0

# Create the (H,W) mask with 1 where any genesis event occurs within horizon_h hours and radius_km.
def build_genesis_mask_for_time(grid_lats: np.ndarray, grid_lons: np.ndarray,
                                ts: pd.Timestamp,
                                events: pd.DataFrame,
                                horizon_h: int, radius_km: float,
                                cfg: GenesisCfg) -> np.ndarray:
    """Binary genesis mask for one timestamp.

    A cell is 1.0 when some event with ``GenesisTime`` in
    ``[ts, ts + horizon_h hours]`` lies within ``radius_km`` (great-circle,
    Earth radius ``cfg.earth_radius_km``) of the cell's lat/lon. All other
    cells are 0.0.

    Args:
        grid_lats, grid_lons: Per-cell latitude/longitude grids in degrees,
            with equal shapes.
        ts: Start of the look-ahead window.
        events: Frame with ``GenesisTime``, ``GenesisLat``, ``GenesisLon``.
            None or an empty frame yields an all-zero mask.
        horizon_h: Window length in hours.
        radius_km: Positive-label radius.
        cfg: Supplies ``earth_radius_km``.

    Returns:
        float32 array shaped like ``grid_lats``.
    """
    mask = np.zeros(grid_lats.shape, dtype=np.float32)
    if events is None or events.empty:
        return mask

    tmax = ts + pd.Timedelta(hours=horizon_h)
    in_window = (events["GenesisTime"] >= ts) & (events["GenesisTime"] <= tmax)
    cand = events.loc[in_window, ["GenesisLat", "GenesisLon"]]

    for glat, glon in cand.itertuples(index=False, name=None):
        # Whole-grid distance in one vectorised call (the helper broadcasts a
        # scalar origin over the grid); replaces the former per-row loop.
        dists = _spherical_dist_km(float(glat), float(glon), grid_lats, grid_lons, cfg.earth_radius_km)
        mask[dists <= radius_km] = 1.0
    return mask


# Alternative method: choose only one pressure-level row per Timestamp
def pick_pressure_row_per_ts(pres: pd.DataFrame, target_pa: float = 85000.0) -> pd.DataFrame:
    """Keep one row per ``Timestamp``: the level closest to ``target_pa``.

    The distance is computed as follows:
    - If ``pres_PL_mean`` exists and has any non-null value, it is compared
      directly with ``target_pa``.
    - Otherwise, if ``Level`` is numeric, ``Level * 100`` is used (Level is
      taken to be hPa).
    - Otherwise every row gets distance 0.

    Rows are first ordered by (Timestamp, Level). The result has a fresh
    RangeIndex and no helper column.

    Raises:
        ValueError: if ``Timestamp`` or ``Level`` is missing.
    """
    has_ts = "Timestamp" in pres
    has_level = "Level" in pres
    if not (has_ts and has_level):
        raise ValueError("pres must contain 'Timestamp' and 'Level' columns")

    ordered = pres.copy().sort_values(["Timestamp", "Level"]).reset_index(drop=True)

    use_plevel_mean = "pres_PL_mean" in ordered and ordered["pres_PL_mean"].notna().any()
    if use_plevel_mean:
        distance = (ordered["pres_PL_mean"] - target_pa).abs()
    elif np.issubdtype(ordered["Level"].dtype, np.number):
        # Level assumed in hPa; convert to Pa before comparing.
        distance = (ordered["Level"] * 100.0 - target_pa).abs()
    else:
        distance = 0.0
    ordered["diff"] = distance

    best = ordered.sort_values(["Timestamp", "diff"])
    best = best.drop_duplicates(subset=["Timestamp"], keep="first")
    return best.drop(columns=["diff"]).reset_index(drop=True)

def to_lon_minus180_180(lon):
    """Wrap longitude(s) in degrees into [-180, 180), computed in float32.

    Accepts a scalar or array-like. Note that +180 maps to -180.
    """
    shifted = np.asarray(lon, dtype=np.float32) + 180.0
    return np.mod(shifted, 360.0) - 180.0

def get_latlon(cfg) -> tuple[np.ndarray, np.ndarray, float, float, float, float]:
    """Read the source lat/lon grid from one sample NetCDF file.

    The sample file is chosen from ``cfg.dataset_dir`` (searched recursively):
    the first ``*.nc`` whose lower-cased name contains ``cfg.surface_token``.
    If there is none, the first one containing ``cfg.pressure_token`` is used.
    Longitudes are wrapped to [-180, 180) and sorted ascending.

    Side effect: stores the grid bounds on ``cfg.min_lat``, ``cfg.max_lat``,
    ``cfg.min_lon`` and ``cfg.max_lon``.

    Returns:
        ``(LON, LAT, min_lat, max_lat, min_lon, max_lon)``. ``LON``/``LAT``
        are meshgrids of shape (n_lat, n_lon).

    Raises:
        FileNotFoundError: if ``cfg.dataset_dir`` is missing or contains no
            matching ``.nc`` file.
    """
    dataset_dir = Path(cfg.dataset_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"dataset_dir does not exist: {dataset_dir}")

    # Single directory walk: stop at the first surface file and remember the
    # first pressure file as the fallback.
    sample = None
    pres_fallback = None
    for candidate in dataset_dir.rglob("*.nc"):
        name = candidate.name.lower()
        if cfg.surface_token in name:
            sample = candidate
            break
        if pres_fallback is None and cfg.pressure_token in name:
            pres_fallback = candidate
    if sample is None:
        sample = pres_fallback
    if sample is None:
        raise FileNotFoundError(f"No .nc with '{cfg.surface_token}' or '{cfg.pressure_token}' under {dataset_dir}")

    # netCDF4 is given a plain string path for compatibility with older versions.
    with Dataset(str(sample), "r") as nc:
        lats = np.array(nc.variables["lat"][:], dtype=np.float32)  # (Hsrc,)
        lons = np.array(nc.variables["lon"][:], dtype=np.float32)  # (Wsrc,)

    lons = to_lon_minus180_180(lons)
    # NOTE(review): only the coordinate vector is re-sorted here; data read with
    # the file's original column order would need the same permutation.
    lons = lons[np.argsort(lons)]

    LON, LAT = np.meshgrid(lons, lats)  # (Hsrc, Wsrc)

    cfg.min_lat = float(lats.min())
    cfg.max_lat = float(lats.max())
    cfg.min_lon = float(lons.min())
    cfg.max_lon = float(lons.max())

    return LON, LAT, cfg.min_lat, cfg.max_lat, cfg.min_lon, cfg.max_lon

# Generate (or load from cache) the latitude and longitude grids resized to the H×W model image size
def get_latlon_grid_cached(cfg) -> tuple[np.ndarray, np.ndarray]:
    """Return per-pixel latitude/longitude images at the model resolution.

    The source grid from ``get_latlon`` is bilinearly resized to
    ``cfg.image_size`` (an int, or ``(H, W)``) and flipped vertically. The
    result is cached in ``cfg.out_root/latlon_{H}x{W}.npz``. The grid bounds
    are cached too, so a cache hit restores ``cfg.min_lat``/``max_lat``/
    ``min_lon``/``max_lon`` exactly as a fresh build does. Caches written
    without bounds are still accepted.

    Returns:
        ``(lat_img, lon_img)``, each float32 with shape (H, W).
    """
    if isinstance(cfg.image_size, int):
        H = W = int(cfg.image_size)
    else:
        H, W = map(int, cfg.image_size)

    cache = Path(cfg.out_root) / f"latlon_{H}x{W}.npz"
    cache.parent.mkdir(parents=True, exist_ok=True)

    if cache.exists():
        # Using a context manager closes the NpzFile handle promptly; a
        # lingering handle can keep the file locked on Windows.
        with np.load(cache) as data:
            lat_img = data["lat_img"]
            lon_img = data["lon_img"]
            if "bounds" in data.files:
                cfg.min_lat, cfg.max_lat, cfg.min_lon, cfg.max_lon = (float(v) for v in data["bounds"])
        return lat_img, lon_img

    LON, LAT, min_lat, max_lat, min_lon, max_lon = get_latlon(cfg)

    # PIL resize takes (width, height).
    lat_img = np.array(Image.fromarray(LAT).resize((W, H), Image.BILINEAR), dtype=np.float32)
    lon_img = np.array(Image.fromarray(LON).resize((W, H), Image.BILINEAR), dtype=np.float32)

    lat_img = np.flipud(lat_img)
    lon_img = np.flipud(lon_img)

    print("lat_img min/max:", lat_img.min(), lat_img.max())
    print("lon_img min/max:", lon_img.min(), lon_img.max())
    print("cfg min/max:", cfg.min_lat, cfg.max_lat, cfg.min_lon, cfg.max_lon)

    np.savez_compressed(cache, lat_img=lat_img, lon_img=lon_img,
                        bounds=np.array([min_lat, max_lat, min_lon, max_lon], dtype=np.float64))
    return lat_img, lon_img

# Apply PCA or the preferred feature list to select image channels based on more significant mean features
def select_image_channels(df: pd.DataFrame, cfg: GenesisCfg) -> List[str]:
    """Choose which ``img_<VAR>`` columns to use as model input channels.

    A variable is a candidate when ``df`` has a statistic column
    ``sfc_<VAR>_..._mean`` or ``pres_<VAR>_..._mean`` (VAR = second
    underscore-separated token) and also an ``img_<VAR>`` column.

    - With ``cfg.apply_PCA``, candidates are ranked by the sum of their
      absolute PCA loadings weighted by explained-variance ratio.
    - Otherwise, ``cfg.prefer_vars_surface`` then ``cfg.prefer_vars_pressure``
      are kept in order when available.

    Duplicates are removed, and at most ``cfg.max_channels`` names are returned.

    Raises:
        ValueError: if no candidate ``img_<VAR>`` column exists.
    """
    mean_cols = [c for c in df.columns if c.startswith(("sfc_", "pres_")) and c.endswith("_mean")]
    img_map = {}
    for col in mean_cols:
        img_col = f"img_{col.split('_')[1]}"
        if img_col in df.columns:
            img_map[col] = img_col

    if not img_map:
        raise ValueError("No matching img_<VAR> columns found.")

    if cfg.apply_PCA:
        # Numeric preparation is only needed for PCA; the preference-list path
        # therefore no longer fails on non-numeric statistic columns.
        mean_keys = list(img_map)
        X = df[mean_keys].astype(np.float32).to_numpy(copy=True)
        bad = ~np.isfinite(X)
        col_means = np.nanmean(np.where(bad, np.nan, X), axis=0)
        X[bad] = np.take(col_means, np.where(bad)[1])  # impute with column means
        Xs = StandardScaler().fit_transform(X)

        # PCA cannot have more components than samples or features.
        n_comp = min(cfg.max_channels, Xs.shape[0], Xs.shape[1])
        pca = PCA(n_components=n_comp, random_state=cfg.seed)
        pca.fit(Xs)

        loadings = pca.components_.T                               # (n_vars, n_comp)
        weights = pca.explained_variance_ratio_.reshape(1, -1)     # (1, n_comp)
        importance = np.sum(np.abs(loadings) * weights, axis=1)
        ranked = (img_map[mean_keys[i]] for i in np.argsort(-importance))
        # Several statistic columns can map to the same image; keep each image once.
        return list(dict.fromkeys(ranked))[:cfg.max_channels]

    available = set(img_map.values())
    preferred: List[str] = []
    for var in (*cfg.prefer_vars_surface, *cfg.prefer_vars_pressure):
        img_var = f"img_{var}"
        if img_var in available and img_var not in preferred:
            preferred.append(img_var)
    return preferred[:cfg.max_channels]


# Build input DataFrame merging surface and pressure data
def build_input_table(cfg: GenesisCfg, inference: bool = False, cross_validation=False) -> pd.DataFrame:
    """Load the surface and pressure index CSVs and inner-join them into one table.

    CSV pair (read from ``cfg.out_root``):
    - ``inference`` selects the ``*_inference`` files;
    - otherwise ``cross_validation`` selects ``sfc_csv``/``pres_csv``;
    - otherwise the ``*_train`` files are used.

    Processing:
    - Only pressure rows whose ``img_U`` path contains ``lev3``
      (case-insensitive) are kept.
    - Each storm's earliest surface row with a non-empty ``Latitude`` gets
      ``Genesis = 1``; every other row gets 0.
    - Image channels are chosen per table by ``select_image_channels``.
    - The tables are de-duplicated and joined on
      ``Storm_id, Timestamp, Latitude, Longitud``. If either side lacked
      ``Storm_id``, the join is on ``Timestamp`` only.
    """
    root = Path(cfg.out_root)
    if inference:
        sfc_name, pres_name = cfg.sfc_csv_inference, cfg.pres_csv_inference
    elif cross_validation:
        sfc_name, pres_name = cfg.sfc_csv, cfg.pres_csv
    else:
        sfc_name, pres_name = cfg.sfc_csv_train, cfg.pres_csv_train
    sfc = pd.read_csv(root / sfc_name, parse_dates=["Timestamp"])
    pres = pd.read_csv(root / pres_name, parse_dates=["Timestamp"])

    level_rows = pres["img_U"].str.contains("lev3", case=False, na=False)
    pres = pres[level_rows]

    has_lat = sfc["Latitude"].notna() & (sfc["Latitude"] != "")
    first_rows = sfc[has_lat].groupby("Storm_id")["Timestamp"].idxmin()
    sfc["Genesis"] = 0
    sfc.loc[first_rows, "Genesis"] = 1

    vars_surface = select_image_channels(sfc, cfg)
    vars_pressure = select_image_channels(pres, cfg)

    id_cols = ["Storm_id", "Timestamp", "Latitude", "Longitud"]
    sfc = sfc[id_cols + ["Genesis"] + [c for c in vars_surface if c in sfc.columns]]
    pres = pres[id_cols + [c for c in vars_pressure if c in pres.columns]]

    if "Storm_id" in sfc.columns and "Storm_id" in pres.columns:
        merge_key = list(id_cols)
    else:
        merge_key = ["Timestamp"]

    sfc = sfc.drop_duplicates(subset=merge_key)
    pres = pres.drop_duplicates(subset=merge_key)

    merged = sfc.merge(pres, on=merge_key, how="inner")
    merged = merged.sort_values(merge_key).reset_index(drop=True)

    print("Columnas sfc:", merged.columns)
    return merged
  

# Load genesis events from CSV
def load_genesis_events(cfg: GenesisCfg) -> pd.DataFrame:
    """Read ``cfg.out_root/cfg.genesis_csv`` and return it with Genesis* column names.

    Longitudes are wrapped into [-180, 180) with ``to_lon_minus180_180``. The
    head of the wrapped (not yet renamed) frame is printed for inspection.

    Returns:
        DataFrame with columns ``GenesisTime, GenesisLat, GenesisLon, Storm_id``.
    """
    genesis_path = Path(cfg.out_root) / cfg.genesis_csv
    events = pd.read_csv(genesis_path, parse_dates=["time"])

    def _wrap(value):
        return to_lon_minus180_180(float(value))

    events["lon"] = events["lon"].map(_wrap)
    print(events.head())

    renamed = events.rename(columns={"time": "GenesisTime", "lat": "GenesisLat", "lon": "GenesisLon"})
    return renamed[["GenesisTime", "GenesisLat", "GenesisLon", "Storm_id"]]


# For each row in df, build a mask and save it
def precompute_masks(df: pd.DataFrame, cfg, events: pd.DataFrame | None, save_dir: Path) -> pd.DataFrame:
    """Build and save one genesis mask (``.npy``) per row of ``df``.

    Each row's mask is built at its ``Timestamp`` on the cached model lat/lon
    grid. When both ``df`` and ``events`` have ``Storm_id``, only that storm's
    events are used. With ``events=None`` an all-zero mask is written instead
    of crashing.

    Masks are saved vertically flipped (``np.flipud``), to
    ``save_dir/<Storm_id>/mask_<Storm_id>_<time>.npy``, or to
    ``save_dir/mask_<time>.npy`` when there is no storm id.

    Returns:
        A copy of ``df`` with an added ``mask_path`` column (str paths).
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    lat_img, lon_img = get_latlon_grid_cached(cfg)  # same grid for every row

    has_sid_col = "Storm_id" in df.columns
    # Normalise event storm ids once instead of once per row.
    event_sids = None
    if events is not None and "Storm_id" in events.columns:
        event_sids = events["Storm_id"].astype(str).str.strip()

    paths: list[str] = []
    for _, row in df.iterrows():
        ts = pd.Timestamp(row["Timestamp"])
        sid = str(row["Storm_id"]).strip() if has_sid_col else None

        if events is None:
            mask = np.zeros(lat_img.shape, dtype=np.float32)
        else:
            ev_use = events
            if event_sids is not None and sid is not None:
                ev_use = events[event_sids == sid]
            mask = build_genesis_mask_for_time(
                lat_img, lon_img, ts, ev_use,
                cfg.horizon_hours, cfg.radius_km, cfg
            ).astype(np.float32)

        mask_flip = np.flipud(mask)

        # ':' is not allowed in Windows file names.
        time_tag = ts.isoformat().replace(":", "-")
        if sid:
            out_dir = save_dir / sid
            out_dir.mkdir(parents=True, exist_ok=True)
            p = out_dir / f"mask_{sid}_{time_tag}.npy"
        else:
            p = save_dir / f"mask_{time_tag}.npy"

        np.save(p, mask_flip)
        paths.append(str(p))

    df = df.copy()
    df["mask_path"] = paths
    print(paths)
    return df


# Load genesis events (if configured), build the merged input table, then cache masks.
if cfg.genesis_csv:
    events = load_genesis_events(cfg)
else:
    events = None

df_input = build_input_table(cfg)
mask_cache_dir = Path(cfg.out_root) / "mask_cache"
df_masks = precompute_masks(df_input, cfg, events, mask_cache_dir)

            Storm_id                time   lat         lon
0    Hurricane_Delta 2020-10-04 18:00:00  17.0  -76.000000
1      Hurricane_Eta 2020-10-31 18:00:00  14.9  -67.099998
2     Hurricane_Goni 2020-10-26 12:00:00  14.1  141.700012
3  Hurricane_Haishen 2020-08-30 00:00:00  14.0  147.000000
4      Hurricane_Ida 2021-08-27 00:00:00  18.3  -80.199997
Columnas sfc: Index(['Storm_id', 'Timestamp', 'Latitude', 'Longitud', 'Genesis', 'img_SLP',
       'img_U10M', 'img_V10M', 'img_T2M', 'img_FROCEAN', 'img_EFLUX',
       'img_HFLUX', 'img_TQL', 'img_TQI', 'img_TQV', 'img_U', 'img_V', 'img_T',
       'img_QV', 'img_OMEGA', 'img_CLOUD'],
      dtype='object')
['C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Computer vision\\Final Project\\Outputs\\mask_cache\\Hurricane_Delta\\mask_Hurricane_Delta_2020-10-04T00-00-00.npy', 'C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Computer vision\\Final Project\\Outputs\\mask_cache\\Hurricane_Delta\\mask_Hurrica

### Verify the pregenerated masks

In [5]:
# Check some masks stats
def debug_masks(df, sample=10):
    """Print statistics for up to ``sample`` evenly spaced saved masks.

    For each selected row, the file at ``mask_path`` is loaded. Its shape,
    min, max, mean and sum are printed next to the row's Timestamp and Storm_id.
    """
    n_rows = len(df)
    print(f"Mask rows: {n_rows}")
    picks = np.linspace(0, n_rows - 1, num=min(sample, n_rows), dtype=int)
    print("Checking mask stats for rows:", picks.tolist())
    for pos in picks:
        rec = df.iloc[pos]
        arr = np.load(rec["mask_path"]).astype(np.float32)
        stats = f"min={arr.min():.3f} max={arr.max():.3f} mean={arr.mean():.6f} sum={arr.sum():.0f}"
        print(f"[{pos}] {rec['Timestamp']} {rec['Storm_id']} shape={arr.shape}  {stats}")

# check events stats
def debug_events(events):
    """Print the event count. When there are events, also print the head and lon/lat ranges."""
    print("Events rows:", len(events))
    if not len(events):
        return
    print(events.head())
    lons, lats = events["GenesisLon"], events["GenesisLat"]
    print("Lon range:", lons.min(), "to", lons.max())
    print("Lat range:", lats.min(), "to", lats.max())

# Check candidates for a given timestamp
def dbg_candidates_for_ts(ts, events, horizon_h):
    """Print how many events have GenesisTime in [ts, ts + horizon_h hours] and return them."""
    start = pd.Timestamp(ts)
    end = start + pd.Timedelta(hours=horizon_h)
    when = events["GenesisTime"]
    hits = events[(when >= start) & (when <= end)]
    print(f"{start}: candidates in 0–{horizon_h}h =", len(hits))
    return hits

# Check candidate for one row in df
def check_one(df, events, cfg, idx=0):
    """Report the events in row ``idx``'s look-ahead window.

    Prints how many events fall in ``[Timestamp, Timestamp + cfg.horizon_hours]``
    for that row and returns up to 10 of them.
    """
    start = pd.Timestamp(df.iloc[idx]["Timestamp"])
    end = start + pd.Timedelta(hours=cfg.horizon_hours)
    when = events["GenesisTime"]
    window = events.loc[(when >= start) & (when <= end)]
    print(f"TS={start}  candidates in window={len(window)}")
    return window.head(10)

# Describe lat/lon grid
def describe_grid(lat_img, lon_img, name="grid"):
    """Print a lat/lon grid's shape and ranges (NaN-ignoring).

    Also prints the means of the top/bottom rows (latitude) and of the
    left/right columns (longitude).
    """
    print(f"{name}: shape={lat_img.shape}")

    lat_lo, lat_hi = float(np.nanmin(lat_img)), float(np.nanmax(lat_img))
    top_mean = float(np.nanmean(lat_img[0, :]))
    bottom_mean = float(np.nanmean(lat_img[-1, :]))
    print(f"  lat range: {lat_lo:.2f} → {lat_hi:.2f} (top row mean={top_mean:.2f}, bottom row mean={bottom_mean:.2f})")

    lon_lo, lon_hi = float(np.nanmin(lon_img)), float(np.nanmax(lon_img))
    left_mean = float(np.nanmean(lon_img[:, 0]))
    right_mean = float(np.nanmean(lon_img[:, -1]))
    print(f"  lon range: {lon_lo:.2f} → {lon_hi:.2f} (left col mean={left_mean:.2f}, right col mean={right_mean:.2f})")


# Spot-check the look-ahead window at two known genesis times.
for probe_time in ("2021-08-27 00:00:00", "2020-11-13 12:00:00"):
    ts0 = pd.Timestamp(probe_time)
    cand = dbg_candidates_for_ts(ts0, events, cfg.horizon_hours)
    print("Candidates", cand)

# Summaries of the saved masks, the events and the cached grid.
debug_masks(df_masks, sample=10)
debug_events(events)
check_one(df_masks, events, cfg, idx=0)
describe_grid(*get_latlon_grid_cached(cfg), name="cached grid")


2021-08-27 00:00:00: candidates in 0–24h = 1
Candidates   GenesisTime  GenesisLat  GenesisLon       Storm_id
4  2021-08-27        18.3  -80.199997  Hurricane_Ida
2020-11-13 12:00:00: candidates in 0–24h = 1
Candidates           GenesisTime  GenesisLat  GenesisLon        Storm_id
5 2020-11-13 12:00:00       12.06  -74.800003  Hurricane_Iota
Mask rows: 280
Checking mask stats for rows: [0, 31, 62, 93, 124, 155, 186, 217, 248, 279]
[0] 2020-10-04 00:00:00 Hurricane_Delta shape=(224, 352)  min=0.000 max=1.000 mean=0.000368 sum=29
[31] 2020-10-07 21:00:00 Hurricane_Delta shape=(224, 352)  min=0.000 max=0.000 mean=0.000000 sum=0
[62] 2020-11-01 18:00:00 Hurricane_Eta shape=(224, 352)  min=0.000 max=0.000 mean=0.000000 sum=0
[93] 2020-10-26 15:00:00 Hurricane_Goni shape=(224, 352)  min=0.000 max=0.000 mean=0.000000 sum=0
[124] 2020-11-13 12:00:00 Hurricane_Iota shape=(224, 352)  min=0.000 max=1.000 mean=0.000380 sum=30
[155] 2020-09-29 09:00:00 Hurricane_Marie shape=(224, 352)  min=0.000 max=

### Plot masks for verification

In [6]:
# Visualize the precomputed specific mask at row idx
def show_mask(df, idx: int, vmin=0.0, vmax=1.0, cmap="turbo"):
    """Display the saved mask for row ``idx``, flipped vertically, with a colour bar."""
    rec = df.iloc[idx]
    raw = np.load(rec["mask_path"]).astype(np.float32)  # (H, W)
    plt.figure(figsize=(6, 3))
    shown = np.flipud(raw)

    handle = plt.imshow(shown, vmin=vmin, vmax=vmax, cmap=cmap)
    plt.colorbar(handle, fraction=0.046, pad=0.04, label="Genesis prob (0–24h)")
    plt.title(f"Mask {rec['Timestamp']}")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

# Visualize the precomputed specific mask overlaid on an input image over specific column or feature
def show_mask_overlay(df, idx: int, img_col: str, out_root: str, alpha=0.3,
                      img_cmap="terrain", mask_cmap="Reds"):
    """Save a PNG of row ``idx``'s mask overlaid on the image named in ``img_col``.

    The image path in ``img_col`` is resolved relative to ``out_root``. The
    image is read as greyscale, resized to the mask shape and scaled to [0, 1].
    A blank base image is used when the path is missing, is not a string, or
    points to a file that does not exist. Both layers are flipped vertically
    before plotting.

    The figure is written to
    ``cfg.results_dir/mask_overlay_{idx}_{img_col}.png`` using the
    module-level ``cfg``, and is always closed.
    """
    save_dir = Path(cfg.results_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    row = df.iloc[idx]

    mask = np.load(row["mask_path"]).astype(np.float32)

    rel = row.get(img_col, "")
    if not isinstance(rel, str) or len(rel) == 0:
        base = np.zeros_like(mask)
        title_img = f"(missing) {img_col}"
    else:
        p = Path(out_root) / rel
        if p.exists():
            # Image.open is lazy and keeps the file open; the context manager
            # releases it once the converted copy has been made.
            with Image.open(p) as src:
                im = src.convert("L").resize((mask.shape[1], mask.shape[0]), Image.BILINEAR)
            base = np.asarray(im, dtype=np.float32) / 255.0
            title_img = p.name
        else:
            base = np.zeros_like(mask)
            title_img = f"(not found) {img_col}"

    fig = plt.figure(figsize=(6, 3))
    try:
        plt.imshow(np.flipud(base), cmap=img_cmap, vmin=0, vmax=1)
        plt.imshow(np.flipud(mask), cmap=mask_cmap, vmin=0, vmax=1, alpha=alpha)
        plt.title(f"{img_col} + mask {row['Timestamp']}\n{title_img}")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(save_dir / f"mask_overlay_{idx}_{img_col}.png", bbox_inches="tight")
    finally:
        # Close even on failure so repeated calls in a loop do not accumulate figures.
        plt.close(fig)

#Plot some gallery of masks across time (every `step` rows).
def show_mask_gallery(df, step=10, max_panels=12):
    """Save a grid of every ``step``-th saved mask (at most ``max_panels``).

    The grid has up to 4 columns, each panel is flipped vertically and titled
    with its timestamp. It is written to ``cfg.results_dir/mask_gallery.png``
    using the module-level ``cfg``.
    """
    out_dir = Path(cfg.results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    picks = list(range(0, len(df), step))[:max_panels]
    if not picks:
        print("No rows to show.")
        return

    n_cols = min(4, len(picks))
    n_rows = int(np.ceil(len(picks) / n_cols))

    plt.figure(figsize=(4 * n_cols, 3 * n_rows))
    for panel, pos in enumerate(picks, start=1):
        rec = df.iloc[pos]
        raw = np.load(rec["mask_path"]).astype(np.float32)
        ax = plt.subplot(n_rows, n_cols, panel)
        ax.imshow(np.flipud(raw), vmin=0, vmax=1)
        ax.set_title(pd.to_datetime(rec["Timestamp"]).strftime("%Y-%m-%d %H:%M"))
        ax.axis("off")
    plt.tight_layout()
    plt.savefig(cfg.results_dir / "mask_gallery.png", bbox_inches="tight")
    plt.close()


# pick all image columns
img_cols = [c for c in df_input.columns if c.startswith("img_")]

# Overlay every img_* column on the masks of the first two rows.
for idx in range(min(2, len(df_masks))):
    for img_col in img_cols:
        show_mask_overlay(df_masks, idx=idx, img_col=img_col,
                          out_root=cfg.out_root, alpha=0.5)

# Gallery: every 20th mask, up to 48 panels.
show_mask_gallery(df_masks, step=20, max_panels=48)


### Train the model and run inference

#### Unet

In [7]:
# Classes that implements the U-Net architecture
class ConvBNReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        c_in, c_out: Input / output channel counts.
        k, s, p: Kernel size, stride and padding of the convolution.
    """

    def __init__(self, c_in, c_out, k=3, s=1, p=1):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
    

class UNetSmall(nn.Module):
    """Two-level U-Net that outputs a single logit channel.

    Encoder: ConvBNReLU pairs with 2×2 max-pooling, widening
    c_mid -> 2·c_mid -> 4·c_mid (bottleneck).
    Decoder: 2×2 transposed-conv upsampling, concatenation with the matching
    encoder output, and ConvBNReLU pairs.

    Input (B, c_in, H, W) gives output (B, 1, H, W). When H or W is not
    divisible by 4, each upsampled tensor is bilinearly resized to its skip
    tensor's size before concatenation. Otherwise the computation is unchanged.
    """

    def __init__(self, c_in=6, c_mid=32):
        super().__init__()
        # Module creation order is kept as-is so parameter init and state_dict keys are unchanged.
        self.e1 = nn.Sequential(ConvBNReLU(c_in, c_mid), ConvBNReLU(c_mid, c_mid))
        self.p1 = nn.MaxPool2d(2)
        self.e2 = nn.Sequential(ConvBNReLU(c_mid, c_mid*2), ConvBNReLU(c_mid*2, c_mid*2))
        self.p2 = nn.MaxPool2d(2)
        self.b  = nn.Sequential(ConvBNReLU(c_mid*2, c_mid*4), ConvBNReLU(c_mid*4, c_mid*4))
        self.u2 = nn.ConvTranspose2d(c_mid*4, c_mid*2, 2, 2)
        self.d2 = nn.Sequential(ConvBNReLU(c_mid*4, c_mid*2), ConvBNReLU(c_mid*2, c_mid*2))
        self.u1 = nn.ConvTranspose2d(c_mid*2, c_mid, 2, 2)
        self.d1 = nn.Sequential(ConvBNReLU(c_mid*2, c_mid), ConvBNReLU(c_mid, c_mid))
        self.head = nn.Conv2d(c_mid, 1, 1)

    @staticmethod
    def _match_skip(up, skip):
        """Resize ``up`` to ``skip``'s spatial size when they differ (odd H/W after pooling)."""
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(up, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return up

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.p1(e1))
        b = self.b(self.p2(e2))
        u2 = self._match_skip(self.u2(b), e2)
        d2 = self.d2(torch.cat([u2, e2], dim=1))
        u1 = self._match_skip(self.u1(d2), e1)
        d1 = self.d1(torch.cat([u1, e1], dim=1))
        return self.head(d1)  # (B,1,H,W)


#### Attention U-Net: attention gates, dropout regularization, and strided convolutions

In [None]:

# Attention gate module used to filter skip connection features (x) 
class AttentionGate(nn.Module):
    """Additive attention gate that re-weights skip-connection features.

    ``g`` (gating signal, ``F_g`` channels) and ``x`` (skip features, ``F_l``
    channels) are each projected to ``F_int`` channels by a 1×1 conv + BN.
    The projections are summed and passed through ReLU, then reduced to a
    one-channel sigmoid map by a 1×1 conv + BN. ``x`` is multiplied by that
    map. ``g`` and ``x`` must share the same spatial size because they are
    added elementwise.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def _proj(n_in, n_out):
            # 1x1 conv + BN projection; builds the same submodules as the original layout.
            return nn.Sequential(
                nn.Conv2d(n_in, n_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(n_out),
            )

        self.W_g = _proj(F_g, F_int)
        self.W_x = _proj(F_l, F_int)
        # Output attention map: conv + BN (indices 0, 1) followed by Sigmoid (index 2).
        self.psi = nn.Sequential(*_proj(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)

# 3-level UNet with Strided Convolutions for downsampling and Attention Gates
class AttentionUNet(nn.Module):
    """Three-level U-Net with strided-conv downsampling and attention-gated skips.

    Encoder: ``enc1`` (c_mid, full resolution) -> ``enc2`` (2·c_mid, stride 2)
    -> ``bottleneck`` (4·c_mid, stride 2, Dropout2d p=0.2).
    Decoder: each stage upsamples with a 2×2 transposed conv, gates the
    matching encoder output with an ``AttentionGate``, concatenates the two,
    and applies two ConvBNReLU layers. A 1×1 head produces one logit channel.

    Input (B, c_in, H, W) gives output (B, 1, H, W). When strided
    downsampling produces sizes that do not double back exactly (H or W not
    divisible by 4), the upsampled tensor is bilinearly resized to the skip
    size first.
    """

    def __init__(self, c_in=16, c_mid=32):
        super().__init__()

        # --- Encoder path (module creation order kept: same init and state_dict keys) ---
        # Level 1, full resolution -> skip 1 (c_mid channels)
        self.enc1 = nn.Sequential(ConvBNReLU(c_in, c_mid), ConvBNReLU(c_mid, c_mid))

        # Level 2, downsample to ~H/2 -> skip 2 (2·c_mid channels)
        self.enc2 = nn.Sequential(
            ConvBNReLU(c_mid, c_mid*2, s=2),
            ConvBNReLU(c_mid*2, c_mid*2)
        )

        # Bottleneck, downsample to ~H/4 (4·c_mid channels)
        self.bottleneck = nn.Sequential(
            ConvBNReLU(c_mid*2, c_mid*4, s=2),
            ConvBNReLU(c_mid*4, c_mid*4),
            nn.Dropout2d(p=0.2)
        )

        # --- Decoder path ---
        # 4·c_mid -> 2·c_mid
        self.up2 = nn.ConvTranspose2d(c_mid*4, c_mid*2, 2, 2)
        # Gate channels match up2 output; skip channels match enc2.
        self.att2 = AttentionGate(F_g=c_mid*2, F_l=c_mid*2, F_int=c_mid*2)
        self.dec2 = nn.Sequential(
            # input: 2·c_mid (up) + 2·c_mid (attended skip)
            ConvBNReLU(c_mid*4, c_mid*2),
            ConvBNReLU(c_mid*2, c_mid*2)
        )

        # 2·c_mid -> c_mid
        self.up1 = nn.ConvTranspose2d(c_mid*2, c_mid, 2, 2)
        # Gate channels match up1 output; skip channels match enc1.
        self.att1 = AttentionGate(F_g=c_mid, F_l=c_mid, F_int=c_mid)
        self.dec1 = nn.Sequential(
            # input: c_mid (up) + c_mid (attended skip)
            ConvBNReLU(c_mid*2, c_mid),
            ConvBNReLU(c_mid, c_mid)
        )

        # Output head: one logit channel
        self.head = nn.Conv2d(c_mid, 1, 1)

        print(f"Instantiated Attention UNet with C_in={c_in}, Base Dim={c_mid} (Channels fixed)")

    @staticmethod
    def _match_skip(up, skip):
        """Resize ``up`` to ``skip``'s spatial size when they differ."""
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(up, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return up

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)          # skip 1
        e2 = self.enc2(e1)         # skip 2
        b = self.bottleneck(e2)

        # Decoder 2
        u2 = self._match_skip(self.up2(b), e2)
        e2_att = self.att2(g=u2, x=e2)
        d2 = self.dec2(torch.cat([u2, e2_att], dim=1))

        # Decoder 1 — the concatenated input was previously never built
        # (``d1_in`` was undefined, raising NameError on every forward pass).
        u1 = self._match_skip(self.up1(d2), e1)
        e1_att = self.att1(g=u1, x=e1)
        d1_in = torch.cat([u1, e1_att], dim=1)
        d1 = self.dec1(d1_in)

        return self.head(d1)

#### ResSegNet

In [None]:
# Classes that implement a ResNet-like block architecture.
# A basic residual block: two ConvBNReLU layers plus a shortcut branch.
class ResNetBlock(nn.Module):
    """Two stacked ConvBNReLU layers with an additive shortcut and a final ReLU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        downsample: when True, the first convolution and the shortcut use stride 2.
    """

    def __init__(self, c_in, c_out, downsample=False):
        super().__init__()
        step = 2 if downsample else 1
        self.conv1 = ConvBNReLU(c_in, c_out, s=step, p=1)
        self.conv2 = ConvBNReLU(c_out, c_out, s=1, p=1)

        # A 1x1 projection is needed whenever the channel count or the stride changes;
        # otherwise the input is added back unchanged.
        needs_projection = c_in != c_out or downsample
        if not needs_projection:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=step, bias=False),
                nn.BatchNorm2d(c_out),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        residual = self.conv2(self.conv1(x))
        return F.relu(residual + skip, inplace=False)

# Simplified segmentation network: ResNet-like encoder and bilinear-upsampling decoder.
class ResSegNet(nn.Module):
    """FCN-style segmentation network producing one logit map per input image.

    The encoder reduces the resolution three times (entry conv with stride 2 and
    two downsampling ResNet blocks). The decoder upsamples bilinearly with two
    ConvBNReLU refinements. The final upsample targets the input's exact spatial
    size, so H and W need not be multiples of 8.

    Args:
        c_in: number of input channels.
        c_mid: base channel width; the encoder widens it up to ``c_mid*4``.
    """

    def __init__(self, c_in=6, c_mid=32):
        super().__init__()

        # ENCODER (downsampling path)
        self.entry = ConvBNReLU(c_in, c_mid, k=7, s=2, p=3)
        self.block1 = ResNetBlock(c_mid, c_mid)
        self.block2 = ResNetBlock(c_mid, c_mid*2, downsample=True)  # Downsample x2
        self.block3 = ResNetBlock(c_mid*2, c_mid*4, downsample=True)  # Downsample x4
        self.block4 = ResNetBlock(c_mid*4, c_mid*4)

        # DECODER (upsampling path)
        self.up_c2 = ConvBNReLU(c_mid*4, c_mid*2)
        self.up_c1 = ConvBNReLU(c_mid*2, c_mid)

        # FINAL HEAD: c_mid -> 1 channel (logits)
        self.head = nn.Conv2d(c_mid, 1, 1)

    def forward(self, x):
        out_hw = x.shape[-2:]  # spatial size the logits must match

        # Initial downsampling
        x = self.entry(x)

        # Encoder blocks
        x = self.block1(x)  # c_mid,   ~H/2
        x = self.block2(x)  # c_mid*2, ~H/4
        x = self.block3(x)  # c_mid*4, ~H/8
        x = self.block4(x)  # c_mid*4, ~H/8

        # Decoder upsampling (simple FCN-style)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.up_c2(x)  # c_mid*2, ~H/4

        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.up_c1(x)  # c_mid, ~H/2

        # Final upsample straight to the input size. For H, W divisible by 8 this is
        # identical to scale_factor=2; otherwise scale_factor=2 would leave a
        # size mismatch against the (H, W) target mask.
        x = F.interpolate(x, size=out_hw, mode='bilinear', align_corners=False)

        return self.head(x)  # (B, 1, H, W)

#### Functions for training and testing

In [10]:
# Binary focal loss on top of BCE-with-logits, for imbalanced masks.
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits.

    loss = alpha_t * (1 - p_t) ** gamma * BCEWithLogits(inputs, targets)

    Args:
        alpha: weight for positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent; larger values down-weight confident predictions.
        reduction: 'mean' or 'sum'; any other value returns the unreduced tensor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        ce = self.bce_loss(inputs, targets)
        prob = torch.sigmoid(inputs)
        # Probability the model assigns to the true class of each element.
        prob_true = targets * prob + (1 - targets) * (1 - prob)
        focus = (1.0 - prob_true).pow(self.gamma)
        class_weight = targets * self.alpha + (1.0 - targets) * (1.0 - self.alpha)
        per_element = class_weight * focus * ce

        if self.reduction == 'sum':
            return per_element.sum()
        if self.reduction == 'mean':
            return per_element.mean()
        return per_element
        
def visualize_inference(x, y, probs,  max_ch=4):
    """Show the first sample's ground-truth mask next to the model output.

    Both images are flipped vertically and drawn in grayscale over [0, 1]. The
    output is bilinearly resized to the mask's size when the two differ.

    Args:
        x: input batch; unused, kept for interface compatibility.
        y: ground-truth mask batch, read as ``y[0, 0]``.
        probs: sigmoid probabilities, read as ``probs[0, 0]``.
        max_ch: unused, kept for interface compatibility.
    """
    # detach().float(): get plain float32 arrays; cv2.resize does not accept float16
    # (which autocast can produce) and .numpy() fails on grad-tracking tensors.
    gt = np.flipud(y[0, 0].detach().float().cpu().numpy())
    pred = np.flipud(probs[0, 0].detach().float().cpu().numpy())

    h, w = gt.shape
    extent = [0, w, 0, h]
    pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    axes[0].imshow(gt, cmap="gray", vmin=0, vmax=1, extent=extent)
    axes[0].set_title("Ground truth mask")
    axes[0].axis("off")

    im = axes[1].imshow(pred, cmap="gray", vmin=0, vmax=1, extent=extent)
    axes[1].set_title("Model output")
    axes[1].axis("off")

    fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    fig.suptitle("GT vs Output", fontsize=14)
    plt.tight_layout()
    plt.show()
   



# Visualize a batch: inputs, ground truth, model output, and feature maps.
def visualize_batch(model, x, y, probs, device, epoch, feature_maps, max_ch=4):
    """Save diagnostic figures for the first sample of a batch to ``cfg.results_dir``.

    Writes:
      * ``inputs_{epoch}.png``: up to ``max_ch`` input channels;
      * ``output_{epoch}.png``: ground-truth mask vs. sigmoid output;
      * ``{name}_{epoch}.png`` for "enc1", "bottleneck" and "dec1", when captured.

    Args:
        model, device: unused, kept for interface compatibility.
        x: input batch, read as ``x[0]`` -> (C, H, W).
        y: mask batch, read as ``y[0, 0]``.
        probs: sigmoid outputs, read as ``probs[0, 0]``.
        epoch: used in figure titles and file names.
        feature_maps: dict filled by forward hooks, mapping layer name to a tensor
            whose first index is the sample.
        max_ch: maximum number of channels plotted per figure.
    """
    _save_input_channels(x, epoch, max_ch)
    _save_gt_vs_output(y, probs, epoch)
    for name in ("enc1", "bottleneck", "dec1"):
        _save_feature_grid(feature_maps, name, epoch, max_ch)


def _save_input_channels(x, epoch, max_ch):
    """Save up to ``max_ch`` input channels of the first sample as one row of images."""
    sample = x[0].detach().float().cpu()  # (C, H, W)
    n_in = min(max_ch, sample.shape[0])
    if n_in == 0:
        return

    # squeeze=False keeps a 2-D axes array even for a single channel.
    fig, axes = plt.subplots(1, n_in, figsize=(4 * n_in, 4), squeeze=False)
    for i in range(n_in):
        ax = axes[0][i]
        ax.imshow(np.flipud(sample[i].numpy()), cmap="viridis")  # flip top <-> bottom
        ax.set_title(f"Input ch {i}")
        ax.axis("off")
    fig.suptitle(f"Epoch {epoch} – Inputs", fontsize=14)
    plt.tight_layout()
    plt.savefig(cfg.results_dir / f"inputs_{epoch}.png", bbox_inches="tight")
    plt.close(fig)


def _save_gt_vs_output(y, probs, epoch):
    """Save the ground-truth mask next to the (resized) sigmoid output."""
    # float(): under autocast probs may be float16, which cv2.resize does not accept.
    gt = np.flipud(y[0, 0].detach().float().cpu().numpy())
    pred = np.flipud(probs[0, 0].detach().float().cpu().numpy())
    # The output may have a different resolution than the mask; match the mask.
    pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_LINEAR)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(gt, cmap="gray", vmin=0, vmax=0.5)
    axes[0].set_title("GT mask")
    axes[0].axis("off")

    im = axes[1].imshow(pred, cmap="gray", vmin=0, vmax=0.5)
    axes[1].set_title("Model output (sigmoid)")
    axes[1].axis("off")
    fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    fig.suptitle(f"Epoch {epoch} – GT vs Output", fontsize=14)
    plt.tight_layout()
    plt.savefig(cfg.results_dir / f"output_{epoch}.png", bbox_inches="tight")
    plt.close(fig)


def _save_feature_grid(feature_maps, name, epoch, max_ch):
    """Save a 4-row grid with the first ``4 * (max_ch // 4)`` channels of a feature map."""
    if name not in feature_maps:
        print(f"[WARN] No feature map captured for {name}")
        return
    fm = feature_maps[name][0].float()  # first sample: (C, H, W)
    n_cols = int(min(max_ch, fm.shape[0]) / 4)
    if n_cols == 0:
        print(f"[WARN] Not enough channels to plot for {name}")
        return

    # squeeze=False keeps axes 2-D; with one column the old code indexed past the end.
    fig, axes = plt.subplots(4, n_cols, figsize=(18, 12), squeeze=False)
    img = 0
    for j in range(4):
        for i in range(n_cols):
            arr_fm = np.flipud(fm[img].numpy())  # flip top <-> bottom
            img += 1
            axes[j][i].imshow(arr_fm, cmap="viridis")
            axes[j][i].set_title(f"{name} ch {img}")
            axes[j][i].axis("off")

    fig.suptitle(f"Epoch {epoch} – {name}", fontsize=14)
    plt.tight_layout()
    plt.savefig(cfg.results_dir / f"{name}_{epoch}.png", bbox_inches="tight")
    plt.close(fig)


def plot_image_np(img, title):
    """Display a 2-D array as a grayscale image over [0, 1], flipped vertically.

    Args:
        img: 2-D array-like to display.
        title: figure title.
    """
    # Flip vertically (top <-> bottom), as the other plotting helpers do.
    # The former H/W lookup from cfg.image_size was never used and has been removed.
    flipped = np.flipud(img)

    plt.figure(figsize=(10, 6))
    plt.imshow(flipped, cmap="gray", vmin=0, vmax=1)
    plt.title(title)
    plt.show()




#### Classes to manage the dataset

In [None]:

def _open_gray(path: Path, size_hw: Tuple[int,int]) -> np.ndarray:
    H,W = size_hw
    if not path.exists():
        print(f"[WARN] missing image: {path}")
        return np.zeros((H,W), dtype=np.float32)
    im = Image.open(path).convert("L").resize((W,H), Image.BILINEAR)
    arr = np.asarray(im, dtype=np.float32) / 255.0
    return arr

# Generate DataLoader for the training or evaluation
def _mk_loader(ds, cfg, shuffle):
    
    is_windows = (platform.system() == "Windows")
    eff_workers = cfg.num_workers if (cfg.num_workers and not is_windows) else 0
    eff_persistent = (eff_workers > 0) and cfg.persistent_workers

    return torch.utils.data.DataLoader(
            ds,
            batch_size=cfg.batch_size, shuffle=shuffle,
            num_workers=eff_workers,
            pin_memory=(cfg.pin_memory and DEVICE.type=="cuda"),
            persistent_workers=eff_persistent,
            prefetch_factor=(cfg.prefetch_factor if eff_workers>0 else None),
            drop_last=False
        )

# Class to manage the datasets for training, validation, and testing
class GenesisDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(x, y, timestamp)`` samples.

    ``x`` stacks one grayscale channel per ``img_*`` column as a ``(C, H, W)``
    float32 tensor. ``y`` is the mask loaded from ``mask_path`` as ``(1, H, W)``
    float32 (zeros when the column or file is missing). ``timestamp`` is the row's
    ``Timestamp`` as a string.

    For ``split="train"`` / ``split="val"``, rows with ``Genesis == 1`` and rows with
    ``Genesis`` NaN or 0 are each split 80/20 in their existing order, then
    concatenated (genesis rows first). Any other split keeps every row.
    """

    def __init__(self, df, cfg, events=None, split="train"):
        # ``events`` is accepted for interface compatibility but not used here.
        self.df = df.copy()

        # Image size configuration
        if isinstance(cfg.image_size, int):
            self.H = self.W = int(cfg.image_size)
        else:
            self.H, self.W = map(int, cfg.image_size)

        # Separate rows with and without genesis so both groups appear in both splits.
        df_has_genesis = self.df[self.df["Genesis"].notna() & (self.df['Genesis'] == 1)]
        df_no_genesis = self.df[self.df['Genesis'].isna() | (self.df['Genesis'] == 0)]
        print(f"Total samples: {len(self.df)} | With genesis: {len(df_has_genesis)} | Without genesis: {len(df_no_genesis)}")

        if split == "train" or split == "val":
            train_ratio = 0.8
            train_has_genesis, val_has_genesis = self._ordered_split(df_has_genesis, train_ratio)
            train_no_genesis, val_no_genesis = self._ordered_split(df_no_genesis, train_ratio)

            if split == "train":
                self.df = pd.concat([train_has_genesis, train_no_genesis])
                print(f"Training samples: {len(self.df)}")
            else:  # "val"
                self.df = pd.concat([val_has_genesis, val_no_genesis])
                print(f"Validation samples: {len(self.df)}")

        self.idx = np.arange(len(self.df))

        # Other initializations
        self.root = Path(cfg.out_root)
        self.img_cols = [c for c in self.df.columns if c.startswith("img_")]
        self.has_mask = "mask_path" in self.df.columns
        self.image_transform = getattr(cfg, "image_transform", "identity")

    def __len__(self):
        return len(self.idx)

    @staticmethod
    def _ordered_split(frame, train_ratio):
        """Split ``frame`` in row order into ``(train, val)`` parts.

        ``train_test_split`` raises when the train part would be empty (fewer than
        two rows here), so such small groups go entirely to the training part.
        """
        if len(frame) < 2:
            return frame, frame.iloc[0:0]
        return train_test_split(frame, train_size=train_ratio, shuffle=False)

    # Utilities in class
    def _apply_raster_transform(self, arr: np.ndarray) -> np.ndarray:
        """Apply the configured flip/rotation and return a C-contiguous array.

        ``np.flipud``/``np.fliplr``/``np.rot90`` return negative-stride views, which
        ``torch.from_numpy`` rejects, so the result is made contiguous. This is a no-op
        for arrays that already are. Unknown transform names leave the data unchanged.
        """
        t = self.image_transform
        if t == "flipud":
            arr = np.flipud(arr)
        elif t == "fliplr":
            arr = np.fliplr(arr)
        elif t == "rot180":
            arr = np.rot90(arr, 2)
        return np.ascontiguousarray(arr)

    def _load_gray(self, rel_path: str) -> np.ndarray:
        """Load an image channel as float32 in [0, 1], resized to (H, W).

        Non-string or empty paths, missing files and load errors yield zeros.
        Paths are tried as given first, then relative to ``cfg.out_root``.
        """
        H, W = self.H, self.W
        if not isinstance(rel_path, str) or not rel_path:
            return np.zeros((H, W), dtype=np.float32)
        p = Path(rel_path)
        if not p.exists():
            p = self.root / rel_path  # try relative to out_root
        if not p.exists():
            print(f"[WARN] missing image: {rel_path}")
            return np.zeros((H, W), dtype=np.float32)
        try:
            with Image.open(p) as im:
                gray = im.convert("L").resize((W, H), Image.BILINEAR)
            arr = np.asarray(gray, dtype=np.float32) / 255.0
            return self._apply_raster_transform(arr)
        except Exception as e:  # best effort: log and fall back to an empty channel
            print(f"[ERR] load image failed {p}: {e}")
            return np.zeros((H, W), dtype=np.float32)

    def _load_mask(self, rel_path: str) -> np.ndarray:
        """Load a ``.npy`` mask as float32 clipped to [0, 1] with shape (H, W).

        Non-string or empty paths, missing files and load errors yield zeros.
        """
        H, W = self.H, self.W
        if not isinstance(rel_path, str) or not rel_path:
            return np.zeros((H, W), dtype=np.float32)
        p = Path(rel_path)
        if not p.exists():
            p = self.root / rel_path
        if not p.exists():
            print(f"[WARN] missing mask: {rel_path}")
            return np.zeros((H, W), dtype=np.float32)
        try:
            m = np.load(p).astype(np.float32)
            if m.ndim == 3 and m.shape[0] == 1:
                m = m[0]  # drop a leading singleton axis
            if m.shape != (H, W):
                # Nearest-neighbour resampling avoids blending mask values.
                m_img = Image.fromarray(m).resize((W, H), Image.NEAREST)
                m = np.asarray(m_img, dtype=np.float32)
            m = np.clip(m, 0.0, 1.0)
            return self._apply_raster_transform(m)
        except Exception as e:  # best effort: log and fall back to an empty mask
            print(f"[ERR] load mask failed {p}: {e}")
            return np.zeros((H, W), dtype=np.float32)

    # Get item by index, load images and masks
    def __getitem__(self, i):
        try:
            row = self.df.iloc[self.idx[i]]

            # stack channels
            chans = [self._load_gray(row.get(col)) for col in self.img_cols]
            x = np.stack(chans, axis=0).astype(np.float32)  # (C,H,W)

            # mask
            if self.has_mask:
                mask = self._load_mask(row.get("mask_path"))
            else:
                mask = np.zeros((self.H, self.W), dtype=np.float32)

            x = torch.from_numpy(x)                    # (C,H,W), float32
            y = torch.from_numpy(mask).unsqueeze(0)    # (1,H,W)
            ts = str(row.get("Timestamp", ""))         # keep as string

            return x, y, ts
        except Exception as e:
            print(f"[DatasetError] idx={i} global={self.idx[i]}: {e}")
            # re-raise so with num_workers=0 you get the traceback
            raise



# Generate folds for Leave-One-Storm-Out Cross-Validation (LOSOCV); each fold holds out one storm.
def get_cv_folds(df: pd.DataFrame, n_splits: int = 9):
    """Build one cross-validation fold per unique storm (leave-one-storm-out).

    Args:
        df: input table with a ``Storm_id`` column; NaN/empty means "no storm".
        n_splits: ignored. The number of folds always equals the number of
            unique storms; the parameter is kept for backward compatibility.

    Returns:
        List of dicts with keys ``fold_index`` (1-based), ``test_storm``,
        ``train_df``, ``val_df`` and ``test_df``. Returns an empty list when no
        storm is present. All remaining rows (other storms and no-storm rows) are
        split 90/10 into train/validation with ``random_state=42``.
    """
    has_storm = df['Storm_id'].notna() & (df['Storm_id'] != '')
    unique_storms = df.loc[has_storm, 'Storm_id'].unique().tolist()

    if not unique_storms:
        print("[ERROR] Cannot perform CV: No unique storms found in the dataset.")
        return []

    num_storms = len(unique_storms)
    print(f"Found {num_storms} unique storms. Generating {num_storms} CV folds (LOSOCV).")

    all_folds = []
    for fold_index, test_storm in enumerate(unique_storms, start=1):
        print(f"--- FOLD {fold_index}/{num_storms}: Test Storm = {test_storm} ---")

        # Test split: exactly one storm.
        test_df = df[df['Storm_id'] == test_storm]
        # Training/validation pool: everything else, including 'no-storm' rows.
        train_val_df = df[df['Storm_id'] != test_storm]

        if len(train_val_df) >= 2:
            train_df, val_df = train_test_split(train_val_df, test_size=0.10, random_state=42, shuffle=True)
        else:
            # train_test_split needs at least one row per side; keep whatever exists
            # for training. Empty frames keep the input's columns.
            train_df = train_val_df.copy()
            val_df = train_val_df.iloc[0:0].copy()

        all_folds.append({
            "fold_index": fold_index,
            "test_storm": test_storm,
            "train_df": train_df,
            "val_df": val_df,
            "test_df": test_df,
        })

    return all_folds

# Class to manage a pre-split dataset for a specific CV fold.
class GenesisDataset_CV(torch.utils.data.Dataset):
    """Dataset over an already-split fold DataFrame, yielding ``(x, y, timestamp)``.

    No splitting happens here: every row of ``fold_df`` is used. Image channels are
    the ``img_<var>`` columns for the variables in
    ``cfg.prefer_vars_surface + cfg.prefer_vars_pressure`` that exist in the frame,
    in that order. ``x`` is ``(C, H, W)`` float32. ``y`` is the ``mask_path`` mask
    as ``(1, H, W)`` float32 (zeros when absent). ``timestamp`` is a string.
    """

    def __init__(self, fold_df: pd.DataFrame, cfg: "GenesisCfg", events=None, split="fold"):
        # ``events`` is accepted for interface compatibility but not used here.
        self.df = fold_df.copy().reset_index(drop=True)
        self.split = split

        if isinstance(cfg.image_size, int):
            self.H = self.W = int(cfg.image_size)
        else:
            self.H, self.W = map(int, cfg.image_size)

        self.idx = np.arange(len(self.df))
        print(f"Dataset '{split}' initialized with {len(self.df)} samples.")

        self.root = Path(cfg.out_root)
        # Channel order follows the configured variable lists, not the frame's column order.
        all_img_vars = cfg.prefer_vars_surface + cfg.prefer_vars_pressure
        self.img_cols = [f"img_{v}" for v in all_img_vars if f"img_{v}" in self.df.columns]
        self.has_mask = "mask_path" in self.df.columns
        self.image_transform = getattr(cfg, "image_transform", "identity")

    def __len__(self):
        return len(self.idx)

    def _apply_raster_transform(self, arr: np.ndarray) -> np.ndarray:
        """Apply the configured flip/rotation and return a C-contiguous array.

        ``np.flipud``/``np.fliplr``/``np.rot90`` return negative-stride views, which
        ``torch.from_numpy`` rejects, so the result is made contiguous. This is a no-op
        for arrays that already are. Unknown transform names leave the data unchanged.
        """
        t = self.image_transform
        if t == "flipud":
            arr = np.flipud(arr)
        elif t == "fliplr":
            arr = np.fliplr(arr)
        elif t == "rot180":
            arr = np.rot90(arr, 2)
        return np.ascontiguousarray(arr)

    def _load_gray(self, rel_path: str) -> np.ndarray:
        """Load an image channel as float32 in [0, 1], resized to (H, W).

        Non-string or empty paths, missing files and load errors yield zeros.
        Paths are tried as given first, then relative to ``cfg.out_root``.
        """
        H, W = self.H, self.W
        if not isinstance(rel_path, str) or not rel_path:
            return np.zeros((H, W), dtype=np.float32)
        p = Path(rel_path)
        if not p.exists():
            p = self.root / rel_path  # try relative to out_root
        if not p.exists():
            print(f"[WARN] missing image: {rel_path}")
            return np.zeros((H, W), dtype=np.float32)
        try:
            with Image.open(p) as im:
                gray = im.convert("L").resize((W, H), Image.BILINEAR)
            arr = np.asarray(gray, dtype=np.float32) / 255.0
            return self._apply_raster_transform(arr)
        except Exception as e:  # best effort: log and fall back to an empty channel
            print(f"[ERR] load image failed {p}: {e}")
            return np.zeros((H, W), dtype=np.float32)

    def _load_mask(self, rel_path: str) -> np.ndarray:
        """Load a ``.npy`` mask as float32 clipped to [0, 1] with shape (H, W).

        Non-string or empty paths, missing files and load errors yield zeros.
        """
        H, W = self.H, self.W
        if not isinstance(rel_path, str) or not rel_path:
            return np.zeros((H, W), dtype=np.float32)
        p = Path(rel_path)
        if not p.exists():
            p = self.root / rel_path
        if not p.exists():
            print(f"[WARN] missing mask: {rel_path}")
            return np.zeros((H, W), dtype=np.float32)
        try:
            m = np.load(p).astype(np.float32)
            if m.ndim == 3 and m.shape[0] == 1:
                m = m[0]  # drop a leading singleton axis
            if m.shape != (H, W):
                # Nearest-neighbour resampling avoids blending mask values.
                m_img = Image.fromarray(m).resize((W, H), Image.NEAREST)
                m = np.asarray(m_img, dtype=np.float32)
            m = np.clip(m, 0.0, 1.0)
            return self._apply_raster_transform(m)
        except Exception as e:  # best effort: log and fall back to an empty mask
            print(f"[ERR] load mask failed {p}: {e}")
            return np.zeros((H, W), dtype=np.float32)

    # Get item by index, load images and masks
    def __getitem__(self, i):
        try:
            row = self.df.iloc[self.idx[i]]

            # stack channels
            chans = [self._load_gray(row.get(col)) for col in self.img_cols]
            x = np.stack(chans, axis=0).astype(np.float32)  # (C,H,W)

            # mask
            if self.has_mask:
                mask = self._load_mask(row.get("mask_path"))
            else:
                mask = np.zeros((self.H, self.W), dtype=np.float32)

            x = torch.from_numpy(x)                    # (C,H,W), float32
            y = torch.from_numpy(mask).unsqueeze(0)    # (1,H,W)
            ts = str(row.get("Timestamp", ""))         # keep as string

            return x, y, ts
        except Exception as e:
            print(f"[DatasetError] idx={i} global={self.idx[i]}: {e}")
            # re-raise so with num_workers=0 you get the traceback
            raise

def get_model_class(model_name: str):
    """Return the network class registered under ``model_name``.

    Known names: "UNet" (UNetSmall), "ResSegNet", "SwinUNet", "AttentionUNet".

    Raises:
        ValueError: if ``model_name`` is none of the known names.
    """
    # Lambdas defer each global lookup, so only the requested class has to exist.
    candidates = (
        ("UNet", lambda: UNetSmall),
        ("ResSegNet", lambda: ResSegNet),
        ("SwinUNet", lambda: SwinUNet),
        ("AttentionUNet", lambda: AttentionUNet),
    )
    for registered_name, resolve in candidates:
        if model_name == registered_name:
            return resolve()
    raise ValueError(f"Unknown model name: {model_name}")


# Evaluate a segmentation model: pixel-level ROC-AUC, accuracy, precision/recall/F1,
# Brier score, IoU/Dice, plus ROC and precision-recall curve data.
@torch.no_grad()
def test_model(model, loader, device, threshold=0.2):
    """Evaluate ``model`` on every batch of ``loader``.

    Probabilities are ``sigmoid(model(x))``. A pixel is predicted positive when its
    probability exceeds ``threshold``. IoU/Dice are reported two ways: as the mean
    of per-(sample, channel) scores, and as "global" scores pooled over all pixels.

    Args:
        model: network returning logits shaped like the masks.
        loader: yields ``(x, y, _)`` batches; ``y`` is 4-D (sums run over dims 2, 3).
        device: device the batches are moved to.
        threshold: probability cut-off for binary predictions (default 0.2, the
            previously hard-coded value).

    Returns:
        ``(metrics, plot_data)``: ``metrics`` maps metric names to floats;
        ``plot_data`` holds ``"PR_Curve"`` and ``"ROC_Curve"`` dictionaries.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    y_true, y_prob = [], []
    all_iou_scores, all_dice_scores = [], []
    total_intersection = total_union = total_pred_sum = total_target_sum = 0.0

    for xb, yb, _ in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        probs = torch.sigmoid(model(xb))
        preds = (probs > threshold).float()  # (B, 1, H, W)
        targets = yb.float()

        y_true.append(yb.cpu().numpy().ravel())
        y_prob.append(probs.cpu().numpy().ravel())

        intersection = (preds * targets).sum(dim=(2, 3))
        pred_sum = preds.sum(dim=(2, 3))
        target_sum = targets.sum(dim=(2, 3))
        union = pred_sum + target_sum - intersection

        iou = intersection / (union + 1e-9)
        dice = (2 * intersection) / (pred_sum + target_sum + 1e-6)
        all_iou_scores.extend(iou.view(-1).cpu().tolist())
        all_dice_scores.extend(dice.view(-1).cpu().tolist())

        total_intersection += intersection.sum().item()
        total_union += union.sum().item()
        total_pred_sum += pred_sum.sum().item()
        total_target_sum += target_sum.sum().item()

    if not y_true:
        raise ValueError("test_model: the loader produced no batches")

    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)
    y_pred_binary = (y_prob > threshold).astype(int)

    # Mean of per-(sample, channel) scores vs. scores pooled over the whole dataset.
    final_iou = np.mean(all_iou_scores)
    final_dice = np.mean(all_dice_scores)
    final_iou_global = total_intersection / (total_union + 1e-9)
    final_dice_global = (2.0 * total_intersection) / (total_pred_sum + total_target_sum + 1e-6)

    print(f"Max predicted probability: {y_prob.max():.6f}")
    print(f"Min predicted probability: {y_prob.min():.6f}")
    print("Ground true", y_true)
    print("Predicted probs", y_prob)
    print(f"Mean predicted probability: {y_prob.mean():.6f}")

    # NOTE(review): the sklearn metrics below expect binary y_true; soft (non 0/1)
    # mask values would make them raise — confirm masks are binary.
    if np.sum(y_true) == 0 or np.sum(y_true) == len(y_true):
        # Only one class present: ROC/PR curves are undefined, so use placeholders.
        print("AUC metrics cannot be calculated.")
        roc_auc = 0.0
        pr_curve_data = {"precision": [1.0, 0.0], "recall": [0.0, 1.0], "thresholds": [0.1]}
        roc_curve_data = {"fpr": [0.0, 1.0], "tpr": [0.0, 1.0], "thresholds": [0.1]}
    else:
        roc_auc = roc_auc_score(y_true, y_prob)
        precision, recall, pr_thresholds = precision_recall_curve(y_true, y_prob)
        pr_curve_data = {
            "precision": precision.tolist(),
            "recall": recall.tolist(),
            "thresholds": pr_thresholds.tolist(),
        }
        fpr, tpr, roc_thresholds = roc_curve(y_true, y_prob)
        roc_curve_data = {
            "fpr": fpr.tolist(),
            "tpr": tpr.tolist(),
            "thresholds": roc_thresholds.tolist(),
        }

    accuracy = accuracy_score(y_true, y_pred_binary)
    prec, rec, f1, _support = precision_recall_fscore_support(
        y_true, y_pred_binary, average='binary', zero_division=0
    )
    brier = brier_score_loss(y_true, y_prob)

    print("IOU", final_iou, "global",  final_iou_global)
    print("F1 Score", f1)

    metrics = {
        "ROC_AUC": float(roc_auc),
        "Accuracy": float(accuracy),
        "Precision": float(prec),
        "Recall": float(rec),
        "F1_Score": float(f1),
        "Brier_Loss": float(brier),
        "IoU": float(final_iou),
        "Dice": float(final_dice),
        "IoU_Global": float(final_iou_global),
        "Dice_Global": float(final_dice_global),
    }
    plot_data = {
        "PR_Curve": pr_curve_data,
        "ROC_Curve": roc_curve_data,
    }
    return metrics, plot_data

# Training loss: focal loss on the raw logits.
def criterion(logits, y):
    """Return the mean focal loss (alpha=0.25, gamma=2.0) of ``logits`` against ``y``.

    Alternative (currently unused): weighted BCE via
    ``F.binary_cross_entropy_with_logits(logits, y, pos_weight=torch.tensor([cfg.pos_weight]))``.
    """
    loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
    return loss_fn(logits, y)

# Save a model checkpoint.
def save_model(model: nn.Module, img_cols: List[str], cfg: "GenesisCfg", model_name, epoch=0, vavg=0.0) -> Path:
    """Save ``model``'s weights plus metadata under ``<cfg.out_root>/Models``.

    The file is ``genesis_{model_name}.pt``, or ``genesis_{model_name}_epoch{epoch}.pt``
    when ``epoch > 0``. The checkpoint dict holds ``state_dict``, ``img_cols``,
    ``cfg`` (``vars(cfg)``), ``model_name`` (the architecture name, as accepted by
    ``get_model_class``), ``checkpoint_name`` (the tagged name used in the file
    name), ``epoch`` and ``vavg``.

    Returns:
        Path of the written checkpoint file.
    """
    # Keep ``model_name`` untouched: load_model feeds ckpt["model_name"] to
    # get_model_class, which rejects epoch-tagged names like "UNet_epoch3".
    checkpoint_name = f"{model_name}_epoch{epoch}" if epoch > 0 else model_name
    path = Path(cfg.out_root) / "Models" / f"genesis_{checkpoint_name}.pt"
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "state_dict": model.state_dict(),
            "img_cols": img_cols,
            "cfg": vars(cfg),
            "model_name": model_name,
            "checkpoint_name": checkpoint_name,
            "epoch": epoch,
            "vavg": vavg,
        },
        path,
    )
    return path

# Load a model checkpoint written by save_model.
def load_model(path: Path, device: torch.device) -> tuple[nn.Module, List[str], "GenesisCfg"]:
    """Rebuild a model, its input columns and its config from a checkpoint.

    Args:
        path: checkpoint file written by ``save_model``.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, img_cols, cfg)``, with ``model`` in eval mode and channels_last layout.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=device)
    img_cols = ckpt["img_cols"]
    C_in = len(img_cols)

    # Checkpoints without a stored name default to AttentionUNet.
    model_name = ckpt.get("model_name", "AttentionUNet")
    # save_model used to store the epoch-tagged name (e.g. "AttentionUNet_epoch12"),
    # which get_model_class rejects; strip such a tag to recover the architecture.
    if isinstance(model_name, str):
        base, sep, tag = model_name.rpartition("_epoch")
        if sep and tag.isdigit():
            model_name = base
    ModelClass = get_model_class(model_name)

    # State dicts saved from a torch.compile'd model carry an "_orig_mod." key prefix.
    prefix = "_orig_mod."
    clean_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }

    model = ModelClass(c_in=C_in, c_mid=32).to(device).to(memory_format=torch.channels_last)
    model.load_state_dict(clean_sd, strict=True)
    cfg = GenesisCfg(**ckpt["cfg"])
    model.eval()
    return model, img_cols, cfg

# Forward-hook factory used to capture intermediate feature maps.
def make_hook(name, feature_maps: dict):
    """Return a forward hook that stores the module output in ``feature_maps[name]``.

    The stored tensor is detached and on the CPU, and replaces any previous entry.
    """
    def _capture(_module, _inputs, output):
        # Detaching and moving to CPU avoids keeping the autograd graph or GPU memory alive.
        feature_maps[name] = output.detach().cpu()

    return _capture


# Training function (single train/validation split).
def _genesis_train_epoch(model, loader, opt, scaler, amp_ctx, grad_accum, device):
    """Run one training epoch and return the mean unscaled per-batch loss.

    Gradients are accumulated over ``grad_accum`` batches. If the epoch ends inside
    an accumulation window, the optimizer still steps on those gradients instead of
    dropping them (they would otherwise be zeroed at the next epoch's start).
    """
    model.train()
    opt.zero_grad(set_to_none=True)
    running, seen, step = 0.0, 0, 0

    for step, (xb, yb, _) in enumerate(loader, 1):
        xb = xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        yb = yb.to(device, non_blocking=True)
        with amp_ctx:
            logits = model(xb)
            batch_loss = criterion(logits, yb)
            loss = batch_loss / grad_accum
        scaler.scale(loss).backward()
        if step % grad_accum == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        running += batch_loss.item()
        seen += 1

    if step % grad_accum != 0:
        # Flush the last, incomplete accumulation window.
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)

    return running / max(1, seen)


def _genesis_validation_loss(model, loader, device):
    """Return the mean per-batch ``criterion`` loss over ``loader``.

    The caller is responsible for eval mode, ``torch.inference_mode`` and autocast.
    """
    total, batches = 0.0, 0
    for xb, yb, _ in loader:
        xb = xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        yb = yb.to(device, non_blocking=True)
        total += criterion(model(xb), yb).item()
        batches += 1
    return total / max(1, batches)


def train_genesis(cfg: "GenesisCfg", model_name: str = "UNet"):
    """Train ``model_name`` on GenesisDataset's train/val split with early stopping.

    Steps:
      * Build the input table and precompute masks when ``cfg.genesis_csv`` is set.
      * Train with AdamW, optionally with AMP and gradient accumulation.
      * After each epoch, compute the validation loss.
      * Checkpoint via ``save_model`` whenever the validation loss improves by more
        than ``cfg.min_delta``.
      * Stop after ``cfg.patience`` epochs without such an improvement.

    For AttentionUNet, diagnostic figures of the first training sample are saved
    at epoch 1 and every 10th epoch.

    Returns:
        ``(model, img_cols, train_loss_hist, val_loss_hist, epoch_hist)``

    Raises:
        ValueError: if the training split is empty.
    """
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except (AttributeError, RuntimeError):
        pass  # best effort: not available on older torch versions
    device = DEVICE

    # Datasets and loaders
    df = build_input_table(cfg, inference=False, cross_validation=False)
    events = load_genesis_events(cfg) if cfg.genesis_csv else None
    if events is not None:
        df = precompute_masks(df, cfg, events, Path(cfg.out_root) / "mask_cache")

    ds_tr = GenesisDataset(df, cfg, None, split="train")
    ds_va = GenesisDataset(df, cfg, None, split="val")
    if len(ds_tr) == 0:
        raise ValueError("train_genesis: the training split is empty")

    # GenesisDataset orders training rows as [all genesis rows, all non-genesis rows],
    # so an unshuffled loader yields class-homogeneous batches. Shuffle by default;
    # set cfg.shuffle_train = False to restore the old ordering.
    tr = _mk_loader(ds_tr, cfg, getattr(cfg, "shuffle_train", True))
    va = _mk_loader(ds_va, cfg, False)

    # Model, optimizer, AMP
    ModelClass = get_model_class(model_name)
    C_in = len(ds_tr.img_cols)
    base_model = ModelClass(c_in=C_in, c_mid=32).to(device).to(memory_format=torch.channels_last)

    feature_maps = {}
    visualize = model_name == "AttentionUNet"
    if visualize:
        for layer in ("enc1", "bottleneck", "dec1"):
            getattr(base_model, layer).register_forward_hook(make_hook(layer, feature_maps))

    if cfg.use_compile and hasattr(torch, "compile"):
        model = torch.compile(base_model)
    else:
        model = base_model

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    use_amp = device.type == "cuda" and cfg.use_amp
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    amp_ctx = torch.amp.autocast('cuda', enabled=use_amp)
    grad_accum = max(1, getattr(cfg, "grad_accum_steps", 1))

    # Fixed visualization sample: the first training row. This is what the first
    # batch of the previously unshuffled loader started with; it is loaded once
    # instead of re-iterating the loader every epoch.
    vis_x, vis_y, _ = ds_tr[0] if visualize else (None, None, None)

    best_val = float("inf")
    train_loss_hist, val_loss_hist, epoch_hist = [], [], []
    no_improve_epochs = 0
    patience = cfg.patience
    min_delta = cfg.min_delta

    for epoch in range(1, cfg.epochs + 1):
        print("Epoch", epoch)
        train_avg = _genesis_train_epoch(model, tr, opt, scaler, amp_ctx, grad_accum, device)
        train_loss_hist.append(train_avg)

        # Validation (and periodic visualization)
        model.eval()
        with torch.inference_mode(), amp_ctx:
            if visualize and (epoch == 1 or epoch % 10 == 0):
                x = vis_x.unsqueeze(0).to(device)
                y = vis_y.unsqueeze(0).to(device)
                feature_maps.clear()
                logits = model(x)  # forward hooks fill feature_maps
                # float32 probabilities: cv2.resize in visualize_batch rejects float16.
                probs = torch.sigmoid(logits.float())
                visualize_batch(model, x, y, probs, device, epoch, feature_maps, max_ch=16)

            vavg = _genesis_validation_loss(model, va, device)

        val_loss_hist.append(vavg)
        epoch_hist.append(epoch)
        print(f"Epoch {epoch:02d} | train_loss={train_avg:.6f} | val_loss={vavg:.6f}")

        if vavg < best_val - min_delta:
            best_val = vavg
            no_improve_epochs = 0
            save_model(model, ds_tr.img_cols, cfg, model_name, epoch, vavg)
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"Early stopping triggered after {epoch} epochs. Best val_loss={best_val:.6f}")
                break

    return model, ds_tr.img_cols, train_loss_hist, val_loss_hist, epoch_hist

# Train Leave-One-Storm-Out Cross-Validation
def _cv_train_one_epoch(model, loader, opt, scaler, amp_ctx, grad_accum, device):
    """Run one training epoch with gradient accumulation and return the mean batch loss.

    Gradients of a trailing partial accumulation window are applied at the end of
    the epoch instead of being discarded by the next epoch's ``zero_grad``.
    """
    model.train()
    opt.zero_grad(set_to_none=True)
    running, seen = 0.0, 0
    pending = 0  # micro-batches whose gradients have not been applied yet
    for xb, yb, _ in loader:
        xb = xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        yb = yb.to(device, non_blocking=True)
        with amp_ctx:
            logits = model(xb)
            batch_loss = criterion(logits, yb)
            loss = batch_loss / grad_accum

        # Scaler handles the mixed-precision backward pass
        scaler.scale(loss).backward()
        pending += 1
        if pending == grad_accum:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            pending = 0

        running += batch_loss.item()
        seen += 1

    if pending:
        # Flush the last, incomplete accumulation window
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
    return running / max(1, seen)


def _cv_validate(model, loader, amp_ctx, device):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``loader`` (no grad)."""
    model.eval()
    vloss, vseen = 0.0, 0
    with torch.inference_mode(), amp_ctx:
        for xb, yb, _ in loader:
            xb = xb.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            yb = yb.to(device, non_blocking=True)
            logits = model(xb)
            vloss += criterion(logits, yb).item()
            vseen += 1
    return vloss / max(1, vseen)


def train_genesis_cv(cfg: GenesisCfg, model_name: str = "UNet"):
    """Train and evaluate ``model_name`` with leave-one-storm-out cross-validation.

    For every fold returned by ``get_cv_folds``, a fresh model is trained on the
    fold's training split and early-stopped on its validation split. It is then
    reloaded from the best saved checkpoint (when that succeeds) and evaluated
    with ``test_model`` on the held-out storm.

    Args:
        cfg: Genesis configuration (paths, optimiser, AMP and early-stopping settings).
        model_name: Architecture name understood by ``get_model_class``.

    Returns:
        ``(model, img_cols, train_loss_hist, val_loss_hist, epoch_hist, results_df)``.
        The model, channel names and loss histories belong to the LAST fold.
        ``results_df`` has one row of test metrics per fold. Returns ``None``
        when no CV folds could be built.
    """
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except (AttributeError, RuntimeError):
        pass  # torch without this knob: keep the default matmul precision
    device = DEVICE

    # 1. Build the input table; precompute target masks before splitting into folds
    df = build_input_table(cfg, inference=False, cross_validation=True)
    events = load_genesis_events(cfg) if cfg.genesis_csv else None
    if events is not None:
        df = precompute_masks(df, cfg, events, Path(cfg.out_root) / "mask_cache")

    # 2. Generate CV folds
    cv_folds = get_cv_folds(df)
    if not cv_folds:
        print("CV aborted due to lack of storm data.")
        return

    all_fold_metrics = []
    patience = cfg.patience
    min_delta = cfg.min_delta

    # 3. Iterate through all folds
    for fold in cv_folds:
        fold_idx = fold["fold_index"]
        test_storm = fold["test_storm"]
        print(f"Fold {fold_idx}: Testing on {test_storm}")
        print(f"Train/Val/Test samples: {len(fold['train_df'])}/{len(fold['val_df'])}/{len(fold['test_df'])}")

        ds_tr = GenesisDataset_CV(fold['train_df'], cfg, events, split=f"Fold{fold_idx}_Train")
        ds_va = GenesisDataset_CV(fold['val_df'], cfg, events, split=f"Fold{fold_idx}_Val")
        ds_te = GenesisDataset_CV(fold['test_df'], cfg, events, split=f"Fold{fold_idx}_Test")

        # NOTE(review): the training loader is not shuffled; confirm this is intended.
        tr = _mk_loader(ds_tr, cfg, shuffle=False)
        va = _mk_loader(ds_va, cfg, shuffle=False)
        te = _mk_loader(ds_te, cfg, shuffle=False)

        ModelClass = get_model_class(model_name)
        C_in = len(ds_tr.img_cols)
        model = ModelClass(c_in=C_in, c_mid=32).to(device).to(memory_format=torch.channels_last)
        if cfg.use_compile and hasattr(torch, "compile"):
            model = torch.compile(model)

        use_amp = device.type == "cuda" and cfg.use_amp
        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
        scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
        amp_ctx = torch.amp.autocast('cuda', enabled=use_amp)
        grad_accum = max(1, getattr(cfg, "grad_accum_steps", 1))

        # Early-stopping state is scoped to the current fold
        best_val = float("inf")
        no_improve_epochs = 0
        best_path = Path(cfg.out_root) / f"genesis_{model_name.lower()}_fold{fold_idx}.pt"

        train_loss_hist = []
        val_loss_hist = []
        epoch_hist = []

        for epoch in range(1, cfg.epochs + 1):
            print(f"Epoch {epoch:02d} running...")
            train_avg = _cv_train_one_epoch(model, tr, opt, scaler, amp_ctx, grad_accum, device)
            train_loss_hist.append(train_avg)

            vavg = _cv_validate(model, va, amp_ctx, device)
            val_loss_hist.append(vavg)
            epoch_hist.append(epoch)
            print(f"\n Fold {fold_idx} | Epoch {epoch:02d} | train_loss={train_avg:.6f} | val_loss={vavg:.6f}")

            if vavg < best_val - min_delta:
                best_val = vavg
                no_improve_epochs = 0
                saved = save_model(model, ds_tr.img_cols, cfg, model_name, epoch, vavg)
                # Reload later from wherever save_model actually wrote the checkpoint
                if saved:
                    best_path = Path(saved)
            else:
                no_improve_epochs += 1
                if no_improve_epochs >= patience:
                    print(f"Early stopping triggered after {epoch} epochs. Best val_loss={best_val:.6f}")
                    break

        # Test evaluation on the held-out storm (best checkpoint if it can be loaded)
        try:
            # The checkpoint was written by save_model during this run (trusted file)
            # and may hold non-tensor objects, hence weights_only=False.
            checkpoint = torch.load(best_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['state_dict'])
            print(f"Loaded best checkpoint for Fold {fold_idx}.")
        except Exception as e:
            # Best effort: fall back to the last-epoch weights
            print(f"Could not load checkpoint for fold {fold_idx}: {e}. ")

        test_metrics, _plot_data = test_model(model, te, device)

        fold_metrics = {
            "Fold": fold_idx,
            "Test_Storm_ID": test_storm,
            "Test_Samples": len(fold['test_df']),
            **test_metrics
        }
        all_fold_metrics.append(fold_metrics)

        print(f"\n--- FOLD {fold_idx} RESULTS (Test Storm: {test_storm}) ---")
        for k, v in test_metrics.items():
            print(f"{k}: {v:.4f}")

    # 4. Aggregate and report final results
    results_df = pd.DataFrame(all_fold_metrics)
    metric_cols = results_df.drop(columns=['Fold', 'Test_Storm_ID', 'Test_Samples'])
    mean_metrics = metric_cols.mean(numeric_only=True).to_dict()
    std_metrics = metric_cols.std(numeric_only=True).to_dict()

    print("CROSS-VALIDATION FINAL SUMMARY")
    print("--- Individual Fold Results ---")
    try:
        print(results_df.to_markdown(index=False))
    except ImportError:
        # DataFrame.to_markdown requires the optional 'tabulate' package
        print(results_df.to_string(index=False))

    print("\n--- Average Metrics (Mean ± Std Dev) ---")
    for k, mean_val in mean_metrics.items():
        std_val = std_metrics.get(k, 0.0)
        print(f"{k:<12}: {mean_val:.4f} ± {std_val:.4f}")

    return model, ds_tr.img_cols, train_loss_hist, val_loss_hist, epoch_hist, results_df

# Plot ROC Curve
def plot_roc_curve(plot_data, roc_auc, model_name="UNet"):
    """Save the ROC curve stored in ``plot_data["ROC_Curve"]`` as a PNG.

    Draws TPR against FPR with the chance diagonal for reference and writes
    ``plot_roc_<model_name>.png`` into the global ``cfg.results_dir``.

    Args:
        plot_data: Mapping whose ``"ROC_Curve"`` entry holds ``"fpr"`` and ``"tpr"``.
        roc_auc: Area under the ROC curve, shown in the legend.
        model_name: Used in the plot title and in the output file name.
    """
    curve = plot_data["ROC_Curve"]
    false_pos, true_pos = curve["fpr"], curve["tpr"]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos, true_pos, color='#0F7CBB', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--', label='Random Guess (AUC = 0.50)')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate (FPR)')
    ax.set_ylabel('True Positive Rate (TPR)')
    ax.set_title(f'ROC Curve for {model_name} Model')
    ax.legend(loc="lower right")
    ax.grid(True, linestyle='--', alpha=0.6)

    fig.savefig(cfg.results_dir / f"plot_roc_{model_name}.png", bbox_inches="tight")
    plt.close(fig)
# Plot Precision-Recall Curve
def plot_pr_curve(plot_data, model_name="AttentionUNet"):
 
    prc_data = plot_data["PR_Curve"]    

    precision = prc_data["precision"]
    recall = prc_data["recall"]
    thresholds = prc_data["thresholds"]
    average_precision = np.trapz(precision, recall)    

    plt.figure(figsize=(6, 5))
    plt.plot(recall, precision, label=f"{model_name} (AP={average_precision:.3f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision–Recall Curve")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    #plt.show()
    plt.savefig(cfg.results_dir / f"plot_pr_{model_name}.png", bbox_inches="tight")
    plt.close()   
    

# Plot BCE vs epochs
def plot_loss_epochs(train_loss_hist, val_loss_hist, epoch_hist, model_name="unet"):
    """Save the training and validation loss curves (labelled BCE) as a PNG.

    Args:
        train_loss_hist: Mean training loss per epoch.
        val_loss_hist: Mean validation loss per epoch.
        epoch_hist: Epoch numbers matching both histories, used as the x axis.
        model_name: Used in the title and in ``plot_loss_epoch_<model_name>.png``,
            which is written to the global ``cfg.results_dir``.
    """
    plt.figure(figsize=(6, 4))
    plt.plot(epoch_hist, train_loss_hist, label="Train BCE", marker="o")
    plt.plot(epoch_hist, val_loss_hist, label="Val BCE", marker="s")
    plt.xlabel("Epoch")  # the x axis was previously unlabeled
    plt.ylabel("Binary cross-entropy loss")
    plt.title("BCE loss vs. epochs - " + model_name)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(cfg.results_dir / f"plot_loss_epoch_{model_name}.png", bbox_inches="tight")
    plt.close()

# Compares the cross-validation mean and standard deviation for multiple models
def compare_models_cv_results(all_results: T.Dict[str, pd.DataFrame]):
    """Print and return a "mean ± std" table of cross-validation metrics per model.

    Args:
        all_results: Mapping of model name -> per-fold results DataFrame (as
            returned by ``train_genesis_cv``). The bookkeeping columns ``Fold``,
            ``Test_Storm_ID`` and ``Test_Samples`` are ignored if present, and so
            are non-numeric columns. ``None`` or empty entries are skipped with
            a warning.

    Returns:
        DataFrame with a ``Model`` column plus one "mean ± std" string column per
        metric, or ``None`` when there is nothing to compare.
    """
    if not all_results:
        print("No results provided for comparison.")
        return

    print("######### FINAL CROSS-MODEL CV PERFORMANCE SUMMARY ##########")
    comparison_data = []

    # 1. Calculate Mean ± Std for each model
    for model_name, df in all_results.items():
        # train_genesis_cv returns None when CV is aborted; treat it like an empty result
        if df is None or df.empty:
            print(f"[Warning] Skipping {model_name}: Empty results DataFrame.")
            continue

        metrics = df.drop(columns=['Fold', 'Test_Storm_ID', 'Test_Samples'], errors='ignore')
        mean = metrics.mean(numeric_only=True)
        std = metrics.std(numeric_only=True)  # sample std (ddof=1); NaN for a single fold

        row = {"Model": model_name}
        for metric_name in mean.index:
            row[metric_name] = f"{mean[metric_name]:.4f} ± {std[metric_name]:.4f}"
        comparison_data.append(row)

    # 2. Create the final comparison DataFrame
    comparison_df = pd.DataFrame(comparison_data)
    if comparison_df.empty:
        print("No valid model results found for comparison.")
        return

    # 3. Print the table with plain string formatting (no optional dependencies)
    headers = comparison_df.columns.tolist()
    col_widths = {}
    for h in headers:
        max_len = max(comparison_df[h].astype(str).str.len(), default=0)
        col_widths[h] = max(len(h), max_len) + 1  # +1 for padding

    header_line = "| " + " | ".join(h.ljust(col_widths[h]) for h in headers) + " |"
    print(header_line)
    print("-" * len(header_line))

    for _, row in comparison_df.iterrows():
        cells = " | ".join(str(row[h]).ljust(col_widths[h]) for h in headers)
        print(f"| {cells} |")
    print("-" * len(header_line))

    return comparison_df

# Compares final test-set metrics for multiple models and prints a clean table.
def compare_test_metrics(all_test_metrics: T.Dict[str, T.Dict[str, float]]):
    """Print and return a table of final test metrics, one row per model.

    Numeric values are formatted with 4 decimals; anything else is shown with
    ``str()``. If at least one model reports ``ROC_AUC``, the model with the
    highest value is reported. Models without it no longer cause a KeyError.

    Args:
        all_test_metrics: Mapping of model name -> {metric name: value}.

    Returns:
        The comparison DataFrame (formatted strings), or ``None`` if empty.
    """
    if not all_test_metrics:
        print("No test metrics provided for comparison.")
        return

    print("############ FINAL MODEL TEST METRICS SUMMARY #############")

    comparison_data = []
    for model_name, metrics in all_test_metrics.items():
        row = {"Model": model_name}
        for k, v in metrics.items():
            try:
                row[k] = f"{v:.4f}"
            except (TypeError, ValueError):
                row[k] = str(v)  # non-numeric metric value
        comparison_data.append(row)

    comparison_df = pd.DataFrame(comparison_data)
    if comparison_df.empty:
        print("No valid model results found for comparison.")
        return

    headers = comparison_df.columns.tolist()
    col_widths = {}
    for h in headers:
        max_len = max(comparison_df[h].astype(str).str.len(), default=0)
        col_widths[h] = max(len(h), max_len) + 1  # +1 for padding

    header_line = "| " + " | ".join(h.ljust(col_widths[h]) for h in headers) + " |"
    print(header_line)
    print("-" * len(header_line))

    for _, row in comparison_df.iterrows():
        cells = " | ".join(str(row[h]).ljust(col_widths[h]) for h in headers)
        print(f"| {cells} |")
    print("-" * len(header_line))

    # Only models that actually report ROC_AUC compete for "best"
    roc_scores = {name: m["ROC_AUC"] for name, m in all_test_metrics.items() if "ROC_AUC" in m}
    if roc_scores:
        best_name = max(roc_scores, key=roc_scores.get)  # first model wins ties
        print(f"\nInsight: The **{best_name}** model achieved the highest ROC_AUC score of {roc_scores[best_name]:.4f}.")

    return comparison_df

# Testing function
def test_genesis(cfg: GenesisCfg, model_path: str, model):
    """Evaluate a genesis model on the inference/test input table.

    Args:
        cfg: Genesis configuration.
        model_path: Checkpoint to load with ``load_model``. When empty or None,
            the ``model`` argument is evaluated as given.
        model: Model to evaluate when ``model_path`` is not provided; it is
            replaced by the loaded model otherwise.

    Returns:
        ``(model, img_cols, test_metrics, plot_data)``. ``img_cols`` comes from
        the checkpoint when one is loaded. Otherwise it falls back to the test
        dataset's ``img_cols`` attribute, or ``None`` if the dataset has none.
        Before this change, an empty ``model_path`` raised NameError.
    """
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except (AttributeError, RuntimeError):
        pass  # torch without this knob: keep the default matmul precision
    device = DEVICE

    img_cols = None
    if model_path:
        model, img_cols, _cfg_loaded = load_model(Path(model_path), device)

    # Build the test table and, when genesis events exist, precompute masks
    df = build_input_table(cfg, inference=True, cross_validation=False)
    events = load_genesis_events(cfg) if cfg.genesis_csv else None
    if events is not None:
        df = precompute_masks(df, cfg, events, Path(cfg.out_root) / "mask_cache")

    # NOTE(review): events go to precompute_masks but the dataset receives None;
    # confirm GenesisDataset reads the precomputed masks from df.
    ds_te = GenesisDataset(df, cfg, None, split="test")
    te = _mk_loader(ds_te, cfg, False)

    if img_cols is None:
        img_cols = getattr(ds_te, "img_cols", None)

    test_metrics, plot_data = test_model(model, te, device)

    print("\nTEST METRICS")
    for key, value in test_metrics.items():
        print(f"{key:<12}: {value:.4f}")
    print(f"PR Curve data points: {len(plot_data['PR_Curve']['precision'])}")
    print(f"ROC Curve data points: {len(plot_data['ROC_Curve']['fpr'])}")
    print("--------------------------")

    return model, img_cols, test_metrics, plot_data

#### Leave-one-storm-out cross-validation (robustness assessment)

In [None]:
# Cross-validate each architecture, then save the model returned for its last fold as "<name>_CV".
def _run_cv_and_save(name):
    """Run train_genesis_cv for `name`, save the returned model and report completion."""
    model, cols, tr_hist, va_hist, ep_hist, cv_results = train_genesis_cv(cfg, model_name=name)
    saved_path = save_model(model, cols, cfg, model_name=f"{name}_CV")
    print(f"{name} CV finished")
    return model, cols, tr_hist, va_hist, ep_hist, cv_results, saved_path


(model_Unet, img_cols_UNet, train_loss_hist_UNet, val_loss_hist_UNet,
 epoch_hist_UNet, results_cv_Unet, path_Unet) = _run_cv_and_save("UNet")

(model_AttentionUNet, img_cols_AttentionUNet, train_loss_hist_AttentionUNet, val_loss_hist_AttentionUNet,
 epoch_hist_AttentionUNet, results_cv_AttentionUNet, path_AttentionUNet) = _run_cv_and_save("AttentionUNet")

(model_ResSegNet, img_cols_ResSegNet, train_loss_hist_ResSegNet, val_loss_hist_ResSegNet,
 epoch_hist_ResSegNet, results_cv_ResSegNet, path_ResSegNet) = _run_cv_and_save("ResSegNet")

all_cv_results = dict(
    UNet=results_cv_Unet,
    AttentionUNet=results_cv_AttentionUNet,
    ResSegNet=results_cv_ResSegNet,
)

final_comparison_table = compare_models_cv_results(all_cv_results)

#### Train the final models and evaluate them on the test set

In [12]:
final_test_metrics = {}
cfg.epochs = 150


def _train_save_and_test(name):
    """Train `name`, save it, reload and evaluate it; record its metrics in final_test_metrics."""
    model, cols, tr_hist, va_hist, ep_hist = train_genesis(cfg, model_name=name)
    saved_path = save_model(model, cols, cfg, model_name=name)
    # test_genesis reloads the model from the saved checkpoint and returns it
    model, test_cols, metrics, plots = test_genesis(cfg, str(saved_path), model)
    final_test_metrics[name] = metrics
    print(f"{name} finished")
    return model, cols, tr_hist, va_hist, ep_hist, saved_path, test_cols, metrics, plots


(model_AttentionUNet, img_cols_AttentionUNet, train_loss_hist_AttentionUNet, val_loss_hist_AttentionUNet,
 epoch_hist_AttentionUNet, path_AttentionUNet, test_img_cols_AttentionUNet,
 test_metrics_AttentionUNet, plot_data_AttentionUNet) = _train_save_and_test("AttentionUNet")

(model_Unet, img_cols_UNet, train_loss_hist_UNet, val_loss_hist_UNet,
 epoch_hist_UNet, path_Unet, test_img_cols_Unet,
 test_metrics_Unet, plot_data_Unet) = _train_save_and_test("UNet")

(model_ResSegNet, img_cols_ResSegNet, train_loss_hist_ResSegNet, val_loss_hist_ResSegNet,
 epoch_hist_ResSegNet, path_ResSegNet, test_img_cols_ResSegNet,
 test_metrics_ResSegNet, plot_data_ResSegNet) = _train_save_and_test("ResSegNet")

# Compare all final test metrics
comparison_table = compare_test_metrics(final_test_metrics)


Columnas sfc: Index(['Storm_id', 'Timestamp', 'Latitude', 'Longitud', 'Genesis', 'img_SLP',
       'img_U10M', 'img_V10M', 'img_T2M', 'img_FROCEAN', 'img_EFLUX',
       'img_HFLUX', 'img_TQL', 'img_TQI', 'img_TQV', 'img_U', 'img_V', 'img_T',
       'img_QV', 'img_OMEGA', 'img_CLOUD'],
      dtype='object')
            Storm_id                time   lat         lon
0    Hurricane_Delta 2020-10-04 18:00:00  17.0  -76.000000
1      Hurricane_Eta 2020-10-31 18:00:00  14.9  -67.099998
2     Hurricane_Goni 2020-10-26 12:00:00  14.1  141.700012
3  Hurricane_Haishen 2020-08-30 00:00:00  14.0  147.000000
4      Hurricane_Ida 2021-08-27 00:00:00  18.3  -80.199997
['C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Computer vision\\Final Project\\Outputs\\mask_cache\\Hurricane_Delta\\mask_Hurricane_Delta_2020-10-04T00-00-00.npy', 'C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Computer vision\\Final Project\\Outputs\\mask_cache\\Hurricane_Delta\\mask_Hurrica

  ckpt = torch.load(path, map_location=device)


Instantiated Attention UNet with C_in=16, Base Dim=32 (Channels fixed)
Columnas sfc: Index(['Storm_id', 'Timestamp', 'Latitude', 'Longitud', 'Genesis', 'img_SLP',
       'img_U10M', 'img_V10M', 'img_T2M', 'img_FROCEAN', 'img_EFLUX',
       'img_HFLUX', 'img_TQL', 'img_TQI', 'img_TQV', 'img_U', 'img_V', 'img_T',
       'img_QV', 'img_OMEGA', 'img_CLOUD'],
      dtype='object')
            Storm_id                time   lat         lon
0    Hurricane_Delta 2020-10-04 18:00:00  17.0  -76.000000
1      Hurricane_Eta 2020-10-31 18:00:00  14.9  -67.099998
2     Hurricane_Goni 2020-10-26 12:00:00  14.1  141.700012
3  Hurricane_Haishen 2020-08-30 00:00:00  14.0  147.000000
4      Hurricane_Ida 2021-08-27 00:00:00  18.3  -80.199997
['C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Computer vision\\Final Project\\Outputs\\mask_cache\\Hurricane_Haishen\\mask_Hurricane_Haishen_2020-08-29T00-00-00.npy', 'C:\\Users\\adrid\\OneDrive\\Documents\\Master UMiami AI\\Topic DS - Compute

#### Plot ROC and precision–recall curves


In [None]:
# ROC and precision-recall curves for every trained model. Previously only
# AttentionUNet got a PR curve, saved under the inconsistent name "AttentionUNet genesis".
for _name, _plot_data, _metrics in (
    ("AttentionUNet", plot_data_AttentionUNet, test_metrics_AttentionUNet),
    ("UNet", plot_data_Unet, test_metrics_Unet),
    ("ResSegNet", plot_data_ResSegNet, test_metrics_ResSegNet),
):
    plot_roc_curve(_plot_data, _metrics["ROC_AUC"], model_name=_name)
    plot_pr_curve(_plot_data, model_name=_name)

### Functions for inference and plotting results

In [None]:

# Visualize the background map and the inferred probability image
def show_inference_image(map_image, prob_image, timestamp, hurricane_events="", model_name="unet"):
    """Overlay a genesis-probability map on a grayscale background map and save it.

    ``prob_image`` is resized (bilinear) to ``map_image``'s height and width.
    Both images are flipped vertically before display. The figure is written to
    the global ``cfg.results_dir`` as
    ``inference_<model_name>_<YYYY-MM-DD>_<HH>.png``. The model name is part of
    the file name so that several models inferred at the same timestamp do not
    overwrite each other's images.

    Args:
        map_image: 2-D background image (values displayed with vmax=0.5).
        prob_image: 2-D probability map.
        timestamp: datetime-like timestamp (``strftime`` formatting is used).
        hurricane_events: Ground-truth label shown in the title.
        model_name: Shown in the title and used in the file name.
    """
    plt.figure(figsize=(8, 4))

    map_h, map_w = map_image.shape[0], map_image.shape[1]
    # cv2.resize does not accept float16 input; work in float32
    prob_resized = cv2.resize(np.asarray(prob_image, dtype=np.float32), (map_w, map_h),
                              interpolation=cv2.INTER_LINEAR)

    # Flip both images vertically (top <-> bottom)
    background = np.flipud(map_image)
    overlay = np.flipud(prob_resized)

    plt.imshow(background, vmin=0, vmax=0.5, cmap="gray", alpha=0.4)
    plt.imshow(overlay, vmin=0, vmax=0.5, cmap="turbo", alpha=0.9)
    plt.colorbar(label="Genesis probability (0–24h)")
    plt.title(f"{model_name} Genesis probability @ {timestamp} - Ground truth case: {hurricane_events}")
    plt.tight_layout()
    plt.savefig(cfg.results_dir / f"inference_{model_name}_{timestamp:%Y-%m-%d_%H}.png", bbox_inches="tight")
    plt.close()
                            

# Inference: build the input channels for the row closest to a timestamp and predict a genesis-probability map
@torch.no_grad()
def infer_genesis_map(cfg: GenesisCfg, model: nn.Module, img_cols: List[str], when: Optional[pd.Timestamp] = None, hurricane_events: Optional[str] = None, model_name: str = "unet"):
    """Run single-sample inference and save a probability-map figure.

    Args:
        cfg: Genesis configuration (``out_root``, ``image_size``, ``visual_feature``, ``use_amp``).
        model: Trained model; switched to eval mode before the forward pass.
        img_cols: Channel columns in the order the model was trained with. They
            are used when every one of them is a column of the input table.
            Otherwise all ``img_`` columns of the table are used, in table order.
        when: Target timestamp; the row with the nearest ``Timestamp`` is used.
            ``None`` selects the last row.
        hurricane_events: Ground-truth label passed to the plot title.
        model_name: Name used in prints and in the saved figure.

    Returns:
        ``(prob, timestamp)``: the first output channel of the first sample after
        sigmoid, as a float32 NumPy array, and the selected row's ``Timestamp``.
    """
    df = build_input_table(cfg, inference=True)
    table_img_cols = [col for col in df.columns if 'img_' in col]
    if img_cols and all(col in df.columns for col in img_cols):
        channel_cols = list(img_cols)  # match the training channel order
    else:
        channel_cols = table_img_cols

    if when is None:
        row = df.iloc[-1]
    else:
        row = df.iloc[(df["Timestamp"] - when).abs().argsort().iloc[0]]

    H, W = cfg.image_size
    lat_img, lon_img = get_latlon_grid_cached(cfg)  # same static grid for all rows

    root = Path(cfg.out_root)
    chans = []
    for col in channel_cols:
        rel = row.get(col)
        chans.append(_open_gray(root / rel, (H, W)) if isinstance(rel, str) else np.zeros((H, W), np.float32))

    # Background image for the plot: last img_ column matching cfg.visual_feature that has a file
    map_image = None
    for col in table_img_cols:
        rel = row.get(col)
        if cfg.visual_feature in col and isinstance(rel, str):
            with Image.open(root / rel) as im:
                gray = im.convert("L").resize((W, H), Image.BILINEAR)
            map_image = np.asarray(gray, dtype=np.float32) / 255.0
    if map_image is None:
        print(f"[Warning] No '{cfg.visual_feature}' image for this row; using a blank background.")
        map_image = np.zeros((H, W), np.float32)

    x = np.stack(chans, axis=0)[None, ...]
    xt = torch.from_numpy(x).to(DEVICE)

    model.eval()  # inference must not use training-mode layers (e.g. batch-norm statistics)
    with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda" and cfg.use_amp)):
        logits = model(xt)
    print("Logits shape:", logits.shape)
    # Compute in float32: fp16 output would break cv2.resize and overflow the sum below
    prob = torch.sigmoid(logits.float())[0, 0].cpu().numpy()
    print("Probability shape:", prob.shape)

    r, c = np.unravel_index(np.argmax(prob), prob.shape)
    r_min, c_min = np.unravel_index(np.argmin(prob), prob.shape)
    target_sum = np.sum(prob)

    genesis_latitude = lat_img[r, c]
    genesis_longitude = lon_img[r, c]
    max_prob = prob[r, c]

    print(f"Maximum Probability: {max_prob}")
    print(f"Coordinates (row, column): ({r}, {c})")
    print(f"Minimum Probability: {prob[r_min, c_min]}")
    print(f"Coordinates (row, column): ({r_min}, {c_min})")
    print("Target sum", target_sum)
    print(f"Latitude: {genesis_latitude}° (Range: -90 to 90)")
    print(f"Longitude: {genesis_longitude}° (Range: -180 to 180)")

    show_inference_image(map_image, prob, row["Timestamp"], hurricane_events, model_name)

    return prob, row["Timestamp"]



#### Plot loss vs epochs

In [None]:
# Loss-vs-epoch plot for each trained architecture
for _name, _train_hist, _val_hist, _epochs in (
    ("UNet", train_loss_hist_UNet, val_loss_hist_UNet, epoch_hist_UNet),
    ("AttentionUNet", train_loss_hist_AttentionUNet, val_loss_hist_AttentionUNet, epoch_hist_AttentionUNet),
    ("ResSegNet", train_loss_hist_ResSegNet, val_loss_hist_ResSegNet, epoch_hist_ResSegNet),
):
    plot_loss_epochs(_train_hist, _val_hist, _epochs, model_name=_name)


### Run inference

In [None]:
# Call inferences
# (timestamp, ground-truth label) pairs kept together so the two can never drift out of alignment.
inference_cases = [
    ("2021-08-27 00:00:00", "Genesis_Ida"),
    ("2021-08-27 03:00:00", "Ida"),
    ("2021-08-28 12:00:00", "Ida"),
    ("2021-08-29 18:00:00", "Ida"),
    ("2020-08-20 00:00:00", "Before Laura"),
    ("2020-08-20 03:00:00", "Genesis_Laura"),
    ("2020-08-21 12:00:00", "Laura"),
    ("2020-08-22 18:00:00", "Laura"),
    ("2020-08-23 00:00:00", "Laura"),
    ("2020-08-23 18:00:00", "Laura"),
    # NOTE(review): same timestamp as the previous entry; confirm the intended time.
    ("2020-08-23 18:00:00", "Laura"),
    ("2020-08-29 00:00:00", "Before_Haishen"),
    ("2020-08-29 21:00:00", "Before_Haishen"),
    ("2020-08-30 00:00:00", "Genesis_Haishen"),
    ("2020-08-30 21:00:00", "Haishen"),
    ("2020-08-31 00:00:00", "Haishen"),
    ("2020-09-01 00:00:00", "Haishen"),
    ("2020-01-01 18:00:00", "No-Hurricane"),
    ("2020-02-01 18:00:00", "No-Hurricane"),
    ("2020-02-02 06:00:00", "No-Hurricane"),
    ("2020-02-03 09:00:00", "No-Hurricane"),
]
# Parallel views kept for any later cell that uses them
times = [stamp for stamp, _ in inference_cases]
hurricane_events = [label for _, label in inference_cases]

for stamp, label in inference_cases:
    when = pd.Timestamp(stamp)
    res_unet = infer_genesis_map(cfg, model_Unet, img_cols_UNet, when=when, hurricane_events=label, model_name="UNet")
    res_AttentionUNet = infer_genesis_map(cfg, model_AttentionUNet, img_cols_AttentionUNet, when=when, hurricane_events=label, model_name="AttentionUNet")
    res_ResSegNet = infer_genesis_map(cfg, model_ResSegNet, img_cols_ResSegNet, when=when, hurricane_events=label, model_name="ResSegNet")


