# FWI prediction (LSTM)

This notebook is the **FWI-only** variant: it trains an LSTM on a time series of historical Fire Weather Index (FWI) GeoTIFFs and forecasts the FWI grid for the next timestep.


In [None]:
import torch

In [None]:
import os
import math
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.io import DatasetReader
import numpy as np
from datetime import datetime

In [None]:
# Destination folder for the forecast GeoTIFF(s); created up front.
OUT_DIR = "Jin_fwi2"
os.makedirs(OUT_DIR, exist_ok=True)

# Resampling for continuous meteorological fields
# (use Resampling.nearest instead for categorical rasters).
# NOTE(review): not referenced by the code visible in this notebook section.
RESAMPLE_METHOD = Resampling.bilinear

# Outputs are standardized to float32 with NaN as the nodata marker.
DST_DTYPE = "float32"
DST_NODATA = np.nan


## FWI raster utilities

In [None]:
import re
from typing import Dict, List, Optional, Tuple
import os


import numpy as np
import rasterio
import zipfile


def unzip_fwi_zip(zip_path: str, out_dir: str) -> str:
    """Extract every member of the archive at ``zip_path`` into ``out_dir``.

    ``out_dir`` is created first if it does not exist. It is returned
    unchanged so the call can be used inline.
    """
    os.makedirs(out_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=out_dir)
    print("Unzipped to:", out_dir)
    return out_dir



def _parse_date_from_path(path: str) -> Optional[str]:
    """Try to extract YYYYMMDD or YYYYMMDDTHHMM from a filename."""
    m = re.search(r'(\d{8}T\d{4})', path)
    if m:
        return m.group(1)
    m = re.search(r'(\d{8})', path)
    if m:
        return m.group(1)
    return None


def load_single_raster(path: str, band: int = 1, to_float32: bool = True) -> Tuple[np.ndarray, dict]:
    """Load one band of a GeoTIFF as a numpy array, converting nodata to NaN.

    Parameters
    ----------
    path : str
        Raster path opened with ``rasterio.open``.
    band : int
        1-based band index passed to ``ds.read``.
    to_float32 : bool
        Cast the band to float32 (default). When False, the native dtype is
        kept, except that an integer band with a nodata value set is promoted
        to float64 so NaN can be stored.

    Returns
    -------
    arr : np.ndarray
        Band values, NaN wherever the dataset's nodata value occurred.
    profile : dict
        Copy of the dataset profile.
    """
    with rasterio.open(path) as ds:
        arr = ds.read(band)
        profile = ds.profile.copy()
        nodata = ds.nodata

    # Locate nodata pixels BEFORE any dtype cast. Comparing after a float32
    # cast against np.float32(nodata) either misses matches (a float64 band
    # compared to a float32-rounded sentinel) or merges neighbouring integer
    # values into the sentinel.
    mask = None
    if nodata is not None and not np.isnan(nodata):
        if np.issubdtype(arr.dtype, np.floating):
            # A float band stores the sentinel rounded to its own precision.
            sentinel = arr.dtype.type(nodata)
        else:
            sentinel = nodata
        mask = arr == sentinel

    if to_float32:
        arr = arr.astype(np.float32, copy=False)

    if mask is not None:
        if not np.issubdtype(arr.dtype, np.floating):
            # Integer arrays cannot hold NaN.
            arr = arr.astype(np.float64)
        # ds.read returns a fresh array, so in-place assignment is safe.
        arr[mask] = np.nan

    return arr, profile


def load_raster_stack(paths: List[str], band: int = 1) -> Tuple[np.ndarray, List[str], dict]:
    """
    Read one band from each raster into a float32 cube of shape [T, H, W].

    Files are ordered by the date/time token parsed from each path
    (``YYYYMMDDTHHMM`` or ``YYYYMMDD``). A path with no token sorts by the
    path string itself. The sort is stable, so equal keys keep their input
    order.

    Parameters
    ----------
    paths : list[str]
        GeoTIFF paths, in any order.
    band : int
        1-based band index read from every file.

    Returns
    -------
    stack : np.ndarray
        float32 with NaN for nodata, shape [T, H, W], in sorted order.
    time_keys : list[str]
        Sort key per timestep (the parsed token, or the path when none found).
    ref_profile : dict
        Profile of the first raster after sorting (useful for writing outputs).

    Raises
    ------
    ValueError
        If ``paths`` is empty or a raster's shape differs from the first one.
    """
    if len(paths) == 0:
        raise ValueError('paths is empty')

    # Pair each path with its chronological key; fall back to the raw path.
    keyed = sorted(
        ((_parse_date_from_path(p) or p, p) for p in paths),
        key=lambda pair: pair[0],
    )
    time_keys = [key for key, _ in keyed]
    sorted_paths = [path for _, path in keyed]

    first, ref_profile = load_single_raster(sorted_paths[0], band=band)
    H, W = first.shape
    stack = np.empty((len(sorted_paths), H, W), dtype=np.float32)

    for i, p in enumerate(sorted_paths):
        # The first raster was already read for the reference shape/profile.
        a = first if i == 0 else load_single_raster(p, band=band)[0]
        if a.shape != (H, W):
            raise ValueError(f'Shape mismatch for {p}: got {a.shape}, expected {(H, W)}')
        stack[i] = a

    return stack, time_keys, ref_profile


def fit_standard_scaler(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute NaN-aware per-feature mean and std over the last axis of ``x``.

    ``x`` is typically [N, L, F]. All leading axes are pooled. Near-constant
    features (std below 1e-6) get a std of 1.0 so that scaling never divides
    by ~0. Both statistics are returned as float32 arrays of length F.
    """
    n_feat = x.shape[-1]
    rows = x.reshape(-1, n_feat)
    mu = np.nanmean(rows, axis=0)
    sigma = np.nanstd(rows, axis=0)
    sigma[sigma < 1e-6] = 1.0
    return mu.astype(np.float32), sigma.astype(np.float32)


def apply_standard_scaler(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Return ``(x - mean) / std`` as float32, broadcasting over the last axis."""
    centered = np.subtract(x, mean)
    return (centered / std).astype(np.float32, copy=False)


def sample_sequences(
    y_stack: np.ndarray,
    x_stack: Optional[np.ndarray] = None,
    lookback: int = 14,
    horizon: int = 1,
    n_samples: int = 200_000,
    seed: int = 42,
    require_finite: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Randomly sample (pixel, time) training windows from raster stacks.

    y_stack: [T,H,W] target FWI.
    x_stack: optional [T,H,W,F] features aligned to y_stack timesteps.

    A window ending at time index ``t`` uses history ``y[t-lookback:t]`` as
    input and ``y[t:t+horizon]`` as target. ``t`` is drawn uniformly from
    ``[lookback, T - horizon]`` inclusive, so the most recent complete window
    can be sampled. Pixels and times are drawn independently with
    replacement, so duplicate windows are possible.

    Returns
    -------
    X: [N, lookback, n_features]
    y: [N, horizon]
        N <= n_samples. Fewer rows are returned when the retry budget
        (30 * n_samples draws) runs out first.

    Raises
    ------
    ValueError
        If lookback/horizon < 1, T < lookback + horizon, or x_stack is not
        shaped [T, H, W, F] matching y_stack.
    RuntimeError
        If no valid window could be sampled.

    Notes
    -----
    - If x_stack is provided, per-timestep features become [FWI, x1..xF].
    - With require_finite=True, windows containing NaN/inf are skipped.
    """
    if lookback < 1 or horizon < 1:
        raise ValueError('lookback and horizon must both be >= 1')

    rng = np.random.default_rng(seed)
    T, H, W = y_stack.shape
    if T < lookback + horizon:
        raise ValueError('Not enough timesteps for given lookback+horizon')
    if x_stack is not None and (x_stack.ndim != 4 or x_stack.shape[:3] != (T, H, W)):
        raise ValueError(
            f'x_stack must have shape [T, H, W, F] matching y_stack {(T, H, W)}; '
            f'got {x_stack.shape}'
        )

    F = 0 if x_stack is None else int(x_stack.shape[-1])
    n_features = 1 + F

    X = np.empty((n_samples, lookback, n_features), dtype=np.float32)
    y = np.empty((n_samples, horizon), dtype=np.float32)

    # Latest valid start of the target window: t + horizon <= T.
    # Generator.integers excludes `high` by default, so endpoint=True is
    # needed to include t = T - horizon. Without it, the most recent window
    # is never sampled, and T == lookback + horizon would crash.
    last_t = T - horizon
    filled = 0
    max_tries = n_samples * 30

    for _ in range(max_tries):
        if filled >= n_samples:
            break

        t = int(rng.integers(lookback, last_t, endpoint=True))
        r = int(rng.integers(0, H))
        c = int(rng.integers(0, W))

        y_hist = y_stack[t - lookback:t, r, c]
        y_fut = y_stack[t:t + horizon, r, c]

        if require_finite and not (np.isfinite(y_hist).all() and np.isfinite(y_fut).all()):
            continue

        if x_stack is not None:
            x_hist = x_stack[t - lookback:t, r, c, :]  # [L,F]
            if require_finite and not np.isfinite(x_hist).all():
                continue
            feat = np.concatenate([y_hist[:, None], x_hist], axis=1)
        else:
            feat = y_hist[:, None]

        X[filled] = feat
        y[filled] = y_fut
        filled += 1

    if filled == 0:
        raise RuntimeError('Could not sample any valid sequences (too many NaNs / no overlap).')

    if filled < n_samples:
        # Copy so the oversized preallocated buffers can be released.
        return X[:filled].copy(), y[:filled].copy()
    return X, y

## LSTM model + training utilities

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, Optional

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class SequenceDataset(Dataset):
    """Map-style dataset over paired input windows and targets.

    Both arrays are wrapped with ``torch.from_numpy`` and converted to
    float32. When an input is already float32, the tensor shares memory with
    the numpy array.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X = torch.from_numpy(X).to(torch.float32)
        self.y = torch.from_numpy(y).to(torch.float32)

    def __len__(self):
        # Number of samples = size of the leading axis.
        return self.X.size(0)

    def __getitem__(self, idx: int):
        sample, target = self.X[idx], self.y[idx]
        return sample, target


class LSTMForecaster(nn.Module):
    """Many-to-one LSTM: encode a [B, L, F] window, predict [B, horizon].

    Only the LSTM output at the final time step feeds the MLP head
    (LayerNorm -> Linear -> GELU -> Linear).
    """

    def __init__(
        self,
        n_features: int,
        hidden_size: int = 64,
        num_layers: int = 2,
        dropout: float = 0.1,
        horizon: int = 1,
    ):
        super().__init__()
        self.horizon = horizon

        # nn.LSTM applies dropout only between stacked layers, so a
        # single-layer encoder gets 0 (which also avoids PyTorch's warning).
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            n_features,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        norm = nn.LayerNorm(hidden_size)
        hidden = nn.Linear(hidden_size, hidden_size)
        act = nn.GELU()
        proj = nn.Linear(hidden_size, horizon)
        self.head = nn.Sequential(norm, hidden, act, proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_out = self.lstm(x)[0]         # [B, L, hidden]
        return self.head(seq_out[:, -1])  # last step -> [B, horizon]



@dataclass
class TrainConfig:
    """Hyper-parameters consumed by ``train_lstm``.

    Declared as a dataclass so fields can be overridden by keyword, e.g.
    ``TrainConfig(epochs=10, patience=3)``. Without the decorator the class
    has no generated ``__init__``, and such calls raise ``TypeError``.
    """

    epochs: int = 20            # maximum number of passes over the training set
    batch_size: int = 4096      # samples per training / validation batch
    lr: float = 1e-3            # AdamW learning rate
    weight_decay: float = 1e-5  # AdamW weight decay
    patience: int = 5           # epochs without val-RMSE improvement before stopping
    num_workers: int = 0        # DataLoader worker processes


def _rmse(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    return torch.sqrt(torch.mean((pred - targ) ** 2))


def train_lstm(
    model: nn.Module,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    cfg: Optional[TrainConfig] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Train ``model`` with AdamW + MSE loss and early stopping on validation RMSE.

    Parameters
    ----------
    model : nn.Module
        Network expected to map a [B, L, F] batch to [B, horizon]
        predictions. It is moved to ``device``. On return it holds the
        weights from the epoch with the lowest validation RMSE, provided any
        epoch produced a finite validation RMSE.
    X_train, y_train, X_val, y_val : np.ndarray
        Input windows [N, L, F] and targets [N, horizon]; converted to
        float32 by ``SequenceDataset``.
    cfg : TrainConfig, optional
        Hyper-parameters. A fresh ``TrainConfig()`` is used when omitted,
        which avoids sharing one default instance between calls.
    device : str, optional
        Torch device; defaults to 'cuda' when available, else 'cpu'.

    Returns
    -------
    dict
        'best_val_rmse' (float), 'history' (per-epoch 'train_rmse' and
        'val_rmse' lists) and 'device' (the device used).

    Raises
    ------
    ValueError
        If the training or validation set is empty.
    """
    if cfg is None:
        cfg = TrainConfig()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if len(X_train) == 0 or len(X_val) == 0:
        raise ValueError('train_lstm needs non-empty training and validation sets')

    model = model.to(device)
    # Pinned host memory only helps CUDA transfers; this also covers 'cuda:N'.
    pin = torch.device(device).type == 'cuda'

    def _make_loader(X: np.ndarray, y: np.ndarray, shuffle: bool) -> DataLoader:
        """Build a DataLoader over (X, y) using the shared config settings."""
        return DataLoader(
            SequenceDataset(X, y),
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=pin,
            drop_last=False,
        )

    train_loader = _make_loader(X_train, y_train, shuffle=True)
    val_loader = _make_loader(X_val, y_val, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    loss_fn = nn.MSELoss()

    best_val = float('inf')
    best_state = None
    bad = 0
    history = {'train_rmse': [], 'val_rmse': []}

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        # Accumulate squared error over every element so the epoch metric is
        # the true dataset RMSE. An unweighted mean of per-batch RMSEs is
        # biased and over-weights the smaller final batch.
        train_sse, train_count = 0.0, 0
        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb)
            loss = loss_fn(pred, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            train_sse += float(((pred.detach() - yb) ** 2).sum())
            train_count += yb.numel()

        model.eval()
        val_sse, val_count = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)
                pred = model(xb)
                val_sse += float(((pred - yb) ** 2).sum())
                val_count += yb.numel()

        tr = (train_sse / train_count) ** 0.5
        va = (val_sse / val_count) ** 0.5
        history['train_rmse'].append(tr)
        history['val_rmse'].append(va)
        print(f"Epoch {epoch:02d}/{cfg.epochs} | train RMSE={tr:.4f} | val RMSE={va:.4f}")

        if va < best_val:
            best_val = va
            # Snapshot on CPU so later epochs cannot overwrite it in place.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= cfg.patience:
                print(f"Early stopping (best val RMSE={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return {'best_val_rmse': best_val, 'history': history, 'device': device}


def predict_lstm(
    model: nn.Module,
    X: np.ndarray,
    batch_size: int = 8192,
    device: Optional[str] = None,
) -> np.ndarray:
    """Run ``model`` over ``X`` in mini-batches and return predictions as numpy.

    Parameters
    ----------
    model : nn.Module
        Moved to ``device`` and switched to eval mode (this mutates the
        caller's module).
    X : np.ndarray
        Input windows, presumably [N, L, F]. Each batch is cast to float32.
    batch_size : int
        Rows per forward pass; must be >= 1.
    device : str, optional
        Defaults to 'cuda' when available, else 'cpu'.

    Returns
    -------
    np.ndarray
        Model outputs concatenated along axis 0, one row per input row. For
        an empty ``X``, a float32 array of shape (0, horizon) is returned,
        using the model's ``horizon`` attribute when present (1 otherwise).

    Raises
    ------
    ValueError
        If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = model.to(device)
    model.eval()

    n = X.shape[0]
    if n == 0:
        # np.concatenate([]) would raise on an empty list of batches.
        return np.empty((0, int(getattr(model, 'horizon', 1))), dtype=np.float32)

    outs = []
    with torch.no_grad():
        # Slice the array directly rather than wrapping it in a DataLoader
        # with a dummy all-zeros target array (an extra N-row allocation).
        for start in range(0, n, batch_size):
            chunk = np.ascontiguousarray(X[start:start + batch_size], dtype=np.float32)
            xb = torch.from_numpy(chunk).to(device, non_blocking=True)
            outs.append(model(xb).cpu().numpy())

    return np.concatenate(outs, axis=0)


## Grid inference + GeoTIFF writer

In [None]:
def predict_next_grid_lstm(
    model: LSTMForecaster,
    y_stack: np.ndarray,
    x_stack: Optional[np.ndarray],
    lookback: int,
    mean: np.ndarray,
    std: np.ndarray,
    batch_pixels: int = 65536,
    device: Optional[str] = None,
) -> np.ndarray:
    """
    Forecast the next ``model.horizon`` FWI grid(s) from the trailing window.

    The last ``lookback`` timesteps of ``y_stack`` (plus ``x_stack``
    features when given) form one input window per pixel. Each window is
    standardized with ``mean``/``std`` and fed to the model. Pixels are
    processed in chunks of ``batch_pixels``, so the full [P, L, F] tensor is
    never materialized at once.

    Returns
    -------
    pred : np.ndarray
        float32, shape [H, W] when horizon == 1, else [horizon, H, W].
        NaN at pixels whose input window contains any non-finite value.

    Raises
    ------
    ValueError
        If ``y_stack`` has fewer than ``lookback`` timesteps.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = model.to(device)
    model.eval()

    n_t, n_rows, n_cols = y_stack.shape
    if n_t < lookback:
        raise ValueError('Not enough timesteps for inference lookback')

    horizon = int(model.horizon)
    n_pix = n_rows * n_cols
    first = n_t - lookback

    # [L, P]: trailing FWI window with pixels flattened (a view when contiguous).
    hist = y_stack[first:n_t].reshape(lookback, n_pix)
    if x_stack is None:
        covars = None
    else:
        n_feat = int(x_stack.shape[-1])
        covars = x_stack[first:n_t].reshape(lookback, n_pix, n_feat)  # [L, P, F]

    out_shape = (n_rows, n_cols) if horizon == 1 else (horizon, n_rows, n_cols)
    out = np.full(out_shape, np.nan, dtype=np.float32)
    # Writable view of `out` indexed by flat pixel: [P] or [horizon, P].
    flat_out = out.reshape(-1) if horizon == 1 else out.reshape(horizon, n_pix)

    for lo in range(0, n_pix, batch_pixels):
        hi = min(n_pix, lo + batch_pixels)
        seq = hist[:, lo:hi].T[:, :, None]  # [B, L, 1]
        if covars is not None:
            extra = covars[:, lo:hi, :].transpose(1, 0, 2)  # [B, L, F]
            seq = np.concatenate([seq, extra], axis=2)

        # Only fully finite windows are predicted; the rest stay NaN.
        ok = np.isfinite(seq).all(axis=(1, 2))
        if not ok.any():
            continue

        scaled = apply_standard_scaler(seq[ok], mean, std)
        preds = predict_lstm(model, scaled, batch_size=8192, device=device)  # [Nv, horizon]

        pix = lo + np.flatnonzero(ok)
        if horizon == 1:
            flat_out[pix] = preds[:, 0].astype(np.float32)
        else:
            flat_out[:, pix] = preds.T.astype(np.float32)

    return out


def write_geotiff(path: str, arr: np.ndarray, ref_profile: dict):
    """Write an [H, W] or [B, H, W] array as a float32 GeoTIFF with NaN nodata.

    Starts from a copy of ``ref_profile`` (georeferencing, driver, size) and
    overrides dtype=float32, nodata=NaN, DEFLATE compression with the
    floating-point predictor, BIGTIFF=IF_SAFER, and the band count.

    Raises
    ------
    ValueError
        If ``arr`` is neither 2-D nor a 3-D array with at least one band.
    """
    import rasterio

    if arr.ndim == 2:
        bands = arr[np.newaxis, ...]
    elif arr.ndim == 3 and arr.shape[0] >= 1:
        bands = arr
    else:
        raise ValueError('arr must be [H,W] or [B,H,W]')

    n_bands = bands.shape[0]
    profile = ref_profile.copy()
    profile.update(
        dtype='float32',
        nodata=np.nan,
        compress='deflate',
        predictor=3,
        BIGTIFF='IF_SAFER',
        count=n_bands,
    )

    band_indexes = list(range(1, n_bands + 1))  # rasterio bands are 1-based
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(bands.astype(np.float32), indexes=band_indexes)

    print('Wrote:', path)


## FWI-only training + forecast wrapper

In [None]:
from typing import Any, Dict, List, Optional


def run_lstm_fwi_forecast(
    fwi_paths: List[str],
    lookback: int = 14,
    horizon: int = 1,
    n_samples: int = 200000,
    val_frac: float = 0.1,
    seed: int = 42,
    model_kwargs: Optional[Dict[str, Any]] = None,
    train_cfg: TrainConfig = TrainConfig(),
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Train an LSTM on past FWI only and forecast the next-horizon FWI grid.

    Pipeline:
    1. Load the FWI stack.
    2. Randomly sample (pixel, time) windows, skipping any that contain NaN.
    3. Shuffle the windows and split them into train/val.
    4. Standardize the inputs with statistics from the train split only.
    5. Train with early stopping.
    6. Predict every pixel from the latest ``lookback`` timesteps.

    Note: the split is a random partition of randomly sampled windows, so
    validation windows can share pixels and timesteps with training windows,
    or even repeat them. The reported val RMSE is therefore optimistic
    compared with a temporal hold-out.

    Returns a dict with keys:
      - model, train_info
      - pred_grid (next horizon forecast)
      - ref_profile (for write_geotiff)
      - mean, std (scaler)
      - y_keys (time keys)
    """
    y_stack, y_keys, ref_profile = load_raster_stack(fwi_paths)

    X, y = sample_sequences(
        y_stack=y_stack,
        x_stack=None,
        lookback=lookback,
        horizon=horizon,
        n_samples=n_samples,
        seed=seed,
        require_finite=True,
    )

    # Shuffle, then take the first n_val windows (at least one) for validation.
    perm = np.random.default_rng(seed).permutation(len(X))
    X, y = X[perm], y[perm]
    n_val = max(1, int(val_frac * len(X)))
    X_val, y_val = X[:n_val], y[:n_val]
    X_train, y_train = X[n_val:], y[n_val:]

    # Scaler statistics come from the training split only.
    mean, std = fit_standard_scaler(X_train)
    X_train_s, X_val_s = (apply_standard_scaler(a, mean, std) for a in (X_train, X_val))

    model = LSTMForecaster(n_features=X.shape[-1], horizon=horizon, **(model_kwargs or {}))
    train_info = train_lstm(
        model=model,
        X_train=X_train_s,
        y_train=y_train,
        X_val=X_val_s,
        y_val=y_val,
        cfg=train_cfg,
        device=device,
    )

    pred_grid = predict_next_grid_lstm(
        model=model,
        y_stack=y_stack,
        x_stack=None,
        lookback=lookback,
        mean=mean,
        std=std,
        batch_pixels=65536,
        device=train_info.get('device', device),
    )

    return dict(
        model=model,
        train_info=train_info,
        pred_grid=pred_grid,
        ref_profile=ref_profile,
        mean=mean,
        std=std,
        y_keys=y_keys,
    )


## End-to-end run

In [None]:
import glob

# --- Data location ---
# Point ZIP_PATH at an uploaded archive of daily FWI GeoTIFFs; when the file
# exists it is extracted into FWI_DIR. Otherwise point FWI_DIR at a folder
# that already holds the .tif files.
ZIP_PATH = 'fwi_tifs_2025_09_12_2025_10_10.zip'  # adjust
FWI_DIR = 'fwi'                                  # adjust

if os.path.exists(ZIP_PATH):
    unzip_fwi_zip(ZIP_PATH, FWI_DIR)

# --- Collect FWI files (recursive search under FWI_DIR) ---
tif_pattern = os.path.join(FWI_DIR, '**', '*.tif')
fwi_paths = sorted(glob.glob(tif_pattern, recursive=True))
print('FWI tif count:', len(fwi_paths))
if not fwi_paths:
    raise FileNotFoundError('No .tif files found under FWI_DIR. Check ZIP_PATH/FWI_DIR.')

# --- Load time series stack (sanity check of the inputs) ---
y_stack, y_keys, ref_profile = load_raster_stack(fwi_paths)
print('Loaded y_stack:', y_stack.shape, 'timesteps:', len(y_keys))

# --- Train + forecast ---
lstm_cfg = TrainConfig(epochs=10, batch_size=4096, lr=1e-3, patience=3)
result = run_lstm_fwi_forecast(
    fwi_paths=fwi_paths,
    lookback=14,
    horizon=1,
    n_samples=200_000,
    val_frac=0.1,
    seed=42,
    model_kwargs=dict(hidden_size=128, num_layers=2, dropout=0.1),
    train_cfg=lstm_cfg,
)

pred = result['pred_grid']
print('pred_grid shape:', pred.shape)

# --- Write output (only the first forecast step when horizon > 1) ---
out_path = os.path.join(OUT_DIR, 'fwi_lstm_forecast.tif')
first_step = pred if pred.ndim == 2 else pred[0]
write_geotiff(out_path, first_step, result['ref_profile'])
print('Wrote forecast:', out_path)
