In [None]:
# Cell 0 ────────────────────────────────────────────────────────────────
# Runtime flags for the notebook
#
# • RUN_TRAIN – set True to launch training
# • RUN_VALID – set True to compute fold-0 validation MAE
# • RUN_TEST  – set True to run ensemble inference on the test set
#
# Toggling these at the very top keeps downstream cells stateless.

RUN_TRAIN = False   # set True to launch training (numeric precision is chosen separately via cfg.precision)
RUN_VALID = True    # set True to compute fold-0 validation MAE
RUN_TEST  = True    # set True to run ensemble inference on the test set


In [None]:
# Cell 1 ────────────────────────────────────────────────────────────────
# Environment checks & dependency bootstrap
#
# • Verifies at least 2 CUDA-visible GPUs.
# • Installs MONAI (lightweight, --no-deps) if missing.
# • Emits a concise capability summary.

import subprocess, sys, importlib.util, textwrap, torch, os

def _install_monai():
    """Install MONAI with pip into the interpreter running this notebook.

    The install is quiet (``-q``) and skips dependency resolution
    (``--no-deps``).  ``subprocess.check_call`` raises
    ``CalledProcessError`` if pip exits with a non-zero status.
    """
    print("Installing MONAI …")
    pip_cmd = [sys.executable, "-m", "pip", "install", "--no-deps", "-q", "monai"]
    subprocess.check_call(pip_cmd)

# Fail fast unless CUDA is usable and at least two devices are visible.
has_enough_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= 2
if not has_enough_gpus:
    raise RuntimeError("Requires ≥ 2 GPUs with CUDA enabled.")

# Only call pip when MONAI cannot be located on the import path.
monai_spec = importlib.util.find_spec("monai")
if monai_spec is None:
    _install_monai()

import monai  # noqa: E402  (import after potential installation)

# One-off capability banner for the notebook log.
capability_summary = textwrap.dedent(f"""
    ╔══════════════════════════════════════════════╗
    ║ CUDA devices      : {torch.cuda.device_count()}                       ║
    ║ cuDNN available   : {torch.backends.cudnn.is_available()}             ║
    ║ bfloat16 support  : {torch.cuda.is_bf16_supported()}                  ║
    ║ MONAI version     : {monai.__version__}                               ║
    ╚══════════════════════════════════════════════╝
""")
print(capability_summary)


In [None]:
%%writefile _cfg.py

# Cell 2 ────────────────────────────────────────────────────────────────
# Global config  ➜ _cfg.py
#
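# This file stores all hyper-parameters and misc settings in a
# SimpleNamespace for easy dot-access.  It imports only torch/numpy and the
# standard library (no model or dataset code); importing it also seeds the
# global RNGs as a side effect.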


from types import SimpleNamespace
import torch, random, numpy as np, os

def set_global_seed(seed: int = 123):
    """Seed every RNG this project touches with the same value.

    Args:
        seed: value handed to ``random.seed``, ``numpy.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed``; it is also
            written to the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

# Single shared configuration object; every other module imports `cfg` from here.
cfg = SimpleNamespace(
    # ------------------------------------------------------------------
    # Reproducibility
    seed                     = 123,
    device                   = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    local_rank               = 0,     # overwritten by the _train.py entrypoint from $RANK
    world_size               = 1,     # overwritten by the _train.py entrypoint from $WORLD_SIZE
    # ------------------------------------------------------------------
    # Data
    subsample                = None,  # e.g. 250 to debug faster (rows kept per dataset/fold group)
    # ------------------------------------------------------------------
    # Model
    backbone                 = "convnext_small.fb_in22k_ft_in1k",
    ema                      = True,
    ema_decay                = 0.99,
    # ------------------------------------------------------------------
    # Optimisation
    epochs                   = 1,
    batch_size               = 16,
    batch_size_val           = 16,
    use_amp                  = True,
    precision                = "bfloat16",    # "float32" | "bfloat16" | "float16"
    gradient_accumulation_steps = 4,
    clip_grad_norm           = 1.0,
    # NOTE(review): warmup_epochs (2) exceeds epochs (1), so the warm-up ramp in
    # get_lr_scheduler is never completed within a run — confirm this is intended.
    warmup_epochs            = 2,
    save_top_k               = 3,
    # ------------------------------------------------------------------
    # Early stopping
    # "streak" is mutable run-time state: the training loop updates it in place.
    early_stopping           = {"patience": 3, "streak": 0},
    # ------------------------------------------------------------------
    # Logging
    logging_steps            = 100,
)

# Make seeding immediately effective when _cfg is imported
# (import side effect: every process that imports this module re-seeds its RNGs).
set_global_seed(cfg.seed)


In [None]:
%%writefile _dataset.py

# Cell 3 ────────────────────────────────────────────────────────────────
# Dataset definition  ➜ _dataset.py
#
#  • Loads metadata CSV for fold splits.
#  • Supports optional row-subsampling for quick experiments.
#  • Performs light on-the-fly augmentation when mode == "train".
#  • Uses numpy.memmap to keep RAM usage low.


from __future__ import annotations
import os, glob, numpy as np, pandas as pd
from tqdm import tqdm
import torch
from typing import List, Tuple

class CustomDataset(torch.utils.data.Dataset):
    """Memory-mapped dataset of (input, label) array pairs listed in ``folds.csv``.

    Every CSV row points at one ``.npy`` input file; the matching label file is
    derived from its path.  A flat sample index is decoded into
    ``(file index, sample-within-file)`` using ``SAMPLES_PER_FILE``.
    """

    # Number of samples assumed to be stored in every .npy file.
    SAMPLES_PER_FILE = 500

    def __init__(self, cfg, mode: str = "train"):
        """
        Args:
            cfg  – global SimpleNamespace from _cfg.py
            mode – "train" uses every fold except 0 and enables augmentation;
                   any other value ("valid", "test") selects fold 0 only.
        """
        self.cfg   = cfg
        self.mode  = mode
        self.data, self.labels, self.records = self._load_metadata()

    # ------------------------------------------------------------------
    def _load_metadata(self) -> Tuple[List[np.ndarray], List[np.ndarray], List[str]]:
        """Open every input/label file of the selected folds as a read-only memmap.

        Returns:
            (data, labels, records): parallel lists with one entry per CSV row;
            ``records`` holds the row's ``dataset`` value.

        Raises:
            FileNotFoundError: if a row does not resolve to exactly one file.
        """
        df = pd.read_csv("/kaggle/input/openfwi-preprocessed-72x72/folds.csv")

        # Sub-sampling happens per (dataset, fold) group *before* the fold split.
        if self.cfg.subsample is not None:
            df = df.groupby(["dataset", "fold"]).head(self.cfg.subsample)

        df = df[df["fold"] != 0] if self.mode == "train" else df[df["fold"] == 0]

        data, labels, records = [], [], []
        mmap_mode = "r"

        for _, row in tqdm(df.iterrows(), total=len(df), disable=self.cfg.local_rank != 0):
            row = row.to_dict()

            # Resolve file-paths across both openfwi_float16 datasets
            parts = row["data_fpath"].split("/")
            patterns = [
                os.path.join("/kaggle/input/open-wfi-1/openfwi_float16_1/", row["data_fpath"]),
                os.path.join("/kaggle/input/open-wfi-1/openfwi_float16_1/", parts[0], "*", parts[-1]),
                os.path.join("/kaggle/input/open-wfi-2/openfwi_float16_2/", row["data_fpath"]),
                os.path.join("/kaggle/input/open-wfi-2/openfwi_float16_2/", parts[0], "*", parts[-1]),
            ]
            # The exact pattern and its wildcard sibling can both hit the same
            # file; de-duplicate (order preserved) so that case is not reported
            # as an ambiguous match.
            farr = list(dict.fromkeys(hit for p in patterns for hit in glob.glob(p)))
            if len(farr) != 1:
                raise FileNotFoundError(f"Expected 1 match, got {len(farr)} for {row['data_fpath']}")
            fdata = farr[0]
            # Label path is obtained by plain substring substitution on the whole path.
            flabel = fdata.replace('seis', 'vel').replace('data', 'model')

            data.append(np.load(fdata, mmap_mode=mmap_mode))
            labels.append(np.load(flabel, mmap_mode=mmap_mode))
            records.append(row["dataset"])

        return data, labels, records

    # ------------------------------------------------------------------
    def __getitem__(self, idx: int):
        """Return one ``(x, y)`` pair as C-contiguous in-memory numpy arrays.

        In "train" mode a chain of independent random augmentations is applied.
        """
        row_idx, col_idx = divmod(idx, self.SAMPLES_PER_FILE)

        # Copy out of the read-only memmaps so the in-place ops below are legal.
        x = self.data[row_idx][col_idx].copy()
        y = self.labels[row_idx][col_idx].copy()

        if self.mode == "train":
            if np.random.rand() < 0.5:
                # Mirror flip: reverses the first and last axes of x and the last
                # axis of y (presumably source order / lateral position — TODO confirm).
                x = x[::-1, :, ::-1]
                y = y[..., ::-1]
            if np.random.rand() < 0.3:        # additive gaussian noise (input only)
                x += np.random.normal(0, 0.03, x.shape).astype(x.dtype)
            if np.random.rand() < 0.3:        # global amplitude scale, ±5 %
                x *= np.random.uniform(0.95, 1.05)
            if np.random.rand() < 0.2:        # second global amplitude scale, ±10 %
                x *= np.random.uniform(0.9, 1.1)
            if np.random.rand() < 0.2:        # circular shift of x by ±1 along its last axis
                offset = np.random.randint(-1, 2)
                if offset:
                    x = np.roll(x, offset, axis=-1)

        # The flip above yields negative-stride views, which the default
        # DataLoader collate cannot turn into tensors; materialise them here.
        # Already-contiguous arrays are passed through without another copy.
        return np.ascontiguousarray(x), np.ascontiguousarray(y)

    def __len__(self) -> int:
        """Total sample count: number of files times ``SAMPLES_PER_FILE``."""
        return len(self.records) * self.SAMPLES_PER_FILE


In [None]:
%%writefile _model.py


# Cell 4 ────────────────────────────────────────────────────────────────
# Model definition  ➜ _model.py
#
# • ConvNeXt encoder with custom asymmetric stem.
# • UNet-style decoder (optional norm layer and SCSE attention; both are off by default).
# • EMA wrapper for test-time stability and Ensemble helper.
# • Utilities to swap activations / norms and patch ConvNeXt blocks.

from __future__ import annotations
from copy import deepcopy
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models.convnext import ConvNeXtBlock
from monai.networks.blocks import UpSample, SubpixelUpsample

# ────────────────────────── EMA / Ensemble ────────────────────────────
class ModelEMA(nn.Module):
    """Exponential-moving-average wrapper.

    Keeps a deep copy of ``model`` in eval mode (``self.module``) whose
    state-dict entries track the live model's entries.

    Args:
        model:  network to mirror.
        decay:  weight given to the previous average in every ``update``.
        device: optional device for the averaged copy.  ``None`` keeps the copy
                wherever ``deepcopy`` placed it; an integer index such as ``0``
                is a valid device.
    """
    def __init__(self, model: nn.Module, decay: float = 0.99, device: torch.device | None = None):
        super().__init__()
        self.module = deepcopy(model).eval()
        self.decay  = decay
        self.device = device
        # Compare against None explicitly: device index 0 is falsy but valid.
        if self.device is not None:
            self.module.to(device=self.device)

    @torch.no_grad()
    def _update(self, model, fn):
        """Overwrite every averaged entry with ``fn(ema_value, model_value)``.

        Entries are paired by position in the two state dicts, so ``model`` may
        use different key prefixes as long as the order matches.
        """
        for ema_w, w in zip(self.module.state_dict().values(), model.state_dict().values()):
            if self.device is not None:
                w = w.to(self.device)
            ema_w.copy_(fn(ema_w, w))

    def update(self, model):
        """Blend ``model`` into the average: ``decay * ema + (1 - decay) * model``."""
        self._update(model, lambda e, m: self.decay * e + (1.0 - self.decay) * m)

    def set(self, model):
        """Copy ``model``'s current state into the average verbatim."""
        self._update(model, lambda _e, m: m)

class EnsembleModel(nn.Module):
    """Averages the outputs of several models (plain arithmetic mean)."""

    def __init__(self, models):
        super().__init__()
        members = nn.ModuleList(models)
        # All members are switched to eval mode at construction time.
        self.models = members.eval()

    def forward(self, x):
        """Run every member on ``x`` and return the element-wise mean."""
        total = 0
        for member in self.models:
            total = total + member(x)
        return total / len(self.models)

# ─────────────────────────── Decoder blocks ───────────────────────────
class ConvBnAct2d(nn.Module):
    """Bias-free Conv2d followed by an optional norm layer and an activation.

    Args:
        ic, oc: input / output channels.
        k:      kernel size.
        p:      padding.
        s:      stride.
        norm:   norm-layer class called as ``norm(oc)``; ``nn.Identity`` disables it.
        act:    activation class.  It is built with ``inplace=True`` when its
                constructor accepts that keyword, otherwise with no arguments.
    """
    def __init__(self, ic, oc, k, p=0, s=1, norm=nn.Identity, act=nn.GELU):
        super().__init__()
        self.conv = nn.Conv2d(ic, oc, k, stride=s, padding=p, bias=False)
        self.norm = norm(oc) if norm is not nn.Identity else nn.Identity()
        # nn.GELU (the default) has no `inplace` argument, so fall back to a
        # plain constructor call for activations that reject the keyword.
        try:
            self.act = act(inplace=True)
        except TypeError:
            self.act = act()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

class SCSEModule2d(nn.Module):
    """Spatial-and-channel squeeze-and-excitation.

    Args:
        c:           number of input channels.
        r:           channel-reduction ratio of the bottleneck.
        in_channels: keyword alias for ``c`` (the name used by ``Attention2d``
                     callers); ignored when ``c`` is given.

    Raises:
        TypeError: if neither ``c`` nor ``in_channels`` is provided.
    """
    def __init__(self, c=None, r=16, in_channels=None):
        super().__init__()
        if c is None:
            c = in_channels
        if c is None:
            raise TypeError("SCSEModule2d requires the channel count (`c` or `in_channels`).")
        # Keep at least one bottleneck channel when c < r.
        hidden = max(1, c // r)
        # Channel gate: global pool → bottleneck → per-channel sigmoid weights.
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c, hidden, 1), nn.GELU(),
            nn.Conv2d(hidden, c, 1), nn.Sigmoid(),
        )
        # Spatial gate: 1×1 conv → per-pixel sigmoid weight.
        self.sSE = nn.Sequential(nn.Conv2d(c, 1, 1), nn.Sigmoid())

    def forward(self, x):
        # Sum of the channel-gated and the spatially-gated input.
        return x * self.cSE(x) + x * self.sSE(x)

class Attention2d(nn.Module):
    """Optional attention wrapper.

    ``name=None`` yields a pass-through; ``"scse"`` builds ``SCSEModule2d``
    with the remaining keyword arguments.  Any other name raises ``KeyError``.
    """

    def __init__(self, name=None, **kw):
        super().__init__()
        if name is None:
            self.attention = nn.Identity()
        else:
            registry = {"scse": SCSEModule2d}
            self.attention = registry[name](**kw)

    def forward(self, x):
        return self.attention(x)

class DecoderBlock2d(nn.Module):
    """One decoder stage: upsample → (optional refinement) → concat skip → 2× conv.

    Args:
        ic:           channels of the incoming (coarser) feature map.
        skip_c:       channels of the skip connection; 0 means "no skip".
        oc:           output channels.
        norm:         norm-layer class forwarded to every ``ConvBnAct2d``.
        attention:    attention name understood by ``Attention2d`` or ``None``.
        intermediate: if True, refine the skip (or, when no skip is passed, the
                      upsampled tensor itself) with two 3×3 convs first.
        up:           "pixelshuffle" selects MONAI ``SubpixelUpsample``; any
                      other value is forwarded to ``UpSample`` as ``mode``.
        scale:        upsampling factor.
    """
    def __init__(
        self,
        ic: int, skip_c: int, oc: int,
        norm: nn.Module = nn.Identity,
        attention: str | None = None,
        intermediate: bool = False,
        up: str = "deconv",
        scale: int = 2,
    ):
        super().__init__()
        # upsample — SubpixelUpsample has no `mode` argument, so the two
        # constructors need separate keyword sets.
        if up == "pixelshuffle":
            self.up = SubpixelUpsample(spatial_dims=2, in_channels=ic, scale_factor=scale)
        else:
            self.up = UpSample(
                spatial_dims=2, in_channels=ic, out_channels=ic, scale_factor=scale, mode=up
            )

        # optional skip refinement
        if intermediate:
            mid_c = skip_c or ic
            self.intermediate_conv = nn.Sequential(
                ConvBnAct2d(mid_c, mid_c, 3, p=1, norm=norm),
                ConvBnAct2d(mid_c, mid_c, 3, p=1, norm=norm),
            )
        else:
            self.intermediate_conv = None

        # `p` and `norm` are passed by keyword: the fifth positional slot of
        # ConvBnAct2d is the stride, not the norm layer.
        self.att1 = Attention2d(attention, in_channels=ic + skip_c)
        self.conv1 = ConvBnAct2d(ic + skip_c, oc, 3, p=1, norm=norm)
        self.conv2 = ConvBnAct2d(oc, oc, 3, p=1, norm=norm)
        self.att2 = Attention2d(attention, in_channels=oc)

    def forward(self, x, skip=None):
        """Upsample ``x``, merge ``skip`` (if any) on the channel axis and convolve."""
        x = self.up(x)
        if self.intermediate_conv is not None:
            if skip is not None:
                skip = self.intermediate_conv(skip)
            else:
                # Without a skip the refinement acts on x itself; turning it
                # into a skip would double the channels that conv1 receives.
                x = self.intermediate_conv(x)
        if skip is not None:
            x = self.att1(torch.cat([x, skip], dim=1))
        x = self.conv2(self.conv1(x))
        return self.att2(x)

class UnetDecoder2d(nn.Module):
    """Stack of ``DecoderBlock2d`` stages driven by a deepest-first feature list.

    Args:
        enc_chs:       encoder channel counts, deepest stage first.
        dec_chs:       output channels per decoder stage; the first entry is
                       dropped when the encoder provides exactly four stages.
        norm, attention, intermediate, up: forwarded to every block.
        scale_factors: per-stage upsampling factors.
    """

    def __init__(
        self,
        enc_chs: tuple[int],
        dec_chs: tuple[int] = (256, 128, 64, 32),
        norm: nn.Module = nn.Identity,
        attention: str | None = None,
        intermediate: bool = False,
        up: str = "deconv",
        scale_factors: tuple[int] = (2, 2, 2, 2),
    ):
        super().__init__()
        if len(enc_chs) == 4:  # convnext_small has 4 stages
            dec_chs = dec_chs[1:]
        self.decoder_channels = dec_chs

        # Stage i consumes the previous output plus encoder feature i + 1;
        # a trailing 0 marks "no skip" should a further stage exist.
        skip_chs = [*enc_chs[1:], 0]
        in_chs = [enc_chs[0], *dec_chs[:-1]]

        stages = []
        for ic, sc, dc, sf in zip(in_chs, skip_chs, dec_chs, scale_factors):
            stages.append(DecoderBlock2d(ic, sc, dc, norm, attention, intermediate, up, sf))
        self.blocks = nn.ModuleList(stages)

    def forward(self, feats):
        """Return ``[feats[0], out_1, …, out_n]`` — the input plus every stage output."""
        outs = [feats[0]]
        skips = feats[1:]
        n_skips = len(skips)
        for i, blk in enumerate(self.blocks):
            skip = skips[i] if i < n_skips else None
            outs.append(blk(outs[-1], skip=skip))
        return outs

class SegmentationHead2d(nn.Module):
    """Final projection: a k×k convolution followed by a MONAI ``UpSample``.

    Args:
        ic, oc: input / output channels of the convolution.
        scale:  upsampling factor applied after the convolution.
        k:      kernel size; padding is ``k // 2``.
        mode:   forwarded to ``UpSample``.
    """

    def __init__(self, ic, oc, scale=1, k=3, mode="nontrainable"):
        super().__init__()
        same_pad = k // 2
        self.conv = nn.Conv2d(ic, oc, kernel_size=k, padding=same_pad)
        self.up = UpSample(
            spatial_dims=2,
            in_channels=oc,
            out_channels=oc,
            scale_factor=scale,
            mode=mode,
        )

    def forward(self, x):
        projected = self.conv(x)
        return self.up(projected)

# ──────────────────── ConvNeXt patch helper ---------------------------
def _convnext_block_forward(self, x):
    shortcut = x
    x = self.conv_dw(x)
    x = self.norm(x)
    if self.use_conv_mlp:
        x = self.mlp(x)
    else:
        x = self.mlp(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
    if self.gamma is not None:
        x = x * self.gamma.reshape(1, -1, 1, 1)
    return self.drop_path(x) + self.shortcut(shortcut)

# ───────────────────────────── Net class ──────────────────────────────
class Net(nn.Module):
    """timm encoder + ``UnetDecoder2d`` + ``SegmentationHead2d`` regression network.

    After construction the encoder is customised in place: ConvNeXt stems are
    rebuilt for the non-square input, every listed activation becomes GELU,
    every norm layer becomes affine InstanceNorm2d, and ConvNeXt blocks get a
    replacement ``forward``.

    Args:
        backbone:   timm model name; it is created with ``in_chans=5``.
        pretrained: forwarded to ``timm.create_model``.
    """
    def __init__(self, backbone: str, pretrained: bool = True):
        super().__init__()
        # Encoder
        self.backbone = timm.create_model(
            backbone,
            in_chans=5,
            pretrained=pretrained,
            features_only=True,
            drop_path_rate=0.0,
        )
        # Channel counts reversed so index 0 is the deepest stage.
        enc_chs = [_["num_chs"] for _ in self.backbone.feature_info][::-1]

        # Decoder & head
        self.decoder  = UnetDecoder2d(enc_chs)
        self.seg_head = SegmentationHead2d(self.decoder.decoder_channels[-1], 1, scale=1)

        # Customisations
        self._update_stem(backbone)
        self._swap_activations(self.backbone)
        self._swap_norms(self.backbone)
        self._patch_convnext_blocks(self.backbone)

    # ---------------- stem tweak for non-square input ------------------
    def _update_stem(self, backbone):
        """Rebuild the ConvNeXt stem: reflection pad → re-strided stem conv → extra conv.

        Does nothing for backbones whose name does not start with "convnext".
        """
        if not backbone.startswith("convnext"):
            return
        stem = self.backbone.stem_0               # first stem convolution of the encoder
        stem.stride, stem.padding = (4, 1), (0, 2)
        with torch.no_grad():
            w = stem.weight
            out_ch, in_ch = w.shape[0], w.shape[1]
            new = nn.Conv2d(out_ch, out_ch, kernel_size=(4, 4), stride=(4, 1), padding=(0, 1))
            # Initialise the extra conv by tiling the stem kernel along its
            # input-channel axis just often enough to cover `out_ch` channels.
            reps = -(-out_ch // in_ch)            # ceil(out_ch / in_ch)
            new.weight.copy_(w.repeat(1, reps, 1, 1)[:, :out_ch])
            new.bias.copy_(stem.bias)
        # ReflectionPad2d takes (left, right, top, bottom): ±1 on the last axis,
        # ±80 on the second-to-last.
        self.backbone.stem_0 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 80, 80)), stem, new
        )

    # --------------------- replacement utilities ----------------------
    def _swap_activations(self, module):
        """Recursively replace every listed activation module with ``nn.GELU()``.

        Note that gating activations in the list (Sigmoid, Tanh, Softmax) are
        replaced as well.
        """
        for name, child in module.named_children():
            if isinstance(child, (
                nn.ReLU, nn.LeakyReLU, nn.Mish, nn.Sigmoid, nn.Tanh,
                nn.Softmax, nn.Hardtanh, nn.ELU, nn.SELU, nn.PReLU,
                nn.CELU, nn.GELU, nn.SiLU,
            )):
                setattr(module, name, nn.GELU())
            else:
                self._swap_activations(child)

    def _swap_norms(self, module):
        """Recursively replace Batch/Instance/Group/LayerNorm with affine InstanceNorm2d."""
        for name, child in module.named_children():
            n_feats = None
            if isinstance(child, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                n_feats = child.num_features
            elif isinstance(child, nn.GroupNorm):
                n_feats = child.num_channels
            elif isinstance(child, nn.LayerNorm):
                n_feats = child.normalized_shape[0]

            if n_feats is not None:
                setattr(module, name, nn.InstanceNorm2d(n_feats, affine=True))
            else:
                self._swap_norms(child)

    def _patch_convnext_blocks(self, module):
        """Bind ``_convnext_block_forward`` as ``forward`` on every ``ConvNeXtBlock``."""
        for child in module.children():
            if isinstance(child, ConvNeXtBlock):
                child.forward = MethodType(_convnext_block_forward, child)
            else:
                self._patch_convnext_blocks(child)

    # -------------------------- forward pass --------------------------
    def _forward_core(self, x_in):
        """Single forward pass: encoder → decoder → head → crop → affine rescale."""
        feats = self.backbone(x_in)[::-1]      # deepest → shallowest
        dec   = self.decoder(feats)
        seg   = self.seg_head(dec[-1])[..., 1:-1, 1:-1]   # drop a 1-pixel border
        # Fixed affine map of the raw output (presumably back to the label
        # value range — TODO confirm against the data).
        return seg * 1500 + 3000

    def proc_flip(self, x):                    # Test-time augmentation
        """Predict on ``x`` flipped along dims -3 and -1, then un-flip dim -1 of the output."""
        return torch.flip(self._forward_core(torch.flip(x, dims=[-3, -1])), dims=[-1])

    def forward(self, x):
        """Training: one plain pass.  Eval: mean of the plain and the flipped pass."""
        pred = self._forward_core(x)
        if self.training:
            return pred
        # two-view TTA: original + flipped view
        return torch.mean(torch.stack([pred, self.proc_flip(x)]), dim=0)


In [None]:
%%writefile _model_init.py


# Cell 5 ────────────────────────────────────────────────────────────────
# Weight-initialisation & extra attention layers  ➜ _model_init.py
#
# • `initialize_model()` – Kaiming init for decoder & seg-head only.
# • Optional ChannelAttention block you can swap into the decoder.

from __future__ import annotations
import torch.nn as nn, torch

# --------------------------------------------------------------------- #
# 1) Improved decoder initialisation                                    #
# --------------------------------------------------------------------- #
def initialize_model(model: nn.Module):
    """Reset the weights of ``model.decoder`` and ``model.seg_head`` in place.

    ``nn.Conv2d`` layers get Kaiming-normal weights (fan-out, ReLU gain) and
    zero biases; Batch/Instance/Group norm layers get unit weights and zero
    biases where those parameters exist.  Other attributes of ``model`` (e.g.
    the backbone) are not touched.

    Returns:
        The same ``model`` object, to allow chaining.
    """
    norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)

    def _reset(layer):
        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
            return
        if isinstance(layer, norm_types):
            if layer.weight is not None:
                nn.init.ones_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    for part in (model.decoder, model.seg_head):
        part.apply(_reset)
    return model

# --------------------------------------------------------------------- #
# 2) Channel-wise Squeeze-and-Excitation (optional)                     #
# --------------------------------------------------------------------- #
class ChannelAttention(nn.Module):
    """
    Lightweight SE block: global-pool → bottleneck 1×1 convs → sigmoid gate.

    Args:
        in_channels: channels of the input feature map.
        r:           reduction ratio; the bottleneck has
                     ``max(1, in_channels // r)`` channels so it never
                     collapses to zero width for narrow inputs.
    """
    def __init__(self, in_channels: int, r: int = 16):
        super().__init__()
        hidden = max(1, in_channels // r)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.GELU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Per-channel gate in (0, 1), broadcast over the spatial dimensions.
        return x * self.fc(self.avgpool(x))


In [None]:
%%writefile _train_utils.py


# Cell 6 ────────────────────────────────────────────────────────────────
# Training utilities  ➜ _train_utils.py
#
# • Mixed-precision dtype selector.
# • Warm-up + cosine LR scheduler.
# • Checkpoint manager retaining top-K best models.

from __future__ import annotations
import os, glob, math, torch, numpy as np
from torch.optim import Optimizer
from types import SimpleNamespace

# --------------------------------------------------------------------- #
# 1) Autocast dtype helper                                              #
# --------------------------------------------------------------------- #
def get_autocast_dtype(precision: str):
    """Translate a precision name into the torch dtype used for autocast.

    "float16" maps to ``torch.float16``.  "bfloat16" maps to ``torch.bfloat16``
    only when CUDA is available and reports bf16 support.  Every other case —
    including an unsupported bfloat16 request — yields ``torch.float32``.
    """
    if precision == "float16":
        return torch.float16
    wants_bf16 = precision == "bfloat16"
    if wants_bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32

# --------------------------------------------------------------------- #
# 2) Scheduler: linear warm-up → cosine decay                           #
# --------------------------------------------------------------------- #
def get_lr_scheduler(optimizer: Optimizer, cfg: SimpleNamespace, steps_per_epoch: int):
    """Build a ``LambdaLR`` with linear warm-up followed by cosine decay.

    The multiplier is a function of the number of ``scheduler.step()`` calls,
    so the scheduler is meant to be stepped once per optimizer step.

    Args:
        optimizer:       optimizer whose base LR is scaled.
        cfg:             needs ``warmup_epochs`` and ``epochs``.
        steps_per_epoch: optimizer steps per epoch.

    Returns:
        The scheduler.  The multiplier ramps 0 → 1 over the warm-up steps,
        follows half a cosine down to 0 at ``epochs * steps_per_epoch`` and
        stays at 0 for any later step.
    """
    warm = cfg.warmup_epochs * steps_per_epoch
    total = cfg.epochs * steps_per_epoch

    def lr_lambda(step: int):
        if step < warm:                             # linear warm-up
            return step / max(1, warm)
        # Clamp so extra steps past `total` (or warm >= total) hold the LR at
        # zero instead of riding the cosine back up.
        progress = min(1.0, (step - warm) / max(1, total - warm))
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# --------------------------------------------------------------------- #
# 3) Checkpoint saver                                                   #
# --------------------------------------------------------------------- #
def save_checkpoint(model, ema_model, optimizer, scheduler, epoch,
                    val_loss: float, cfg: SimpleNamespace, is_best: bool):
    """Write training state to ``checkpoints/`` (rank 0 only).

    Always writes ``latest_<seed>.pt``; on a new best also ``best_<seed>.pt``
    and, when ``cfg.save_top_k > 1``, a per-epoch file after which only the
    ``save_top_k`` lowest-loss per-epoch files are kept.

    The stored weights come from the EMA copy when one exists and this is a
    new best; otherwise from the (unwrapped) live model.
    """
    if cfg.local_rank != 0:           # only rank-0 touches disk
        return

    os.makedirs("checkpoints", exist_ok=True)

    # Pick the module whose weights are persisted.
    if ema_model and is_best:
        source = ema_model.module
    elif hasattr(model, "module"):
        source = model.module
    else:
        source = model

    ckpt = {
        "epoch":      epoch,
        "state_dict": source.state_dict(),
        "val_loss":   val_loss,
        "optimizer":  optimizer.state_dict(),
        "scheduler":  scheduler.state_dict() if scheduler else None,
    }

    targets = [f"checkpoints/latest_{cfg.seed}.pt"]
    if is_best:
        targets.append(f"checkpoints/best_{cfg.seed}.pt")
    for path in targets:
        torch.save(ckpt, path)

    if not (is_best and cfg.save_top_k > 1):
        return

    # Keep only top-K best losses
    torch.save(ckpt, f"checkpoints/ep{epoch}_loss{val_loss:.4f}.pt")

    def _loss_of(path):
        # The loss is parsed back out of the file name ("…loss<value>.pt").
        return float(os.path.basename(path).split("loss")[1][:-3])

    ranked = sorted(glob.glob("checkpoints/ep*_loss*.pt"), key=_loss_of)
    for stale in ranked[cfg.save_top_k:]:
        os.remove(stale)


In [None]:
%%writefile _train.py


# Cell 7 ────────────────────────────────────────────────────────────────
# Core trainer script  ➜ _train.py
#
# • DDP-aware training loop with AMP + grad-accum.
# • Early stopping via patience counter.
# • Minimal console noise on non-rank-0 workers.

from __future__ import annotations
import os, time, random, numpy as np, torch, torch.distributed as dist
from tqdm import tqdm
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

from _cfg          import cfg
from _dataset      import CustomDataset
from _model        import Net, ModelEMA
from _train_utils  import get_autocast_dtype, get_lr_scheduler, save_checkpoint

# --------------------------------------------------------------------- #
# 0) Reproducible seed per rank                                         #
# --------------------------------------------------------------------- #
def _set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# --------------------------------------------------------------------- #
# 1) DDP setup / teardown                                               #
# --------------------------------------------------------------------- #
def _ddp_setup(rank: int, world: int):
    """Bind this process to CUDA device ``rank`` and join the NCCL process group.

    Args:
        rank:  this process's rank; also used as the CUDA device index.
        world: total number of processes in the group.
    """
    torch.cuda.set_device(rank)
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world,
    )

def _ddp_cleanup():
    """Wait for every rank at a barrier, then tear the process group down."""
    dist.barrier()
    dist.destroy_process_group()

# --------------------------------------------------------------------- #
# 2) Main training routine                                              #
# --------------------------------------------------------------------- #
def main(cfg):
    """Run the DDP training loop for ``cfg.epochs`` epochs.

    Builds the train/valid loaders, the model (plus optional EMA copy), Adam,
    a warm-up/cosine scheduler and a grad scaler; then trains with gradient
    accumulation, validates after every epoch, checkpoints and applies early
    stopping.

    Args:
        cfg: global config; ``local_rank`` and ``world_size`` must already be
            set for this process.  ``cfg.early_stopping["streak"]`` is updated
            in place.
    """
    is_main = cfg.local_rank == 0

    # ── Data ──────────────────────────────────────────────────────────
    if is_main: print("⌛ Loading datasets …", flush=True)
    train_ds = CustomDataset(cfg, mode="train")
    valid_ds = CustomDataset(cfg, mode="valid")

    train_loader = DataLoader(
        train_ds,
        sampler=DistributedSampler(train_ds, num_replicas=cfg.world_size, rank=cfg.local_rank),
        batch_size=cfg.batch_size,
        num_workers=4,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_ds,
        sampler=DistributedSampler(valid_ds, num_replicas=cfg.world_size, rank=cfg.local_rank),
        batch_size=cfg.batch_size_val,
        num_workers=4,
        pin_memory=True,
    )

    # ── Model / Optim ────────────────────────────────────────────────
    model = Net(cfg.backbone).to(cfg.local_rank)
    ema_model = ModelEMA(model, cfg.ema_decay, device=cfg.local_rank) if cfg.ema else None
    model = DDP(model, device_ids=[cfg.local_rank])

    dtype_autocast = get_autocast_dtype(cfg.precision)
    # float32 is not an autocast target; skip autocast entirely in that case.
    amp_enabled = bool(cfg.use_amp) and dtype_autocast != torch.float32

    criterion  = torch.nn.L1Loss()
    optimizer  = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Loss scaling is only required for float16; it is a no-op when disabled.
    scaler     = GradScaler(enabled=amp_enabled and dtype_autocast == torch.float16)
    # The scheduler counts optimizer steps, hence the division by the
    # accumulation factor here and the per-step `scheduler.step()` below.
    scheduler  = get_lr_scheduler(optimizer, cfg, steps_per_epoch=len(train_loader)//cfg.gradient_accumulation_steps)

    best_loss  = float("inf"); val_loss = float("inf")

    if is_main:
        print(f"🚀 Training for {cfg.epochs} epoch(s) on {cfg.world_size} GPU(s)")

    # ── Epoch loop ───────────────────────────────────────────────────
    for epoch in range(1, cfg.epochs + 1):
        train_loader.sampler.set_epoch(epoch)
        model.train(); epoch_loss = []

        t0 = time.time(); optimizer.zero_grad()

        for step, (x, y) in enumerate(train_loader, 1):
            x, y = x.to(cfg.local_rank), y.to(cfg.local_rank)
            # torch.autocast takes `device_type`; the torch.cuda.amp.autocast
            # wrapper does not accept that keyword.
            with torch.autocast(device_type=cfg.device.type, dtype=dtype_autocast, enabled=amp_enabled):
                preds = model(x)
                loss  = criterion(preds, y) / cfg.gradient_accumulation_steps

            scaler.scale(loss).backward()

            # Optimizer step every N micro-batches, plus a flush on the last batch.
            if step % cfg.gradient_accumulation_steps == 0 or step == len(train_loader):
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                scaler.step(optimizer); scaler.update(); optimizer.zero_grad()
                scheduler.step()              # one LR update per optimizer step
                if ema_model: ema_model.update(model)

            # Undo the accumulation scaling so the log shows the true batch MAE.
            epoch_loss.append(loss.item() * cfg.gradient_accumulation_steps)

            if is_main and step % cfg.logging_steps == 0:
                print(f"Epoch {epoch} Step {step}/{len(train_loader)} "
                      f"Train MAE {np.mean(epoch_loss[-cfg.logging_steps:]):.4f}", flush=True)

        # ── Validation ───────────────────────────────────────────────
        val_loss = _validate(model, ema_model, valid_loader, criterion, cfg, dtype_autocast)
        is_best  = val_loss < best_loss
        best_loss = min(best_loss, val_loss)

        save_checkpoint(model, ema_model, optimizer, scheduler, epoch, val_loss, cfg, is_best)

        if is_main:
            dt = time.time() - t0
            print(f"✓ Epoch {epoch} finished | Val MAE {val_loss:.4f} | "
                  f"Best {best_loss:.4f} | Time {dt/60:.1f} min", flush=True)

        # ── Early stopping ───────────────────────────────────────────
        # val_loss is all-reduced, so every rank reaches the same decision.
        # All ranks must leave the loop together or the remaining ones would
        # block in the next collective call.
        es = cfg.early_stopping
        es["streak"] = 0 if is_best else es["streak"] + 1
        if es["streak"] > es["patience"]:
            if is_main: print("⏹ Early stopping triggered.")
            break

# --------------------------------------------------------------------- #
# 3) Validation helper                                                  #
# --------------------------------------------------------------------- #
@torch.no_grad()
def _validate(model, ema_model, loader, criterion, cfg, dtype_autocast):
    """Compute the validation loss averaged over batches and then over ranks.

    Args:
        model:          live (DDP-wrapped) model; switched to eval mode here.
        ema_model:      optional ``ModelEMA``; when given, its averaged copy is
                        evaluated instead of ``model``.
        loader:         this rank's validation loader.
        criterion:      loss returning a scalar tensor.
        cfg:            needs ``local_rank``, ``world_size``, ``device``, ``use_amp``.
        dtype_autocast: dtype from ``get_autocast_dtype``.

    Returns:
        float: mean of the per-rank mean batch losses (every rank gets the same value).
    """
    model.eval()
    eval_model = ema_model.module if ema_model else model
    # float32 is not an autocast target; skip autocast entirely in that case.
    amp_enabled = bool(cfg.use_amp) and dtype_autocast != torch.float32

    losses = []
    for x, y in loader:
        x, y = x.to(cfg.local_rank), y.to(cfg.local_rank)
        # torch.autocast takes `device_type`; the torch.cuda.amp.autocast
        # wrapper does not accept that keyword.
        with torch.autocast(device_type=cfg.device.type, dtype=dtype_autocast, enabled=amp_enabled):
            out = eval_model(x)
            losses.append(criterion(out, y).item())
    loss = np.mean(losses)
    # average across GPUs
    v = torch.tensor([loss], device=cfg.local_rank); dist.all_reduce(v, op=dist.ReduceOp.SUM)
    return v.item() / cfg.world_size

# --------------------------------------------------------------------- #
# 4) Entrypoint for torchrun                                            #
# --------------------------------------------------------------------- #
if __name__ == "__main__":
    # RANK and WORLD_SIZE are read from the environment (a KeyError is raised
    # if either is missing).
    env = os.environ
    rank, world_size = int(env["RANK"]), int(env["WORLD_SIZE"])

    _ddp_setup(rank, world_size)
    # Offset the seed by rank so each process draws a different random stream.
    _set_seed(cfg.seed + rank)

    # NOTE(review): the global RANK doubles as the CUDA device index here,
    # which presumably assumes a single-node launch — confirm.
    cfg.local_rank, cfg.world_size = rank, world_size
    main(cfg)

    _ddp_cleanup()


In [None]:
%%writefile _train_patching.py


# Cell 8 ────────────────────────────────────────────────────────────────
# Optional enhanced loop patcher  ➜ _train_patching.py
#
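# • `python _train_patching.py --use_improved` monkey-patches _train.main
#   with an advanced loop (gradient-accum, etc.).
#   NOTE: the patch exists only inside the Python process that runs this
#   script; a separately launched `torchrun _train.py` job will not see it.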

import argparse, importlib

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_improved",
    action="store_true",
    help="Patch _train.main with enhanced loop",
)
args = parser.parse_args()

# Without the flag there is nothing to do: report and exit before any
# project module is imported.
if args.use_improved:
    print("🔧  Applying enhanced training loop patch …")
else:
    print("Nothing to patch – run with --use_improved flag.")
    raise SystemExit

# 1) import the existing modules (by name, so a missing file fails right here)
_train           = importlib.import_module("_train")
_enhanced        = importlib.import_module("_enhanced_train_function")
_model_init      = importlib.import_module("_model_init")
_dataset         = importlib.import_module("_dataset")
_model           = importlib.import_module("_model")
_train_utils     = importlib.import_module("_train_utils")

from _cfg import cfg
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch

# 2) define patched main
def patched_main(cfg):
    """Replacement for ``_train.main`` that delegates to the enhanced loop.

    Builds the same loaders, model (with custom decoder initialisation), EMA
    copy, Adam optimizer and LR scheduler, then hands everything to
    ``_enhanced.enhanced_training_loop`` and returns its result.
    """
    def _make_loader(ds, batch_size):
        # Each rank reads its own shard through a DistributedSampler.
        sampler = DistributedSampler(ds, num_replicas=cfg.world_size, rank=cfg.local_rank)
        return DataLoader(
            ds,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=4,
            pin_memory=True,
        )

    # dataset
    train_ds = _dataset.CustomDataset(cfg, "train")
    valid_ds = _dataset.CustomDataset(cfg, "valid")
    train_dl = _make_loader(train_ds, cfg.batch_size)
    valid_dl = _make_loader(valid_ds, cfg.batch_size_val)

    # model / ema — custom init runs before the move to GPU and the EMA copy
    net = _model.Net(cfg.backbone)
    _model_init.initialize_model(net)
    net = net.to(cfg.local_rank)
    ema = None
    if cfg.ema:
        ema = _model.ModelEMA(net, cfg.ema_decay, device=cfg.local_rank)
    ddp_net = DDP(net, device_ids=[cfg.local_rank])

    # optimiser & schedulers
    criterion = torch.nn.L1Loss()
    optim = torch.optim.Adam(ddp_net.parameters(), lr=1e-3)
    optimizer_steps_per_epoch = len(train_dl) // cfg.gradient_accumulation_steps
    sched = _train_utils.get_lr_scheduler(optim, cfg, optimizer_steps_per_epoch)

    return _enhanced.enhanced_training_loop(
        ddp_net, train_dl, valid_dl, criterion, optim, sched, ema, cfg
    )

# 3) Monkey-patch
_train.main = patched_main
print("✅  Patch successful – launch training with torchrun as usual.")

# 4) Make the patch usable.
# Rebinding `_train.main` only affects this Python process, so a separate
# `torchrun _train.py` job would still run the original loop.  When this
# script itself is started by a distributed launcher (RANK and WORLD_SIZE are
# present in the environment) run the patched training right here, mirroring
# the `_train.py` entrypoint.
import os  # local import: only needed for the launcher detection below

if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    _rank = int(os.environ["RANK"])
    _world_size = int(os.environ["WORLD_SIZE"])

    _train._ddp_setup(_rank, _world_size)
    _train._set_seed(cfg.seed + _rank)

    cfg.local_rank, cfg.world_size = _rank, _world_size
    _train.main(cfg)

    _train._ddp_cleanup()
else:
    print("ℹ️  The patch only lives in this process – start this script itself "
          "with torchrun (plus --use_improved) to train with the enhanced loop.")


In [None]:
%%writefile _enhanced_train_function.py


# Cell 9 ────────────────────────────────────────────────────────────────
# Enhanced training loop  ➜ _enhanced_train_function.py
#
# • Gradient accumulation + AMP + gradient-clipping.
# • Early stopping and top-K checkpointing via _train_utils helpers.

from __future__ import annotations
import time, numpy as np, torch
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

from _train_utils import (
    get_autocast_dtype,
    save_checkpoint,
)

# --------------------------------------------------------------------- #
# Core loop                                                             #
# --------------------------------------------------------------------- #
def enhanced_training_loop(
    model,
    train_dl,
    valid_dl,
    criterion,
    optimizer,
    scheduler,
    ema_model,
    cfg,
):
    """Train ``model`` for up to ``cfg.epochs`` epochs and return the best validation loss.

    Each epoch runs a training pass with gradient accumulation, optional
    autocast/GradScaler mixed precision and gradient-norm clipping, then calls
    ``evaluate_model``, steps the scheduler once, calls ``save_checkpoint`` and
    updates the early-stopping streak kept in ``cfg.early_stopping``.

    Args:
        model: Network to optimise.
        train_dl: Training loader; must expose ``sampler.set_epoch`` and yield
            ``(x, y)`` pairs.
        valid_dl: Validation loader, forwarded to ``evaluate_model``.
        criterion: Loss callable ``criterion(preds, y)``.
        optimizer: Optimiser stepped through the ``GradScaler``.
        scheduler: Optional LR scheduler, stepped once per epoch.
        ema_model: Optional EMA wrapper exposing ``update(model)``.
        cfg: Configuration namespace.  Fields read here: ``precision``,
            ``use_amp``, ``gradient_accumulation_steps``, ``epochs``,
            ``local_rank``, ``device``, ``clip_grad_norm``, ``logging_steps``
            and ``early_stopping`` (a dict with ``"streak"`` and
            ``"patience"``; ``"streak"`` is mutated in place).

    Returns:
        The lowest validation loss observed.
    """
    best_loss = float("inf")
    val_loss  = float("inf")
    dtype_cast = get_autocast_dtype(cfg.precision)
    scaler     = GradScaler(enabled=cfg.use_amp)

    accum     = cfg.gradient_accumulation_steps
    n_batches = len(train_dl)

    steps_per_epoch = n_batches // accum
    if scheduler: scheduler.optimizer.steps_per_epoch = steps_per_epoch

    for epoch in range(1, cfg.epochs + 1):
        t0 = time.time()
        train_dl.sampler.set_epoch(epoch)
        model.train()
        epoch_losses = []

        optimizer.zero_grad()

        # ── Train ─────────────────────────────────────────────────────
        for step, (x, y) in enumerate(train_dl, 1):
            x, y = x.to(cfg.local_rank), y.to(cfg.local_rank)

            # torch.autocast is the device-agnostic context manager; the
            # torch.cuda.amp.autocast variant does not accept `device_type`.
            with torch.autocast(device_type=cfg.device.type, dtype=dtype_cast, enabled=cfg.use_amp):
                preds = model(x)
                loss  = criterion(preds, y) / accum

            scaler.scale(loss).backward()

            # Step on every `accum`-th micro-batch and on the last batch so a
            # trailing partial group is not silently dropped.
            if step % accum == 0 or step == n_batches:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if ema_model: ema_model.update(model)

            # Undo the accumulation scaling so the log shows the per-batch loss.
            epoch_losses.append(loss.item() * accum)

            if cfg.local_rank == 0 and step % cfg.logging_steps == 0:
                rolling = np.mean(epoch_losses[-cfg.logging_steps:])
                print(f"[Epoch {epoch}] Step {step}/{n_batches}  "
                      f"Train MAE {rolling:.4f}  Val MAE {val_loss:.4f}", flush=True)

        # ── Validation ───────────────────────────────────────────────
        val_loss = evaluate_model(model, valid_dl, criterion, ema_model, cfg, dtype_cast)

        # ── Scheduler / Checkpoint / ES ──────────────────────────────
        if scheduler: scheduler.step()
        is_best = val_loss < best_loss
        best_loss = min(best_loss, val_loss)

        save_checkpoint(model, ema_model, optimizer, scheduler,
                        epoch, val_loss, cfg, is_best=is_best)

        es = cfg.early_stopping
        es["streak"] = 0 if is_best else es["streak"] + 1
        if es["streak"] > es["patience"]:
            if cfg.local_rank == 0:
                print("⏹  Early stopping triggered.")
            # Every rank must leave the loop together; if only rank 0 stopped,
            # the remaining ranks would keep training and block in the next
            # collective operation.
            break

        if cfg.local_rank == 0:
            dt = (time.time() - t0) / 60
            print(f"✓ Epoch {epoch} done | Val {val_loss:.4f} | Best {best_loss:.4f} | {dt:.1f} min")

    return best_loss

# --------------------------------------------------------------------- #
# Validation helper                                                     #
# --------------------------------------------------------------------- #
@torch.no_grad()
def evaluate_model(model, loader, criterion, ema_model, cfg, dtype_cast):
    """Return the validation loss averaged over batches and over ranks.

    Args:
        model: Network being trained; switched to eval mode here.
        loader: Iterable of ``(x, y)`` batches.
        criterion: Loss callable ``criterion(out, y)`` returning a scalar tensor.
        ema_model: Optional EMA wrapper; when truthy its ``.module`` is the
            network that is actually evaluated.
        cfg: Configuration namespace.  Fields read here: ``local_rank``,
            ``device``, ``use_amp`` and ``world_size``.
        dtype_cast: dtype handed to ``torch.autocast``.

    Returns:
        float: mean of the per-batch losses on this rank, summed across ranks
        with ``all_reduce`` and divided by ``cfg.world_size``.  When no
        process group is initialised the local mean is returned as is.
    """
    model.eval()
    eval_model = ema_model.module if ema_model else model
    losses = []

    for x, y in tqdm(loader, disable=cfg.local_rank != 0):
        x, y = x.to(cfg.local_rank), y.to(cfg.local_rank)
        # torch.autocast accepts `device_type`; torch.cuda.amp.autocast does not.
        with torch.autocast(device_type=cfg.device.type, dtype=dtype_cast, enabled=cfg.use_amp):
            out = eval_model(x)
            losses.append(criterion(out, y).item())

    loss = np.mean(losses)

    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to reduce over.
        return float(loss)

    v = torch.tensor([loss], device=cfg.local_rank)
    dist.all_reduce(v)  # default reduction is a sum across ranks
    return v.item() / cfg.world_size


In [None]:
%%writefile _optimized_test.py


# Cell 10 ───────────────────────────────────────────────────────────────
# Memory-efficient ensemble inference  ➜ _optimized_test.py
#
# • Loads checkpoints, builds EnsembleModel, runs TTA inference.
# • Sub-batches to avoid OOM and flushes GPU cache periodically.

from __future__ import annotations
import csv, time, glob, numpy as np, torch
from tqdm import tqdm
from torch.utils.data import DataLoader, SequentialSampler

from _cfg   import cfg
from _model import Net, EnsembleModel
from _train_utils import get_autocast_dtype

# --------------------------------------------------------------------- #
# Dataset                                                               #
# --------------------------------------------------------------------- #
class _TestDS(torch.utils.data.Dataset):
    def __init__(self, files): self.files = files
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        f = self.files[i]; stem = f.split("/")[-1].split(".")[0]
        return np.load(f, mmap_mode="r"), stem

# --------------------------------------------------------------------- #
# Batched inference helper                                              #
# --------------------------------------------------------------------- #
def _infer_batched(model, x, sub_bs, dtype_cast):
    if sub_bs is None or sub_bs >= x.size(0):
        with torch.cuda.amp.autocast(device_type=cfg.device.type, dtype=dtype_cast, enabled=cfg.use_amp):
            return model(x)
    outs = []
    for i in range(0, x.size(0), sub_bs):
        with torch.cuda.amp.autocast(device_type=cfg.device.type, dtype=dtype_cast, enabled=cfg.use_amp):
            outs.append(model(x[i : i + sub_bs]))
        if i % (sub_bs * 4) == 0:
            torch.cuda.empty_cache()
    return torch.cat(outs)

# --------------------------------------------------------------------- #
# Public entry-point                                                    #
# --------------------------------------------------------------------- #
def run_optimized_inference(
    ckpt_glob="/kaggle/input/openfwi-preprocessed-72x72/models_1000x70/*.pt",
    test_glob="/kaggle/input/open-wfi-test/test/*.npy",
    out_path="submission.csv",
    sub_bs=4,
):
    """Run ensemble inference over the test files and write the submission CSV.

    Every checkpoint matching ``ckpt_glob`` is loaded into a ``Net`` and the
    resulting list is wrapped in ``EnsembleModel``.  For each test file the
    prediction is written as 70 rows (``oid_ypos = "<stem>_y_<row>"``) with
    one column per odd x position (``x_1, x_3, …, x_69``).

    Args:
        ckpt_glob: Glob pattern for the checkpoint files.
        test_glob: Glob pattern for the ``.npy`` test inputs.
        out_path: Destination CSV path.
        sub_bs: Sub-batch size forwarded to ``_infer_batched``
            (adjust per GPU memory).

    Returns:
        ``out_path``.

    Raises:
        FileNotFoundError: if no checkpoint matches ``ckpt_glob``.
    """
    print("🔍  Loading ensemble checkpoints …")
    ckpt_files = sorted(glob.glob(ckpt_glob))
    if not ckpt_files:
        raise FileNotFoundError(f"No checkpoints match {ckpt_glob!r}")

    models = []
    for f in ckpt_files:
        print("  ↳", f)
        m = Net(cfg.backbone, pretrained=False)
        sd = torch.load(f, map_location=cfg.device, weights_only=True)
        # Drop the "_orig_mod." fragment from checkpoint keys before loading.
        sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
        m.load_state_dict(sd)
        models.append(m)
    model = EnsembleModel(models).to(cfg.device).eval()
    print(f"✅  {len(models)} models loaded.")

    dtype_cast = get_autocast_dtype(cfg.precision)

    test_files = sorted(glob.glob(test_glob))
    dl = DataLoader(_TestDS(test_files), sampler=SequentialSampler(test_files),
                    batch_size=cfg.batch_size_val, num_workers=4, pin_memory=True)

    # Column "x_<k>" must carry the value at x position k, so index the map
    # with the positions themselves rather than with the column's ordinal.
    # assumes each prediction is (70, >=70) after dropping the channel dim —
    # TODO confirm against Net / EnsembleModel.
    x_positions = list(range(1, 70, 2))
    x_cols = [f"x_{i}" for i in x_positions]
    fieldnames = ["oid_ypos"] + x_cols

    flush_every = 100_000
    t0 = time.time()
    rows = 0
    next_flush = flush_every
    with open(out_path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames)
        writer.writeheader()

        with torch.inference_mode():
            for x, stems in tqdm(dl):
                x = x.to(cfg.device, non_blocking=True)
                # .float(): NumPy has no bfloat16, so up-cast before converting.
                y = _infer_batched(model, x, sub_bs, dtype_cast).float().cpu().numpy()

                for arr, stem in zip(y[:, 0], stems):
                    for y_pos in range(70):
                        row = {"oid_ypos": f"{stem}_y_{y_pos}"}
                        row.update(zip(x_cols, arr[y_pos, x_positions]))
                        writer.writerow(row)
                        rows += 1

                # Flush roughly every `flush_every` rows; an exact-multiple test
                # would almost never fire because rows grow 70 per sample.
                if rows >= next_flush:
                    fh.flush()
                    next_flush = rows + flush_every
                torch.cuda.empty_cache()

    print(f"📝  submission.csv written ({rows:_} rows)  |  {time.time()-t0:.1f}s")
    return out_path


In [None]:
# Cell 11 ───────────────────────────────────────────────────────────────
# Quick fold-0 validation (runs only if RUN_VALID is True)

if RUN_VALID:
    from tqdm import tqdm
    import numpy as np, torch
    from _dataset import CustomDataset
    from _model   import EnsembleModel
    from _train_utils import get_autocast_dtype

    # NOTE(review): `cfg` and `model` are taken from the kernel namespace; no
    # cell visible in this part of the notebook assigns `model` — confirm it is
    # created before this cell runs (Restart & Run All).

    valid_ds = CustomDataset(cfg, mode="valid")
    valid_dl = torch.utils.data.DataLoader(
        valid_ds,
        sampler=torch.utils.data.SequentialSampler(valid_ds),
        batch_size=cfg.batch_size_val,
        num_workers=4,
    )

    criterion   = torch.nn.L1Loss()
    dtype_cast  = get_autocast_dtype(cfg.precision)
    val_logits, val_targets = [], []

    with torch.no_grad():
        for x, y in tqdm(valid_dl):
            x, y = x.to(cfg.device), y.to(cfg.device)
            # torch.autocast accepts `device_type`; torch.cuda.amp.autocast
            # raises TypeError on that keyword.
            with torch.autocast(device_type=cfg.device.type, dtype=dtype_cast, enabled=cfg.use_amp):
                out = model(x)
            # Up-cast so the MAE is accumulated in float32 regardless of the
            # autocast dtype.
            val_logits.append(out.float().cpu()); val_targets.append(y.cpu())

    val_logits  = torch.cat(val_logits); val_targets = torch.cat(val_targets)
    mae         = criterion(val_logits, val_targets).item()
    print(f"🔎  Fold-0 validation MAE: {mae:.2f}")


In [None]:
# Cell 12 ───────────────────────────────────────────────────────────────
# Test-set inference (runs only if RUN_TEST is True)

if RUN_TEST:
    import glob
    # `%%writefile` only writes _optimized_test.py to disk; its names have to
    # be imported before they can be called from the kernel.
    from _optimized_test import run_optimized_inference, _TestDS

    # Defined up front so it exists whichever inference path runs.
    test_files = sorted(glob.glob("/kaggle/input/open-wfi-test/test/*.npy"))

    try:
        path = run_optimized_inference()
        print(f"✅  Submission saved → {path}")
    except Exception as e:
        print("⚠️  Optimised path failed – falling back.  Error:", e)
        # ─ fallback: single-pass inference without sub-batching ─
        import csv, numpy as np, pandas as pd, time, torch
        from torch.utils.data import DataLoader, SequentialSampler
        from tqdm import tqdm
        from _train_utils import get_autocast_dtype

        # NOTE(review): `model` is read from the kernel namespace; no cell
        # visible in this part of the notebook assigns it — confirm it exists
        # before relying on this fallback.

        dl = DataLoader(_TestDS(test_files), sampler=SequentialSampler(test_files),
                        batch_size=cfg.batch_size_val, num_workers=4)
        # Column "x_<k>" carries the value at x position k (odd positions only).
        # assumes each prediction is (70, >=70) after dropping the channel dim —
        # TODO confirm against the model.
        x_positions = list(range(1, 70, 2))
        x_cols = [f"x_{i}" for i in x_positions]
        fieldnames = ["oid_ypos"] + x_cols
        t0 = time.time(); rows = 0
        with open("submission.csv", "w", newline="") as fh:
            writer = csv.DictWriter(fh, fieldnames); writer.writeheader()
            # torch.autocast accepts `device_type`; torch.cuda.amp.autocast
            # raises TypeError on that keyword.
            with torch.inference_mode(), torch.autocast(device_type=cfg.device.type,
                                                        dtype=get_autocast_dtype(cfg.precision),
                                                        enabled=cfg.use_amp):
                for x, stems in tqdm(dl):
                    # .float(): NumPy has no bfloat16, so up-cast before converting.
                    preds = model(x.to(cfg.device)).float().cpu().numpy()[:, 0]
                    for arr, stem in zip(preds, stems):
                        for y_pos in range(70):
                            row = {"oid_ypos": f"{stem}_y_{y_pos}"}
                            row.update(zip(x_cols, arr[y_pos, x_positions]))
                            writer.writerow(row)
                            rows += 1
        print(f"✓ Fallback inference done ({rows:_} rows) | {time.time()-t0:.1f}s")


In [None]:
# Cell 13 ───────────────────────────────────────────────────────────────
# Visual sanity-check of a few predictions (only if RUN_TEST)

if RUN_TEST:
    import glob
    import matplotlib.pyplot as plt, numpy as np, torch
    n = 5  # number of slices to show
    fig, axes = plt.subplots(1, n, figsize=(n * 2.5, 2.5))

    # Resolve the test files here so the cell does not depend on which
    # inference path an earlier cell happened to take.
    test_files = sorted(glob.glob("/kaggle/input/open-wfi-test/test/*.npy"))
    if not test_files:
        raise FileNotFoundError("No test .npy files found under /kaggle/input/open-wfi-test/test/")

    # NOTE(review): `cfg` and `model` are read from the kernel namespace; no
    # cell visible in this part of the notebook assigns `model` — confirm.
    # Run a tiny inference on the first test file.
    sample = torch.tensor(np.load(test_files[0]), dtype=torch.float32)[None].to(cfg.device)
    with torch.no_grad():
        pred = model(sample)[0, 0].float().cpu().numpy()

    for i in range(n):
        column = pred[:, i * 2]
        if column.ndim == 1:
            # imshow rejects 1-D data; show the column as a one-pixel-wide strip.
            axes[i].imshow(column[:, None], cmap="gray", aspect="auto")
        else:
            axes[i].imshow(column, cmap="gray")
        axes[i].set_title(f"Slice {i}")
        axes[i].axis("off")

    plt.tight_layout()
    plt.show()
