In [70]:
# fUSI → train/val/test windows (80/10/10), streaming, grouped splits to prevent leakage
# - Groups files by session (subject+run) or by subject before splitting
# - Streams windows into small .pt chunks: out_dir/{train,val,test}/*
# - Executes when you run the cell

from pathlib import Path
from typing import List, Tuple, Optional, Dict
import json, gc, re
import numpy as np
import torch

# --------- deps ---------
try:
    import h5py
except ImportError:
    h5py = None
try:
    from scipy.io import loadmat
except ImportError:
    loadmat = None

def _require_iolibs():
    if h5py is None and loadmat is None:
        raise RuntimeError(
            "Need at least one MAT loader. Install one of:\n"
            "  pip install h5py   # for v7.3 MAT\n"
            "  pip install scipy  # for legacy MAT"
        )

# ----------------------- I/O helpers -----------------------
def _find_candidate_key(h5) -> Optional[str]:
    if h5py is None:
        return None
    hits = []
    def visit(name, obj):
        if isinstance(obj, h5py.Dataset):
            lname = name.lower()
            if any(k in lname for k in ["idop", "dop", "doppler"]):
                hits.append(name)
    h5.visititems(visit)
    for c in hits:
        if c.split("/")[-1].lower() == "idop":
            return c
    return hits[0] if hits else None

def _load_mat_v73(path: Path) -> Optional[np.ndarray]:
    if h5py is None:
        return None
    try:
        with h5py.File(path, "r") as f:
            k = _find_candidate_key(f)
            if k is None:
                return None
            return np.array(f[k])
    except Exception:
        return None

def _load_mat_legacy(path: Path) -> Optional[np.ndarray]:
    if loadmat is None:
        return None
    try:
        md = loadmat(path, squeeze_me=False, struct_as_record=False)
        for k in ["iDop", "IDop", "doppler", "Dop", "dop"]:
            if k in md:
                return np.array(md[k])
        for _, v in md.items():
            if isinstance(v, np.ndarray) and v.ndim in (3, 4) and np.issubdtype(v.dtype, np.number):
                return np.array(v)
    except Exception:
        return None
    return None

def load_idop_any(path: Path) -> np.ndarray:
    arr = _load_mat_v73(path)
    if arr is None:
        arr = _load_mat_legacy(path)
    if arr is None:
        raise ValueError(f"Could not find Doppler array in {path}")
    return arr

def canonicalize_idop(arr: np.ndarray) -> np.ndarray:
    """
    Return [N, T, H, W] (no channel).
    3D: two largest -> spatial; remainder -> T.
    4D: spatial = pair with minimal |a-b|; of remaining dims, smaller->T, larger->N.
    """
    if arr.ndim == 3:
        s = list(arr.shape)
        idx = sorted(range(3), key=lambda i: s[i], reverse=True)
        H, W, T = idx[0], idx[1], idx[2]
        out = np.transpose(arr, (T, H, W))   # [T,H,W]
        return out[None, ...].astype(np.float32, copy=False)  # [1,T,H,W]
    if arr.ndim == 4:
        s = list(arr.shape)
        best_pair, best_diff = None, None
        for i in range(4):
            for j in range(i+1, 4):
                d = abs(s[i] - s[j])
                if best_diff is None or d < best_diff:
                    best_diff, best_pair = d, (i, j)
        H_idx, W_idx = best_pair
        rem = [i for i in range(4) if i not in (H_idx, W_idx)]
        T_idx, N_idx = (rem[0], rem[1]) if s[rem[0]] <= s[rem[1]] else (rem[1], rem[0])
        out = np.transpose(arr, (N_idx, T_idx, H_idx, W_idx))  # [N,T,H,W]
        return out.astype(np.float32, copy=False)
    raise ValueError(f"Unsupported ndim={arr.ndim}")

def log1p_safe(x: np.ndarray) -> np.ndarray:
    if x.min() < 0:
        x = x - x.min()
    return np.log1p(x)

def zscore_over_time_per_trial(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """x: [N,T,H,W] — z-score across time per pixel, per trial."""
    m = x.mean(axis=1, keepdims=True)
    s = x.std(axis=1, keepdims=True)
    return (x - m) / (s + eps)

def make_windows(x_nthw: np.ndarray, T_win: int, stride: int, pad_short: bool = False) -> np.ndarray:
    """x: [N,T,H,W] → windows [M,T_win,1,H,W] (float32)."""
    N, T, H, W = x_nthw.shape
    out = []
    for n in range(N):
        if T < T_win:
            if not pad_short:
                continue
            pad = T_win - T
            pre, post = pad // 2, pad - pad // 2
            x_t = np.pad(x_nthw[n], ((pre, post), (0, 0), (0, 0)), mode="edge")
            out.append(x_t[None, :, None, :, :])
            continue
        for t0 in range(0, T - T_win + 1, stride):
            x_t = x_nthw[n, t0:t0 + T_win]
            out.append(x_t[None, :, None, :, :])
    if not out:
        return np.empty((0, T_win, 1, H, W), dtype=np.float32)
    return np.concatenate(out, axis=0).astype(np.float32, copy=False)

# ----------------------- GROUPED splitting (3-way) -----------------------
def _session_key(p: Path) -> str:
    """
    Group files by (subject, run) if present; else by subject; else by stem.
    Handles names like:
      doppler_S129_R1+normcorre.mat
      doppler_S129_R1_allTrials+normcorre.mat
      dopplerContinuous_S129_R1+normcorre.mat
    """
    name = p.name
    m = re.search(r'_S(\d+)_R(\d+)', name, re.IGNORECASE)
    if m:
        return f"S{m.group(1)}_R{m.group(2)}"
    m = re.search(r'_S(\d+)', name, re.IGNORECASE)
    if m:
        return f"S{m.group(1)}"
    return p.stem

def _subject_key(p: Path) -> str:
    name = p.name
    m = re.search(r'_S(\d+)', name, re.IGNORECASE)
    return f"S{m.group(1)}" if m else p.stem

def split_files_3way_grouped(
    files: List[Path],
    ratios=(0.8, 0.1, 0.1),
    seed: int = 0,
    group_by: str = "session"  # "session" (subject+run) or "subject"
) -> Tuple[List[Path], List[Path], List[Path]]:
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"
    key_fn = _session_key if group_by == "session" else _subject_key

    # build groups
    groups: Dict[str, List[Path]] = {}
    for f in files:
        k = key_fn(f)
        groups.setdefault(k, []).append(f)

    # shuffle group ids
    rng = np.random.RandomState(seed)
    gids = list(groups.keys())
    rng.shuffle(gids)

    n = len(gids)
    n_tr = int(np.floor(ratios[0]*n))
    n_va = int(np.floor(ratios[1]*n))
    n_te = n - n_tr - n_va
    if n >= 2 and n_te == 0:
        if n_va > 0: n_va -= 1; n_te = 1
        elif n_tr > 1: n_tr -= 1; n_te = 1

    tr_ids = set(gids[:n_tr])
    va_ids = set(gids[n_tr:n_tr+n_va])
    te_ids = set(gids[n_tr+n_va:])

    train = [f for gid in tr_ids for f in groups[gid]]
    val   = [f for gid in va_ids for f in groups[gid]]
    test  = [f for gid in te_ids for f in groups[gid]]
    return train, val, test

# --------------- STREAMING builder (writes small chunks) ---------------
def build_and_cache_streaming_3way(
    root: Path,
    out_dir: Path,
    pattern: str = "*normcorre.mat",
    ratios=(0.8, 0.1, 0.1),
    split_by: str = "file",           # "file" or "trial"
    group_by: str = "session",        # only used when split_by="file"
    T_win: int = 8,
    stride: int = 4,
    apply_log1p: bool = True,
    apply_zscore: bool = True,
    pad_short: bool = False,
    seed: int = 0,
    chunk_size: int = 32,             # small to avoid OOM in notebooks
) -> Dict[str, List[Path]]:
    """
    Streams windows into out_dir/{train,val,test}/ as small .pt files.
    Splits by grouped files to prevent leakage across variants.
    Returns dict listing the chunk file paths in each split.
    """
    _require_iolibs()
    out_dir = Path(out_dir)
    files = sorted(Path(root).glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files in {root} matching {pattern}")

    for tag in ("train", "val", "test"):
        (out_dir / tag).mkdir(parents=True, exist_ok=True)

    saved = {"train": [], "val": [], "test": []}
    totals = {"train": 0, "val": 0, "test": 0}

    def _save_chunks(tag: str, wins: np.ndarray, stem: str, counter: int) -> int:
        if wins.size == 0:
            return counter
        n = wins.shape[0]
        for i in range(0, n, chunk_size):
            part = torch.from_numpy(wins[i:i + chunk_size].copy())
            path = out_dir / tag / f"{stem}_part{counter:05d}.pt"
            torch.save(part, path)
            saved[tag].append(path)
            counter += 1
        totals[tag] += n
        return counter

    if split_by == "file" and len(files) > 1:
        train_files, val_files, test_files = split_files_3way_grouped(
            files, ratios=ratios, seed=seed, group_by=group_by
        )
        sets = [("train", train_files), ("val", val_files), ("test", test_files)]
        counters = {"train": 0, "val": 0, "test": 0}
        for tag, fset in sets:
            for f in fset:
                try:
                    arr = load_idop_any(f)
                    x = canonicalize_idop(arr)  # [N,T,H,W]
                    if apply_log1p: x = log1p_safe(x)
                    if apply_zscore: x = zscore_over_time_per_trial(x)
                    wins = make_windows(x, T_win=T_win, stride=stride, pad_short=pad_short)
                    counters[tag] = _save_chunks(tag, wins, f.stem, counters[tag])
                    print(f"[{tag}] {f.name}: trials={x.shape[0]} T={x.shape[1]} → saved {wins.shape[0]} windows")
                except Exception as e:
                    print(f"[{tag}] SKIP {f.name}: {e}")
                finally:
                    del arr, x
                    if 'wins' in locals(): del wins
                    gc.collect()
    else:
        # Single-file or explicit trial split (3-way inside each file)
        counters = {"train": 0, "val": 0, "test": 0}
        for f in files:
            try:
                arr = load_idop_any(f)
                x = canonicalize_idop(arr)
                if apply_log1p: x = log1p_safe(x)
                if apply_zscore: x = zscore_over_time_per_trial(x)
                # 3-way split of trials
                rng = np.random.RandomState(seed)
                idx = np.arange(x.shape[0]); rng.shuffle(idx)
                n = x.shape[0]
                n_tr = int(np.floor(ratios[0]*n))
                n_va = int(np.floor(ratios[1]*n))
                tr_idx = np.sort(idx[:n_tr])
                va_idx = np.sort(idx[n_tr:n_tr+n_va])
                te_idx = np.sort(idx[n_tr+n_va:])
                w_tr = make_windows(x[tr_idx], T_win=T_win, stride=stride, pad_short=pad_short)
                w_va = make_windows(x[va_idx], T_win=T_win, stride=stride, pad_short=pad_short)
                w_te = make_windows(x[te_idx], T_win=T_win, stride=stride, pad_short=pad_short)
                counters["train"] = _save_chunks("train", w_tr, f.stem, counters["train"])
                counters["val"]   = _save_chunks("val",   w_va, f.stem, counters["val"])
                counters["test"]  = _save_chunks("test",  w_te, f.stem, counters["test"])
                print(f"[trial-3way] {f.name}: train {w_tr.shape[0]} | val {w_va.shape[0]} | test {w_te.shape[0]} windows")
            except Exception as e:
                print(f"[trial-3way] SKIP {f.name}: {e}")
            finally:
                del arr, x
                if 'w_tr' in locals(): del w_tr
                if 'w_va' in locals(): del w_va
                if 'w_te' in locals(): del w_te
                gc.collect()

    # Write manifest
    manifest = {
        "root": str(root),
        "pattern": pattern,
        "ratios": ratios,
        "split_by": split_by,
        "group_by": group_by,
        "T_win": T_win,
        "stride": stride,
        "apply_log1p": apply_log1p,
        "apply_zscore": apply_zscore,
        "pad_short": pad_short,
        "seed": seed,
        "counts": totals,
        "train_files": [str(p) for p in saved["train"]],
        "val_files":   [str(p) for p in saved["val"]],
        "test_files":  [str(p) for p in saved["test"]],
    }
    with open(out_dir / "manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)

    print(f"\nSaved chunks → train: {totals['train']} windows in {len(saved['train'])} files, "
          f"val: {totals['val']} in {len(saved['val'])}, test: {totals['test']} in {len(saved['test'])}.")
    print(f"Manifest: {out_dir/'manifest.json'}")
    return saved

# ---------------- Lazy dataset for the saved chunks ----------------
class FUSIChunkedDataset(torch.utils.data.Dataset):
    """Loads window chunks lazily from out_dir/train, val, or test."""
    def __init__(self, chunk_paths: List[Path]):
        self.files = [Path(p) for p in chunk_paths]
        self._sizes = []
        for p in self.files:
            t = torch.load(p, map_location="cpu")
            self._sizes.append(int(t.shape[0]))
        self._cum = np.cumsum([0] + self._sizes)
        self._cache_idx = None
        self._cache_tensor = None

    def __len__(self) -> int:
        return int(self._cum[-1])

    def _locate(self, idx: int) -> Tuple[int, int]:
        c = int(np.searchsorted(self._cum, idx, side="right") - 1)
        off = idx - self._cum[c]
        return c, off

    def __getitem__(self, idx: int) -> torch.Tensor:
        c, off = self._locate(idx)
        if self._cache_idx != c:
            self._cache_tensor = torch.load(self.files[c], map_location="cpu")
            self._cache_idx = c
        return self._cache_tensor[off]

def load_manifest(manifest_path: Path) -> Dict:
    with open(manifest_path, "r") as f:
        return json.load(f)

# ------------------------- RUN (executes now) -------------------------
ROOT = Path("fusiData/doppler")
OUT  = Path("fusi_splits_stream_80_10_10")

PATTERN = "**/*normcorre.mat"  # recursive; use "*normcorre.mat" if flat

_ = build_and_cache_streaming_3way(
    root=ROOT,
    out_dir=OUT,
    pattern=PATTERN,
    ratios=(0.8, 0.1, 0.1),
    split_by="file",          # prefer session-level split across files
    group_by="session",       # keep all variants of the same Sxxx_Ryy together
    T_win=8,
    stride=4,
    apply_log1p=True,
    apply_zscore=True,
    pad_short=False,
    seed=0,
    chunk_size=32,
)


FileNotFoundError: No files in fusiData/doppler matching **/*normcorre.mat

In [55]:
# model_unet_temporal.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- UNet building blocks ----
class DoubleConv(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):  # [B,C,H,W]
        return self.net(x)

class Down(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x):
        return self.conv(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch // 2 + out_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x)
        dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
        if dh or dw:
            if dh < 0: x = x[:, :, :skip.size(2), :]
            if dw < 0: x = x[:, :, :, :skip.size(3)]
            dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
            if dh > 0 or dw > 0:
                x = F.pad(x, (0, max(dw,0), 0, max(dh,0)))
        return self.conv(torch.cat([x, skip], dim=1))

# ---- temporal transformer over T only ----
class SinusoidalPosEnc(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2], pe[:, 1::2] = torch.sin(pos*div), torch.cos(pos*div)
        self.register_buffer("pe", pe, persistent=False)
    def forward(self, x):  # [B,T,d]
        return x + self.pe[:x.size(1)].unsqueeze(0).to(x.dtype)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, nlayers=3, dropout=0.0):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                         dim_feedforward=4*d_model, dropout=dropout,
                                         batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=nlayers)
        self.pos = SinusoidalPosEnc(d_model)
    def forward(self, seq):  # [B,T,d]
        return self.encoder(self.pos(seq))

# ---- FiLM conditioning (temporal -> spatial decoder) ----
class FiLMHead(nn.Module):
    def __init__(self, d_model: int, channels: int):
        super().__init__()
        hidden = max(128, d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.SiLU(),
                                 nn.Linear(hidden, 2*channels))
        self.C = channels
    def forward(self, z):  # [B,T,d]
        gb = self.mlp(z)   # [B,T,2C]
        return gb[..., :self.C], gb[..., self.C:]  # gamma, beta

# ---- full model ----
class UNetTemporalDenoiser(nn.Module):
    """
    Input/Output: x ∈ R[B, T, 1, H, W]  →  y ∈ R[B, T, 1, H, W]
    Spatial UNet runs per-frame; temporal transformer runs along T.
    Transformer outputs modulate decoder via FiLM per time step.
    """
    def __init__(self, base=32, d_model=128, nhead=4, nlayers=3,
                 dropout=0.0, predict_residual=True):
        super().__init__()
        c1, c2, c3, c4 = base, 2*base, 4*base, 8*base
        # encoder
        self.inc   = DoubleConv(1, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        # tokens from multi-scale pooled features
        self.to_token = nn.Linear(c2 + c3 + c4, d_model)
        self.temporal = TemporalTransformer(d_model=d_model, nhead=nhead, nlayers=nlayers, dropout=dropout)
        # decoder + FiLM
        self.up3  = Up(c4, c3); self.film3 = FiLMHead(d_model, c3)
        self.up2  = Up(c3, c2); self.film2 = FiLMHead(d_model, c2)
        self.up1  = Up(c2, c1); self.film1 = FiLMHead(d_model, c1)
        self.outc = nn.Conv2d(c1, 1, 1)
        self.predict_residual = predict_residual
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _film(x_bt, g_bt, b_bt):
        return x_bt * (1.0 + g_bt).unsqueeze(-1).unsqueeze(-1) + b_bt.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x):  # x: [B,T,1,H,W]
        B, T, _, H, W = x.shape
        x_bt = x.view(B*T, 1, H, W)

        # encoder per frame
        e1 = self.inc(x_bt)        # [B*T,c1,H,W]
        e2 = self.down1(e1)        # [B*T,c2,H/2,W/2]
        e3 = self.down2(e2)        # [B*T,c3,H/4,W/4]
        b  = self.down3(e3)        # [B*T,c4,H/8,W/8]

        # pooled tokens -> transformer over T
        p2 = F.adaptive_avg_pool2d(e2, 1).flatten(1)
        p3 = F.adaptive_avg_pool2d(e3, 1).flatten(1)
        p4 = F.adaptive_avg_pool2d(b,  1).flatten(1)
        tok = torch.cat([p2, p3, p4], dim=1)          # [B*T, c2+c3+c4]
        z = self.temporal(self.to_token(tok).view(B, T, -1))  # [B,T,d_model]

        # FiLM params per time → flatten to [B*T,C]
        def bt(tC): return tC.reshape(B*T, -1)
        g3, b3 = bt(self.film3(z)[0]), bt(self.film3(z)[1])
        g2, b2 = bt(self.film2(z)[0]), bt(self.film2(z)[1])
        g1, b1 = bt(self.film1(z)[0]), bt(self.film1(z)[1])

        # decoder per frame + FiLM
        d3 = self._film(self.up3(b, e3), g3, b3)
        d2 = self._film(self.up2(d3, e2), g2, b2)
        d1 = self._film(self.up1(d2, e1), g1, b1)
        y_bt = self.outc(d1)                       # [B*T,1,H,W]
        y = y_bt.view(B, T, 1, H, W)
        return x + y if self.predict_residual else y

In [71]:
# train_fusi_denoiser.py
# -- MPS-safe, OOM-resistant training with realistic fUSI-like noise --
import os, math, bisect, json, time, random, datetime as dt
from pathlib import Path
from typing import List, Tuple, Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Optional: helps MPS avoid watermark OOMs
os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")

# ---------------- Device & AMP ----------------
def pick_device():
    if torch.cuda.is_available():
        return "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

DEVICE = pick_device()
USE_AMP = (DEVICE == "cuda")
AMP_DTYPE = torch.float16 if DEVICE == "cuda" else None
if DEVICE == "cuda":
    torch.set_float32_matmul_precision("medium")

# ---------------- Model (inline) ----------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x):
        return self.conv(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch // 2 + out_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x)
        dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
        if dh or dw:
            if dh < 0: x = x[:, :, :skip.size(2), :]
            if dw < 0: x = x[:, :, :, :skip.size(3)]
            dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
            if dh > 0 or dw > 0:
                x = F.pad(x, (0, max(dw,0), 0, max(dh,0)))
        return self.conv(torch.cat([x, skip], dim=1))

class SinusoidalPosEnc(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2], pe[:, 1::2] = torch.sin(pos*div), torch.cos(pos*div)
        self.register_buffer("pe", pe, persistent=False)
    def forward(self, x):  # [B,T,d]
        return x + self.pe[:x.size(1)].unsqueeze(0).to(x.dtype)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, nlayers=3, dropout=0.0):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                         dim_feedforward=4*d_model, dropout=dropout,
                                         batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=nlayers)
        self.pos = SinusoidalPosEnc(d_model)
    def forward(self, seq):  # [B,T,d]
        return self.encoder(self.pos(seq))

class FiLMHead(nn.Module):
    def __init__(self, d_model: int, channels: int):
        super().__init__()
        hidden = max(128, d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.SiLU(),
                                 nn.Linear(hidden, 2*channels))
        self.C = channels
    def forward(self, z):  # [B,T,d]
        gb = self.mlp(z)   # [B,T,2C]
        return gb[..., :self.C], gb[..., self.C:]  # gamma, beta

class UNetTemporalDenoiser(nn.Module):
    """
    Input/Output: x ∈ R[B, T, 1, H, W]  →  y ∈ R[B, T, 1, H, W]
    Spatial UNet runs per-frame; temporal transformer runs along T.
    """
    def __init__(self, base=32, d_model=128, nhead=4, nlayers=3,
                 dropout=0.0, predict_residual=True):
        super().__init__()
        c1, c2, c3, c4 = base, 2*base, 4*base, 8*base
        # encoder
        self.inc   = DoubleConv(1, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        # tokens from multi-scale pooled features
        self.to_token = nn.Linear(c2 + c3 + c4, d_model)
        self.temporal = TemporalTransformer(d_model=d_model, nhead=nhead, nlayers=nlayers, dropout=dropout)
        # decoder + FiLM
        self.up3  = Up(c4, c3); self.film3 = FiLMHead(d_model, c3)
        self.up2  = Up(c3, c2); self.film2 = FiLMHead(d_model, c2)
        self.up1  = Up(c2, c1); self.film1 = FiLMHead(d_model, c1)
        self.outc = nn.Conv2d(c1, 1, 1)
        self.predict_residual = predict_residual
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _film(x_bt, g_bt, b_bt):
        return x_bt * (1.0 + g_bt).unsqueeze(-1).unsqueeze(-1) + b_bt.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x):  # x: [B,T,1,H,W]
        B, T, _, H, W = x.shape
        x_bt = x.view(B*T, 1, H, W)

        # encoder per frame
        e1 = self.inc(x_bt)
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        b  = self.down3(e3)

        # pooled tokens -> transformer over T
        p2 = F.adaptive_avg_pool2d(e2, 1).flatten(1)
        p3 = F.adaptive_avg_pool2d(e3, 1).flatten(1)
        p4 = F.adaptive_avg_pool2d(b,  1).flatten(1)
        tok = torch.cat([p2, p3, p4], dim=1)          # [B*T, c2+c3+c4]
        z = self.temporal(self.to_token(tok).view(B, T, -1))  # [B,T,d_model]

        # FiLM params per time → flatten to [B*T,C]
        def bt(tC): return tC.reshape(B*T, -1)
        g3, b3 = bt(self.film3(z)[0]), bt(self.film3(z)[1])
        g2, b2 = bt(self.film2(z)[0]), bt(self.film2(z)[1])
        g1, b1 = bt(self.film1(z)[0]), bt(self.film1(z)[1])

        # decoder per frame + FiLM
        d3 = self._film(self.up3(b, e3), g3, b3)
        d2 = self._film(self.up2(d3, e2), g2, b2)
        d1 = self._film(self.up1(d2, e1), g1, b1)
        y_bt = self.outc(d1)
        y = y_bt.view(B, T, 1, H, W)
        return x + y if self.predict_residual else y

# ---------------- Dataset over saved shards ----------------
class FUSIChunkedDataset(Dataset):
    """Each item is a clip [T,1,H,W] loaded on demand from shard .pt files."""
    def __init__(self, shard_paths: List[Path]):
        self.files = [Path(p) for p in shard_paths]
        self._sizes = []
        for p in self.files:
            t = torch.load(p, map_location="cpu")  # [M,T,1,H,W]
            self._sizes.append(int(t.shape[0]))
        self._cum: List[int] = [0]
        for n in self._sizes:
            self._cum.append(self._cum[-1] + n)
        self._cache_idx = None
        self._cache = None

    def __len__(self) -> int:
        return self._cum[-1]

    def _locate(self, idx: int) -> Tuple[int, int]:
        i = bisect.bisect_right(self._cum, int(idx)) - 1
        i = max(0, min(i, len(self.files) - 1))
        off = int(idx) - self._cum[i]
        return i, off

    def __getitem__(self, idx: int) -> torch.Tensor:
        if idx < 0 or idx >= len(self):
            raise IndexError(idx)
        shard, off = self._locate(idx)
        if self._cache_idx != shard:
            self._cache = torch.load(self.files[shard], map_location="cpu")  # [M,T,1,H,W]
            self._cache_idx = shard
        return self._cache[off].float()  # [T,1,H,W]

# ---------------- Collate: pad/crop ----------------
def _center_crop_or_pad(x: torch.Tensor, Ht: int, Wt: int) -> torch.Tensor:
    _, _, H, W = x.shape
    if H > Ht:
        top = (H - Ht) // 2
        x = x[:, :, top:top + Ht, :]
    if W > Wt:
        left = (W - Wt) // 2
        x = x[:, :, :, left:left + Wt]
    pad_h, pad_w = Ht - x.shape[2], Wt - x.shape[3]
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
    return x

def collate_and_pad(batch, max_hw: Optional[Tuple[int, int]] = None):
    Tset = {b.shape[0] for b in batch}
    if len(Tset) != 1:
        raise ValueError(f"Mixed T in batch: {Tset}")
    if max_hw is None:
        Ht = max(b.shape[2] for b in batch)
        Wt = max(b.shape[3] for b in batch)
    else:
        Ht, Wt = max_hw
    out = [_center_crop_or_pad(x, Ht, Wt) for x in batch]
    return torch.stack(out, dim=0)  # [B,T,1,Ht,Wt]

# ---------------- Manifest helpers (auto-build if missing) ----------------
def _build_manifest_from_shards(split_dir: Path) -> dict:
    shards = sorted(str(p) for p in split_dir.glob("*.pt"))
    if not shards:
        shards = sorted(str(p) for p in split_dir.rglob("*.pt"))
    if not shards:
        raise FileNotFoundError(
            f"No manifest.json and no .pt shards found under {split_dir}.\n"
            f"Set SPLIT_DIR to your folder with saved shards."
        )
    n = len(shards)
    n_train = max(1, int(round(0.8 * n)))
    n_val   = max(1, int(round(0.1 * n)))
    n_test  = max(1, n - n_train - n_val)
    if n_train + n_val + n_test > n:
        n_test = n - n_train - n_val
        if n_test < 1:
            n_test = 1
            n_val = max(1, n - n_train - n_test)
            if n_train + n_val + n_test > n:
                n_train = n - n_val - n_test

    man = {
        "train_files": shards[:n_train],
        "val_files":   shards[n_train:n_train+n_val] if n_val > 0 else shards[:1],
        "test_files":  shards[n_train+n_val:] if n_test > 0 else shards[-1:],
        "note": "Auto-generated manifest (80/10/10) because manifest.json was missing."
    }
    with open(split_dir / "manifest.json", "w") as f:
        json.dump(man, f, indent=2)
    print(f"[autosplit] Built manifest.json with {n_train}/{n_val}/{n_test} shards in {split_dir}")
    return man

def _load_manifest(split_dir: Path) -> dict:
    mpath = split_dir / "manifest.json"
    if mpath.exists():
        with open(mpath, "r") as f:
            man = json.load(f)
        for k in ("train_files", "val_files", "test_files"):
            if k not in man or not isinstance(man[k], list) or len(man[k]) == 0:
                raise ValueError(f"manifest.json missing or empty '{k}'.")
        return man
    else:
        return _build_manifest_from_shards(split_dir)

def make_loaders(split_dir: Path, batch_size=1, num_workers=0, max_hw: Optional[Tuple[int, int]] = None):
    split_dir = Path(os.path.expanduser(str(split_dir))).resolve()
    if not split_dir.exists():
        raise FileNotFoundError(f"SPLIT_DIR does not exist: {split_dir}")
    man = _load_manifest(split_dir)

    train_ds = FUSIChunkedDataset([Path(p) for p in man["train_files"]])
    val_ds   = FUSIChunkedDataset([Path(p) for p in man["val_files"]])
    test_ds  = FUSIChunkedDataset([Path(p) for p in man["test_files"]])

    cf = (lambda b: collate_and_pad(b, max_hw=max_hw))
    pin = (DEVICE == "cuda")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin, persistent_workers=False,
                              collate_fn=cf)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin, persistent_workers=False,
                              collate_fn=cf)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin, persistent_workers=False,
                              collate_fn=cf)
    return train_loader, val_loader, test_loader

# ---------------- Metrics & noise ----------------
def temporal_tv(x):
    return (x[:, 1:] - x[:, :-1]).abs().mean()

def psnr(x, y, eps=1e-8):
    mse = F.mse_loss(x, y)
    return 10.0 * torch.log10(1.0 / (mse + eps))

def add_noise(x, sigma=0.05, relative=True):
    if relative:
        s = x.std(dim=(2, 3, 4), keepdim=True).clamp_min(1e-6)
        noise = torch.randn_like(x) * (sigma * s)
    else:
        noise = torch.randn_like(x) * sigma
    return x + noise

@torch.no_grad()
def _integer_jitter(x: torch.Tensor, jitter_px: int) -> torch.Tensor:
    """Per-frame replicate-pad jitter. x: [B,T,1,H,W]"""
    if jitter_px <= 0:
        return x
    B, T, C, H, W = x.shape
    out = torch.empty_like(x)
    for b in range(B):
        dx = torch.randint(-jitter_px, jitter_px + 1, (T,), device=x.device)
        dy = torch.randint(-jitter_px, jitter_px + 1, (T,), device=x.device)
        for t in range(T):
            frame = x[b, t].unsqueeze(0)  # [1,1,H,W]
            pad = F.pad(frame, (jitter_px, jitter_px, jitter_px, jitter_px), mode="replicate")
            x0 = jitter_px + int(dx[t].item())
            y0 = jitter_px + int(dy[t].item())
            out[b, t] = pad[:, :, y0:y0+H, x0:x0+W]
    return out

def add_phys_like_noise(
    x: torch.Tensor,               # [B,T,1,H,W]
    fps: float = 10.0,
    sigma_white: float = 0.02,
    resp_band: Tuple[float,float] = (0.2, 0.4),
    card_band: Tuple[float,float] = (0.8, 1.2),
    lowrank_scale: float = 0.2,
    jitter_px: int = 1,
) -> torch.Tensor:
    B,T,_,H,W = x.shape
    device = x.device
    t = torch.arange(T, device=device).float() / max(fps, 1e-6)
    fr = torch.empty(B, device=device).uniform_(*resp_band)
    fc = torch.empty(B, device=device).uniform_(*card_band)
    phir = torch.empty(B, device=device).uniform_(0, 2*math.pi)
    phic = torch.empty(B, device=device).uniform_(0, 2*math.pi)
    resp = torch.sin(2*math.pi*fr[:,None]*t + phir[:,None])
    card = torch.sin(2*math.pi*fc[:,None]*t + phic[:,None])
    phys = (resp + 0.5*card).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # [B,T,1,1,1]

    # smooth spatial map -> broadcast over T
    Hs, Ws = max(1, H//8), max(1, W//8)
    spatial_lr = torch.randn(B, 1, Hs, Ws, device=device)
    spatial = F.interpolate(spatial_lr, size=(H, W), mode="bilinear", align_corners=False)  # (B,1,H,W)
    spatial = spatial / (spatial.flatten(1).std(dim=1, keepdim=True).clamp_min(1e-6)).view(B,1,1,1)
    spatial_bt = spatial.unsqueeze(1).expand(B, T, 1, H, W)

    lowrank = lowrank_scale * phys * spatial_bt
    x_jit = _integer_jitter(x, jitter_px=jitter_px)
    s = x.std(dim=(2,3,4), keepdim=True).clamp_min(1e-6)
    white = torch.randn_like(x) * (sigma_white * s)
    return x_jit + lowrank + white

# ---------------- Safe weights-only save helpers ----------------
def _safe_save(obj, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    torch.save(obj, tmp)
    tmp.replace(path)

def _snapshot_weights(model: nn.Module, epoch: int, step: int, val_loss: Optional[float], out_dir: Path, tag: str):
    payload = {
        "model": model.state_dict(),        # weights only
        "epoch": epoch,
        "step": step,
        "val_loss": float(val_loss) if val_loss is not None else None,
        "timestamp": time.time(),
    }
    _safe_save(payload, out_dir / f"{tag}.pt")

# ---------------- Training loop (hourly + step saves, best/last/final) ----------------
def train_fusi(
    split_dir: Path,
    epochs=1,
    batch_size=1,
    lr=2e-4,
    weight_decay=1e-5,
    # noise
    use_phys_noise: bool = True,
    fps: float = 10.0,
    noise_sigma: float = 0.05,
    sigma_white: float = 0.02,
    resp_band: Tuple[float,float] = (0.2, 0.4),
    card_band: Tuple[float,float] = (0.8, 1.2),
    lowrank_scale: float = 0.2,
    jitter_px: int = 1,
    # reg
    lambda_tv: float = 0.05,
    # model size
    base_channels=12,
    d_model=64,
    nlayers=1,
    # memory controls
    max_hw: Optional[Tuple[int, int]] = (128, 160),
    train_patch_hw: Optional[Tuple[int, int]] = (112, 144),
    # logging & save cadence
    log_every: int = 50,
    empty_cache_every: int = 10,
    save_every_seconds: int = 3600,      # hourly saves
    save_every_steps: int = 0,           # e.g., 200 to also save by step
    max_steps_per_epoch: Optional[int] = None,  # limit steps for quick runs
    # optional model class
    model_cls: Optional[Type[nn.Module]] = None,
):
    print(f"Device: {DEVICE} | AMP: {USE_AMP} ({AMP_DTYPE})  |  phys-noise: {use_phys_noise}")
    torch.backends.cudnn.benchmark = (DEVICE == "cuda")
    torch.manual_seed(0)

    # Loaders
    train_loader, val_loader, _ = make_loaders(split_dir, batch_size=batch_size, num_workers=0, max_hw=max_hw)

    # Model
    if model_cls is None:
        model_cls = UNetTemporalDenoiser
    model = model_cls(base=base_channels, d_model=d_model, nhead=4, nlayers=nlayers, predict_residual=True).to(DEVICE)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.amp.GradScaler(device="cuda") if DEVICE == "cuda" else None

    weights_dir = Path(split_dir) / "weights"   # weights-only folder
    weights_dir.mkdir(parents=True, exist_ok=True)
    best_val = float("inf")

    # initial save
    _snapshot_weights(model, epoch=0, step=0, val_loss=None, out_dir=weights_dir, tag="init")

    ctx_device = "cuda" if DEVICE == "cuda" else ("mps" if DEVICE == "mps" else "cpu")
    def make_noisy(xb: torch.Tensor, train: bool) -> torch.Tensor:
        if use_phys_noise:
            return add_phys_like_noise(
                xb, fps=fps, sigma_white=sigma_white,
                resp_band=resp_band, card_band=card_band,
                lowrank_scale=lowrank_scale, jitter_px=jitter_px
            )
        else:
            return add_noise(xb, sigma=noise_sigma, relative=True)

    global_step = 0
    last_save_time = time.time()

    try:
        for epoch in range(1, epochs + 1):
            t0 = time.time()
            # ----- train -----
            model.train()
            tr_loss = tr_psnr = 0.0
            n_seen = 0
            step = 0
            for xb in train_loader:
                step += 1
                global_step += 1
                xb = xb.float()

                # optional random crop on CPU for memory
                if train_patch_hw is not None:
                    B, T, C, H, W = xb.shape
                    Hc, Wc = min(train_patch_hw[0], H), min(train_patch_hw[1], W)
                    if (Hc < H) or (Wc < W):
                        top = random.randint(0, H - Hc)
                        left = random.randint(0, W - Wc)
                        xb = xb[:, :, :, top:top + Hc, left:left + Wc]

                xb = xb.to(DEVICE, non_blocking=(DEVICE == "cuda"))
                with torch.no_grad():
                    x_noisy = make_noisy(xb, train=True)

                if USE_AMP:
                    with torch.amp.autocast(device_type=ctx_device, dtype=AMP_DTYPE):
                        y = model(x_noisy)
                        loss = F.mse_loss(y, xb) + lambda_tv * temporal_tv(y)
                else:
                    y = model(x_noisy)
                    loss = F.mse_loss(y, xb) + lambda_tv * temporal_tv(y)

                opt.zero_grad(set_to_none=True)
                if DEVICE == "cuda":
                    scaler.scale(loss).backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(opt); scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    opt.step()

                bs = xb.size(0)
                n_seen += bs
                tr_loss += loss.item() * bs
                with torch.no_grad():
                    tr_psnr += psnr(y.clamp_min(0), xb.clamp_min(0)).item() * bs

                if (step % log_every) == 0:
                    print(f"  step {step}: train loss {loss.item():.5f}")

                # Step-based save (optional)
                if save_every_steps and (global_step % save_every_steps == 0):
                    _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=None, out_dir=weights_dir, tag=f"step-{global_step}")
                    _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=None, out_dir=weights_dir, tag="last")

                # Hourly weights save
                now = time.time()
                if save_every_seconds and (now - last_save_time) >= save_every_seconds:
                    stamp = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
                    _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=None, out_dir=weights_dir, tag=f"hourly-{stamp}")
                    _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=None, out_dir=weights_dir, tag="last")
                    last_save_time = now

                if DEVICE == "mps" and (step % empty_cache_every) == 0:
                    del x_noisy, y, loss
                    torch.mps.empty_cache()

                # Optional cap for quick runs
                if (max_steps_per_epoch is not None) and (step >= max_steps_per_epoch):
                    break

            tr_loss /= max(1, n_seen)
            tr_psnr /= max(1, n_seen)

            # ----- val -----
            model.eval()
            va_loss = va_psnr = 0.0
            n_seen = 0
            with torch.no_grad():
                for xb in val_loader:
                    xb = xb.float().to(DEVICE)
                    x_noisy = make_noisy(xb, train=False)
                    if USE_AMP:
                        with torch.amp.autocast(device_type=ctx_device, dtype=AMP_DTYPE):
                            y = model(x_noisy)
                            loss = F.mse_loss(y, xb) + lambda_tv * temporal_tv(y)
                    else:
                        y = model(x_noisy)
                        loss = F.mse_loss(y, xb) + lambda_tv * temporal_tv(y)
                    bs = xb.size(0)
                    n_seen += bs
                    va_loss += loss.item() * bs
                    va_psnr += psnr(y.clamp_min(0), xb.clamp_min(0)).item() * bs
                    if DEVICE == "mps":
                        del x_noisy, y, loss
                        torch.mps.empty_cache()

            va_loss = (va_loss / n_seen) if n_seen > 0 else float("inf")
            va_psnr = (va_psnr / n_seen) if n_seen > 0 else float("nan")

            if DEVICE == "mps":
                torch.mps.empty_cache()

            dt_epoch = time.time() - t0
            print(f"Epoch {epoch:02d} | {dt_epoch:5.1f}s "
                  f"train: loss {tr_loss:.5f}, PSNR {tr_psnr:.2f} | "
                  f"val: loss {va_loss:.5f}, PSNR {va_psnr:.2f}")

            # Save last & best (weights only)
            _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=va_loss, out_dir=weights_dir, tag="last")
            if math.isfinite(va_loss) and (va_loss < best_val):
                best_val = va_loss
                _snapshot_weights(model, epoch=epoch, step=global_step, val_loss=va_loss, out_dir=weights_dir, tag="best")

    except KeyboardInterrupt:
        print("\nKeyboardInterrupt caught — saving interrupt weights...")
        _snapshot_weights(model, epoch=locals().get('epoch', 0), step=locals().get('global_step', 0),
                          val_loss=None, out_dir=weights_dir, tag="interrupt")
        raise
    finally:
        # Always write a final weights file
        _snapshot_weights(model, epoch=epochs, step=global_step,
                          val_loss=best_val if math.isfinite(best_val) else None,
                          out_dir=weights_dir, tag="final")
        print(f"Done. Best val loss: {best_val:.6f}  |  wrote weights in: {weights_dir}")
    return model

# ---- runner ----
if __name__ == "__main__":
    # set this to your real split path
    SPLIT_DIR = Path("fusi_splits_stream_80_10_10")
    _ = train_fusi(
        split_dir=SPLIT_DIR,
        epochs=5,            # quick pass
        batch_size=1,
        lr=2e-4,
        use_phys_noise=True,
        fps=10.0,
        sigma_white=0.02,
        resp_band=(0.2, 0.4),    # human; for rodent try (1.0, 3.0)
        card_band=(0.8, 1.2),    # human; for rodent try (6.0, 10.0)
        lowrank_scale=0.2,
        jitter_px=1,
        lambda_tv=0.05,
        base_channels=12,
        d_model=64,
        nlayers=1,
        max_hw=(128, 160),
        train_patch_hw=(112, 144),
        save_every_seconds=3600,  # hourly weights
        save_every_steps=0,       # set e.g. 200 to also save by step
        max_steps_per_epoch=50,   # cap for speed; raise/remove for full training
        log_every=10,
    )
    

Device: mps | AMP: False (None)  |  phys-noise: True
  step 10: train loss 2.53553
  step 20: train loss 1.02374
  step 30: train loss 1.12471
  step 40: train loss 0.85075
  step 50: train loss 0.82841
Epoch 01 | 130.8s train: loss 2.75089, PSNR 3.80 | val: loss 0.93624, PSNR 5.47
  step 10: train loss 1.15655
  step 20: train loss 0.90731
  step 30: train loss 0.14757
  step 40: train loss 0.95641
  step 50: train loss 0.91374
Epoch 02 | 132.6s train: loss 0.80090, PSNR 6.88 | val: loss 0.77935, PSNR 5.79
  step 10: train loss 0.79201
  step 20: train loss 1.00765
  step 30: train loss 0.81595
  step 40: train loss 0.82681
  step 50: train loss 0.96873
Epoch 03 | 862.7s train: loss 0.70940, PSNR 7.30 | val: loss 0.71869, PSNR 6.02
  step 10: train loss 0.21448
  step 20: train loss 0.65280
  step 30: train loss 0.83225
  step 40: train loss 0.73815
  step 50: train loss 0.79972
Epoch 04 | 2715.5s train: loss 0.62269, PSNR 7.90 | val: loss 0.67268, PSNR 6.23
  step 10: train loss 0.57

In [73]:
# runner_resume_best.py (PyTorch, Python 3.9-safe)
from pathlib import Path
from typing import Union, Tuple
import re
import torch

def infer_hparams_from_weights(weights_path: Union[str, Path]) -> Tuple[int, int, int, int]:
    pkg = torch.load(weights_path, map_location="cpu")
    state = pkg.get("model", pkg)

    base = state["inc.net.0.weight"].shape[0]            # first conv out-ch
    d_model = state["to_token.weight"].shape[0]          # rows of to_token
    # count transformer layers
    layer_idxs = []
    pat = re.compile(r"^temporal\.encoder\.layers\.(\d+)\.")
    for k in state.keys():
        m = pat.match(k)
        if m:
            layer_idxs.append(int(m.group(1)))
    nlayers = (max(layer_idxs) + 1) if layer_idxs else 1
    nhead = 4  # can't infer from weights; matches your training

    return base, d_model, nlayers, nhead

if __name__ == "__main__":

    WEIGHTS = Path("fusi_splits_stream_80_10_10/weights/best.pt")
    SPLIT_DIR = Path("fusi_splits_stream_80_10_10")

    base, d_model, nlayers, nhead = infer_hparams_from_weights(WEIGHTS)
    print(f"[resume] inferred → base={base}, d_model={d_model}, nlayers={nlayers}, nhead={nhead}")

    _ = train_fusi(
        split_dir=SPLIT_DIR,
        # resume from best.pt (make sure train_fusi supports init_weights)
        init_weights=WEIGHTS,
        epochs=20,
        batch_size=1,
        lr=1e-4,                   # lower LR for resume
        weight_decay=1e-5,

        use_phys_noise=True,
        fps=10.0,
        sigma_white=0.02,
        resp_band=(0.2, 0.4),
        card_band=(0.8, 1.2),
        lowrank_scale=0.2,
        jitter_px=1,
        lambda_tv=0.05,

        # use inferred architecture
        base_channels=base,
        d_model=d_model,
        nlayers=nlayers,

        max_hw=(128, 160),
        train_patch_hw=None,       # switch to (112,144) if you hit OOM

        save_every_seconds=3600,
        log_every=50,
        empty_cache_every=10,
    )


[resume] inferred → base=12, d_model=64, nlayers=1, nhead=4


TypeError: train_fusi() got an unexpected keyword argument 'init_weights'

In [None]:
#load model

from pathlib import Path
from typing import Union, Optional, Dict, Any
import torch

def pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

def load_model_from_weights(
    weights_path: Union[str, Path],
    *,
    base: int = 12,
    d_model: int = 64,
    nhead: int = 4,
    nlayers: int = 1,
    predict_residual: bool = True,
    device: Optional[str] = None,
) -> tuple[UNetTemporalDenoiser, str, Dict[str, Any]]:
    """
    Load a model from a weights-only .pt file produced by your training script.
    Pass the same hyperparams used during training (base, d_model, nhead, nlayers, predict_residual).
    Returns (model, device, meta_dict).
    """
    device = device or pick_device()
    weights_path = Path(weights_path)

    if not weights_path.exists():
        raise FileNotFoundError(f"Can't find weights file: {weights_path}")

    pkg = torch.load(weights_path, map_location=device)

    # Accept either {"model": state_dict, ...} or plain state_dict
    if isinstance(pkg, dict) and "model" in pkg:
        state = pkg["model"]
        meta = {k: pkg.get(k) for k in ("epoch", "step", "val_loss", "timestamp")}
        # If a config was saved, prefer it unless caller overrides
        cfg = pkg.get("config") or {}
        base = cfg.get("base", base)
        d_model = cfg.get("d_model", d_model)
        nhead = cfg.get("nhead", nhead)
        nlayers = cfg.get("nlayers", nlayers)
        if "predict_residual" in cfg:
            predict_residual = cfg["predict_residual"]
    else:
        state = pkg
        meta = {}

    model = UNetTemporalDenoiser(
        base=base, d_model=d_model, nhead=nhead, nlayers=nlayers,
        predict_residual=predict_residual
    ).to(device)

    try:
        model.load_state_dict(state, strict=True)
    except RuntimeError as e:
        print("Strict load failed; trying non-strict.\nDetails:\n", e)
        missing, unexpected = model.load_state_dict(state, strict=False)
        print("Loaded with strict=False.\nMissing keys:", missing, "\nUnexpected keys:", unexpected)

    model.eval()
    return model, device, meta


# ---- example usage ----
if __name__ == "__main__":
    weights_file = "fusi_splits_stream_80_10_10/weights/final.pt"
    model, device, meta = load_model_from_weights(
        weights_file,
        base=12, d_model=64, nhead=4, nlayers=1, predict_residual=True
    )
    print("Loaded on:", device, "| meta:", meta)

    # Optional quick test
    # x = torch.randn(1, 16, 1, 112, 144, device=device)
    # with torch.no_grad():
    #     y = model(x)
    # print("Output shape:", y.shape)


Loaded on: mps | meta: {'epoch': 1, 'step': 621, 'val_loss': None, 'timestamp': 1754973217.4593768}
