In [None]:
"""
NOTE: 
Please change input directory before use. 
"""

import torch

In [2]:
import argparse
import os
import sys
import glob
import shutil
import zipfile
from pathlib import Path
from typing import Optional, List
import numpy as np
import rasterio
from rasterio.enums import Resampling
from tqdm import tqdm


In [3]:
# Path to the folder holding the daily FWI GeoTIFFs (change to your input directory)
folder = "fwi"

# Find all .tif files. glob order is filesystem-dependent, so sort for a
# reproducible (name-ordered) report; os.path.join keeps the pattern portable.
tif_files = sorted(glob.glob(os.path.join(folder, "*.tif")))

# Read each file, mask invalid pixels and print per-band summary statistics
for tif in tif_files:
    with rasterio.open(tif) as src:
        array = src.read().astype(float)  # shape: (bands, height, width)
        profile = src.profile

        # --- Robust NoData / overflow masking ---
        # Prefer the file's explicit nodata; otherwise fall back to float32 minimum
        nodata = src.nodata if src.nodata is not None else np.finfo(np.float32).min

        # Build a mask that catches:
        # 1) Non-finite values (nan, inf)
        # 2) Exact nodata (if provided)
        # 3) Any extreme sentinels with |x| >= 1e38 (e.g. float32-min fill values)
        bad = (~np.isfinite(array)) | (array == nodata) | (np.abs(array) >= 1e38)

        # Replace bad with NaN so nan-aware math works
        array[bad] = np.nan

        # OPTIONAL: collapse tiny magnitudes to 0 after masking (tune eps to your units)
        eps = 1e-12
        array[np.abs(array) < eps] = 0.0

        # Per-band stats over valid pixels only (np.nan* functions skip NaN)
        band_means = np.nanmean(array, axis=(1, 2))
        band_stds  = np.nanstd(array, axis=(1, 2))

        # Print quick verification and metadata
        print(f"{tif}: shape={array.shape}, crs={profile['crs']}")
        # Sanity check: after the masking above this should always print False
        print("Has extreme leftovers?",
              bool(np.any(np.isfinite(array) & (np.abs(array) >= 1e38))))
        print("Band means:", band_means)
        print("Band stds :", band_stds)
        print("-" * 60)

fwi\fwi_20250701.tif: shape=(1, 2281, 2709), crs=EPSG:3978
Has extreme leftovers? False
Band means: [8.01814767]
Band stds : [9.04589513]
------------------------------------------------------------
fwi\fwi_20250702.tif: shape=(1, 2281, 2709), crs=EPSG:3978
Has extreme leftovers? False
Band means: [8.30267814]
Band stds : [9.37343538]
------------------------------------------------------------
fwi\fwi_20250703.tif: shape=(1, 2281, 2709), crs=EPSG:3978
Has extreme leftovers? False
Band means: [7.21104751]
Band stds : [8.43844756]
------------------------------------------------------------
fwi\fwi_20250704.tif: shape=(1, 2281, 2709), crs=EPSG:3978
Has extreme leftovers? False
Band means: [7.80738059]
Band stds : [8.60878804]
------------------------------------------------------------
fwi\fwi_20250705.tif: shape=(1, 2281, 2709), crs=EPSG:3978
Has extreme leftovers? False
Band means: [7.9544012]
Band stds : [6.96468958]
------------------------------------------------------------
fwi\fw

In [4]:
import argparse, os, re, math, glob, random, warnings
from datetime import datetime
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [5]:
def set_seed(seed=42):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU, then
    all CUDA devices), in that order. cuDNN autotuning is switched off so the
    selected convolution kernels do not vary between runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def parse_date_from_filename(path):
    """Extract a ``YYYYMMDD`` date (years 2000-2099) from a file's basename.

    Each 8-digit run starting with ``20`` in the basename is tried left to
    right and the first one that is a real calendar date wins. A stray number
    such as ``20251399`` or ``20250230`` is skipped instead of raising.

    Args:
        path: File path; only ``os.path.basename(path)`` is inspected.

    Returns:
        datetime | None: the parsed date, or ``None`` if no candidate parses.
    """
    for match in re.finditer(r'(20\d{6})', os.path.basename(path)):
        try:
            return datetime.strptime(match.group(1), "%Y%m%d")
        except ValueError:
            continue  # not a valid date (e.g. month 13, Feb 30): try the next run
    return None

print("STEP 2")

STEP 2


In [19]:
def load_stack(data_dir):
    """
    Load every dated ``*.tif`` in ``data_dir`` into a time-ordered stack.

    Files are sorted by the date parsed from their basename
    (``parse_date_from_filename``); files without a parseable date are
    skipped. Only band 1 of each raster is read.

    Returns:
      frames: np.ndarray (T, H, W) float64, with NoData masked to NaN
      dates : list[datetime], one per frame
      files : list[str], the paths actually loaded, aligned with ``dates``
              and ``frames`` (undated files are excluded)

    Raises:
      FileNotFoundError: no ``*.tif`` files in ``data_dir``.
      ValueError: no file has a parseable date, or the rasters differ in shape.
    """
    files = sorted(
        glob.glob(os.path.join(data_dir, "*.tif")),
        key=lambda p: parse_date_from_filename(p) or datetime.min
    )
    if not files:
        raise FileNotFoundError(f"No .tif files in {data_dir}")

    imgs, dates, used_files = [], [], []
    for p in files:
        d = parse_date_from_filename(p)
        if d is None:
            continue

        with rasterio.open(p) as src:
            a = src.read(1).astype("float64")  # (H, W) first band
            # Fall back to float32 min as the sentinel when the file declares no nodata
            nodata = src.nodata if src.nodata is not None else np.finfo(np.float32).min
            # prefer explicit mask if present (best effort: not every source can supply one)
            try:
                msk = src.read_masks(1)
            except Exception:
                msk = None

        # Bad pixels: non-finite, huge sentinels, declared nodata, or masked out
        bad = (~np.isfinite(a)) | (np.abs(a) >= 1e38) | (a == nodata)
        if msk is not None:
            bad |= (msk == 0)

        a[bad] = np.nan
        imgs.append(a)
        dates.append(d)
        used_files.append(p)  # keep the returned file list aligned with dates/frames

    if not imgs:
        raise ValueError(f"No .tif file in {data_dir} has a YYYYMMDD date in its name")
    shapes = {img.shape for img in imgs}
    if len(shapes) > 1:
        raise ValueError(f"Rasters in {data_dir} have differing shapes: {sorted(shapes)}")

    frames = np.stack(imgs, axis=0)  # (T, H, W) float64 with NaNs for missing
    return frames, dates, used_files

def standardize(train_frames, all_frames, eps=1e-6):
    """
    Z-score ``all_frames`` with ONE global mean/std fitted on ``train_frames``.

      - mu = single global mean (scalar) over the valid values of train_frames
      - sd = single global std  (scalar) over the valid values of train_frames
      - z  = (all_frames - mu) / sd

    Non-finite values and |x| >= 1e38 sentinels are excluded from the
    statistics and become 0.0 in ``z`` (as does any overflow). If ``sd`` is
    non-finite or below ``eps`` it is replaced by ``eps``. The inputs are
    never modified.

    Returns:
      (z, mu, sd): z is float64 with the shape of ``all_frames``; mu and sd are
      Python floats (mu is NaN when ``train_frames`` has no valid values).
    """
    # asarray may return the caller's own array, so nothing below writes into fr/fa.
    fr = np.asarray(train_frames, dtype=np.float64)
    fa = np.asarray(all_frames, dtype=np.float64)

    # Boolean indexing builds a new 1-D array of the valid training values only.
    valid = fr[np.isfinite(fr) & (np.abs(fr) < 1e38)]
    if valid.size:
        mu = float(valid.mean(dtype=np.float64))
        sd = float(valid.std(dtype=np.float64))
    else:
        mu, sd = float("nan"), float("nan")
    if (not np.isfinite(sd)) or (sd < eps):
        sd = eps

    with np.errstate(invalid="ignore", over="ignore"):
        z = (fa - mu) / sd  # fresh array; safe to edit in place
    # NaN/Inf inputs, sentinel inputs and overflowed results all map to 0.0
    z[(~np.isfinite(z)) | (np.abs(fa) >= 1e38)] = 0.0
    return z, mu, sd

def _is_torch_tensor(x):
    return "torch" in type(x).__module__

def patchify(frames, patch: int):
    """Cut a (D, H, W) stack into non-overlapping ``patch`` x ``patch`` tiles.

    Rows/columns that do not fill a whole tile (``H % patch`` at the bottom,
    ``W % patch`` at the right) are dropped. Accepts NumPy arrays or torch
    tensors and returns the same kind.

    Returns:
      x         : (nph * npw, D, patch * patch) tiles, grid in row-major order
      (Hc, Wc)  : cropped height/width covered by the tiles
      (nph, npw): number of tiles down / across
    """
    depth, height, width = frames.shape
    nph = height // patch
    npw = width // patch
    Hc, Wc = nph * patch, npw * patch

    cropped = frames[:, :Hc, :Wc]
    grid = cropped.reshape(depth, nph, patch, npw, patch)
    # (D, nph, p, npw, p) -> (nph, npw, D, p, p): torch spells this .permute, NumPy .transpose
    reorder = grid.permute if _is_torch_tensor(cropped) else grid.transpose
    x = reorder(1, 3, 0, 2, 4).reshape(nph * npw, depth, patch * patch)
    return x, (Hc, Wc), (nph, npw)

def depatchify(patches, grid, patch: int, hw):
    """Inverse of ``patchify``: reassemble tiles into a (D, Hc, Wc) stack.

    Args:
      patches: (nph * npw, D, patch * patch) tiles in row-major grid order
               (NumPy array or torch tensor).
      grid   : (nph, npw) tile counts as returned by ``patchify``.
      patch  : tile edge length in pixels.
      hw     : (Hc, Wc) cropped size as returned by ``patchify``.
    """
    nph, npw = grid
    Hc, Wc = hw
    depth = patches.shape[1]

    tiles = patches.reshape(nph, npw, depth, patch, patch)
    if _is_torch_tensor(patches):
        # -> (D, nph, patch, npw, patch), materialised contiguously before flattening
        stacked = tiles.permute(2, 0, 3, 1, 4).contiguous()
    else:
        stacked = tiles.transpose(2, 0, 3, 1, 4)  # -> (D, nph, patch, npw, patch)
    image = stacked.reshape(depth, nph * patch, npw * patch)
    return image[:, :Hc, :Wc]  # (D, Hc, Wc)

def build_windows(frames, in_days, out_days):
    """Enumerate every sliding (input, target) window along the time axis.

    Each window is ``(start, mid, end)``: ``frames[start:mid]`` holds the
    ``in_days`` inputs and ``frames[mid:end]`` the ``out_days`` targets.
    Consecutive windows are shifted by one frame. The result is empty when
    there are fewer than ``in_days + out_days`` frames.
    """
    span = in_days + out_days
    n_windows = frames.shape[0] - span + 1
    return [(start, start + in_days, start + span) for start in range(n_windows)]

In [20]:
class FWIWeeklyDataset(Dataset):
    """Patch-level (input days -> target days) samples cut from a (T, H, W) stack.

    Every window ``(s, m, e)`` contributes one sample per spatial patch:
    X = frames[s:m] and Y = frames[m:e], each split by ``patchify`` into
    (Npatch, days, patch*patch). All samples are materialised up front.

    Attributes:
      X   : (N, in_days, patch*patch) inputs
      Y   : (N, out_days, patch*patch) targets
      grid: (nph, npw) patch grid of the first window, or None if no windows
      hw  : (Hc, Wc) cropped raster size of the first window, or None if no windows
    """
    def __init__(self, frames, windows, in_days, out_days, patch, augment=False):
        self.frames = frames
        self.windows = windows
        self.in_days = in_days
        self.out_days = out_days
        self.patch = patch
        self.augment = augment
        # Prepare all patch sequences upfront to keep __getitem__ simple
        X_list, Y_list = [], []
        self.grid = None
        self.hw = None
        for (s, m, e) in windows:
            X = frames[s:m]           # (in_days,H,W)
            Y = frames[m:e]           # (out_days,H,W)
            Xp, hw, grid = patchify(X, patch)
            Yp, _, _   = patchify(Y, patch)
            if self.grid is None:
                self.hw, self.grid = hw, grid
            # Shape: (Npatch, D, P*P)
            X_list.append(Xp)
            Y_list.append(Yp)
        if X_list:
            # Concatenate over time windows -> each row = one patch trajectory
            self.X = np.concatenate(X_list, axis=0)  # (N, in_days, P*P)
            self.Y = np.concatenate(Y_list, axis=0)  # (N, out_days, P*P)
        else:
            # No windows (e.g. an empty validation split): build a valid empty
            # dataset instead of letting np.concatenate([]) raise.
            self.X = np.empty((0, in_days, patch * patch), dtype=np.float64)
            self.Y = np.empty((0, out_days, patch * patch), dtype=np.float64)
        # Optional flip augmentation: reverse the pixel rows of every patch
        # (a vertical / up-down flip) and append the flipped copies.
        if augment and len(self.X) > 0:
            P = self.patch
            Xf = self.X.reshape(-1, self.in_days, P, P)[..., ::-1, :].reshape(self.X.shape)
            Yf = self.Y.reshape(-1, self.out_days, P, P)[..., ::-1, :].reshape(self.Y.shape)
            self.X = np.concatenate([self.X, Xf], axis=0)
            self.Y = np.concatenate([self.Y, Yf], axis=0)

    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])  # (in_days, P*P)
        y = torch.from_numpy(self.Y[idx])  # (out_days, P*P)
        return x, y

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to sequence-first input.

    Row ``pos`` of the table holds sin(pos * w_k) in even columns and
    cos(pos * w_k) in odd columns, with w_k = 10000^(-2k / d_model).
    Input is (S, N, E) with E == d_model and S <= max_len.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len,1,d_model)

    def forward(self, x):  # x: (S,N,E)
        S = x.size(0)
        if S > self.pe.size(0):
            raise ValueError(f"sequence length {S} exceeds max_len {self.pe.size(0)}")
        return x + self.pe[:S]

class PatchTemporalTransformer(nn.Module):
    """Encoder-decoder Transformer mapping past patch sequences to future ones.

    The input ``(B, in_days, patch_dim)`` is linearly embedded, given
    sinusoidal positions and encoded. ``out_days`` learned query vectors (also
    with positions) are decoded against that memory and projected back to
    ``(B, out_days, patch_dim)``. The Transformer itself runs sequence-first
    (``batch_first=False``).
    """
    def __init__(self, patch_dim, d_model=128, nhead=4, num_enc=2, num_dec=2, in_days=7, out_days=7, dropout=0.1):
        super().__init__()
        self.in_days = in_days
        self.out_days = out_days
        # Submodules are created in this order on purpose: under a fixed seed the
        # parameter initialisation depends on it, and the attribute names are the
        # state_dict keys used by saved checkpoints.
        self.embed = nn.Linear(patch_dim, d_model)
        longest = max(in_days, out_days)
        self.pos = PositionalEncoding(d_model, max_len=longest + 32)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_enc,
            num_decoder_layers=num_dec,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=False,
        )
        self.query_embed = nn.Parameter(torch.randn(out_days, d_model))
        self.proj = nn.Linear(d_model, patch_dim)

    def forward(self, src_seq):  # src_seq: (B, in_days, P*P)
        batch, _, _ = src_seq.shape
        # (B, S, P*P) -> (B, S, E) -> (S, B, E), then add positions
        enc_in = self.pos(self.embed(src_seq).transpose(0, 1))
        # one copy of the out_days learned queries per batch element: (T, B, E)
        queries = self.pos(self.query_embed.unsqueeze(1).repeat(1, batch, 1))
        memory = self.transformer.encoder(enc_in)
        decoded = self.transformer.decoder(queries, memory)  # (T, B, E)
        return self.proj(decoded.transpose(0, 1))            # (B, T, P*P)

In [21]:
set_seed(42)
# exist_ok=True: re-running the notebook must not fail because the output folder exists
os.makedirs('Jin_fwi', exist_ok=True)
print("STEP 3")

STEP 3


In [22]:
# Load the dated rasters in 'fwi' (change to your input directory) as a (T, H, W)
# float64 stack with NoData set to NaN, plus their parsed dates and file paths.
frames, dates, files = load_stack('fwi')
print("STEP 4")

STEP 4


In [23]:
# Stack dimensions: T time steps (one per dated raster), raster height H, raster width W
T,H,W = frames.shape

In [24]:
# Number of leading (oldest) frames used to fit the normalisation statistics.
# NOTE: the 0.7 fraction needs tuning; it is defined once here and reused below.
cutoff = int(T * 0.7)

# Simple health checks BEFORE standardize
ref = frames[:cutoff].astype('float64', copy=False)  # the stack used to compute mu/sd
all_ = frames.astype('float64', copy=False)          # the stack to standardize

print("REF shape:", ref.shape, " | ALL shape:", all_.shape)

print("REF nonfinite:", np.count_nonzero(~np.isfinite(ref)))
print("REF extremes (|x|>=1e38):", np.count_nonzero(np.abs(ref) >= 1e38))
print("REF finite min/max:", np.nanmin(ref), np.nanmax(ref))

print("ALL nonfinite:", np.count_nonzero(~np.isfinite(all_)))
print("ALL extremes (|x|>=1e38):", np.count_nonzero(np.abs(all_) >= 1e38))
print("ALL finite min/max:", np.nanmin(all_), np.nanmax(all_))

# Z-score every frame with one global mean/std fitted on the first `cutoff` frames.
# standardize() already maps NaN/Inf to 0.0, so no second nan_to_num pass is needed.
frames_std, mu, sd = standardize(frames[:cutoff], frames)
print("STEP 5")

REF shape: (21, 2281, 2709)  | ALL shape: (31, 2281, 2709)
REF nonfinite: 85973713
REF extremes (|x|>=1e38): 0
REF finite min/max: 3.9805106433199455e-10 134.33778381347656
ALL nonfinite: 126854904
ALL extremes (|x|>=1e38): 0
ALL finite min/max: 3.9805106433199455e-10 134.33778381347656
STEP 5


In [25]:
# Persist the normalisation constants [mu, sd] (stored as float32) so standardized
# predictions can be mapped back later with y * sd + mu.
np.save(os.path.join('Jin_fwi', "norm_stats.npy"), np.array([mu, sd], dtype=np.float32))
print("STEP 6")

STEP 6


In [26]:
# Sliding windows of 7 input frames -> 7 target frames over the standardized stack.
windows = build_windows(frames_std, 7, 7)
if not windows:
    raise RuntimeError("Not enough time windows to build a dataset. Add more daily rasters.")
# Windows are defined by frame index, so they only match calendar weeks when the
# rasters are consecutive days. Warn (don't fail) on gaps or duplicated dates.
gaps = [(a, b) for a, b in zip(dates, dates[1:]) if (b - a).days != 1]
if gaps:
    first_a, first_b = gaps[0]
    warnings.warn(
        f"{len(gaps)} non-consecutive date step(s) in the stack "
        f"(first: {first_a:%Y-%m-%d} -> {first_b:%Y-%m-%d}); windows may span missing days."
    )
print("STEP 7")

STEP 7


In [27]:
# Chronological 80/20 split of the windows (no shuffling: validation = latest windows).
# NOTE(review): build_windows advances one frame at a time, so neighbouring windows share
# all but one frame; the last training windows overlap the first validation windows and
# the validation loss is likely optimistic. The normalisation cutoff (first 70% of frames)
# is also not aligned with this window-level split — consider a gap between the two sets.
split_idx = int(len(windows)*(1-0.2))
train_wins, val_wins = windows[:split_idx], windows[split_idx:]
print("STEP 8")

STEP 8


In [28]:
# Patch-level datasets: 7 input days -> 7 target days, 16x16-pixel patches.
# Flip augmentation is applied to the training split only.
train_ds = FWIWeeklyDataset(frames_std, train_wins, 7, 7, 16, augment=True)
val_ds   = FWIWeeklyDataset(frames_std, val_wins,   7, 7, 16, augment=False)
print("STEP 9")

STEP 9


In [29]:
# Mini-batches of 64 patch sequences; training order is reshuffled each epoch, validation
# order is fixed. num_workers=0 loads batches in the main process.
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
val_dl   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=0, drop_last=False)
print("STEP 10")

STEP 10


In [30]:
# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pixels per flattened patch (patch * patch = 16 * 16 here)
patch_dim = train_ds.X.shape[-1]
model = PatchTemporalTransformer(patch_dim, d_model=128, nhead=4, num_enc=2, num_dec=2, in_days=7, out_days=7).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
loss_fn = nn.MSELoss()  # computed on standardized values, i.e. in z-score units
print("STEP 11")



STEP 11


In [32]:
def run_epoch(loader, train=True, net=None, optimizer=None, criterion=None, dev=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    Args:
      loader: DataLoader yielding ``(x, y)`` batches.
      train: if True, update weights (gradients clipped to norm 1.0); else evaluate only.
      net, optimizer, criterion, dev: model, optimiser, loss and device to use.
        Each defaults to the notebook globals ``model``, ``opt``, ``loss_fn`` and
        ``device``, so existing ``run_epoch(loader, train)`` calls keep working.

    Returns:
      float: sample-weighted mean loss over the batches seen, or NaN if there were none.

    Raises:
      RuntimeError: a batch contains NaN or Inf.
    """
    net = model if net is None else net
    criterion = loss_fn if criterion is None else criterion
    dev = device if dev is None else dev
    if train:
        optimizer = opt if optimizer is None else optimizer

    net.train(train)
    total = 0.0
    n_seen = 0
    for xb, yb in loader:
        xb = xb.to(dev).float()
        yb = yb.to(dev).float()
        # isfinite also rejects +/-Inf, which an isnan-only check lets through
        if not (torch.isfinite(xb).all() and torch.isfinite(yb).all()):
            raise RuntimeError("Found NaN/Inf values in batch; check preprocessing/masks.")
        with torch.set_grad_enabled(train):
            pred = net(xb)
            loss = criterion(pred, yb)
            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
                optimizer.step()
        total += loss.item() * xb.size(0)
        n_seen += xb.size(0)
    # Divide by samples actually seen (correct even with drop_last=True); empty -> NaN
    return total / n_seen if n_seen else float("nan")

ping


In [33]:
# Train, checkpointing whenever the validation loss improves
N_EPOCHS = 11
best_path = os.path.join('Jin_fwi', "best_model.pt")
best_val = float("inf")
for epoch in range(1, N_EPOCHS + 1):
    tr = run_epoch(train_dl, True)
    va = run_epoch(val_dl, False)
    print(f"Epoch {epoch}: train {tr:.4f}  val {va:.4f}")
    if va < best_val:
        best_val = va
        torch.save(model.state_dict(), best_path)

# Restore the best-validation weights. Without this, later cells would run the
# last-epoch model instead of the checkpoint selected by validation loss.
if best_val < float("inf"):
    model.load_state_dict(torch.load(best_path, map_location=device))
print("STEP 12")

Epoch 1: train 0.2223  val 0.1998
Epoch 2: train 0.2014  val 0.2078
Epoch 3: train 0.1868  val 0.2170
Epoch 4: train 0.1747  val 0.2301
Epoch 5: train 0.1651  val 0.2234
Epoch 6: train 0.1574  val 0.2417
Epoch 7: train 0.1507  val 0.2305
Epoch 8: train 0.1450  val 0.2339
Epoch 9: train 0.1402  val 0.2423
Epoch 10: train 0.1355  val 0.2362
Epoch 11: train 0.1318  val 0.2397
STEP 12


In [34]:
# === Forecast the days after the stack from its most recent in_days frames ===
in_days = 7
patch   = 16

model.eval()  # disable dropout for inference

# Use the LAST in_days observed frames as input, so the output covers the days AFTER
# the stack. (windows[-1][0] pointed at the input of the last window, whose targets are
# frames already present in the stack, i.e. not a forecast.)
X_last = frames_std[-in_days:]  # (in_days, H, W)

# patchify -> model -> depatchify back onto the cropped raster grid
Xp, hw, grid = patchify(X_last, patch)
with torch.no_grad():
    x = torch.from_numpy(Xp).to(device, non_blocking=True).float()
    pred = model(x).detach().cpu().numpy()  # (N, out_days, P*P)

Y_pred = depatchify(torch.from_numpy(pred).float(), grid, patch, hw).numpy()  # (out_days, Hc, Wc)

# --- de-standardize with the training mu/sd (tiny guard on sd, silence invalid warnings) ---
eps = 1e-8
sd_safe = sd if (isinstance(sd, (float, np.floating)) and np.isfinite(sd) and sd >= eps) else eps
with np.errstate(invalid='ignore'):
    Y_pred_den = Y_pred * sd_safe + mu

# Re-apply the NoData footprint of the latest observed frame (cropped like the patches):
# standardize() filled NoData with 0, so the model also emits values there.
Hc, Wc = hw
nodata_mask = np.isnan(frames[-1, :Hc, :Wc])
Y_pred_den[:, nodata_mask] = np.nan

# Save each predicted day as an 8-bit TIFF preview (0-255); NoData pixels are written as 0
os.makedirs('Jin_fwi', exist_ok=True)
for i in range(Y_pred_den.shape[0]):
    arr = Y_pred_den[i]
    # per-day 1st-99th percentile stretch for visualization (nan-aware)
    vmin, vmax = np.nanpercentile(arr, 1), np.nanpercentile(arr, 99)
    if (not np.isfinite(vmin)) or (not np.isfinite(vmax)) or (vmax <= vmin):
        # fallback for all-NaN or flat images
        vmin, vmax = 0.0, 1.0
    img = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1)
    img = np.nan_to_num(img, nan=0.0)  # NaN cannot be cast to uint8
    im = Image.fromarray((img * 255).astype(np.uint8))
    im.save(os.path.join('Jin_fwi', f"forecast_day{i+1}.tiff"))

print("Saved predictions to Jin_fwi")

Saved predictions to Jin_fwi
