
# 🌊 Coastal-Aware SST Reconstruction + **Light DDPM** + Marine Heatwave Forecast (Bay of Bengal)

Colab-ready notebook that demonstrates:
- NOAA OISST download & subset
- Coast-distance feature
- **Coastal-Aware U-Net** for gap-filling + MC-Dropout uncertainty
- **Light DDPM (patch-based)** with U-Net-style denoiser (fast demo: 20–50 steps)
- Climatology (2010–2020) + **Hobday et al., 2016** MHW detection
- Simple **3D CNN** short-range forecast (14→K days)


## 📦 Install & Imports

In [None]:

!pip -q install xarray netCDF4 pandas numpy scipy matplotlib tqdm einops torch torchvision scikit-image

import os, math, json, urllib.request, warnings, datetime as dt
import numpy as np, pandas as pd, xarray as xr
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from scipy.ndimage import distance_transform_edt

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings('ignore')


## 🔧 User Controls

In [None]:

# Year & forecast horizon
# YEAR_SELECT is the year analysed for marine heatwaves and used as forecast context;
# FUTURE_DAYS is how many days the 1-day forecaster is rolled forward.
YEAR_SELECT = 2019
FUTURE_DAYS = 3
assert FUTURE_DAYS in [1,3,5], "FUTURE_DAYS must be one of {1,3,5}"

# Region (Bay of Bengal): bounding box in degrees, used to subset every yearly file
REGION = dict(lon_min=80.0, lon_max=100.0, lat_min=5.0, lat_max=23.0)

# Data years used to build training set for U-Net / diffusion / forecaster
YEARS  = [2019]

# Present-day plots (optional; may skip if slow)
USE_PRESENT_DAY = True

# Reproducibility: seed NumPy and torch (CPU, plus all CUDA devices when available)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Compute device and output folders (created eagerly so later cells can write to them)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR = '/content/data'
FIG_DIR  = '/content/figs'
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# U-Net training (fast demo)
PATCH  = 64   # patch side length in grid cells
STRIDE = 32   # step between patch origins
EPOCHS = 6
BATCH  = 32
MC_T   = 6    # number of stochastic forward passes for MC-Dropout

# Forecasting (fast demo): 14-day context -> 1-day model rolled K steps
FCTX_DAYS = 14
F_EPOCHS  = 3
F_BATCH   = 32

# Light DDPM (patch-based)
DDPM_STEPS     = 40          # 20-50 recommended for quick demo
DDPM_EPOCHS    = 3           # small epochs for speed
DDPM_BATCH     = 32
DDPM_BETA_SCHED= 'linear'    # 'linear' or 'cosine'
DDPM_MIN_BETA  = 1e-4
DDPM_MAX_BETA  = 0.02

print("DEVICE:", DEVICE)


## 1) Download NOAA OISST (daily) & subset Bay of Bengal

In [None]:

def fetch_oisst_year(year:int):
    """Return the local path of the OISST daily-mean file for ``year``.

    The file is downloaded from the NOAA PSL server into ``DATA_DIR`` the
    first time it is requested; later calls reuse the cached copy.

    The download goes to a temporary ``.part`` file that is renamed into
    place only after the transfer finished, so an interrupted download can
    never leave a truncated file that a later run would mistake for a
    complete cache entry.

    Args:
        year: Calendar year of the yearly NetCDF file to fetch.

    Returns:
        Path of the cached NetCDF file.

    Raises:
        urllib.error.URLError / OSError: if the download or rename fails.
    """
    url = f"https://downloads.psl.noaa.gov/Datasets/noaa.oisst.v2.highres/sst.day.mean.{year}.nc"
    fn = f"{DATA_DIR}/oisst_{year}.nc"
    if not os.path.exists(fn):
        print('Downloading', url)
        part = fn + '.part'
        try:
            urllib.request.urlretrieve(url, part)
            os.replace(part, fn)  # atomic publish of the finished download
        finally:
            # Only present here when the download or rename failed part-way.
            if os.path.exists(part):
                os.remove(part)
    return fn

def open_subset_year(year:int):
    """Open one yearly OISST file and crop it to ``REGION``.

    The data variable ``sst`` (if present) is renamed to ``SST`` so the rest
    of the notebook can rely on a single name. If the ``time`` coordinate is
    not already ``datetime64`` the dataset is CF-decoded before returning.

    Args:
        year: Calendar year to open (downloaded on demand).

    Returns:
        The region-subset dataset.
    """
    yearly = xr.open_dataset(fetch_oisst_year(year), engine='netcdf4')
    if 'sst' in yearly.data_vars:
        # Normalize the variable name.
        yearly = yearly.rename({'sst': 'SST'})

    lat_window = slice(REGION['lat_min'], REGION['lat_max'])
    lon_window = slice(REGION['lon_min'], REGION['lon_max'])
    yearly = yearly.sel(lat=lat_window, lon=lon_window)

    time_is_decoded = np.issubdtype(yearly['time'].dtype, np.datetime64)
    return yearly if time_is_decoded else xr.decode_cf(yearly)

# Open years used for model training
# Each year is opened and cropped separately, then joined along time and
# sorted so the time axis is monotonically increasing.
ds_list = [open_subset_year(y) for y in YEARS]
ds = xr.concat(ds_list, dim='time').sortby('time')
print("Dataset shape (time, lat, lon):", ds['SST'].shape)


## 2) Coast Distance (0–100 km cap)

In [None]:

# Build a per-cell "distance to coast" feature from the SST field itself.
raw_sst = ds['SST'].astype('float32')
# NaN cells are treated as land; a cell counts as land if it is NaN on any day.
land_mask = np.isnan(raw_sst.values)   # (T,H,W)
land_any  = np.any(land_mask, axis=0)  # (H,W)
ocean = ~land_any
# Euclidean distance (in grid cells) from every ocean cell to the nearest
# non-ocean cell; non-ocean cells get 0.
dist_pix = distance_transform_edt(ocean==1)

# Convert the cell distance to kilometres with a latitude-dependent step:
# 111 km per degree of latitude, 111*cos(lat) km per degree of longitude.
# NOTE(review): the factor is km per *degree*, not per grid cell, so this
# appears to assume 1-degree spacing — confirm against the grid resolution.
lats = ds['lat'].values; lons = ds['lon'].values
dlat_km = 111.0
dlon_km = 111.0*np.cos(np.deg2rad(lats))
dlon_km_grid = np.tile(dlon_km[:,None], (1, len(lons)))
avg_step_km = np.sqrt((dlat_km**2 + dlon_km_grid**2)/2.0)  # RMS-ish coarse step
# Cap at 100 km so the feature saturates away from the coast.
dist_km = np.clip(dist_pix * avg_step_km, 0, 100)

coast_dist = xr.DataArray(dist_km, coords={'lat': ds['lat'], 'lon': ds['lon']},
                          dims=('lat','lon'), name='coast_km')


## 3) Normalize & Patchify (simulate gaps)

In [None]:

# NaNs are replaced by 0.0 and values clipped to [-2, 35] before normalizing.
# NOTE(review): mu/sigma are computed after the fill, so the zero-filled
# cells contribute to the statistics.
arr = raw_sst.copy().astype('float32').fillna(0.0).clip(min=-2.0, max=35.0)
mu = float(arr.mean()); sigma = float(arr.std())
# Standardize; the 1e-6 guards against a zero standard deviation.
arrn = (arr - mu) / (sigma + 1e-6)
# Coast distance rescaled from [0, 100] km to [0, 1].
coast01 = (coast_dist/100.0).clip(0,1)

def extract_patches_with_coast(da3d, coast2d, PATCH=64, STRIDE=32):
    """Cut every time step into square patches, paired with coast patches.

    Patches are taken on a regular grid of top-left corners (step ``STRIDE``)
    and emitted in time-major order: all patches of step 0, then step 1, ...

    Args:
        da3d: Object whose ``.values`` is a (T, H, W) array.
        coast2d: Object whose ``.values`` is an (H, W) array.
        PATCH: Patch side length in grid cells.
        STRIDE: Step between neighbouring patch corners.

    Returns:
        Tuple ``(X, coast, times)``: float32 arrays of shape (N, 1, PATCH,
        PATCH) for the field and coast patches, and an integer array of
        length N with the time index each patch came from.

    Raises:
        ValueError: from ``np.stack`` when no patch fits into the grid.
    """
    frames = da3d.values.astype('float32')       # (T,H,W)
    coast_grid = coast2d.values.astype('float32')  # (H,W)
    n_steps, n_rows, n_cols = frames.shape

    # Top-left corners shared by every time step, in row-major order.
    corners = [(r, c)
               for r in range(0, n_rows - PATCH + 1, STRIDE)
               for c in range(0, n_cols - PATCH + 1, STRIDE)]

    field_patches = [frames[step, r:r + PATCH, c:c + PATCH]
                     for step in range(n_steps) for r, c in corners]
    coast_patches = [coast_grid[r:r + PATCH, c:c + PATCH]
                     for _ in range(n_steps) for r, c in corners]
    step_of_patch = [step for step in range(n_steps) for _ in corners]

    X = np.stack(field_patches)[:, None, ...]      # (N,1,P,P)
    coast = np.stack(coast_patches)[:, None, ...]  # (N,1,P,P)
    return X.astype('float32'), coast.astype('float32'), np.array(step_of_patch)

X_all, C_all, times_idx = extract_patches_with_coast(arrn, coast01, PATCH=PATCH, STRIDE=STRIDE)

rng = np.random.default_rng(SEED)
# keep-prob mask (1 = keep/original, 0 = hole) - we'll provide mask as conditioning
# Each cell is kept independently with probability 0.8.
M_all = (rng.random(X_all.shape) < 0.8).astype('float32')
Y_all = X_all.copy()          # reconstruction target: the unmasked patch
X_in  = X_all * M_all         # network input: holes are set to 0

# train/val/test split
# Random 70 / 15 / 15 split over patches (not grouped by time step).
N = len(X_in); perm = rng.permutation(N)
ntr, nva = int(0.7*N), int(0.85*N)
tr_idx, va_idx, te_idx = perm[:ntr], perm[ntr:nva], perm[nva:]

Xtr, Mtr, Ctr, Ytr = X_in[tr_idx], M_all[tr_idx], C_all[tr_idx], Y_all[tr_idx]
Xva, Mva, Cva, Yva = X_in[va_idx], M_all[va_idx], C_all[va_idx], Y_all[va_idx]
Xte, Mte, Cte, Yte = X_in[te_idx], M_all[te_idx], C_all[te_idx], Y_all[te_idx]

print("Train/Val/Test:", Xtr.shape, Xva.shape, Xte.shape)


## 4) Coastal-Aware U-Net (reconstruction)

In [None]:

class ConvBlock(nn.Module):
    """Two 3x3 conv + ReLU layers followed by channel dropout (p=0.2).

    Padding of 1 keeps the spatial size unchanged.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        layers = [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=0.2),
        ]
        # Attribute name and layer order are kept so state_dict keys match.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` and return the feature map."""
        return self.conv(x)

class UNetCoastal(nn.Module):
    """Small two-level U-Net that fills SST gaps using mask and coast channels.

    The three inputs (masked SST, keep-mask, scaled coast distance) are
    concatenated along the channel axis. The encoder halves the resolution
    twice; the decoder upsamples bilinearly and concatenates the matching
    encoder features (skip connections). Output is a single-channel map.
    """

    def __init__(self):
        super().__init__()
        # 3 in-channels: masked SST, mask, coast01
        self.enc1 = ConvBlock(3, 32)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.enc3 = ConvBlock(64, 128)
        self.dec3 = ConvBlock(128 + 64, 64)
        self.dec2 = ConvBlock(64 + 32, 32)
        self.outc = nn.Conv2d(32, 1, 1)

    @staticmethod
    def _upsample2x(feat):
        """Double the spatial resolution with bilinear interpolation."""
        return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, m, c):
        """Return the reconstructed field for input ``x``, mask ``m``, coast ``c``."""
        stacked = torch.cat((x, m, c), dim=1)

        # Encoder path; skip1/skip2 are reused by the decoder.
        skip1 = self.enc1(stacked)
        skip2 = self.enc2(self.pool(skip1))
        bottom = self.enc3(self.pool(skip2))

        # Decoder path with skip connections.
        up = self.dec3(torch.cat((self._upsample2x(bottom), skip2), dim=1))
        up = self.dec2(torch.cat((self._upsample2x(up), skip1), dim=1))
        return self.outc(up)

# Global reconstruction network; the training and evaluation helpers below use it directly.
model = UNetCoastal().to(DEVICE)

class SSTDset(Dataset):
    """Dataset over four aligned NumPy arrays: input, mask, coast, target."""

    def __init__(self, X,M,C,Y):
        self.X = X
        self.M = M
        self.C = C
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self,i):
        """Return sample ``i`` as a 4-tuple of tensors (input, mask, coast, target)."""
        arrays = (self.X, self.M, self.C, self.Y)
        return tuple(torch.from_numpy(a[i]) for a in arrays)

# Training batches are reshuffled every epoch; validation order is fixed.
train_dl = DataLoader(SSTDset(Xtr,Mtr,Ctr,Ytr), batch_size=BATCH, shuffle=True, drop_last=False)
val_dl   = DataLoader(SSTDset(Xva,Mva,Cva,Yva), batch_size=BATCH, shuffle=False, drop_last=False)

# Optimizer for the global U-Net `model`.
opt  = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

def run_epoch_unet(dl, train=True):
    """Run one pass of the global U-Net ``model`` over ``dl``.

    Args:
        dl: Iterable of ``(input, mask, coast, target)`` batches.
        train: If True, put the model in training mode and update weights
            with the global optimizer ``opt``; otherwise only evaluate.

    Returns:
        Sample-weighted mean of the combined loss (L1 + 0.2 * MSE); 0.0 when
        ``dl`` yields no batches.
    """
    model.train(train)
    tot, n = 0.0, 0
    # Autograd is only needed when weights are updated; disabling it during
    # validation avoids building (and holding) a graph for every batch.
    with torch.set_grad_enabled(bool(train)):
        for xb,mb,cb,yb in dl:
            xb,mb,cb,yb = xb.to(DEVICE), mb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
            pred = model(xb,mb,cb)
            loss = F.l1_loss(pred,yb) + 0.2*F.mse_loss(pred,yb)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            tot += loss.item()*len(xb); n += len(xb)
    return tot/max(n,1)

print("Training U-Net (quick)...")
# Track the weights with the lowest validation loss (copied to CPU) and
# restore them once all epochs are done.
best=float('inf'); best_w=None
for ep in range(1, EPOCHS+1):
    tr = run_epoch_unet(train_dl, True)
    va = run_epoch_unet(val_dl,   False)
    print(f"U-Net ep {ep:02d} | train {tr:.4f} | val {va:.4f}")
    if va<best: best=va; best_w={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
if best_w is not None: model.load_state_dict(best_w); print("U-Net best val:", best)


## 5) Evaluate & MC-Dropout Uncertainty (U-Net)

In [None]:

# Held-out patches, iterated in fixed order.
test_dl = DataLoader(SSTDset(Xte,Mte,Cte,Yte), batch_size=BATCH, shuffle=False)

def grad_mag(z):
    """Return the forward-difference gradient magnitude of a (B, C, H, W) tensor.

    Differences along width and height are zero-padded at the far edge so
    the result has the same shape as ``z``. A small epsilon under the square
    root keeps the value (and its gradient) finite where both differences
    are zero.
    """
    dx = F.pad(z[..., 1:] - z[..., :-1], (0, 1, 0, 0))          # pad last column
    dy = F.pad(z[:, :, 1:, :] - z[:, :, :-1, :], (0, 0, 0, 1))  # pad last row
    return (dx.pow(2) + dy.pow(2) + 1e-8).sqrt()

def eval_unet(dl, T_mc=MC_T):
    """Evaluate the global U-Net with MC-Dropout averaging.

    For every batch the model is run ``T_mc`` times with dropout active and
    the predictions are averaged. Errors are accumulated as sums over all
    elements, so every element counts equally regardless of batch size and
    the RMSE is the root of the dataset-wide mean squared error (not a mean
    of per-batch RMSEs).

    Args:
        dl: Iterable of ``(input, mask, coast, target)`` batches.
        T_mc: Number of stochastic forward passes per batch.

    Returns:
        Tuple ``(mae, rmse, grad_diff)`` of floats in normalized units, where
        ``grad_diff`` is the mean absolute difference of gradient magnitudes.
        All three are NaN when ``dl`` yields no data.
    """
    model.train(True)  # keep dropout ON
    abs_sum = sq_sum = grad_sum = 0.0
    n_elem = 0
    with torch.no_grad():
        for xb,mb,cb,yb in dl:
            xb,mb,cb,yb = xb.to(DEVICE), mb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
            pred = torch.stack([model(xb,mb,cb) for _ in range(T_mc)], 0).mean(0)
            abs_sum  += F.l1_loss(pred, yb, reduction='sum').item()
            sq_sum   += F.mse_loss(pred, yb, reduction='sum').item()
            grad_sum += F.l1_loss(grad_mag(pred), grad_mag(yb), reduction='sum').item()
            n_elem   += yb.numel()
    if n_elem == 0:
        nan = float('nan')
        return nan, nan, nan
    return abs_sum/n_elem, math.sqrt(sq_sum/n_elem), grad_sum/n_elem

mae_u, rmse_u, g_u = eval_unet(test_dl)
print({"U-Net MAE":mae_u, "U-Net RMSE":rmse_u, "U-Net GradDiff":g_u})

def denorm(z):
    """Map normalized values back to the original scale using the global mu/sigma."""
    return z*(sigma+1e-6)+mu

# Visualize one sample
i = np.random.randint(0, len(Xte))
xb = torch.from_numpy(Xte[i:i+1]).to(DEVICE)
mb = torch.from_numpy(Mte[i:i+1]).to(DEVICE)
cb = torch.from_numpy(Cte[i:i+1]).to(DEVICE)
yb = torch.from_numpy(Yte[i:i+1]).to(DEVICE)

# MC-Dropout: dropout stays active, but no gradients are needed for plotting,
# so the passes run under no_grad to avoid building autograd graphs.
model.train(True)
with torch.no_grad():
    preds=[model(xb,mb,cb) for _ in range(MC_T)]
mc_stack = torch.stack(preds,0)
pm = mc_stack.mean(0)[0,0].detach().cpu().numpy()   # MC mean
ps = mc_stack.std(0)[0,0].detach().cpu().numpy()    # MC spread (normalized units)
gt = yb[0,0].detach().cpu().numpy()
ms = mb[0,0].detach().cpu().numpy()

# The spread is rescaled by sigma only (no offset) to express it in degrees.
fig,axs=plt.subplots(1,4,figsize=(14,3.2))
axs[0].imshow(denorm(gt),origin='lower'); axs[0].set_title('GT °C')
axs[1].imshow(ms,origin='lower'); axs[1].set_title('Mask')
axs[2].imshow(denorm(pm),origin='lower'); axs[2].set_title('U-Net Pred °C')
im=axs[3].imshow(ps*(sigma+1e-6),origin='lower'); axs[3].set_title('Uncertainty °C')
plt.colorbar(im,ax=axs[3]); plt.tight_layout(); plt.savefig(f"{FIG_DIR}/unet_recon_uncertainty.png",dpi=200); plt.show()


## 6) Light DDPM (Patch-based) — U-Net-like denoiser

In [None]:

def make_beta_schedule(T:int, schedule='linear', beta_start=1e-4, beta_end=0.02):
    """Build the per-step noise variances (betas) for a ``T``-step diffusion.

    Args:
        T: Number of diffusion steps.
        schedule: ``'linear'`` for evenly spaced betas from ``beta_start`` to
            ``beta_end``; ``'cosine'`` for betas derived from a squared-cosine
            cumulative-alpha curve.
        beta_start: First beta (linear) / lower clip bound (cosine).
        beta_end: Last beta (linear) / upper clip bound (cosine).

    Returns:
        1-D float tensor of length ``T``.

    Raises:
        ValueError: if ``schedule`` is not one of the supported names.
    """
    if schedule=='linear':
        return torch.linspace(beta_start, beta_end, T)
    elif schedule=='cosine':
        s = 0.008  # small offset so the first beta is not vanishingly small
        t = torch.linspace(0, 1, T+1)
        alphas_cumprod = (torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2)
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        # Honour the caller's bounds; the defaults reproduce the previous
        # hard-coded clip range of [1e-4, 0.02].
        return torch.clip(betas, beta_start, beta_end)
    else:
        raise ValueError("Unknown schedule")

# Diffusion constants, kept on DEVICE and indexed by timestep.
BETAS = make_beta_schedule(DDPM_STEPS, DDPM_BETA_SCHED, DDPM_MIN_BETA, DDPM_MAX_BETA).to(DEVICE)
ALPHAS = (1.0 - BETAS)
# Cumulative product of alphas up to and including each step.
ALPHAS_BAR = torch.cumprod(ALPHAS, dim=0)  # (T,)

def extract(a, t, x_shape):
    """Look up ``a[t]`` per batch item and make it broadcastable to ``x_shape``.

    Args:
        a: 1-D tensor of per-timestep values.
        t: Integer tensor of indices into ``a``.
        x_shape: Shape of the tensor the result will be combined with.

    Returns:
        Float32 tensor with trailing singleton dimensions appended until it
        has as many dimensions as ``x_shape``.
    """
    picked = a.gather(0, t).float()
    missing_dims = len(x_shape) - picked.dim()
    # A non-positive count appends nothing, leaving the shape untouched.
    return picked.reshape(tuple(picked.shape) + (1,) * missing_dims)

class DDPMDenoiser(nn.Module):
    """U-Net-style noise predictor for the patch-based DDPM.

    Inputs are the noisy patch ``x_t``, the keep-mask, the scaled coast
    distance and the timestep, which is broadcast to a constant image channel
    scaled to [0, 1]. The output is the predicted noise, one channel, same
    spatial size as the input.

    NOTE(review): the network sees the mask and coast channels but not the
    observed SST values, so samples are not tied to the observed pixels.
    """

    def __init__(self):
        super().__init__()
        # 4 channels: x_t, mask, coast, t_map
        self.enc1 = ConvBlock(4, 32)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.enc3 = ConvBlock(64, 128)
        self.dec3 = ConvBlock(128+64, 64)
        self.dec2 = ConvBlock(64+32, 32)
        self.outc = nn.Conv2d(32, 1, 1)

    def forward(self, x_t, m, c, t):
        """Predict the noise contained in ``x_t`` at integer timesteps ``t`` (shape (B,))."""
        B, _, H, W = x_t.shape
        # Scale t to [0, 1]; the max() guard avoids 0/0 = NaN when the
        # schedule has a single step (DDPM_STEPS == 1).
        t_scale = max(DDPM_STEPS - 1, 1)
        t_norm = (t.float() / t_scale).view(B,1,1,1).expand(B,1,H,W)
        z = torch.cat([x_t, m, c, t_norm], dim=1)
        # Encoder
        e1 = self.enc1(z)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        # Decoder with skip connections
        d3 = F.interpolate(e3, scale_factor=2, mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e2], dim=1))
        d2 = F.interpolate(d3, scale_factor=2, mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e1], dim=1))
        eps = self.outc(d2)
        return eps

# Global denoiser and its optimizer; the DDPM training and sampling helpers use them directly.
ddpm = DDPMDenoiser().to(DEVICE)
opt_d = torch.optim.AdamW(ddpm.parameters(), lr=1e-3, weight_decay=1e-5)

class PatchSet(Dataset):
    """Dataset over four aligned NumPy arrays: input, mask, coast, target."""

    def __init__(self, X, M, C, Y):
        self.X = X
        self.M = M
        self.C = C
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self,i):
        """Return sample ``i`` as a 4-tuple of tensors (input, mask, coast, target)."""
        arrays = (self.X, self.M, self.C, self.Y)
        return tuple(torch.from_numpy(a[i]) for a in arrays)

# Same train/validation patches as the U-Net, batched for the DDPM.
tr_dl_ddpm = DataLoader(PatchSet(Xtr,Mtr,Ctr,Ytr), batch_size=DDPM_BATCH, shuffle=True)
va_dl_ddpm = DataLoader(PatchSet(Xva,Mva,Cva,Yva), batch_size=DDPM_BATCH, shuffle=False)

def q_sample(x0, t, noise):
    """Diffuse clean data ``x0`` to timestep ``t`` using the supplied ``noise``.

    Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`` with
    the global ``ALPHAS_BAR`` table.
    """
    alpha_bar = extract(ALPHAS_BAR, t, x0.shape)
    signal_scale = alpha_bar.sqrt()
    noise_scale = (1.0 - alpha_bar).sqrt()
    return signal_scale * x0 + noise_scale * noise

def train_ddpm_epoch(dl):
    """Train the global denoiser ``ddpm`` for one epoch.

    For every batch a random timestep per sample and Gaussian noise are
    drawn, the clean target is diffused to that timestep, and the network is
    trained to predict the noise (MSE loss).

    Args:
        dl: Iterable of ``(input, mask, coast, target)`` batches. The masked
            input is not used by the DDPM and is therefore left on the host.

    Returns:
        Sample-weighted mean noise-prediction loss; 0.0 for an empty loader.
    """
    ddpm.train(True)
    total, n = 0.0, 0
    for _, mb, cb, yb in dl:
        mb, cb, yb = mb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
        B = yb.size(0)
        t = torch.randint(0, DDPM_STEPS, (B,), device=DEVICE).long()
        noise = torch.randn_like(yb)
        x_t = q_sample(yb, t, noise)
        eps_hat = ddpm(x_t, mb, cb, t)
        loss = F.mse_loss(eps_hat, noise)
        opt_d.zero_grad()
        loss.backward()
        opt_d.step()
        total += loss.item()*B; n += B
    return total/max(n,1)

def eval_ddpm_epoch(dl):
    """Compute the noise-prediction loss of the global ``ddpm`` on ``dl``.

    Runs in eval mode without gradients. Timesteps and noise are drawn
    randomly on every call, so the value fluctuates between calls.

    Args:
        dl: Iterable of ``(input, mask, coast, target)`` batches. The masked
            input is not used by the DDPM and is therefore left on the host.

    Returns:
        Sample-weighted mean MSE between predicted and true noise; 0.0 for an
        empty loader.
    """
    ddpm.train(False)
    total, n = 0.0, 0
    with torch.no_grad():
        for _, mb, cb, yb in dl:
            mb, cb, yb = mb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
            B = yb.size(0)
            t = torch.randint(0, DDPM_STEPS, (B,), device=DEVICE).long()
            noise = torch.randn_like(yb)
            x_t = q_sample(yb, t, noise)
            eps_hat = ddpm(x_t, mb, cb, t)
            loss = F.mse_loss(eps_hat, noise)
            total += loss.item()*B; n += B
    return total/max(n,1)

print("Training Light DDPM (patch-based, quick)...")
# Keep a CPU copy of the weights with the lowest validation loss and restore
# them after the last epoch.
best_d=float('inf'); best_dw=None
for ep in range(1, DDPM_EPOCHS+1):
    tr = train_ddpm_epoch(tr_dl_ddpm)
    va = eval_ddpm_epoch(va_dl_ddpm)
    print(f"DDPM ep {ep:02d} | train {tr:.5f} | val {va:.5f}")
    if va < best_d: best_d = va; best_dw = {k:v.detach().cpu().clone() for k,v in ddpm.state_dict().items()}
if best_dw is not None:
    ddpm.load_state_dict(best_dw)
    print("DDPM best val:", best_d)

@torch.no_grad()
def p_sample_step(x_t, m, c, t):
    """Perform one reverse-diffusion step from ``x_t`` at timesteps ``t``.

    The posterior mean is computed from the noise predicted by the global
    ``ddpm``; Gaussian noise with standard deviation ``sqrt(beta_t)`` is
    added for every batch item whose timestep is greater than zero. The
    decision is made per item, so a batch with mixed timesteps is handled
    correctly; when all timesteps are zero no random numbers are drawn.

    Args:
        x_t: Current noisy sample, batch-first.
        m: Mask conditioning channel.
        c: Coast conditioning channel.
        t: Integer tensor of shape (B,) with the current timestep per item.

    Returns:
        The sample at the previous timestep, same shape as ``x_t``.
    """
    beta_t     = extract(BETAS, t, x_t.shape)
    alpha_t    = extract(ALPHAS, t, x_t.shape)
    alpha_bar_t= extract(ALPHAS_BAR, t, x_t.shape)
    eps_hat = ddpm(x_t, m, c, t)
    mean = (1.0/torch.sqrt(alpha_t))*(x_t - ((1.0 - alpha_t)/torch.sqrt(1.0 - alpha_bar_t))*eps_hat)

    needs_noise = t > 0
    if not bool(needs_noise.any()):
        return mean  # final step for the whole batch: deterministic

    # 1.0 for items that still receive noise, 0.0 for items already at t == 0.
    gate = needs_noise.to(mean.dtype).view(-1, *([1] * (x_t.dim() - 1)))
    z = torch.randn_like(x_t)
    sigma_t = torch.sqrt(beta_t)
    return mean + gate * sigma_t * z

@torch.no_grad()
def ddpm_sample(m, c, shape):
    """Generate samples by running the full reverse diffusion chain.

    Starts from standard Gaussian noise of the given ``shape`` and applies
    ``p_sample_step`` for timesteps ``DDPM_STEPS - 1`` down to 0, feeding the
    mask ``m`` and coast ``c`` channels at every step.

    Returns:
        The final sample, in normalized units.
    """
    batch_size = m.size(0)
    sample = torch.randn(shape, device=DEVICE)
    for step in range(DDPM_STEPS - 1, -1, -1):
        timesteps = torch.full((batch_size,), step, device=DEVICE, dtype=torch.long)
        sample = p_sample_step(sample, m, c, timesteps)
    return sample  # normalized units


## 7) Compare Reconstructions: U-Net vs DDPM (patch)

In [None]:

# take a random test patch
i = np.random.randint(0, len(Xte))
xb = torch.from_numpy(Xte[i:i+1]).to(DEVICE)   # masked input (not directly used by DDPM trainer)
mb = torch.from_numpy(Mte[i:i+1]).to(DEVICE)   # mask channel
cb = torch.from_numpy(Cte[i:i+1]).to(DEVICE)   # coast channel
yb = torch.from_numpy(Yte[i:i+1]).to(DEVICE)   # GT (normalized)

# U-Net MC average
# Dropout stays active for MC sampling, but gradients are not needed here,
# so the passes run under no_grad to avoid building autograd graphs.
model.train(True)  # keep dropout for MC
with torch.no_grad():
    preds=[model(xb,mb,cb) for _ in range(MC_T)]
pm_unet = torch.stack(preds,0).mean(0)[0,0].detach().cpu().numpy()
mask_np = mb[0,0].detach().cpu().numpy()
gt_np   = yb[0,0].detach().cpu().numpy()

# DDPM conditional sample
ddpm.eval()
x_gen = ddpm_sample(mb, cb, shape=yb.shape)    # normalized
pm_ddpm = x_gen[0,0].detach().cpu().numpy()

# plot
def denorm(z):
    """Map normalized values back to the original scale using the global mu/sigma."""
    return z*(sigma+1e-6)+mu

# Holes are stored as 0 in normalized units, which denormalizes to the mean
# SST and would be indistinguishable from real data; show them as NaN (blank).
masked_in = np.where(mask_np > 0, denorm(xb[0,0].cpu().numpy()), np.nan)

fig,axs=plt.subplots(1,4,figsize=(15,3.4))
axs[0].imshow(denorm(gt_np), origin='lower'); axs[0].set_title('GT °C')
axs[1].imshow(masked_in, origin='lower'); axs[1].set_title('Masked In °C')
axs[2].imshow(denorm(pm_unet), origin='lower'); axs[2].set_title('U-Net Recon °C')
axs[3].imshow(denorm(pm_ddpm), origin='lower'); axs[3].set_title('DDPM Sample °C')
for ax in axs: ax.axis('off')
plt.tight_layout(); plt.savefig(f"{FIG_DIR}/compare_unet_ddpm.png", dpi=200); plt.show()


## 8) Climatology (2010–2020) & MHW Detector (Hobday et al., 2016)

In [None]:

# Baseline period 2010..2020 inclusive; each year is opened (and downloaded
# if missing) and cropped to the region before concatenation.
base_years = list(range(2010, 2021))
base_list = []
print("Building climatology (may download multiple yearly files) ...")
for y in base_years:
    ds_y = open_subset_year(y)
    base_list.append(ds_y)
clim_ds = xr.concat(base_list, dim='time').sortby('time')

# Day-of-year mean over the baseline, then smoothed with a centred 11-day
# rolling mean. min_periods=1 keeps the first/last days (shorter windows);
# the window does not wrap around the year boundary.
clim = clim_ds['SST'].groupby('time.dayofyear').mean('time', skipna=True)
clim = clim.rolling(dayofyear=11, center=True, min_periods=1).mean()

# 90th-percentile threshold per day of year, smoothed the same way.
thresh = clim_ds['SST'].groupby('time.dayofyear').quantile(0.90, dim='time', skipna=True)
thresh = thresh.rolling(dayofyear=11, center=True, min_periods=1).mean()

def detect_mhw(ds_year, clim, thresh):
    """Flag cells whose SST exceeds the day-of-year threshold.

    This is a per-day exceedance test only; no minimum-duration criterion is
    applied.

    Args:
        ds_year: Dataset containing ``SST`` (or the SST array itself) with a
            ``time`` coordinate.
        clim: Climatological mean indexed by ``dayofyear``.
        thresh: Threshold indexed by ``dayofyear``.

    Returns:
        Tuple ``(mhw, anomaly, thresh_match)``: the anomaly where SST is above
        the threshold (NaN elsewhere), the full anomaly field, and the
        threshold aligned to the input's time axis.
    """
    sst = ds_year['SST'] if 'SST' in ds_year else ds_year
    if not np.issubdtype(sst['time'].dtype, np.datetime64):
        # Time axis still encoded: decode before using the .dt accessor.
        sst = xr.decode_cf(ds_year)['SST']

    day_of_year = sst['time'].dt.dayofyear
    baseline = clim.sel(dayofyear=day_of_year)
    thresh_match = thresh.sel(dayofyear=day_of_year)

    anomaly = sst - baseline
    exceedance = anomaly.where(sst > thresh_match)
    return exceedance, anomaly, thresh_match

def longest_consecutive_run(binary_series):
    """Return the length of the longest unbroken run of truthy values.

    Returns 0 for an empty series or one without any truthy value.
    """
    longest = 0
    streak = 0
    for flag in binary_series:
        streak = streak + 1 if flag else 0
        if streak > longest:
            longest = streak
    return longest

print(f"Analyzing MHW for {YEAR_SELECT} ...")
ds_y = open_subset_year(YEAR_SELECT)
mhw, anomaly, thresh_match = detect_mhw(ds_y, clim, thresh)

# Fraction of grid cells flagged on each day (NaN cells count as not flagged).
area_frac_per_day = mhw.notnull().mean(dim=("lat","lon")).fillna(0.0)
# Days with at least one flagged cell.
total_days = float((area_frac_per_day > 0).sum().item())
# Longest run of days on which more than 20% of the cells are flagged.
longest = int(longest_consecutive_run((area_frac_per_day.values > 0.2).astype(np.int32)))
# Mean anomaly over all flagged cell-days.
mean_intensity = float(mhw.mean(skipna=True).values)
print(f"{YEAR_SELECT} summary: total_days={total_days:.0f}, longest_run_days={longest}, mean_intensity={mean_intensity:.2f}°C")

has_mhw = (area_frac_per_day.values > 0).any()
if has_mhw:
    # argmax on a boolean array returns the first True, i.e. the first flagged day.
    first_idx = int(np.argmax(area_frac_per_day.values > 0))
    first_day = pd.to_datetime(ds_y.time.values[first_idx]).strftime("%Y-%m-%d")
    plt.figure(figsize=(6,4))
    mhw.isel(time=first_idx).plot(cmap="hot", vmin=0)
    plt.title(f"MHW anomaly map — {YEAR_SELECT} — {first_day}")
    plt.tight_layout(); plt.savefig(f"{FIG_DIR}/MHW_map_{YEAR_SELECT}.png", dpi=200); plt.show()
else:
    print(f"No MHW days found in {YEAR_SELECT} (90th percentile threshold).")


## 9) Present-day anomaly & MHW map (optional)

In [None]:

if USE_PRESENT_DAY:
    # Best-effort cell: the current year's file may be unavailable or slow to
    # download, so any failure is reported and the step is skipped.
    try:
        cur_year = dt.date.today().year
        ds_today = open_subset_year(cur_year)
        # Date (YYYY-MM-DD) of the most recent time step in the file.
        latest_time = str(ds_today.time.values[-1])[:10]
        mhw_today, anomaly_today, thresh_today = detect_mhw(ds_today, clim, thresh)

        plt.figure(figsize=(10,4))
        anomaly_today.isel(time=-1).plot(cmap="RdBu_r", vmin=-3, vmax=3)
        plt.title(f"SST Anomaly — {latest_time}")
        plt.tight_layout(); plt.savefig(f"{FIG_DIR}/SST_anomaly_present.png", dpi=200); plt.show()

        plt.figure(figsize=(6,4))
        mhw_today.isel(time=-1).plot(cmap="hot", vmin=0)
        plt.title(f"MHW Mask — {latest_time}")
        plt.tight_layout(); plt.savefig(f"{FIG_DIR}/MHW_present.png", dpi=200); plt.show()
        print('Present-day processed:', latest_time)
    except Exception as e:
        print('⚠ Present-day step skipped:', e)


## 10) Short-range Forecast (3D CNN): past 14 → next day, rolled K steps

In [None]:

# Full-grid normalized fields used to build forecaster training sequences.
A = arrn.values.astype('float32')  # normalized (T,H,W) for YEARS
T,H,W = A.shape
T_in = FCTX_DAYS  # number of past days fed to the forecaster

def build_sequences(A, T_in=14):
    """Build (context, next-day target) pairs from a time-ordered array.

    Each sample uses the ``T_in`` consecutive frames ``A[t-T_in:t]`` as input
    and the frame immediately following them, ``A[t]``, as target. This
    matches how the forecaster is rolled forward one day at a time.
    (The previous version used ``A[t+1]`` as target, i.e. it skipped a day.)

    Args:
        A: Array of shape (T, H, W), ordered in time.
        T_in: Number of context frames per sample; must be at least 1.

    Returns:
        Tuple ``(Xseq, Yt)`` of float32 arrays with shapes
        (N, 1, T_in, H, W) and (N, 1, H, W), where N = T - T_in.

    Raises:
        ValueError: if ``T_in`` is smaller than 1.
        RuntimeError: if ``A`` is too short to form a single sample.
    """
    if T_in < 1:
        raise ValueError("T_in must be >= 1")
    Xseq, Yt = [], []
    for t in range(T_in, len(A)):
        Xseq.append(A[t-T_in:t])   # days t-T_in .. t-1
        Yt.append(A[t])            # day t: the day right after the context
    if not Xseq:
        raise RuntimeError("Not enough days to build sequences. Increase YEARS or reduce FCTX_DAYS.")
    Xseq = np.stack(Xseq)[:,None,...]  # (N,1,T_in,H,W)
    Yt   = np.stack(Yt)[:,None,...]    # (N,1,H,W)
    return Xseq.astype('float32'), Yt.astype('float32')

Xseq, Yseq = build_sequences(A, T_in=T_in)
N = len(Xseq); idx = np.arange(N)
rng = np.random.default_rng(SEED); rng.shuffle(idx)
# 80/20 random split with at least one training sample.
ntr = max(1, int(0.8*N))
if N - ntr > 0:
    tr_idx, va_idx = idx[:ntr], idx[ntr:]
else:
    # Too few sequences for a separate validation set: reuse all of them.
    # (The former one-line conditional applied only to the second tuple
    # element, so this fallback would have assigned a tuple to va_idx.)
    tr_idx, va_idx = idx, idx
Xtr_s, Ytr_s = Xseq[tr_idx], Yseq[tr_idx]
Xva_s, Yva_s = Xseq[va_idx], Yseq[va_idx]

class CNN3DForecaster(nn.Module):
    """3D-CNN that maps ``T_in`` past frames to a single next frame.

    Three padded 3x3x3 convolutions keep the (T, H, W) size; the head then
    collapses the time axis with a ``(T_in, 1, 1)`` kernel and projects to
    one channel.
    """

    def __init__(self, T_in=14):
        super().__init__()
        feature_layers = []
        for c_in, c_out in ((1, 8), (8, 16), (16, 16)):
            feature_layers.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            feature_layers.append(nn.ReLU())
        self.net = nn.Sequential(*feature_layers)

        # The temporal kernel spans the whole context, so the time axis has
        # length 1 afterwards.
        collapse_time = nn.Conv3d(16, 8, kernel_size=(T_in, 1, 1))
        project = nn.Conv3d(8, 1, kernel_size=1)
        self.head = nn.Sequential(collapse_time, nn.ReLU(), project)

    def forward(self, x):  # (B,1,T,H,W)
        """Return the predicted frame with shape (B, 1, H, W)."""
        features = self.net(x)
        frame = self.head(features)   # (B,1,1,H,W)
        return frame.squeeze(2)

# Global forecaster and its optimizer; used directly by run_epoch_fore.
f_model = CNN3DForecaster(T_in=T_in).to(DEVICE)
optf = torch.optim.AdamW(f_model.parameters(), lr=1e-3, weight_decay=1e-5)

class SeqSet(Dataset):
    """Dataset of (context sequence, target frame) pairs held as NumPy arrays."""

    def __init__(self, X,Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self,i):
        """Return sample ``i`` as a pair of tensors (context, target)."""
        context = torch.from_numpy(self.X[i])
        target = torch.from_numpy(self.Y[i])
        return context, target

# Training sequences are reshuffled every epoch; validation order is fixed.
tr_dl_s = DataLoader(SeqSet(Xtr_s, Ytr_s), batch_size=F_BATCH, shuffle=True)
va_dl_s = DataLoader(SeqSet(Xva_s, Yva_s), batch_size=F_BATCH, shuffle=False)

def run_epoch_fore(dl, train=True):
    """Run one pass of the global forecaster ``f_model`` over ``dl``.

    Args:
        dl: Iterable of ``(context, target)`` batches.
        train: If True, update weights with the global optimizer ``optf``;
            otherwise only evaluate.

    Returns:
        Sample-weighted mean of the combined loss (L1 + 0.2 * MSE); 0.0 when
        ``dl`` yields no batches.
    """
    f_model.train(train)
    tot, n = 0.0, 0
    # Gradients are only required for weight updates; disabling autograd in
    # validation avoids building a graph for every batch.
    with torch.set_grad_enabled(bool(train)):
        for xb,yb in dl:
            xb,yb = xb.to(DEVICE), yb.to(DEVICE)
            pred = f_model(xb)
            loss = F.l1_loss(pred,yb) + 0.2*F.mse_loss(pred,yb)
            if train:
                optf.zero_grad()
                loss.backward()
                optf.step()
            tot += loss.item()*len(xb); n += len(xb)
    return tot/max(n,1)

print("Training forecaster (quick)...")
# Keep a CPU copy of the weights with the lowest validation loss and restore
# them after the last epoch.
bestf=float('inf'); bestfw=None
for ep in range(1, F_EPOCHS+1):
    tr = run_epoch_fore(tr_dl_s, True)
    va = run_epoch_fore(va_dl_s, False)
    print(f"[Forecast] ep {ep:02d} | train {tr:.4f} | val {va:.4f}")
    if va<bestf: bestf=va; bestfw={k:v.detach().cpu().clone() for k,v in f_model.state_dict().items()}
if bestfw is not None:
    f_model.load_state_dict(bestfw)
    print("Forecast best val:", bestf)

@torch.no_grad()
def forecast_multi(model, A_last, steps=3, T_in=FCTX_DAYS):
    """Roll a one-step forecaster forward ``steps`` times.

    After each prediction the oldest frame of the context window is dropped
    and the prediction is appended, so later steps are conditioned on earlier
    forecasts.

    Args:
        model: Callable mapping a (1, 1, T, H, W) tensor to (1, 1, H, W).
        A_last: Array of shape (T, H, W) with the most recent frames; it is
            copied, not modified.
        steps: Number of frames to predict.
        T_in: Accepted for interface compatibility; not read by this function.

    Returns:
        Array of the predicted frames, one per step, in order.
    """
    window = A_last.copy()
    frames = []
    for _ in range(steps):
        batch = torch.from_numpy(window[None, None, ...]).to(DEVICE)
        next_frame = model(batch).cpu().numpy()[0, 0]
        frames.append(next_frame)
        # Slide the context: drop the oldest day, append the new forecast.
        window = np.concatenate([window[1:], next_frame[None, ...]], axis=0)
    return np.array(frames)

# build last-seq from selected year
# The fields are preprocessed exactly like the training data (NaN -> 0, clip,
# normalize with the training mu/sigma).
ds_y = open_subset_year(YEAR_SELECT)
sst_y = ds_y['SST'].astype('float32').fillna(0.0).clip(min=-2.0, max=35.0)
A_year = ((sst_y - mu) / (sigma + 1e-6)).values.astype('float32')
if A_year.shape[0] < FCTX_DAYS+1:
    raise RuntimeError("Chosen year lacks days; reduce FCTX_DAYS or choose another year.")

last_seq = A_year[-FCTX_DAYS:]  # (14,H,W)
preds_n = forecast_multi(f_model, last_seq, steps=int(FUTURE_DAYS), T_in=FCTX_DAYS)
preds_c = preds_n*(sigma+1e-6)+mu   # back to the original scale

# Forecast dates start the day after the last available day of the year.
last_date = pd.to_datetime(ds_y.time.values[-1])
future_times = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=int(FUTURE_DAYS), freq='D')

pred_da = xr.DataArray(preds_c, coords={'time': future_times, 'lat': ds_y['lat'], 'lon': ds_y['lon']},
                       dims=('time','lat','lon'), name='SST_pred')

# detect MHW on forecast fields
# Each forecast day is compared with the climatology/threshold of the nearest
# available day of year.
mhw_list = []; anom_list=[]
for tt in range(int(FUTURE_DAYS)):
    date = future_times[tt]; doy  = int(date.dayofyear)
    clim_d = clim.sel(dayofyear=doy, method='nearest')
    thr_d  = thresh.sel(dayofyear=doy, method='nearest')
    sst_d  = pred_da.isel(time=tt)
    anom_d = sst_d - clim_d
    mhw_d  = anom_d.where(sst_d > thr_d)
    anom_list.append(anom_d); mhw_list.append(mhw_d)

anom_future = xr.concat(anom_list, dim='time'); anom_future['time'] = future_times
mhw_future  = xr.concat(mhw_list,  dim='time'); mhw_future['time']  = future_times

for i, t in enumerate(future_times):
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1); pred_da.isel(time=i).plot(cmap='jet'); plt.title(f'Forecast SST (°C)\n{t.date()}')
    plt.subplot(1,3,2); anom_future.isel(time=i).plot(cmap='RdBu_r', vmin=-3, vmax=3); plt.title('Forecast Anomaly (°C)')
    plt.subplot(1,3,3); mhw_future.isel(time=i).plot(cmap='hot', vmin=0); plt.title('Forecast MHW mask')
    plt.tight_layout(); plt.savefig(f"{FIG_DIR}/forecast_{YEAR_SELECT}_day{i+1}.png", dpi=200); plt.show()

# Fraction of grid cells flagged per forecast day, exported as CSV.
area_frac = mhw_future.notnull().mean(dim=('lat','lon')).fillna(0.0).to_pandas()
summary = pd.DataFrame({'date': future_times, 'mhw_area_fraction': area_frac.values})
summary.to_csv(f"{FIG_DIR}/forecast_MHW_summary_{YEAR_SELECT}.csv", index=False)
print("Saved:")
print(f" - Comparison plot: {FIG_DIR}/compare_unet_ddpm.png")
print(f" - MHW map (example year): {FIG_DIR}/MHW_map_{YEAR_SELECT}.png (if any)")
print(f" - Forecast per-day: {FIG_DIR}/forecast_{YEAR_SELECT}_day*.png")
print(f" - Forecast summary CSV: {FIG_DIR}/forecast_MHW_summary_{YEAR_SELECT}.csv")
