## Pipeline
1. **DICOM → 3D Volume**: Normalize to `(32, 384, 384)`
2. **EfficientNetV2-S**: 32-channel input, 14 binary outputs
3. **Ensemble**: Average 5-fold predictions

In [None]:
# import timm
# print("timm version:", timm.__version__)

In [None]:
# import timm, pprint

# # all models (long list)
# all_models = timm.list_models()
# print(len(all_models), "models available")
# for pat in [
#     "*convnextv2*",
#     "*efficientnetv2*",
#     "*maxvit*",
#     "*regnet*",
#     "*vit_*patch16*",
#     "*swin*",
# ]:
#     print("\n", pat)
#     pprint.pp(sorted(timm.list_models(pat)))

In [None]:
import os
import numpy as np
import pydicom
import cv2
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from scipy import ndimage
import warnings
import gc
from time import time
from tqdm.auto import tqdm
warnings.filterwarnings('ignore')

class DICOMPreprocessorKaggle:
    """
    Turn one DICOM series into a fixed-size uint8 volume.

    Single-series port of the original DICOMPreprocessor logic: files are
    read, ordered along z, mapped to 0-255 (fixed window for CT, percentile
    normalization otherwise) and resampled to ``target_shape``.
    """

    def __init__(self, target_shape: Tuple[int, int, int] = (32, 384, 384)):
        """Remember the output size, given as (depth, height, width)."""
        self.target_depth, self.target_height, self.target_width = target_shape

    def load_dicom_series(self, series_path: str) -> Tuple[List["pydicom.Dataset"], str]:
        """
        Read every ``.dcm`` file found (recursively) under ``series_path``.

        Files pydicom cannot read are skipped on purpose (best-effort load).

        Returns:
            (datasets, series_name); ``series_name`` is the directory name.

        Raises:
            ValueError: no ``.dcm`` file exists, or none could be read.
        """
        series_path = Path(series_path)
        series_name = series_path.name

        # Sorted paths make the load order (and therefore the tie-break between
        # slices sharing a z position, and the index fallback) deterministic
        # instead of depending on filesystem enumeration order.
        # The extension check is case-insensitive so ".DCM" files are found too.
        dicom_files = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(series_path)
            for file in files
            if file.lower().endswith('.dcm')
        )

        if not dicom_files:
            raise ValueError(f"No DICOM files found in {series_path}")

        datasets = []
        for filepath in dicom_files:
            try:
                datasets.append(pydicom.dcmread(filepath, force=True))
            except Exception:
                # Deliberate best-effort: one unreadable file must not abort the series.
                continue

        if not datasets:
            raise ValueError(f"No valid DICOM files in {series_path}")

        return datasets, series_name

    def extract_slice_info(self, datasets: List["pydicom.Dataset"]) -> List[Dict]:
        """
        Build one record per dataset with keys ``dataset``, ``index``,
        ``instance_number`` and ``z_position``.

        ``z_position`` is ImagePositionPatient[2] when available, otherwise the
        InstanceNumber, otherwise the load index.
        """
        slice_info = []

        for i, ds in enumerate(datasets):
            info = {
                'dataset': ds,
                'index': i,
                'instance_number': getattr(ds, 'InstanceNumber', i),
            }

            try:
                position = getattr(ds, 'ImagePositionPatient', None)
                if position is not None and len(position) >= 3:
                    info['z_position'] = float(position[2])
                else:
                    # Fallback: use InstanceNumber
                    info['z_position'] = float(info['instance_number'])
            except Exception:
                # Malformed position / instance number: fall back to load order.
                info['z_position'] = float(i)

            slice_info.append(info)

        return slice_info

    def sort_slices_by_position(self, slice_info: List[Dict]) -> List[Dict]:
        """Return the slice records ordered by ascending ``z_position`` (stable sort)."""
        return sorted(slice_info, key=lambda x: x['z_position'])

    def get_windowing_params(self, ds: "pydicom.Dataset", img: Optional[np.ndarray] = None) -> Tuple[Optional[float], Optional[float]]:
        """
        Return (center, width) for windowing if appropriate, else (None, None).

        CT (also assumed when Modality is missing) gets a fixed angiography
        window; every other modality returns (None, None) so that percentile
        normalization is used downstream. ``img`` is accepted for interface
        compatibility and is not used.
        """
        modality = str(getattr(ds, "Modality", "CT")).upper()

        if modality == "CT":
            # CTA-style windowing for vessels
            return 50.0, 350.0

        return None, None

    def apply_windowing_or_normalize(self, img: np.ndarray, center: Optional[float], width: Optional[float]) -> np.ndarray:
        """
        Map ``img`` to uint8 in 0-255.

        If (center, width) are both given the window is applied; otherwise a
        1st-99th percentile normalization is used, then plain min-max, and a
        constant image yields all zeros.
        """
        if center is not None and width is not None:
            # CT/CTA windowing
            img_min = center - width / 2.0
            img_max = center + width / 2.0
            windowed = np.clip(img, img_min, img_max)
            windowed = (windowed - img_min) / max(1e-6, (img_max - img_min))
            return (windowed * 255.0).astype(np.uint8)

        # MR (or unknown) -> percentile normalization
        p1, p99 = np.percentile(img, [1, 99])
        if p99 > p1:
            norm = np.clip(img, p1, p99)
            norm = (norm - p1) / (p99 - p1)
            return (norm * 255.0).astype(np.uint8)

        # Fallback: min-max
        mn, mx = float(img.min()), float(img.max())
        if mx > mn:
            norm = (img - mn) / (mx - mn)
            return (norm * 255.0).astype(np.uint8)
        return np.zeros_like(img, dtype=np.uint8)

    def extract_pixel_array(self, ds: "pydicom.Dataset") -> np.ndarray:
        """
        Extract a 2D float32 image from a dataset (no scaling to 0-255 here).

        Steps: take the middle frame of a multi-frame array, invert
        MONOCHROME1, apply RescaleSlope/RescaleIntercept, and replace padding
        pixels (and any NaN) with the smallest remaining finite value.
        """
        img = ds.pixel_array.astype(np.float32)

        # If multi-frame (3D in a single file), take the middle frame for 2D path
        if img.ndim == 3:
            img = img[img.shape[0] // 2]

        # PixelPaddingValue is expressed in stored pixel values, so padding has
        # to be located BEFORE inversion / rescale change those values.
        padding_mask = None
        ppv = getattr(ds, "PixelPaddingValue", None)
        if ppv is not None:
            padding_mask = np.isclose(img, float(ppv))

        # Handle MONOCHROME1 (invert): larger values are darker -> flip
        photometric = str(getattr(ds, "PhotometricInterpretation", "") or "")
        if photometric.upper() == "MONOCHROME1":
            img = img.max() - img

        # Apply DICOM rescale (DO NOT override slope/intercept)
        slope = float(getattr(ds, "RescaleSlope", 1.0))
        intercept = float(getattr(ds, "RescaleIntercept", 0.0))
        if slope != 1.0 or intercept != 0.0:
            img = img * slope + intercept

        if padding_mask is not None and padding_mask.any():
            # img is a private float32 copy at this point, in-place write is safe.
            img[padding_mask] = np.nan

        # Replace NaNs with the finite minimum; if nothing is finite, use zero.
        if np.isnan(img).any():
            finite = img[np.isfinite(img)]
            fill_val = finite.min() if finite.size else 0.0
            img = np.nan_to_num(img, nan=fill_val)

        return img  # float32

    def resize_volume_3d(self, volume: np.ndarray) -> np.ndarray:
        """
        Resample a (depth, height, width) volume to the target shape.

        A volume that already has the target shape is returned unchanged (no
        dtype cast); otherwise it is linearly interpolated, cropped/edge-padded
        to the exact size and returned as uint8.
        """
        current_shape = volume.shape
        target_shape = (self.target_depth, self.target_height, self.target_width)

        if current_shape == target_shape:
            return volume

        zoom_factors = [
            target_shape[i] / current_shape[i] for i in range(3)
        ]

        # Resize with linear interpolation
        resized_volume = ndimage.zoom(volume, zoom_factors, order=1, mode='nearest')

        # Clip to exact size just in case rounding produced an extra voxel
        resized_volume = resized_volume[:self.target_depth, :self.target_height, :self.target_width]

        # ... and pad if rounding left an axis short
        pad_width = [
            (0, max(0, self.target_depth - resized_volume.shape[0])),
            (0, max(0, self.target_height - resized_volume.shape[1])),
            (0, max(0, self.target_width - resized_volume.shape[2]))
        ]

        if any(pw[1] > 0 for pw in pad_width):
            resized_volume = np.pad(resized_volume, pad_width, mode='edge')

        return resized_volume.astype(np.uint8)

    @staticmethod
    def _estimate_orig_depth(datasets) -> Optional[int]:
        """Best-effort count of source slices/frames, read before any resampling."""
        if len(datasets) == 1:
            ds0 = datasets[0]
            orig_depth = None
            # multiframe (enhanced) 3D?
            nframes = getattr(ds0, "NumberOfFrames", None)
            if nframes is not None:
                try:
                    orig_depth = int(nframes)
                except Exception:
                    orig_depth = None
            # if not set, try pixel array ndim==3
            if orig_depth is None:
                try:
                    arr0 = ds0.pixel_array  # pydicom will decode
                    if arr0.ndim == 3:
                        orig_depth = int(arr0.shape[0])
                except Exception:
                    pass
            return orig_depth

        # multiple single-slice files: count unique z positions if possible
        z_vals = []
        for ds in datasets:
            ipp = getattr(ds, "ImagePositionPatient", None)
            iop = getattr(ds, "ImageOrientationPatient", None)
            if ipp is not None and iop is not None and len(ipp) == 3 and len(iop) >= 6:
                # Quick proxy: use z = IPP[2]
                z_vals.append(float(ipp[2]))
            else:
                sl = getattr(ds, "SliceLocation", None)
                if sl is not None:
                    z_vals.append(float(sl))
        if z_vals:
            return int(np.unique(np.round(z_vals, 5)).size)
        # fallback: number of DICOM files
        return len(datasets)

    @staticmethod
    def _read_spacing(ds_ref) -> Tuple[Optional[float], Optional[float], Optional[float]]:
        """Best-effort (dz, dy, dx) from one header; all None if anything is malformed."""
        try:
            px = getattr(ds_ref, "PixelSpacing", None)  # [dy, dx]
            dy = float(px[0]) if px is not None else None
            dx = float(px[1]) if px is not None else None
            # dz from SpacingBetweenSlices, else SliceThickness
            dz = getattr(ds_ref, "SpacingBetweenSlices", None)
            if dz is None:
                dz = getattr(ds_ref, "SliceThickness", None)
            dz = float(dz) if dz is not None else None
            return dz, dy, dx
        except Exception:
            return None, None, None

    def process_series(self, series_path: str, return_meta: bool = False):
        """
        Process a DICOM series and return the resampled NumPy volume.

        If return_meta=True, also returns a dict with:
            - 'orig_depth': int number of source slices/frames, or None
            - 'spacing':   tuple (dz, dy, dx) if all three are available, else None

        Errors from loading or decoding propagate to the caller.
        """
        # 1) Load DICOM files (headers + pixels on demand)
        datasets, series_name = self.load_dicom_series(series_path)

        # Metadata is taken from the headers BEFORE resampling.
        orig_depth = self._estimate_orig_depth(datasets)
        dz, dy, dx = self._read_spacing(datasets[0])

        # 2) Produce the resampled volume
        first_ds = datasets[0]
        first_img = first_ds.pixel_array

        if len(datasets) == 1 and first_img.ndim == 3:
            vol = self._process_single_3d_dicom(first_ds, series_name)
        else:
            vol = self._process_multiple_2d_dicoms(datasets, series_name)

        if return_meta:
            return vol, {"orig_depth": int(orig_depth) if orig_depth is not None else None,
                         "spacing": (dz, dy, dx) if (dz is not None and dy is not None and dx is not None) else None}
        return vol

    def _process_single_3d_dicom(self, ds: "pydicom.Dataset", series_name: str) -> np.ndarray:
        """
        Process a single multi-frame DICOM file (for Kaggle: no file saving).

        Frames are rescaled, windowed/normalized one by one and the stack is
        resized to the target shape. ``series_name`` is unused.
        """
        volume = ds.pixel_array.astype(np.float32)

        # Apply RescaleSlope and RescaleIntercept
        slope = float(getattr(ds, "RescaleSlope", 1.0))
        intercept = float(getattr(ds, "RescaleIntercept", 0.0))
        if slope != 1.0 or intercept != 0.0:
            volume = volume * slope + intercept

        window_center, window_width = self.get_windowing_params(ds)

        # Per-frame so that percentile normalization stays slice-local.
        processed_slices = [
            self.apply_windowing_or_normalize(volume[i], window_center, window_width)
            for i in range(volume.shape[0])
        ]

        return self.resize_volume_3d(np.stack(processed_slices, axis=0))

    def _process_multiple_2d_dicoms(self, datasets: List["pydicom.Dataset"], series_name: str) -> np.ndarray:
        """
        Process multiple single-slice DICOM files (for Kaggle: no file saving).

        Slices are sorted by z, converted to uint8, resized in-plane with
        OpenCV and finally resampled along depth. ``series_name`` is unused.
        """
        sorted_slices = self.sort_slices_by_position(self.extract_slice_info(datasets))

        # Window parameters depend only on the header, so the first slice's
        # pixels do not need to be decoded an extra time here.
        window_center, window_width = self.get_windowing_params(sorted_slices[0]['dataset'])

        processed_slices = []
        for slice_data in sorted_slices:
            img = self.extract_pixel_array(slice_data['dataset'])
            processed_img = self.apply_windowing_or_normalize(img, window_center, window_width)
            resized_img = cv2.resize(processed_img, (self.target_width, self.target_height), interpolation=cv2.INTER_LINEAR)
            processed_slices.append(resized_img)

        return self.resize_volume_3d(np.stack(processed_slices, axis=0))

def process_dicom_series_kaggle(series_path: str, target_shape: Tuple[int, int, int] = (32, 384, 384)) -> np.ndarray:
    """
    Preprocess one DICOM series for Kaggle inference.

    A fresh DICOMPreprocessorKaggle is created for every call.

    Args:
        series_path: Directory holding the series' DICOM files.
        target_shape: Output size as (depth, height, width).

    Returns:
        np.ndarray: The processed volume.
    """
    return DICOMPreprocessorKaggle(target_shape=target_shape).process_series(series_path)

# Safe processing function with memory cleanup
def process_dicom_series_safe(series_path: str, target_shape: Tuple[int, int, int] = (32, 384, 384)) -> np.ndarray:
    """
    Preprocess one DICOM series and always trigger a garbage collection.

    Args:
        series_path: Directory holding the series' DICOM files.
        target_shape: Output size as (depth, height, width).

    Returns:
        np.ndarray: The processed volume. Exceptions from processing propagate
        after the collection has run.
    """
    try:
        return process_dicom_series_kaggle(series_path, target_shape)
    finally:
        # Runs on success and on failure alike.
        gc.collect()

# Test function
def test_single_series(series_path: str, target_shape: Tuple[int, int, int] = (32, 384, 384)):
    """
    Smoke-test preprocessing on a single series.

    Returns the processed volume, or None when processing raised any
    Exception (the error is intentionally not reported).
    """
    try:
        return process_dicom_series_safe(series_path, target_shape)
    except Exception:
        # Quiet by design: callers only check for None.
        return None

In [None]:
#  DDP / AMP helpers

import os
import torch.distributed as dist

def get_dist_env():
    """
    Read (local_rank, rank, world_size) from the environment.

    LOCAL_RANK falls back to SLURM_LOCALID and then 0; RANK defaults to 0 and
    WORLD_SIZE to 1. All three values are returned as ints.
    """
    env = os.environ
    slurm_local = env.get("SLURM_LOCALID", 0)
    local_rank = int(env.get("LOCAL_RANK", slurm_local))
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return local_rank, rank, world_size

def setup_distributed():
    """
    Initialise torch.distributed when WORLD_SIZE > 1.

    The CUDA device is bound to ``local_rank`` before the process group is
    created, and the backend is "nccl" when CUDA is available, otherwise
    "gloo" (so CPU-only multi-process runs also work). Nothing is initialised
    for single-process runs or when a group already exists.

    Returns:
        (local_rank, rank, world_size, is_distributed)
    """
    import torch  # local import: this cell only imports torch.distributed

    local_rank, rank, world_size = get_dist_env()
    is_distributed = world_size > 1
    if is_distributed and not dist.is_initialized():
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            # Select this rank's GPU first so NCCL never defaults to device 0.
            torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl" if use_cuda else "gloo", init_method="env://")
    return local_rank, rank, world_size, is_distributed

def cleanup_distributed():
    """Synchronise all ranks and destroy the process group; no-op if none is active."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
    dist.destroy_process_group()

def is_main_process():
    """Return True when the RANK environment variable is 0 or unset."""
    rank = os.environ.get("RANK", 0)
    return int(rank) == 0

def seed_everything(base_seed=42):
    """
    Seed Python, NumPy and torch (CPU + all CUDA devices) with ``base_seed + rank``.

    Each rank gets its own seed, derived deterministically from ``base_seed``.
    cuDNN is left in benchmark (non-deterministic) mode.

    Returns:
        The seed that was applied on this rank.
    """
    rank = get_dist_env()[1]
    seed = base_seed + rank

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    return seed

def scale_lr_for_world_size(lr: float):
    """Linear LR scaling: multiply ``lr`` by WORLD_SIZE (never by less than 1)."""
    world_size = get_dist_env()[2]
    factor = max(1, world_size)
    return lr * factor

# Optional toggles; nothing in this cell reads them, they are meant for code elsewhere.
USE_CHANNELS_LAST = True     # author's note: improves throughput on T4 with AMP
USE_TORCH_COMPILE = False     # set True to try torch.compile; consumer is expected to fall back if it errors
# NOTE(review): the value is "bf16" while the guidance below says T4 should use fp16 — confirm the target GPU.
AMP_DTYPE = "bf16"           # author's guidance: T4 → fp16; A100/H100 can use "bf16"

In [None]:
# Config & labels 
import os, math, json, random, copy
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm

# Label order must match competition columns
LABEL_COLS = [
    # internal carotid arteries
    "Left Infraclinoid Internal Carotid Artery",
    "Right Infraclinoid Internal Carotid Artery",
    "Left Supraclinoid Internal Carotid Artery",
    "Right Supraclinoid Internal Carotid Artery",
    # middle / anterior circulation
    "Left Middle Cerebral Artery",
    "Right Middle Cerebral Artery",
    "Anterior Communicating Artery",
    "Left Anterior Cerebral Artery",
    "Right Anterior Cerebral Artery",
    # posterior circulation
    "Left Posterior Communicating Artery",
    "Right Posterior Communicating Artery",
    "Basilar Tip",
    "Other Posterior Circulation",
    # series-level label
    "Aneurysm Present",
]
# Position of the series-level label, derived from the list so the two cannot drift apart.
ANEURYSM_PRESENT_IDX = LABEL_COLS.index("Aneurysm Present")

@dataclass
class CFG:
    """
    Training configuration.

    ``in_chans`` is derived: base slices + extra cached channels + localizer
    crops (when ``use_localizers``) + vessel side-car channels (when
    ``use_vessel_sidecar``). The class-level default and the value recomputed
    in ``__post_init__`` follow the same formula, so ``CFG.in_chans`` and
    ``CFG().in_chans`` agree for the default settings.
    """

    # --- data locations ---
    series_root: str = r"D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/series"
    train_csv: str  = r"D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/train.csv"
    localizers_csv_path: str = r"D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/train_localizers.csv"

    # --- input layout ---
    img_size: int = 384
    base_slices: int = 32
    extra_cached_chans: int = 0
    use_vessel_sidecar: bool = True
    vessel_sidecar_mode: str = "mip"
    # Evaluated once in the class body from the DEFAULTS above:
    # "mip" -> 1 channel, "per_slice" -> 32 channels, side-car disabled -> 0.
    vessel_extra: int = 1 if (use_vessel_sidecar and vessel_sidecar_mode == "mip") else (32 if (use_vessel_sidecar and vessel_sidecar_mode == "per_slice") else 0)
    use_localizers: bool = True
    max_localizer_crops: int = 3

    local_crop_size: int = 128
    p_localizer_dropout: float = 0.30
    # optional: occasionally turn localizers fully off
    p_global_localizer_off: float = 0.10

    # --- model / optimisation ---
    num_classes: int = 14
    model_name: str = "maxvit_base_tf_384"
    epochs: int = 34
    batch_size: int = 2
    num_workers: int = 2
    lr: float = 1.6e-4
    weight_decay: float = 0.05
    warmup_epochs: float = 5.0
    min_lr: float = 3e-6
    clip_grad_norm: float = 1.0
    use_amp: bool = True
    label_smoothing: float = 0.02
    focal_loss: bool = False
    focal_gamma: float = 1.5
    pos_weight: float = 1.0
    folds: int = 5
    seed: int = 42
    out_dir: str = "./outputs"
    save_name: str = "maxvitbasemodel"
    seeds: list = field(default_factory=lambda: [42, 2025, 8])

    # Class-level default (used when CFG is read without instantiating it);
    # instances recompute it in __post_init__.
    in_chans: int = base_slices + extra_cached_chans + max_localizer_crops + vessel_extra

    def __post_init__(self):
        # Recompute from the instance's settings. The vessel channels were
        # previously left out here, which made CFG().in_chans disagree with
        # the class-level CFG.in_chans.
        localizer_chans = self.max_localizer_crops if self.use_localizers else 0
        vessel_chans = self.vessel_extra if self.use_vessel_sidecar else 0
        self.in_chans = (
            self.base_slices
            + self.extra_cached_chans
            + localizer_chans
            + vessel_chans
        )


def set_seed(seed=42):
    """
    Seed Python, NumPy and torch (CPU + all CUDA devices) with ``seed``.

    cuDNN is left in benchmark (non-deterministic) mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

def per_label_auc(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """
    Compute ROC-AUC per label, keyed by the names in LABEL_COLS.

    Column ``i`` of both arrays is matched to ``LABEL_COLS[i]``. A label whose
    ground truth holds fewer than two distinct values gets NaN.
    """
    scores = {}
    for col, label in enumerate(LABEL_COLS):
        truth = y_true[:, col]
        prob = y_prob[:, col]
        if len(np.unique(truth)) >= 2:
            scores[label] = roc_auc_score(truth, prob)
        else:
            scores[label] = np.nan
    return scores

def comp_weighted_auc(aucs: Dict[str, float]) -> float:
    """
    Weighted mean of the per-label AUCs.

    The label at ANEURYSM_PRESENT_IDX has weight 13, every other label weight
    1. NaN entries are left out of both sums; NaN is returned if nothing is left.
    """
    scored = [
        (name, 13.0 if i == ANEURYSM_PRESENT_IDX else 1.0)
        for i, name in enumerate(LABEL_COLS)
        if not np.isnan(aucs[name])
    ]
    if not scored:
        return np.nan
    weighted_sum = sum(aucs[name] * w for name, w in scored)
    weight_total = sum(w for _, w in scored)
    return weighted_sum / weight_total

def cfg_to_dict(cfg_cls) -> dict:
    """
    Collect the non-callable attributes of ``cfg_cls`` into a dict.

    Names are taken from ``dir(cfg_cls)``; anything starting with a double
    underscore is skipped.
    """
    collected = {}
    for attr in dir(cfg_cls):
        if attr.startswith("__"):
            continue
        value = getattr(cfg_cls, attr)
        if callable(value):
            continue
        collected[attr] = value
    return collected

# Caching via shards
Strategy

Shard the cache into folders/files ≤ ~5–8 GB each (safe margin).

In each Kaggle run, build one shard in /kaggle/working/shard_k/…, then:

Option A: leave as plain .npy files inside shard folder.

Option B (nice for scale): pack into WebDataset tar shards (.tar with simple naming).

Download the shard (or “Commit & Save Output”) and upload as a Kaggle Dataset (either one dataset with many files or multiple versions).

In your training notebook, Add Data → select your dataset(s). They’ll appear under /kaggle/input/<your-dataset>/… (no 20 GB limit).

Your DataLoader reads from /kaggle/input paths (mmap .npy or stream tar shards).

In [None]:

# import os, math, hashlib, numpy as np, pandas as pd, time, multiprocessing as mp
# from tqdm.auto import tqdm

# # Prevent oversubscription (each worker will do BLAS/ndimage work)
# os.environ.setdefault("OMP_NUM_THREADS", "1")
# os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
# os.environ.setdefault("MKL_NUM_THREADS", "1")
# os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# IMG_SIZE      = CFG.img_size      # 384 to preserve quality (or 320)
# BYTES_PER_ELT = 1                 # uint8
# DEPTH         = 32
# TARGET_BYTES_PER_SHARD = 6 * 1024**3
# OUT_BASE = f"/kaggle/working/cache_u8_{IMG_SIZE}"

# # Choose which shard to build in THIS run:
# NUM_SHARDS   = None   # None = auto-compute
# SHARD_ID     = 3      # change per run
# NUM_WORKERS  = max(1, mp.cpu_count()-1)  

# df_all = pd.read_csv(CFG.train_csv)
# df_all = df_all[df_all["SeriesInstanceUID"].apply(
#     lambda u: os.path.isdir(os.path.join(CFG.series_root, str(u)))
# )].reset_index(drop=True)

# bytes_per_series = DEPTH * IMG_SIZE * IMG_SIZE * BYTES_PER_ELT
# est_total_bytes  = len(df_all) * bytes_per_series
# if NUM_SHARDS is None:
#     NUM_SHARDS = max(1, math.ceil(est_total_bytes / TARGET_BYTES_PER_SHARD))

# print(f"[Shard plan] ~{est_total_bytes/1e9:.2f} GB total, "
#       f"~{TARGET_BYTES_PER_SHARD/1e9:.1f} GB per shard → NUM_SHARDS={NUM_SHARDS}")
# assert 0 <= SHARD_ID < NUM_SHARDS, "Set SHARD_ID within [0, NUM_SHARDS)"

# def sid_to_shard(sid: str, num_shards: int) -> int:
#     h = int(hashlib.md5(sid.encode("utf-8")).hexdigest()[:8], 16)
#     return h % num_shards

# df_shard = df_all[df_all["SeriesInstanceUID"].astype(str).map(
#     lambda s: sid_to_shard(s, NUM_SHARDS) == SHARD_ID
# )].reset_index(drop=True)
# print(f"[Shard {SHARD_ID}/{NUM_SHARDS}] series: {len(df_shard)}")

# OUT_DIR = f"{OUT_BASE}_shard{SHARD_ID:02d}"
# os.makedirs(OUT_DIR, exist_ok=True)
# print("Output dir:", OUT_DIR)

# def cache_path_for(sid: str) -> str:
#     return os.path.join(OUT_DIR, f"{sid}_{IMG_SIZE}.npy")

# sids_all = df_shard["SeriesInstanceUID"].astype(str).tolist()
# to_write = [sid for sid in sids_all if not os.path.exists(cache_path_for(sid))]
# to_skip  = len(sids_all) - len(to_write)
# print(f"Will save: {len(to_write)}  |  Already cached (skip): {to_skip}")

# # Worker function (creates its own preprocessor)
# def _worker(sid):
#     try:
#         from __main__ import DICOMPreprocessorKaggle, CFG
#         preproc = DICOMPreprocessorKaggle(target_shape=(DEPTH, IMG_SIZE, IMG_SIZE))
#         vol, meta = preproc.process_series(os.path.join(CFG.series_root, sid), return_meta=True)
#         if vol.dtype != np.uint8:
#             vol = np.clip(vol, 0, 255).astype(np.uint8)
#         np.save(cache_path_for(sid), vol, allow_pickle=False)
#         return (sid, int(meta.get('orig_depth') or 32), True, None)
#     except Exception as e:
#         return (sid, None, False, str(e))


# t0 = time.time()

# # Collect results from workers
# results = []
# ok, err = 0, 0
# with mp.Pool(processes=NUM_WORKERS) as pool, tqdm(total=len(to_write), unit="series",
#                                                   desc=f"Shard {SHARD_ID} @ {IMG_SIZE}px") as pbar:
#     for sid, orig_depth, success, error_msg in pool.imap_unordered(_worker, to_write, chunksize=2):
#         results.append((sid, orig_depth, success, error_msg))
#         if success: ok += 1
#         else: err += 1
#         pbar.update(1)

# # Build manifest rows: include both processed and skipped
# rows = []

# # 1) Add processed results
# for sid, orig_depth, success, error_msg in results:
#     rows.append({
#         "SeriesInstanceUID": sid,
#         "img_size": IMG_SIZE,
#         "cache_path": cache_path_for(sid),
#         "orig_depth": orig_depth,
#         "skipped": False,
#         "success": bool(success),
#         "error": (None if success else (str(error_msg) if error_msg is not None else "unknown"))
#     })

# # 2) Add skipped (already cached) entries so the manifest is complete
# skipped_sids = [sid for sid in sids_all if sid not in set(to_write)]
# for sid in skipped_sids:
#     rows.append({
#         "SeriesInstanceUID": sid,
#         "img_size": IMG_SIZE,
#         "cache_path": cache_path_for(sid),
#         "orig_depth": None,   # unknown because we didn't reprocess; can be filled later if you have it
#         "skipped": True,
#         "success": True,
#         "error": None
#     })

# # Save manifest as Parquet in the shard folder
# manifest_path = os.path.join(OUT_DIR, f"manifest_{IMG_SIZE}.parquet")
# pd.DataFrame(rows).to_parquet(manifest_path, index=False)
# print(f"Manifest saved: {manifest_path}")

# dt = time.time() - t0
# print(f"[Shard {SHARD_ID}] saved: {ok}, skipped: {to_skip}, errors: {err}, "
#       f"elapsed: {dt/60:.1f} min (~{(dt/max(1, max(ok,1))):0.2f}s/series)")
# print("Shard folder is ready to download or save as Notebook Output.")

In [None]:
# # =============================
# # Cache vs Raw Preprocessing Check
# # =============================
# import os, random, math, numpy as np, pandas as pd
# from pathlib import Path
# from typing import List, Optional, Tuple
# from tqdm.auto import tqdm
# import glob

# # --- reuse your cache discovery helpers ---
# def discover_shard_roots_for(img_size: int) -> List[str]:
#     SHARDS_ROOT = "/kaggle/input/shards"
#     pattern = os.path.join(SHARDS_ROOT, "*", f"cache_u8_{img_size}_shard*")
#     roots = sorted([p for p in glob.glob(pattern) if os.path.isdir(p)])
#     if is_main_process():
#         print(f"Found {len(roots)} shard roots for {img_size}px")
#     return roots

# def make_find_cached_path(shard_roots: List[str]):
#     def _find_cached_path(sid: str, img_size: int) -> Optional[str]:
#         fname = f"{sid}_{img_size}.npy"
#         for root in shard_roots:
#             p = os.path.join(root, fname)
#             if os.path.exists(p):
#                 return p
#         return None
#     return _find_cached_path

# def _psnr_u8(a: np.ndarray, b: np.ndarray) -> float:
#     # a, b uint8 arrays of same shape
#     diff = a.astype(np.float32) - b.astype(np.float32)
#     mse = float(np.mean(diff**2))
#     if mse == 0: 
#         return float('inf')
#     return 20.0 * math.log10(255.0) - 10.0 * math.log10(mse)

# def _compare_pair(vol_cached_u8: np.ndarray, vol_raw_u8: np.ndarray) -> dict:
#     assert vol_cached_u8.shape == vol_raw_u8.shape, f"Shape mismatch {vol_cached_u8.shape} vs {vol_raw_u8.shape}"
#     assert vol_cached_u8.dtype == np.uint8 and vol_raw_u8.dtype == np.uint8

#     a = vol_cached_u8
#     b = vol_raw_u8

#     mean_abs = float(np.mean(np.abs(a.astype(np.int16) - b.astype(np.int16))))
#     max_abs  = int(np.max(np.abs(a.astype(np.int16) - b.astype(np.int16))))
#     eq_rate  = float(np.mean(a == b))
#     psnr     = _psnr_u8(a, b)

#     # also compare after scaling to [0,1] like training input
#     a01 = a.astype(np.float32) / 255.0
#     b01 = b.astype(np.float32) / 255.0
#     mae01 = float(np.mean(np.abs(a01 - b01)))
#     rmse01 = float(np.sqrt(np.mean((a01 - b01)**2)))

#     return {
#         "mean_abs_u8": mean_abs,
#         "max_abs_u8":  max_abs,
#         "eq_rate":     eq_rate,
#         "psnr_u8":     psnr,
#         "mae_01":      mae01,
#         "rmse_01":     rmse01,
#         "min_cached":  int(a.min()), "max_cached": int(a.max()),
#         "min_raw":     int(b.min()), "max_raw":   int(b.max()),
#     }

# @torch.no_grad()
# def verify_cache_vs_raw(sample_n: int = 50, seed: int = 123, verbose: bool = True) -> pd.DataFrame:
#     """
#     Compare cached u8 volumes vs on-the-fly preprocessor outputs.
#     Returns a DataFrame with metrics per sampled SID.
#     """
#     rng = random.Random(seed)

#     # load list of available series
#     df_all = pd.read_csv(CFG.train_csv)
#     exists = df_all["SeriesInstanceUID"].apply(lambda u: os.path.isdir(os.path.join(CFG.series_root, str(u))))
#     df_all = df_all[exists].reset_index(drop=True)

#     # cache resolver
#     shard_roots = discover_shard_roots_for(CFG.img_size)
#     find_cached_path = make_find_cached_path(shard_roots)

#     # pick SIDs that have a cache file
#     candidates = []
#     for sid in df_all["SeriesInstanceUID"].astype(str).tolist():
#         cp = find_cached_path(sid, CFG.img_size)
#         if cp is not None:
#             candidates.append((sid, cp))

#     if len(candidates) == 0:
#         raise RuntimeError("No cached files found; build cache first.")

#     if sample_n > len(candidates):
#         sample_n = len(candidates)

#     sample = rng.sample(candidates, sample_n)

#     # instantiate preprocessor once (same as cache builder)
#     preproc = DICOMPreprocessorKaggle(target_shape=(CFG.in_chans, CFG.img_size, CFG.img_size))

#     rows = []
#     for sid, cache_path in tqdm(sample, desc="Checking cache vs raw", unit="series"):
#         # load cached
#         vol_cached = np.load(cache_path, mmap_mode="r")
#         if vol_cached.dtype != np.uint8:
#             vol_cached = np.clip(vol_cached, 0, 255).astype(np.uint8)

#         # recompute raw
#         series_path = os.path.join(CFG.series_root, sid)
#         vol_raw = preproc.process_series(series_path)
#         if vol_raw.dtype != np.uint8:
#             vol_raw = np.clip(vol_raw, 0, 255).astype(np.uint8)

#         # compare
#         try:
#             metrics = _compare_pair(vol_cached, vol_raw)
#         except AssertionError as e:
#             metrics = {"error": str(e)}

#         row = {"SeriesInstanceUID": sid, **metrics}
#         rows.append(row)

#     df = pd.DataFrame(rows)

#     if verbose:
#         if "error" in df.columns:
#             has_error = df["error"].notna() & (df["error"].astype(str).str.len() > 0)
#         else:
#             has_error = pd.Series(False, index=df.index)

#         ok = df[~has_error]
#         if len(ok):
#             psnr_vals = ok["psnr_u8"].replace(np.inf, 100.0) if "psnr_u8" in ok else pd.Series(dtype=float)
#             print(
#                 "Summary (no-error rows): "
#                 f"mean mean_abs_u8={ok['mean_abs_u8'].mean():.4f}, "
#                 f"mean max_abs_u8={ok['max_abs_u8'].mean():.2f}, "
#                 f"mean eq_rate={ok['eq_rate'].mean():.4f}, "
#                 f"mean psnr_u8={psnr_vals.mean():.2f} dB, "
#                 f"mean mae_01={ok['mae_01'].mean():.6f}, "
#                 f"mean rmse_01={ok['rmse_01'].mean():.6f}"
#             )
#         bad = df[has_error]
#         if len(bad):
#             print(f"{len(bad)} series had shape/dtype errors; inspect df for details.")

#     return df

# # ---- Run it (example) ----
# df_check = verify_cache_vs_raw(sample_n=40, seed=42)
# display(df_check.sort_values("mean_abs_u8", ascending=False).head(10))

In [None]:
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import cv2
import torch
from torch.utils.data import Dataset
from typing import Optional, Dict, List, Tuple
import glob
import ast

def load_localizers_csv(csv_path: Optional[str], max_points_per_series: int = 3) -> Dict[str, List[dict]]:
    """
    Load localizer annotations from a CSV and group them by series.

    Columns are matched case-insensitively:
      - SeriesInstanceUID
      - SOPInstanceUID (optional)
      - coordinates (optional): string like "{'x': 258.3, 'y': 261.4}" or
        '{"x":..., "y":...}'. Both Python-literal and strict JSON forms are
        accepted (JSON ``null`` becomes ``None``).
      - location (optional): text label

    Args:
        csv_path: path of the CSV; a falsy value returns an empty mapping.
        max_points_per_series: each series keeps only its first
            ``max_points_per_series`` rows (in file order).

    Returns: { sid: [ {'x': float|None, 'y': float|None, 'sop': str|None, 'loc': str|None}, ... ] }
        Coordinates that cannot be parsed yield ``x = y = None``.

    Raises:
        KeyError: if the CSV has no SeriesInstanceUID column.
    """
    import json  # local import: only needed for the JSON fallback below

    if not csv_path:
        return {}

    def _parse_xy(raw):
        """Return (x, y) parsed from a dict-like string, or (None, None)."""
        s = str(raw).strip()
        if not (s.startswith('{') and s.endswith('}')):
            return None, None
        try:
            # Python-literal form, e.g. {'x': 1.0, 'y': 2.0}
            xy = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            try:
                # strict JSON form; literal_eval rejects null/true/false
                xy = json.loads(s)
            except (ValueError, RecursionError):
                return None, None
        if not isinstance(xy, dict):
            return None, None
        try:
            x = float(xy['x']) if xy.get('x') is not None else None
            y = float(xy['y']) if xy.get('y') is not None else None
        except (TypeError, ValueError, OverflowError):
            # one unusable component invalidates the whole point
            return None, None
        return x, y

    df = pd.read_csv(csv_path)
    # normalize column names
    cols = {c.lower(): c for c in df.columns}
    sid_col = cols.get('seriesinstanceuid') or 'SeriesInstanceUID'
    sop_col = cols.get('sopinstanceuid')
    coord_col = cols.get('coordinates')
    loc_col = cols.get('location')

    def _values(col):
        # A missing optional column behaves like a column full of None.
        return df[col] if col is not None else [None] * len(df)

    by_sid: Dict[str, List[dict]] = defaultdict(list)
    rows = zip(df[sid_col], _values(sop_col), _values(coord_col), _values(loc_col))
    for sid, sop, coord, loc in rows:
        x, y = _parse_xy(coord) if pd.notna(coord) else (None, None)
        by_sid[str(sid)].append({
            'x': x,
            'y': y,
            'sop': str(sop) if pd.notna(sop) else None,
            'loc': str(loc) if pd.notna(loc) else None,
        })

    # cap per series
    return {sid: pts[:max_points_per_series] for sid, pts in by_sid.items()}

In [None]:
def discover_shard_roots(shards_root: Optional[str] = None) -> List[str]:
    """
    Find all cache shard folders matching ``<shards_root>/*/cache_u8_{CFG.img_size}_shard*``.

    Args:
        shards_root: directory that contains the per-shard parent folders.
            Defaults to the local cache location used by this notebook.

    Returns:
        Sorted list of matching directories (non-directories are filtered out).
        On the main process the count and the first 8 paths are printed.
    """
    if shards_root is None:
        shards_root = "D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/cache"
    pattern = os.path.join(shards_root, "*", f"cache_u8_{CFG.img_size}_shard*")
    shard_roots = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    if is_main_process():
        print("Found shard roots:", len(shard_roots))
        for p in shard_roots[:8]:
            print("  ", p)
    return shard_roots

In [None]:
# Cached Dataset 

# Module-level switch read by RSNADataset.__getitem__: when True, a series with
# no cached volume raises FileNotFoundError; when False, the raw series is
# preprocessed on the fly instead.
STRICT_CACHE_ONLY = True   # set False to allow on-the-fly fallback

def _safe_crop_2d(img: np.ndarray, cx: int, cy: int, size: int) -> np.ndarray:
    """
    Crop a (size, size) window centred on (cx, cy), edge-padding where it leaves the image.

    The centre is first clamped into the image. The window spans
    ``[c - size // 2, c - size // 2 + size)`` on each axis, so pixel (cy, cx)
    lands at index ``(size // 2, size // 2)`` of the result, also when the
    window is clipped at a border or ``size`` is odd.

    Args:
        img: 2D array of shape (H, W).
        cx, cy: centre column / row.
        size: side length of the square crop.

    Returns:
        Array of shape (size, size) with the dtype of ``img``.
    """
    H, W = img.shape
    half = size // 2

    # Keep the centre inside the image so the window always overlaps it
    # (and slice bounds never go negative).
    cx = min(max(cx, 0), W - 1)
    cy = min(max(cy, 0), H - 1)

    # Ideal (unclipped) window, then the part of it that exists in the image.
    left, top = cx - half, cy - half
    right, bottom = left + size, top + size
    x0, x1 = max(0, left), min(W, right)
    y0, y1 = max(0, top), min(H, bottom)
    crop = img[y0:y1, x0:x1]

    # Pad each side by exactly the amount that was clipped there, so the
    # content is not shifted relative to the requested centre.
    pad = ((y0 - top, bottom - y1), (x0 - left, right - x1))
    if any(amount for axis in pad for amount in axis):
        crop = np.pad(crop, pad, mode='edge')
    return crop

def _resize_2d(img: np.ndarray, out_hw: Tuple[int,int]) -> np.ndarray:
    """Bilinearly resize a 2D image to ``out_hw``, given as (height, width)."""
    out_h, out_w = out_hw[0], out_hw[1]
    # cv2.resize expects the target size as (width, height)
    return cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

def _map_localizer_to_cached_depth(
    loc_z: Optional[float],
    loc_f: Optional[int],
    cached_depth: int,
    orig_depth: Optional[int] = None
) -> int:
    """
    Convert a localizer depth hint into an index in ``0..cached_depth-1``.

    When ``orig_depth`` is truthy, the first available hint (``loc_f``, then
    ``loc_z``) is scaled linearly from ``0..orig_depth-1`` onto
    ``0..cached_depth-1``, rounded and clipped. Without a usable hint or
    ``orig_depth`` the middle slice ``cached_depth // 2`` is returned.
    """
    if orig_depth:
        top = cached_depth - 1
        for hint in (loc_f, loc_z):  # loc_f takes priority over loc_z
            if hint is None:
                continue
            scaled = round(hint / max(1, (orig_depth - 1)) * top)
            return int(np.clip(scaled, 0, top))
    # no depth information: fall back to the middle slice
    return cached_depth // 2

import os, pydicom, numpy as np
from functools import lru_cache

def _rank_to_cached_idx(rank: int, orig_depth: int, cached_depth: int) -> int:
    """
    Map a slice rank in ``0..orig_depth-1`` onto ``0..cached_depth-1``.

    The rank is clipped to the valid range and scaled linearly; a series with
    at most one slice maps to the middle index ``cached_depth // 2``.
    """
    last_orig = orig_depth - 1
    if last_orig <= 0:
        return cached_depth // 2
    clamped = np.clip(rank, 0, last_orig)
    scaled = clamped / last_orig * (cached_depth - 1)
    return int(round(scaled))

@lru_cache(maxsize=512)
def _build_sop_rank_map(series_dir: str) -> tuple[dict, int]:
    """
    Returns (sop_to_rank, orig_depth) for a series, computed by sorting slices by z.

    Only ``*.dcm`` files directly inside ``series_dir`` are considered, and
    they are read with ``stop_before_pixels=True``. ``z`` is the third
    component of ImagePositionPatient when present with 3 components,
    otherwise SliceLocation (0.0 if absent). Ranks are 0..N-1 in ascending z.

    Best-effort: a file whose header cannot be read or parsed is skipped
    individually, and an unreadable directory yields ``({}, 0)``.

    Results are memoised per ``series_dir`` (including empty results); the
    returned dict is the cached object, so callers must not mutate it.
    """
    try:
        # os.listdir order is arbitrary; sort so ties in z rank reproducibly.
        names = sorted(os.listdir(series_dir))
    except OSError:
        return ({}, 0)

    items = []
    for name in names:
        if not name.lower().endswith(".dcm"):
            continue
        path = os.path.join(series_dir, name)
        try:
            ds = pydicom.dcmread(path, stop_before_pixels=True, force=True)
            sop = str(getattr(ds, "SOPInstanceUID", os.path.splitext(name)[0]))
            ipp = getattr(ds, "ImagePositionPatient", None)
            z = float(ipp[2]) if ipp is not None and len(ipp) == 3 else float(getattr(ds, "SliceLocation", 0.0))
        except Exception:
            # Deliberately broad: one bad header must not discard the slices
            # that were (or will be) read successfully.
            continue
        items.append((sop, z))

    if not items:
        return ({}, 0)
    # sort by z, assign ranks 0..N-1
    items.sort(key=lambda t: t[1])
    sop_to_rank = {sop: i for i, (sop, _) in enumerate(items)}
    return sop_to_rank, len(items)

import hashlib
import numpy as np

def _sid_seed(sid: str, salt: str = "locrand") -> int:
    """Return a stable 32-bit seed: the first 4 bytes of sha1(salt + sid), big-endian."""
    payload = (salt + sid).encode()
    leading = hashlib.sha1(payload).digest()[:4]
    return int.from_bytes(leading, "big")

def _random_local_points(H: int, W: int, K: int, rng: np.random.Generator) -> list[tuple[int,int]]:
    """
    Draw ``K`` random (x, y) integer points away from the image border.

    x is drawn from ``[W//8, W - W//8)`` and y from ``[H//8, H - H//8)``;
    all x values are drawn from ``rng`` before the y values. ``K <= 0``
    returns an empty list without touching ``rng``.
    """
    if K <= 0:
        return []
    margin_x, margin_y = W // 8, H // 8
    cols = rng.integers(low=margin_x, high=W - margin_x, size=K)
    rows = rng.integers(low=margin_y, high=H - margin_y, size=K)
    return [(int(x), int(y)) for x, y in zip(cols, rows)]

class RSNADataset(Dataset):
    """
    Map-style dataset yielding ``(x, y, sid)`` for one row of ``df``.

    ``x`` is a float32 tensor of shape ``(CFG.in_chans, CFG.img_size, CFG.img_size)``
    scaled to [0, 1]. It is assembled from, in order:
      1. the cached uint8 volume for the series (or, when ``STRICT_CACHE_ONLY``
         is False, the output of ``preproc.process_series``),
      2. optional vessel-sidecar channel(s) loaded from a file next to the cache,
      3. ``max_localizer_crops`` localizer channels (when ``CFG.use_localizers``),
    and finally zero-padded or truncated to ``CFG.in_chans`` channels.
    ``y`` holds the row's ``LABEL_COLS`` values as float32; ``sid`` is the
    SeriesInstanceUID as a string.

    Random choices (localizer dropout, random crop positions) are seeded from
    ``(tag, sid, epoch, CFG.seed)`` through ``_sid_seed`` so they are identical
    in every process; built-in ``hash()`` of strings is randomized per process
    and would not be. Call ``set_epoch`` to vary them between epochs.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        series_root: str,
        preproc,  # DICOMPreprocessorKaggle
        find_cached_path_fn,
        localizers_csv_path: Optional[str] = None,
        max_localizer_crops: int = 3,
        local_crop_size: int = 128,
        sid_to_orig_depth: Optional[Dict[str, int]] = None,  # if you have it
    ):
        """
        Args:
            df: one row per series; needs ``SeriesInstanceUID`` and ``LABEL_COLS``.
            series_root: directory containing one folder per SeriesInstanceUID.
            preproc: object with ``process_series(path)``, used only when a
                series is not cached and ``STRICT_CACHE_ONLY`` is False.
            find_cached_path_fn: ``(sid, img_size) -> path or None``.
            localizers_csv_path: localizer CSV; when None, falls back to
                ``CFG.localizers_csv_path`` if that attribute exists.
            max_localizer_crops: number K of localizer channels appended.
            local_crop_size: side of the square crop taken around each point.
            sid_to_orig_depth: optional ``{sid: original slice count}``.
        """
        self.df = df.reset_index(drop=True)
        self.series_root = series_root
        self.preproc = preproc
        self.find_cached_path = find_cached_path_fn

        # Localizers: an explicit argument wins; CFG is only the fallback.
        if localizers_csv_path is None:
            localizers_csv_path = getattr(CFG, "localizers_csv_path", None)
        self.localizers_map: Dict[str, List[dict]] = load_localizers_csv(
            localizers_csv_path, max_points_per_series=max_localizer_crops
        ) if localizers_csv_path else {}
        self.max_localizer_crops = max_localizer_crops
        self.local_crop_size = int(local_crop_size)
        self.sid_to_orig_depth = sid_to_orig_depth or {}
        self._epoch = 0
        self._rng = np.random.default_rng(CFG.seed)

    def set_epoch(self, e: int):
        """Set the epoch number mixed into the per-sample random seeds."""
        self._epoch = int(e)

    def __len__(self):
        return len(self.df)

    def _epoch_rng(self, sid: str, *tags) -> np.random.Generator:
        """Generator seeded reproducibly from (tags, sid, current epoch, CFG.seed)."""
        key = "|".join(str(part) for part in (*tags, sid, self._epoch, CFG.seed))
        return np.random.default_rng(_sid_seed(key))

    @staticmethod
    def _resize2d(arr2d, out_hw):
        """Bilinear resize of a 2D array to (H, W); returns the input unchanged if it already matches."""
        import cv2
        h, w = int(out_hw[0]), int(out_hw[1])
        if arr2d.shape != (h, w):
            return cv2.resize(arr2d, (w, h), interpolation=cv2.INTER_LINEAR)
        return arr2d

    def _mip_crop_channel(self, vol_for_mip, z_idx: int, cx: int, cy: int, H: int, W: int):
        """
        Build one localizer channel of shape (1, H, W).

        Takes the max-intensity projection over slices ``[z_idx-8, z_idx+9)``
        (clipped to the volume), crops ``local_crop_size`` around (cx, cy) and
        resizes the crop back to (H, W).
        """
        depth = vol_for_mip.shape[0]
        z0 = max(0, z_idx - 8)
        z1 = min(depth, z_idx + 9)
        slab = vol_for_mip[z0:z1]
        mip = slab.max(axis=0) if slab.size else np.zeros((H, W), dtype=vol_for_mip.dtype)
        crop = _safe_crop_2d(mip, cx, cy, size=self.local_crop_size)
        return self._resize2d(crop, (H, W))[np.newaxis, ...]

    def __getitem__(self, idx):
        """
        Return ``(x, y, sid)`` for row ``idx``.

        Raises:
            FileNotFoundError: series not cached while ``STRICT_CACHE_ONLY`` is True.
            ValueError: the loaded volume is not 3D or has fewer than 32 slices.
        """
        import cv2

        row = self.df.iloc[idx]
        sid = str(row["SeriesInstanceUID"])
        cp = self.find_cached_path(sid, CFG.img_size)

        # load cached (or preprocess)
        if cp is not None:
            vol_u8 = np.load(cp, mmap_mode="r")  # (C,H,W) usually C==32
        else:
            if STRICT_CACHE_ONLY:
                raise FileNotFoundError(f"Not found in cache: {sid}_{CFG.img_size}.npy")
            series_path = os.path.join(self.series_root, sid)
            # presumably (32,H,W) in the 0..255 range -- confirm against the preprocessor
            vol = self.preproc.process_series(series_path)
            vol_u8 = vol if vol.dtype == np.uint8 else np.clip(vol, 0, 255).astype(np.uint8)

        # sanity / resize base stack to (H,W)
        if vol_u8.ndim != 3:
            raise ValueError(f"Expected (C,H,W), got {tuple(vol_u8.shape)}")
        base_C = vol_u8.shape[0]
        if base_C < 32:
            raise ValueError(f"Expected at least 32 slices, got {base_C}")

        H, W = vol_u8.shape[1], vol_u8.shape[2]
        if (H != CFG.img_size) or (W != CFG.img_size):
            vol_u8 = np.stack(
                [cv2.resize(vol_u8[c], (CFG.img_size, CFG.img_size), interpolation=cv2.INTER_LINEAR)
                 for c in range(vol_u8.shape[0])],
                axis=0
            )
            H, W = CFG.img_size, CFG.img_size  # recompute

        # optional vesselness sidecar (added before localizers)
        # sidecar filename: "<sid>_<img>_vessel_u8.npy", looked up next to the cached volume
        sidecar_path = None
        if cp is not None and getattr(CFG, "use_vessel_sidecar", True):
            guess1 = cp.replace(f"_{CFG.img_size}.npy", f"_{CFG.img_size}_vessel_u8.npy")
            guess2 = os.path.join(os.path.dirname(cp), f"{sid}_{CFG.img_size}_vessel_u8.npy")
            for guess in (guess1, guess2):
                if os.path.exists(guess):
                    sidecar_path = guess
                    break

        if sidecar_path is not None:
            vess = np.load(sidecar_path, mmap_mode="r")  # 2D or 3D array
            mode = getattr(CFG, "vessel_sidecar_mode", "mip")  # "mip" or "per_slice"

            if vess.ndim == 2:
                vess2d = self._resize2d(np.asarray(vess), (H, W)).astype(np.uint8)
                vol_u8 = np.concatenate([vol_u8, vess2d[np.newaxis, ...]], axis=0)

            elif vess.ndim == 3:
                if vess.shape[-2:] != (H, W):
                    vess = np.stack([self._resize2d(vess[z], (H, W)) for z in range(vess.shape[0])], axis=0)
                if mode == "per_slice":
                    # appends one channel per sidecar slice
                    vol_u8 = np.concatenate([vol_u8, vess.astype(np.uint8)], axis=0)
                else:
                    # any other mode: a single max-projection channel
                    vess2d = np.asarray(vess).max(axis=0).astype(np.uint8)
                    vol_u8 = np.concatenate([vol_u8, vess2d[np.newaxis, ...]], axis=0)

        # localizer-based extra channels (fixed K)
        K = self.max_localizer_crops

        # Global-off: K stays the same, but real points are ignored for this sample.
        force_random = False
        pgo = getattr(CFG, "p_global_localizer_off", 0.0)
        if pgo > 0 and self._epoch_rng(sid, "global_off").random() < pgo:
            force_random = True

        if CFG.use_localizers and K > 0:
            # Build MIPs from the first 32 slices only (do not include sidecar channels)
            vol_for_mip = vol_u8[:32]
            cached_depth = vol_for_mip.shape[0]

            hint_orig_depth = self.sid_to_orig_depth.get(sid, None)
            series_dir = os.path.join(self.series_root, sid)
            if os.path.isdir(series_dir):
                sop_to_rank, hdr_orig_depth = _build_sop_rank_map(series_dir)
            else:
                sop_to_rank, hdr_orig_depth = ({}, 0)
            use_orig_depth = hint_orig_depth or hdr_orig_depth or cached_depth

            # per-sample dropout of the real points, or forced by global-off
            use_random = force_random or (
                self._epoch_rng(sid, "dropout").random() < CFG.p_localizer_dropout
            )

            local_chans = []
            pts = [] if use_random else self.localizers_map.get(sid, [])
            for p in pts[:K]:
                sop = p.get('sop')
                if sop and sop in sop_to_rank and use_orig_depth > 0:
                    z_idx = _rank_to_cached_idx(sop_to_rank[sop], use_orig_depth, cached_depth)
                else:
                    # Entries built by load_localizers_csv carry no 'z'/'f'
                    # keys, so for them this resolves to the middle slice.
                    z_idx = _map_localizer_to_cached_depth(
                        p.get('z'), p.get('f'),
                        cached_depth=cached_depth, orig_depth=use_orig_depth
                    )

                # NOTE(review): x/y are used as pixel coordinates of the cached
                # (H, W) grid without rescaling -- confirm the CSV coordinates
                # are not in the original DICOM resolution.
                px, py = p.get('x'), p.get('y')
                if px is None or py is None:
                    cx, cy = W // 2, H // 2
                else:
                    cx = int(round(np.clip(px, 0, W - 1)))
                    cy = int(round(np.clip(py, 0, H - 1)))

                local_chans.append(self._mip_crop_channel(vol_for_mip, z_idx, cx, cy, H, W))

            # Fill up to K with seeded random crops (epoch-varying)
            need = K - len(local_chans)
            if need > 0:
                cz = self._epoch_rng(sid, "z").integers(low=0, high=max(1, cached_depth), size=need)
                for i in range(need):
                    (cx, cy) = _random_local_points(H, W, 1, self._epoch_rng(sid, "xy", i))[0]
                    local_chans.append(self._mip_crop_channel(vol_for_mip, int(cz[i]), cx, cy, H, W))

            extra = np.concatenate(local_chans[:K], axis=0)
            vol_u8 = np.concatenate([vol_u8, extra], axis=0)

        # final channel alignment to CFG.in_chans
        C = vol_u8.shape[0]
        target_C = int(CFG.in_chans)
        if C < target_C:
            pad = np.zeros((target_C - C, H, W), dtype=vol_u8.dtype)
            vol_u8 = np.concatenate([vol_u8, pad], axis=0)
        elif C > target_C:
            vol_u8 = vol_u8[:target_C]

        # --- to tensor ---
        x = torch.from_numpy(np.asarray(vol_u8)).to(torch.float32).div_(255.0)  # (C,H,W) in [0,1]
        y = torch.tensor(row[LABEL_COLS].values.astype(np.float32))
        return x, y, sid


    # def __getitem__(self, idx):
    #     row = self.df.iloc[idx]
    #     sid = str(row["SeriesInstanceUID"])
    #     cp = self.find_cached_path(sid, CFG.img_size)
    
    #     # --- load cached (or preprocess) ---
    #     if cp is not None:
    #         vol_u8 = np.load(cp, mmap_mode="r")  # (C,H,W) or (32,H,W)
    #     else:
    #         if STRICT_CACHE_ONLY:
    #             raise FileNotFoundError(f"Not found in cache: {sid}_{CFG.img_size}.npy")
    #         series_path = os.path.join(self.series_root, sid)
    #         vol = self.preproc.process_series(series_path)  # (32,H,W) float32 0..255
    #         vol_u8 = vol if vol.dtype == np.uint8 else np.clip(vol, 0, 255).astype(np.uint8)
    
    #     # --- sanity checks / shapes ---
    #     if vol_u8.ndim != 3:
    #         raise ValueError(f"Expected (C,H,W), got {tuple(vol_u8.shape)}")
    #     base_C = vol_u8.shape[0]
    #     if base_C < 32:
    #         raise ValueError(f"Expected at least 32 slices, got {base_C}")
    
    #     H, W = vol_u8.shape[1], vol_u8.shape[2]
    #     if (H != CFG.img_size) or (W != CFG.img_size):
    #         import cv2
    #         vol_u8 = np.stack(
    #             [cv2.resize(vol_u8[c], (CFG.img_size, CFG.img_size), interpolation=cv2.INTER_LINEAR)
    #              for c in range(vol_u8.shape[0])],
    #             axis=0
    #         )
    #     # recompute after possible resize
    #     H, W = vol_u8.shape[1], vol_u8.shape[2]

    #     # ---- localizer-based extra channels (fixed K) ----
    #     local_chans = []
    #     K = self.max_localizer_crops
    #     force_random = False

    #     # Global-off: keep K the same, but ignore real points for this sample
    #     pgo = getattr(CFG, "p_global_localizer_off", 0.0)
    #     if pgo > 0:
    #         grng = np.random.default_rng(hash(('global_off', sid, self._epoch, CFG.seed)) & 0xffffffff)
    #         if grng.random() < pgo:
    #             force_random = True

    #     if CFG.use_localizers and K > 0:
    #         vol_for_mip = vol_u8[:32] if base_C >= 32 else vol_u8
    #         cached_depth = vol_for_mip.shape[0]

    #         hint_orig_depth = self.sid_to_orig_depth.get(sid, None)
    #         series_dir = os.path.join(self.series_root, sid)
    #         if os.path.isdir(series_dir):
    #             sop_to_rank, hdr_orig_depth = _build_sop_rank_map(series_dir)
    #         else:
    #             sop_to_rank, hdr_orig_depth = ({}, 0)
    #         use_orig_depth = hint_orig_depth or hdr_orig_depth or cached_depth

    #         rng = np.random.default_rng(hash((sid, self._epoch, CFG.seed)) & 0xffffffff)
    #         # anti-leakage dropout OR forced-random from global-off
    #         use_random = force_random or (rng.random() < CFG.p_localizer_dropout)

    #         pts = self.localizers_map.get(sid, [])
    #         use_pts = (len(pts) > 0) and (not use_random)

    #         if use_pts:
    #             for p in pts[:K]:
    #                 sop = p.get('sop')
    #                 if sop and sop in sop_to_rank and use_orig_depth > 0:
    #                     rank = sop_to_rank[sop]
    #                     z_idx = _rank_to_cached_idx(rank, use_orig_depth, cached_depth)
    #                 else:
    #                     z_idx = _map_localizer_to_cached_depth(
    #                         p.get('z'), p.get('f'), cached_depth=cached_depth, orig_depth=use_orig_depth
    #                     )
    #                 z0 = max(0, z_idx - 8); z1 = min(cached_depth, z_idx + 9)
    #                 slab = vol_for_mip[z0:z1]
    #                 mip = slab.max(axis=0) if slab.size else np.zeros((H, W), dtype=vol_u8.dtype)

    #                 px, py = p.get('x'), p.get('y')
    #                 if px is None or py is None:
    #                     cx, cy = W // 2, H // 2
    #                 else:
    #                     cx = int(round(np.clip(px, 0, W - 1)))
    #                     cy = int(round(np.clip(py, 0, H - 1)))

    #                 crop = _safe_crop_2d(mip, cx, cy, size=self.local_crop_size)
    #                 crop_full = _resize_2d(crop, (H, W))
    #                 local_chans.append(crop_full[np.newaxis, ...])

    #         # Fill up to K with deterministic random crops (epoch-varying)
    #         if len(local_chans) < K:
    #             need = K - len(local_chans)
    #             z_rng = np.random.default_rng(hash((sid, 'z', self._epoch, CFG.seed)) & 0xffffffff)
    #             cz = z_rng.integers(low=0, high=max(1, cached_depth), size=need)
    #             for i in range(need):
    #                 z_idx = int(cz[i])
    #                 z0 = max(0, z_idx - 8); z1 = min(cached_depth, z_idx + 9)
    #                 slab = vol_for_mip[z0:z1]
    #                 mip = slab.max(axis=0) if slab.size else np.zeros((H, W), dtype=vol_u8.dtype)

    #                 rrng = np.random.default_rng(hash((sid, 'xy', i, self._epoch, CFG.seed)) & 0xffffffff)
    #                 (cx, cy) = _random_local_points(H, W, 1, rrng)[0]

    #                 crop = _safe_crop_2d(mip, cx, cy, size=self.local_crop_size)
    #                 crop_full = _resize_2d(crop, (H, W))
    #                 local_chans.append(crop_full[np.newaxis, ...])

    #         # Trim (paranoia) and concat
    #         if len(local_chans) > K:
    #             local_chans = local_chans[:K]
    #         extra = np.concatenate(local_chans, axis=0) if local_chans else np.zeros((K, H, W), dtype=vol_u8.dtype)
    #         vol_u8 = np.concatenate([vol_u8, extra], axis=0)


    #     # --- to tensor ---
    #     x = torch.from_numpy(np.asarray(vol_u8)).to(torch.float32).div_(255.0)  # (C,H,W)
    #     y = torch.tensor(row[LABEL_COLS].values.astype(np.float32))
    #     return x, y, sid


#  smarter stratification helpers 

def _age_to_bin(s: pd.Series) -> pd.Series:
    """
    Bucket PatientAge-like values into integer age bins.

    The first run of digits in each value's string form is taken as the age in
    years (so "067Y" -> 67). Bins are [0,30], (30,40], (40,50], (50,60],
    (60,70], (70,80], (80,200] numbered 0..6; values without digits or outside
    0..200 map to -1.
    """
    as_text = s.astype(str)
    first_digits = as_text.str.extract(r'(\d+)')[0]
    years = pd.to_numeric(first_digits, errors='coerce')
    edges = [0,30,40,50,60,70,80,200]
    bin_idx = pd.cut(years, bins=edges, labels=False, include_lowest=True)
    return bin_idx.fillna(-1).astype(int)

def _slice_bin_for_series(series_root: str, sid: str) -> int:
    """
    Cheap proxy for series size: bucket the number of files in the series folder.

    Counts regular files (of any extension) directly inside
    ``series_root/sid``, stopping once more than 300 were seen. A missing
    folder counts as 0 files. Buckets: <=64 -> 0, <=128 -> 1, <=256 -> 2,
    otherwise 3.
    """
    folder = os.path.join(series_root, str(sid))
    count = 0
    try:
        with os.scandir(folder) as entries:
            for entry in entries:
                if not entry.is_file():
                    continue
                count += 1
                if count > 300:
                    # already past the largest threshold; no need to keep counting
                    break
    except FileNotFoundError:
        count = 0

    for bucket, upper in enumerate((64, 128, 256)):
        if count <= upper:
            return bucket
    return 3

def _make_strat_key(df: pd.DataFrame) -> pd.Series:
    """
    Build one composite stratification label per row of ``df``.

    The label is ``"<target>_<modality>_<sex>_<age bin>_<slice bin>"`` using:
      - ``Aneurysm Present`` (required) cast to int,
      - ``Modality`` and ``PatientSex`` upper-cased, with "UNK" for missing
        values or a missing column,
      - ``PatientAge`` binned by ``_age_to_bin`` (-1 when the column is absent),
      - the file-count bucket of ``CFG.series_root/<SeriesInstanceUID>`` from
        ``_slice_bin_for_series``.

    Returns:
        String Series aligned with ``df.index``.
    """
    def _category(name: str) -> pd.Series:
        # df.get(name, "UNK") would hand back a plain str for a missing
        # column, so build a constant Series on df's index instead.
        if name not in df.columns:
            return pd.Series("UNK", index=df.index)
        col = df[name]
        # Decide missingness on the raw column: astype(str) turns NaN into "nan".
        return col.astype(str).str.upper().where(col.notna(), "UNK")

    # core target
    ap = df["Aneurysm Present"].astype(int)

    # simple categorical covariates (normalized)
    mod = _category("Modality")
    sex = _category("PatientSex")

    # age bins
    if "PatientAge" in df.columns:
        ageb = _age_to_bin(df["PatientAge"])
    else:
        ageb = pd.Series(-1, index=df.index)

    # slice-count bins (very fast directory count)
    # note: uses CFG.series_root and SeriesInstanceUID
    slb = df["SeriesInstanceUID"].astype(str).apply(lambda sid: _slice_bin_for_series(CFG.series_root, sid))

    # compose a single strat label
    key = (
        ap.astype(str) + "_" +
        mod + "_" +
        sex + "_" +
        ageb.astype(str) + "_" +
        slb.astype(str)
    )
    return key

def build_folds() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Read ``CFG.train_csv``, drop series missing on disk and assign folds.

    Folds come from ``StratifiedKFold(CFG.folds, shuffle=True,
    random_state=CFG.seed)`` stratified on ``_make_strat_key``.

    Returns:
        ``(df, train_df, val_df)`` where ``df`` carries a ``fold`` column and
        the train/val frames are the split for fold 0.
    """
    def _series_on_disk(uid) -> bool:
        return os.path.isdir(os.path.join(CFG.series_root, str(uid)))

    # keep only series whose folder exists under CFG.series_root
    all_df = pd.read_csv(CFG.train_csv)
    on_disk = all_df["SeriesInstanceUID"].apply(_series_on_disk)
    all_df = all_df[on_disk].reset_index(drop=True)

    # composite label: target + modality + sex + age bin + slice-count bin
    labels = _make_strat_key(all_df)

    splitter = StratifiedKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
    all_df["fold"] = -1
    for fold_no, (_, holdout_rows) in enumerate(splitter.split(all_df, labels)):
        all_df.loc[holdout_rows, "fold"] = fold_no

    # default split; build_loaders accepts a different fold index
    default_fold = 0
    in_val = all_df["fold"] == default_fold
    train_part = all_df[~in_val].reset_index(drop=True)
    val_part = all_df[in_val].reset_index(drop=True)
    return all_df, train_part, val_part

def load_sid_to_orig_depth(
    shard_roots: list[str],
    img_size: int,
    extra_search_roots: list[str] | None = None,
    verbose: bool = True,
) -> dict[str, int]:
    """
    Recursively search under each root for 'manifest_{img_size}.parquet' and build:
        { SeriesInstanceUID -> orig_depth }
    Works for paths like:
      D:/.../cache/SHARD_ID=0/cache_u8_384_shard00/manifest_384.parquet
    """
    # Collect candidate manifest files (recursive, Windows-safe)
    roots = (shard_roots or []) + (extra_search_roots or [])
    man_paths = set()
    for root in roots:
        if not root or not os.path.isdir(root):
            continue
        # same-dir check
        p = os.path.join(root, f"manifest_{img_size}.parquet")
        if os.path.exists(p):
            man_paths.add(os.path.normpath(p))
        # recursive search (any depth)
        pattern = os.path.join(root, "**", f"manifest_{img_size}.parquet")
        for mp in glob.glob(pattern, recursive=True):
            man_paths.add(os.path.normpath(mp))

    if not man_paths:
        if verbose:
            print("[manifest] No manifest files found under provided roots.")
        return {}

    if verbose:
        print(f"[manifest] Found {len(man_paths)} manifest(s).")

    sid_to_depth: dict[str, int] = {}
    for mp in sorted(man_paths):
        try:
            m = pd.read_parquet(mp)                         
        except Exception:
            m = pd.read_parquet(mp, engine="fastparquet") # fallback
        if not {"SeriesInstanceUID", "orig_depth"}.issubset(m.columns):
            continue
        sub = m[["SeriesInstanceUID", "orig_depth"]].dropna(subset=["orig_depth"])
        for sid, od in zip(sub["SeriesInstanceUID"].astype(str), sub["orig_depth"].astype(int)):
            # "latest wins" if duplicates across shards
            sid_to_depth[sid] = int(od)

    if verbose:
        print(f"[manifest] Loaded orig_depth for {len(sid_to_depth)} series.")
    return sid_to_depth


def build_loaders(fold_idx: int = 0):
    """
    Build the train/val DataLoaders for one fold.

    Steps: discover cache shards and index their ``*_{CFG.img_size}.npy``
    files, load the per-shard ``orig_depth`` manifests, set up the distributed
    environment and seed, recompute the same stratified folds as
    ``build_folds()``, then wrap both splits in ``RSNADataset`` and
    ``DataLoader`` (with a ``DistributedSampler`` each when world size > 1).

    Returns:
        ``(train_loader, val_loader, fold_idx, world_size, rank, local_rank)``
    """
    # discover cache + resolver
    shard_roots = discover_shard_roots()

    # O(1) index: (uid, size) -> path, searched recursively below each shard.
    # A key seen again later overwrites the earlier path.
    cache_index = {}
    npy_glob = f"*_{CFG.img_size}.npy"
    for shard in shard_roots:
        for npy_path in glob.glob(os.path.join(shard, "**", npy_glob), recursive=True):
            # file name looks like "<uid>_<size>.npy"
            stem, size_token = os.path.basename(npy_path).rsplit("_", 1)
            cache_index[(stem, int(size_token.split(".")[0]))] = npy_path

    def find_cached_path(uid: str, img_size: int) -> str | None:
        return cache_index.get((uid, img_size))

    # SeriesInstanceUID -> orig_depth from per-shard manifests
    sid_to_orig_depth = load_sid_to_orig_depth(shard_roots, CFG.img_size)

    # distributed env
    local_rank, rank, world_size, _is_distributed = setup_distributed()
    seed_everything(CFG.seed)

    # folds (recompute the same split deterministically)
    def _series_on_disk(uid) -> bool:
        return os.path.isdir(os.path.join(CFG.series_root, str(uid)))

    df = pd.read_csv(CFG.train_csv)
    df = df[df["SeriesInstanceUID"].apply(_series_on_disk)].reset_index(drop=True)

    labels = _make_strat_key(df)
    splitter = StratifiedKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
    df["fold"] = -1
    for fold_no, (_, holdout_rows) in enumerate(splitter.split(df, labels)):
        df.loc[holdout_rows, "fold"] = fold_no

    in_val = df["fold"] == fold_idx
    train_df = df[~in_val].reset_index(drop=True)
    val_df = df[in_val].reset_index(drop=True)

    # optional: quick cache check (main process only)
    if is_main_process() and shard_roots and len(df) > 0:
        first_sid = str(df.iloc[0]["SeriesInstanceUID"])
        print("Sample cached path:", find_cached_path(first_sid, CFG.img_size))

    # datasets (identical settings for both splits)
    preproc = DICOMPreprocessorKaggle(target_shape=(CFG.base_slices, CFG.img_size, CFG.img_size))
    dataset_kwargs = dict(
        localizers_csv_path=getattr(CFG, "localizers_csv_path", None),
        max_localizer_crops=getattr(CFG, "max_localizer_crops", 3),
        local_crop_size=getattr(CFG, "local_crop_size", 128),
        sid_to_orig_depth=sid_to_orig_depth,
    )
    train_ds = RSNADataset(train_df, CFG.series_root, preproc, find_cached_path, **dataset_kwargs)
    val_ds = RSNADataset(val_df, CFG.series_root, preproc, find_cached_path, **dataset_kwargs)

    # samplers
    train_sampler = None
    val_sampler = None
    if world_size > 1:
        train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
        val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    # loaders (worker settings shared by both)
    use_workers = CFG.num_workers > 0
    worker_kwargs = dict(
        num_workers=CFG.num_workers,
        pin_memory=True,
        persistent_workers=use_workers,
        prefetch_factor=2 if use_workers else None,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=CFG.batch_size,             # per-GPU
        sampler=train_sampler,
        shuffle=(train_sampler is None),       # the sampler shuffles when present
        drop_last=True,
        **worker_kwargs,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=max(1, CFG.batch_size // 2),
        sampler=val_sampler,
        shuffle=False,
        **worker_kwargs,
    )

    if is_main_process():
        print(f"World size: {world_size}  |  Rank: {rank}  |  Local rank: {local_rank}")
        print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")

    return train_loader, val_loader, fold_idx, world_size, rank, local_rank

In [None]:
class BCEWithLogitsSmooth(nn.Module):
    """
    Binary cross-entropy on logits with optional label smoothing.

    With ``smoothing > 0`` each target ``t`` is replaced by
    ``t * (1 - smoothing) + 0.5 * smoothing`` before the loss is computed.
    ``pos_weight`` is passed through to
    ``binary_cross_entropy_with_logits``; the default mean reduction applies.
    """

    def __init__(self, smoothing=0.0, pos_weight=None):
        super().__init__()
        self.pos_weight = pos_weight
        self.smoothing = smoothing

    def forward(self, logits, targets):
        eps = self.smoothing
        soft_targets = targets * (1 - eps) + 0.5 * eps if eps > 0.0 else targets
        return nn.functional.binary_cross_entropy_with_logits(
            logits, soft_targets, pos_weight=self.pos_weight
        )

class FocalWithLogits(nn.Module):
    """
    Focal loss on logits for binary / multi-label targets.

    The element-wise BCE (optionally ``pos_weight``-ed) is scaled by
    ``(1 - p_t) ** gamma``, where ``p_t`` is the probability the model assigns
    to the given target, and the result is averaged over all elements.
    """

    def __init__(self, gamma=1.5, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        per_elem = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=self.pos_weight, reduction="none"
        )
        prob = torch.sigmoid(logits)
        # probability mass on the labelled outcome
        p_true = prob * targets + (1 - prob) * (1 - targets)
        modulator = (1 - p_true) ** self.gamma
        return (modulator * per_elem).mean()

def cosine_sched(step, total_steps, base_lr, min_lr, warmup_steps):
    """
    Learning rate at ``step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate grows linearly from 0 to
    ``base_lr``; afterwards it follows a half cosine from ``base_lr`` down to
    ``min_lr``, reached at ``total_steps``.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return base_lr * (step / max(1, warmup_steps))
    decay_span = max(1, total_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    cosine = 1 + math.cos(math.pi*progress)
    return min_lr + 0.5*(base_lr - min_lr)*cosine

def make_model():
    """
    Create the pretrained timm model named by ``CFG.model_name``.

    The model is built with ``CFG.in_chans`` input channels and
    ``CFG.num_classes`` outputs, switched to channels-last memory format when
    ``USE_CHANNELS_LAST`` is set, and wrapped with ``torch.compile`` when
    ``USE_TORCH_COMPILE`` is set. If compilation raises, the uncompiled model
    is returned silently.
    """
    # NOTE(review): img_size is forwarded to the model constructor; confirm
    # the configured architecture accepts that keyword.
    create_kwargs = dict(
        in_chans=CFG.in_chans,
        num_classes=CFG.num_classes,
        img_size=CFG.img_size,
        pretrained=True,
    )
    net = timm.create_model(CFG.model_name, **create_kwargs)

    if USE_CHANNELS_LAST:
        net = net.to(memory_format=torch.channels_last)

    if not USE_TORCH_COMPILE:
        return net
    # optional compile for a few extra %
    try:
        return torch.compile(net, mode="reduce-overhead", fullgraph=False)
    except Exception:
        return net

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

class Trainer:
    """Train and validate a single cross-validation fold.

    Bundles model creation (optionally wrapped in DDP), the loss, AdamW with a
    per-step warmup + cosine learning rate, optional mixed precision, an
    exponential-moving-average (EMA) copy of the weights that is used for
    validation and checkpointing, and early stopping on the weighted AUC.
    """

    def __init__(self, train_loader, val_loader, fold: int):
        """Set up model, loss, optimizer, EMA copy, AMP state and output folder.

        Args:
            train_loader: loader yielding ``(x, y, _)`` training batches.
            val_loader: loader yielding ``(x, y, sid)`` validation batches.
            fold: fold index, used in output folder and file names.
        """
        # DDP env
        self.local_rank, self.rank, self.world, self.is_distributed = setup_distributed()

        # device & model
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda", self.local_rank) if has_cuda else torch.device("cpu")
        if has_cuda:
            # Only touch the CUDA runtime when a GPU is actually present.
            torch.cuda.set_device(self.local_rank)
        torch.set_float32_matmul_precision("high")

        self.model = make_model().to(self.device)

        if self.is_distributed:
            # important: broadcast buffers True, find_unused False for speed
            self.model = DDP(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                broadcast_buffers=True,
                find_unused_parameters=False,
            )

        # loss / opt / scaler
        pos_weight = torch.tensor([CFG.pos_weight]*CFG.num_classes, device=self.device)
        self.criterion = (FocalWithLogits(CFG.focal_gamma, pos_weight)
                          if CFG.focal_loss else BCEWithLogitsSmooth(CFG.label_smoothing, pos_weight))

        # No world-size scaling of the LR; keeps it stable moving from 1->2 GPUs.
        base_lr = CFG.lr
        self.optimizer = optim.AdamW(self.model.parameters(), lr=base_lr, weight_decay=CFG.weight_decay)
        self.global_step = 0
        self.ema_decay = 0.9998
        base_model = self.model.module if isinstance(self.model, DDP) else self.model
        self.ema = copy.deepcopy(base_model).eval()
        # ensure EMA is on the same device & memory format
        self.ema.to(self.device)
        if USE_CHANNELS_LAST:
            self.ema.to(memory_format=torch.channels_last)
        for p in self.ema.parameters():
            p.requires_grad_(False)

        self.base_lr = self.optimizer.param_groups[0]["lr"]

        # Mixed precision: "fp16" -> float16 autocast with loss scaling,
        # "bf16" -> bfloat16 autocast without scaling, anything else -> no autocast.
        self.autocast_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(AMP_DTYPE)
        self.scaler = GradScaler(enabled=CFG.use_amp and AMP_DTYPE == "fp16")  # GradScaler is for fp16 only

        self.train_loader, self.val_loader = train_loader, val_loader
        self.fold = fold
        self.fold_dir = os.path.join(CFG.out_dir, f"{CFG.save_name}_seed{CFG.seed}_fold{self.fold}")
        # Creates CFG.out_dir as well, since it is a parent of fold_dir.
        os.makedirs(self.fold_dir, exist_ok=True)
        self.total_steps = CFG.epochs * len(train_loader)
        self.warmup_steps = int(CFG.warmup_epochs * len(train_loader))
        self.best_auc = -1.0

    def _cast_input(self, x):
        """Return ``x`` in channels-last layout when enabled and ``x`` is 4-D."""
        # Keep channels_last for better memory access on T4
        if USE_CHANNELS_LAST and x.ndim == 4:
            x = x.contiguous(memory_format=torch.channels_last)
        return x

    def _update_ema(self, step: int):
        """EMA update: params with decay; buffers (BN stats) copied 1:1 each step.
           On very first step, copy whole state 1:1 (warm start)."""
        src = self.model.module if isinstance(self.model, DDP) else self.model

        # Warm-start EMA at first step to avoid bias
        if step == 0:
            self.ema.load_state_dict(src.state_dict(), strict=True)
            return

        with torch.no_grad():
            # 1) EMA for parameters
            for pe, pm in zip(self.ema.parameters(), src.parameters()):
                pe.mul_(self.ema_decay).add_(pm.detach(), alpha=1.0 - self.ema_decay)
            # 2) Direct copy for buffers (BN running stats, etc.)
            for be, bm in zip(self.ema.buffers(), src.buffers()):
                be.copy_(bm.detach())

    def one_epoch(self, epoch):
        """Run one training epoch and return the mean training loss.

        The learning rate is set before every step from ``cosine_sched``; each
        optimizer step is followed by an EMA update. Under DDP the returned
        loss is averaged across ranks.
        """
        if self.is_distributed and hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"):
            self.train_loader.sampler.set_epoch(epoch)
        # always update dataset epoch (train & val)
        if hasattr(self.train_loader, "dataset") and hasattr(self.train_loader.dataset, "set_epoch"):
            self.train_loader.dataset.set_epoch(epoch)
        if hasattr(self.val_loader, "dataset") and hasattr(self.val_loader.dataset, "set_epoch"):
            self.val_loader.dataset.set_epoch(0)  # keep val deterministic

        self.model.train()
        running = 0.0
        start_step = epoch * len(self.train_loader)
        self.optimizer.zero_grad(set_to_none=True)

        iterator = self.train_loader
        if is_main_process():
            iterator = tqdm(self.train_loader, total=len(self.train_loader), desc=f"Epoch {epoch+1}", leave=False)

        use_autocast = CFG.use_amp and self.autocast_dtype is not None
        use_scaler = self.scaler.is_enabled()  # True only for fp16 AMP

        t0 = time()
        for it, (x, y, _) in enumerate(iterator):
            x = self._cast_input(x.to(self.device, non_blocking=True))
            y = y.to(self.device, non_blocking=True)

            lr = cosine_sched(start_step+it, self.total_steps, self.base_lr, CFG.min_lr, self.warmup_steps)
            for pg in self.optimizer.param_groups:
                pg["lr"] = lr

            # forward
            if use_autocast:
                with autocast(dtype=self.autocast_dtype):
                    logits = self.model(x)
                    loss = self.criterion(logits, y)
            else:
                logits = self.model(x)
                loss = self.criterion(logits, y)

            # backward + step
            if use_scaler:
                # fp16: scale the loss, unscale before clipping, let the scaler step.
                self.scaler.scale(loss).backward()
                if CFG.clip_grad_norm:
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), CFG.clip_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # bf16 autocast or plain fp32: no loss scaling needed.
                loss.backward()
                if CFG.clip_grad_norm:
                    nn.utils.clip_grad_norm_(self.model.parameters(), CFG.clip_grad_norm)
                self.optimizer.step()

            # EMA update with step
            self._update_ema(self.global_step)
            self.global_step += 1

            self.optimizer.zero_grad(set_to_none=True)
            loss_value = loss.item()
            running += loss_value

            if is_main_process() and (it+1) % 20 == 0:
                dt = time() - t0
                ips = (it+1) * CFG.batch_size / max(dt, 1e-6)
                iterator.set_postfix(lr=f"{lr:.2e}", loss=f"{loss_value:.4f}", ips=f"{ips:.1f} it/s")

        avg_loss = torch.tensor([running / max(1, len(self.train_loader))], device=self.device)
        if self.is_distributed:
            dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        return avg_loss.item()

    @torch.no_grad()
    def validate(self):
        """Evaluate on the validation loader, preferring the EMA weights.

        Under DDP the per-rank predictions are gathered so every rank computes
        the metrics on the full validation set. The main process also writes
        the fold's predictions and targets to a CSV in ``self.fold_dir``.

        Returns:
            ``(va_loss, wauc, aucs, y_prob)``: mean validation loss over this
            rank's batches, the value returned by ``comp_weighted_auc``, a dict
            of per-label AUC (NaN when a label has a single class), and the
            array of predicted probabilities.
        """
        net = self.ema if getattr(self, "ema", None) is not None else (
            self.model.module if isinstance(self.model, DDP) else self.model
        )
        net.eval()

        tot = 0.0
        probs_all, tgts_all, sids_all = [], [], []

        for batch in self.val_loader:
            # unpack (x,y,sid)
            x, y, sid = batch
            x = self._cast_input(x.to(self.device, non_blocking=True))
            y = y.to(self.device, non_blocking=True)

            # validation runs in full precision
            with torch.cuda.amp.autocast(enabled=False):
                logits = net(x)
                loss = self.criterion(logits, y)
            tot += float(loss.item())

            probs = torch.sigmoid(logits.float())
            probs_all.append(probs.cpu().numpy())
            tgts_all.append(y.float().cpu().numpy())
            sids_all.extend(list(sid))

        # local stacks
        y_prob = np.concatenate(probs_all, axis=0)
        y_true = np.concatenate(tgts_all, axis=0)
        sids   = np.array(sids_all)

        # clean: keep probabilities strictly inside (0, 1) and drop non-finite rows
        y_prob = np.clip(y_prob, 1e-6, 1-1e-6)
        finite_mask = np.isfinite(y_prob).all(axis=1) & np.isfinite(y_true).all(axis=1)
        y_prob, y_true, sids = y_prob[finite_mask], y_true[finite_mask], sids[finite_mask]

        # gather (no padding)
        if self.is_distributed:
            local = {"prob": y_prob, "true": y_true, "sid": sids}
            gathered = [None] * self.world
            dist.all_gather_object(gathered, local)
            y_prob = np.concatenate([g["prob"] for g in gathered if g is not None], axis=0)
            y_true = np.concatenate([g["true"] for g in gathered if g is not None], axis=0)
            sids   = np.concatenate([g["sid"]  for g in gathered if g is not None], axis=0)

        # metrics
        aucs = {}
        for j, name in enumerate(LABEL_COLS):
            yi, pi = y_true[:, j], y_prob[:, j]
            m = np.isfinite(yi) & np.isfinite(pi)
            yi, pi = yi[m], pi[m]
            # AUC is undefined when only one class is present
            aucs[name] = roc_auc_score(yi, pi) if np.unique(yi).size >= 2 else np.nan

        wauc = comp_weighted_auc(aucs)
        va_loss = tot / max(1, len(self.val_loader))

        # SAVE fold-level OOF chunk (VAL predictions for this fold) on main process
        if is_main_process():
            df_pred = pd.DataFrame({"SeriesInstanceUID": sids})
            for j, name in enumerate(LABEL_COLS):
                df_pred[name] = y_prob[:, j]
            # Optional: include targets for analysis
            for j, name in enumerate(LABEL_COLS):
                df_pred[f"{name}_target"] = y_true[:, j]
            oof_path = os.path.join(self.fold_dir, f"oof_fold{self.fold}_seed{CFG.seed}.csv")
            df_pred.to_csv(oof_path, index=False)

        return va_loss, wauc, aucs, y_prob  # (kept same return shape)

    def fit(self):
        """Train for up to ``CFG.epochs`` epochs with early stopping.

        After each epoch the EMA model is validated; whenever the weighted AUC
        improves, the main process saves the EMA weights (``best_ema.pth``),
        the live weights (``best_raw.pth``) and ``metrics.json``. Training
        stops after 7 consecutive epochs without improvement. At the end the
        best EMA weights are loaded back into the live model.

        Returns:
            The best weighted AUC observed.
        """
        best_state = None
        no_improve = 0
        patience = 7

        for epoch in range(CFG.epochs):
            tr_loss = self.one_epoch(epoch)
            va_loss, wauc, aucs, _ = self.validate()
            ap = aucs[LABEL_COLS[ANEURYSM_PRESENT_IDX]]

            if is_main_process():
                print(f"[{epoch+1:02d}/{CFG.epochs}] tr={tr_loss:.4f}  va={va_loss:.4f}  wAUC={wauc:.5f}  AneurysmPresent={ap:.5f}")

            if wauc > self.best_auc:
                self.best_auc = wauc
                no_improve = 0
                # Save EMA as the best snapshot
                state = copy.deepcopy(self.ema.state_dict())
                best_state = state  # keep a local copy for end-of-training reload
                if is_main_process():
                    torch.save(
                        {"state_dict": state, "cfg": cfg_to_dict(CFG), "best_wAUC": float(wauc),
                         "fold": int(self.fold), "is_ema": True},
                        os.path.join(self.fold_dir, "best_ema.pth")
                    )
                    # raw/live weights (optional)
                    live = (self.model.module if isinstance(self.model, DDP) else self.model).state_dict()
                    torch.save(
                        {"state_dict": live, "cfg": cfg_to_dict(CFG), "best_wAUC": float(wauc),
                         "fold": int(self.fold), "is_ema": False},
                        os.path.join(self.fold_dir, "best_raw.pth")
                    )
                    # metrics sidecar
                    with open(os.path.join(self.fold_dir, "metrics.json"), "w") as f:
                        json.dump({"best_wAUC": float(wauc), "epoch": int(epoch), "fold": int(self.fold)}, f, indent=2)
            else:
                no_improve += 1

            if no_improve >= patience:
                if is_main_process():
                    print("Early stopping.")
                break

        # Load EMA-best back into the active model so further eval/infer use it
        if best_state is not None:
            target = self.model.module if isinstance(self.model, DDP) else self.model
            target.load_state_dict(best_state, strict=False)

        if is_main_process():
            print("Best wAUC:", self.best_auc)

        return self.best_auc

In [None]:
# # Single-GPU local run (Windows/Jupyter safe)
# import os, platform, torch

# print("cuda available:", torch.cuda.is_available())
# if torch.cuda.is_available():
#     print("device name:", torch.cuda.get_device_name(0))

# CFG.num_workers = 0
# CFG.persistent_workers = False
# CFG.pin_memory = torch.cuda.is_available()

# # Force non-distributed mode
# def _setup_dist_dummy():
#     # local_rank, rank, world_size, is_distributed
#     return 0, 0, 1, False

# def _cleanup_dist_dummy():
#     pass

# # Override any previous distributed helpers
# setup_distributed = _setup_dist_dummy
# cleanup_distributed = _cleanup_dist_dummy

# # Optional: small perf knobs
# torch.set_num_threads(max(1, os.cpu_count() // 2))
# if torch.cuda.is_available():
#     torch.backends.cudnn.benchmark = True  # speeds up convs with fixed input size

# # Train single process 
# from __main__ import build_loaders, Trainer, CFG, seed_everything  # ensure these are defined above

# seed_everything(CFG.seed)
# train_loader, val_loader, fold_idx, world, r, local = build_loaders(fold_idx=0)
# print(f"[single] world={world} rank={r} local_rank={local}  |  "
#       f"Train: {len(train_loader.dataset)}  Val: {len(val_loader.dataset)}")

# trainer = Trainer(train_loader, val_loader, fold=fold_idx)
# trainer.fit()

# Inference

In [None]:
# #  Inference + Kaggle Server (OFFLINE, multi-seed, robust paths)
# import os, gc, shutil, warnings
# warnings.filterwarnings("ignore")
# from pathlib import Path
# from typing import Dict, Tuple, List, Optional
# import numpy as np
# import polars as pl
# import torch
# import torch.nn as nn
# from torch.cuda.amp import autocast
# import timm

# ID_COL = "SeriesInstanceUID"
# TARGET_COLS = LABEL_COLS
# NUM_CLASSES = len(TARGET_COLS)

# # no internet pulls during submit
# os.environ["HF_HUB_OFFLINE"] = "1"
# os.environ["TRANSFORMERS_OFFLINE"] = "1"

# class InferenceCFG:
#     model_name: str = CFG.model_name
#     img_size:  int = CFG.img_size
#     in_chans:  int = CFG.in_chans
#     num_classes: int = NUM_CLASSES
#     channels_last: bool = True
#     use_amp: bool = True
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     base_ckpt_dir: str = "/kaggle/input/maxvit-base-tf-384-dataset"
#     save_name: str = CFG.save_name  # e.g. "maxvitbasemodel"
#     # which folds to use currently have fold0 only
#     folds: List[int] = [0]

#     # filenames we prefer (searched recursively)
#     ckpt_prefer = ["best_ema.pth", "ema.pth", "model_ema.pth", "best_raw.pth"]

# ICFG = InferenceCFG()

# def _make_model_for_infer() -> nn.Module:
#     m = timm.create_model(
#         ICFG.model_name,
#         in_chans=ICFG.in_chans,
#         num_classes=ICFG.num_classes,
#         img_size=ICFG.img_size,
#         pretrained=False,      # never fetch online
#     )
#     if ICFG.channels_last:
#         m = m.to(memory_format=torch.channels_last)
#     return m

# def _find_seed_dirs(base: Path, save_name: str, folds: List[int]) -> List[Path]:
#     """
#     Find seed directories like:
#       base / f"{save_name}_seed{X}_fold{F}"
#     (supports your current nested layout too)
#     """
#     cand_dirs = []
#     for p in base.iterdir():
#         if not p.is_dir():
#             continue
#         name = p.name
#         if not name.startswith(f"{save_name}_seed"):
#             continue
#         # keep only ones that end with _foldK we care about
#         if any(name.endswith(f"_fold{f}") for f in folds):
#             cand_dirs.append(p)
#     return sorted(cand_dirs)

# def _find_ckpt_in_dir(dirpath: Path) -> Optional[Path]:
#     """
#     Robust: try preferred filenames under dirpath (recursively).
#     If not found, fall back to any *.pth (if unique).
#     Handles your nested pattern: <seed_dir>/<same_name>/*.
#     """
#     # try preferred names anywhere under this dir
#     for fname in ICFG.ckpt_prefer:
#         hits = list(dirpath.rglob(fname))
#         if hits:
#             return hits[0]

#     # else: any *.pth (unique)
#     all_pth = list(dirpath.rglob("*.pth"))
#     if len(all_pth) == 1:
#         return all_pth[0]
#     # if multiple, try pick the one with 'ema' in its name
#     ema_pth = [p for p in all_pth if "ema" in p.name.lower()]
#     if len(ema_pth) == 1:
#         return ema_pth[0]
#     return None

# def _load_model_from_ckpt(ckpt_path: Path) -> nn.Module:
#     ckpt = torch.load(ckpt_path, map_location="cpu")
#     state = ckpt.get("state_dict", ckpt.get("model"))
#     if state is None:
#         raise KeyError(f"{ckpt_path} missing 'state_dict'/'model' keys")
#     m = _make_model_for_infer()
#     m.load_state_dict(state, strict=True)
#     m.to(ICFG.device).eval()
#     return m

# # cache of loaded models, keyed by (seed_dir_name)
# _MODELS: Dict[str, nn.Module] = {}

# def _ensure_models_loaded():
#     if _MODELS:
#         return
#     base = Path(ICFG.base_ckpt_dir)
#     seed_dirs = _find_seed_dirs(base, ICFG.save_name, ICFG.folds)

#     if not seed_dirs:
#         raise FileNotFoundError(f"No seed folders found under {base} for save_name={ICFG.save_name}")

#     loaded = 0
#     for sd in seed_dirs:
#         ckpt = _find_ckpt_in_dir(sd)
#         if ckpt is None:
#             # skip politely if a seed dir has only OOF csv but no weights
#             continue
#         try:
#             model = _load_model_from_ckpt(ckpt)
#             _MODELS[sd.name] = model
#             loaded += 1
#         except Exception as e:
#             # skip broken checkpoints but continue others
#             print(f"Warning: failed to load {ckpt}: {e}")

#     if loaded == 0:
#         raise FileNotFoundError(f"No usable *.pth checkpoints found under {base} (looked recursively).")

#     # optional warmup
#     with torch.no_grad():
#         dummy = torch.randn(1, ICFG.in_chans, ICFG.img_size, ICFG.img_size, device=ICFG.device)
#         for m in _MODELS.values():
#             _ = m(dummy)

# # preprocessing 
# def _process_series(series_path: str, target_shape: Tuple[int,int,int]) -> np.ndarray:
#     pre = DICOMPreprocessorKaggle(target_shape=target_shape)
#     vol = pre.process_series(series_path)  # (32,H,W), uint8 or float in [0,255]
#     return vol

# # predict 
# @torch.no_grad()
# def _predict_single(model: nn.Module, vol_u8: np.ndarray) -> np.ndarray:
#     x = torch.from_numpy(np.asarray(vol_u8)).to(torch.float32).div_(255.0).unsqueeze(0)
#     if ICFG.channels_last:
#         x = x.contiguous(memory_format=torch.channels_last)
#     x = x.to(ICFG.device, non_blocking=True)
#     with autocast(enabled=ICFG.use_amp):
#         logits = model(x)
#     prob = torch.sigmoid(logits.float()).cpu().numpy().squeeze(0)
#     return np.clip(prob, 1e-6, 1-1e-6)

# def _predict_ensemble(vol_u8: np.ndarray) -> np.ndarray:
#     _ensure_models_loaded()
#     preds = [_predict_single(m, vol_u8) for m in _MODELS.values()]
#     return np.mean(np.stack(preds, 0), 0)

# def _predict_inner(series_path: str) -> pl.DataFrame:
#     vol = _process_series(series_path, (ICFG.in_chans, ICFG.img_size, ICFG.img_size))
#     pred = _predict_ensemble(vol)
#     return pl.DataFrame([pred.tolist()], schema=TARGET_COLS)

# def predict(series_path: str) -> pl.DataFrame:
#     try:
#         return _predict_inner(series_path)
#     except Exception:
#         # conservative fallback
#         return pl.DataFrame([[0.1]*NUM_CLASSES], schema=TARGET_COLS)
#     finally:
#         # required cleanup between served series
#         shared_dir = "/kaggle/shared"
#         shutil.rmtree(shared_dir, ignore_errors=True)
#         os.makedirs(shared_dir, exist_ok=True)
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#         gc.collect()

In [None]:
# Inference + Kaggle Server (OFFLINE, multi-seed, robust paths) 
import os, gc, shutil, warnings, glob, ast, hashlib
warnings.filterwarnings("ignore")
from pathlib import Path
from typing import Dict, Tuple, List, Optional
from functools import lru_cache
import numpy as np
import polars as pl
import pandas as pd
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import timm
import pydicom
import cv2
from collections import defaultdict

# Competition constants: the series id column, plus the target columns
# (taken as-is from LABEL_COLS) and how many of them there are.
ID_COL = "SeriesInstanceUID"
TARGET_COLS = LABEL_COLS
NUM_CLASSES = len(TARGET_COLS)

# No internet pulls during submit: put the HF hub / transformers libraries
# into offline mode via their environment switches.
os.environ.update({"HF_HUB_OFFLINE": "1", "TRANSFORMERS_OFFLINE": "1"})

def is_main_process():
    """Report whether this is the main process; always ``True`` in this cell.

    The Kaggle inference server is single-process, so this stub only exists
    to keep callers of ``is_main_process()`` from hitting a ``NameError``.
    NOTE(review): defining it here rebinds any ``is_main_process`` defined by
    an earlier cell in the same session — confirm that is intended.
    """
    return True

# load manifests from shard roots 
def _discover_shard_roots() -> List[str]:
    """Return the sorted cache shard directories for ``CFG.img_size``.

    Globs ``<cache>/*/cache_u8_{CFG.img_size}_shard*`` under the hard-coded
    cache root and keeps only directories. On the main process the number of
    shards and the first eight paths are printed.
    """
    cache_root = "D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/cache"
    shard_glob = os.path.join(cache_root, "*", f"cache_u8_{CFG.img_size}_shard*")
    found = sorted(filter(os.path.isdir, glob.glob(shard_glob)))
    if is_main_process():
        print("Found shard roots:", len(found))
        for shard in found[:8]:
            print("  ", shard)
    return found

def _load_sid_to_orig_depth(shard_roots: List[str], img_size: int) -> Dict[str, int]:
    sid2: Dict[str, int] = {}
    for root in shard_roots:
        man = os.path.join(root, f"manifest_{img_size}.parquet")
        if os.path.exists(man):
            try:
                m = pd.read_parquet(man)
            except Exception:
                m = pd.read_parquet(man, engine="fastparquet")
            if {"SeriesInstanceUID","orig_depth"}.issubset(m.columns):
                sub = m[["SeriesInstanceUID","orig_depth"]].dropna()
                for sid, od in zip(sub["SeriesInstanceUID"].astype(str), sub["orig_depth"].astype(int)):
                    sid2[sid] = int(od)
    return sid2

# cached path index: (sid, size) -> npy path 
def _build_uid_to_path(shard_roots: List[str], img_size: int) -> Dict[Tuple[str,int], str]:
    idx: Dict[Tuple[str,int], str] = {}
    suffix = f"_{img_size}.npy"
    for root in shard_roots:
        with os.scandir(root) as it:
            for e in it:
                if e.is_file() and e.name.endswith(suffix):
                    uid = e.name[:-len(suffix)]
                    idx[(uid, img_size)] = e.path
    return idx

# parse localizers CSV (your schema) 
def _load_localizers_csv(csv_path: Optional[str], max_points_per_series: int = 3) -> Dict[str, List[dict]]:
    """
    Columns:
      - SeriesInstanceUID
      - SOPInstanceUID
      - coordinates: "{'x': ..., 'y': ...}" or '{"x":..., "y":...}'
      - location (optional)
    Returns: { sid: [ {'x','y','sop','loc'} ... ] }
    """
    if not csv_path or not os.path.exists(csv_path):
        return {}
    df = pd.read_csv(csv_path)
    cols = {c.lower(): c for c in df.columns}
    sid_col   = cols.get('seriesinstanceuid') or 'SeriesInstanceUID'
    sop_col   = cols.get('sopinstanceuid')   or 'SOPInstanceUID'
    coord_col = cols.get('coordinates')      or 'coordinates'
    loc_col   = cols.get('location') if 'location' in cols else None

    keep = [c for c in [sid_col, sop_col, coord_col, loc_col] if c and c in df.columns]
    df = df[keep].copy()

    by_sid: Dict[str, List[dict]] = defaultdict(list)
    for _, r in df.iterrows():
        sid = str(r[sid_col])
        sop = str(r[sop_col]) if sop_col in r and pd.notna(r[sop_col]) else None
        x = y = None
        if coord_col in r and pd.notna(r[coord_col]):
            s = str(r[coord_col]).strip()
            try:
                xy = ast.literal_eval(s)
                x = float(xy.get("x")) if xy.get("x") is not None else None
                y = float(xy.get("y")) if xy.get("y") is not None else None
            except Exception:
                x = y = None
        loc = str(r[loc_col]) if loc_col and pd.notna(r[loc_col]) else None
        by_sid[sid].append({"x": x, "y": y, "sop": sop, "loc": loc})

    for sid in list(by_sid.keys()):
        by_sid[sid] = by_sid[sid][:max_points_per_series]
    return dict(by_sid)

# SOP -> rank map (headers only, cached)
@lru_cache(maxsize=512)
def _build_sop_rank_map(series_dir: str) -> Tuple[Dict[str,int], int]:
    items = []
    try:
        for name in os.listdir(series_dir):
            if not name.lower().endswith(".dcm"):
                continue
            path = os.path.join(series_dir, name)
            ds = pydicom.dcmread(path, stop_before_pixels=True, force=True)
            sop = str(getattr(ds, "SOPInstanceUID", os.path.splitext(name)[0]))
            ipp = getattr(ds, "ImagePositionPatient", None)
            z = float(ipp[2]) if ipp is not None and len(ipp) == 3 else float(getattr(ds, "SliceLocation", 0.0))
            items.append((sop, z))
    except Exception:
        pass
    if not items:
        return ({}, 0)
    items.sort(key=lambda t: t[1])
    return ({sop: i for i, (sop, _) in enumerate(items)}, len(items))

def _rank_to_cached_idx(rank: int, orig_depth: int, cached_depth: int) -> int:
    if orig_depth <= 1:
        return cached_depth // 2
    r = np.clip(rank, 0, orig_depth-1)
    return int(round(r / (orig_depth - 1) * (cached_depth - 1)))

def _map_localizer_to_cached_depth(loc_z, loc_f, cached_depth: int, orig_depth: Optional[int]=None) -> int:
    if orig_depth and loc_f is not None:
        return int(np.clip(round(loc_f / max(1, (orig_depth-1)) * (cached_depth-1)), 0, cached_depth-1))
    if orig_depth and loc_z is not None:
        return int(np.clip(round(loc_z  / max(1, (orig_depth-1)) * (cached_depth-1)), 0, cached_depth-1))
    return cached_depth // 2

class InferenceCFG:
    """Inference-time settings, mostly copied from the training ``CFG``.

    All values are class attributes evaluated once when this cell runs; the
    list-valued ones (``folds``, ``ckpt_prefer``) are therefore shared by every
    instance.
    """
    # must match training
    model_name: str = CFG.model_name
    img_size:  int = CFG.img_size
    in_chans:  int = CFG.in_chans
    num_classes: int = NUM_CLASSES
    channels_last: bool = True
    use_amp: bool = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # uploaded dataset root 
    # NOTE(review): this path already ends in a "<save_name>_seed.._fold.." folder,
    # while _find_seed_dirs looks for such folders *inside* the base it is given —
    # confirm the checkpoints sit in a nested folder of the same name.
    base_ckpt_dir: str = "D:/User Data/Downloads/rsna-intracranial-aneurysm-detection/outputs/maxvitbasemodel_seed42_fold0"
    save_name: str = CFG.save_name
    # fold indices whose seed folders are considered
    folds: List[int] = [0]
    # checkpoint file names in order of preference (EMA weights first)
    ckpt_prefer = ["best_ema.pth", "ema.pth", "model_ema.pth", "best_raw.pth"]

# Module-wide settings instance used by the inference helpers.
ICFG = InferenceCFG()

# Recompute the expected channel count from the training flags:
# base slices + extra cached channels + localizer crop channels (only when
# localizers are enabled).
if getattr(CFG, "use_localizers", False):
    _K = getattr(CFG, "max_localizer_crops", 0)
else:
    _K = 0
ICFG.in_chans = CFG.base_slices + getattr(CFG, "extra_cached_chans", 0) + _K

# Build the cache indices and per-series metadata once, at cell execution time.
_SHARD_ROOTS = _discover_shard_roots()
_UID2PATH = _build_uid_to_path(_SHARD_ROOTS, ICFG.img_size)
_SID2DEPTH = _load_sid_to_orig_depth(_SHARD_ROOTS, ICFG.img_size)

# The localizers CSV path must be the same one that was used for training.
_LOCALIZERS = _load_localizers_csv(
    getattr(CFG, "localizers_csv_path", None),
    max_points_per_series=getattr(CFG, "max_localizer_crops", 3),
)

def _load_cached_or_preprocess(series_path: str, sid: str) -> np.ndarray:
    """Return the uint8 volume for a series, from cache when possible.

    If the cache index holds an existing ``.npy`` for ``sid`` it is opened
    memory-mapped (read-only); otherwise the DICOM series is preprocessed and
    converted to uint8 (clipped to 0..255 when needed). Planes whose height or
    width differ from ``ICFG.img_size`` are resized with bilinear interpolation.
    """
    side = ICFG.img_size
    cached = _UID2PATH.get((sid, side))
    if cached and os.path.exists(cached):
        vol_u8 = np.load(cached, mmap_mode="r")
    else:
        preproc = DICOMPreprocessorKaggle(target_shape=(CFG.base_slices, side, side))
        raw = preproc.process_series(series_path)
        if raw.dtype == np.uint8:
            vol_u8 = raw
        else:
            vol_u8 = np.clip(raw, 0, 255).astype(np.uint8)

    # ensure spatial dims
    n_chan, height, width = vol_u8.shape
    if height == side and width == side:
        return vol_u8
    resized = [
        cv2.resize(vol_u8[c], (side, side), interpolation=cv2.INTER_LINEAR)
        for c in range(n_chan)
    ]
    return np.stack(resized, axis=0)

def _safe_crop_2d(img: np.ndarray, cx: int, cy: int, size: int) -> np.ndarray:
    H, W = img.shape
    half = size // 2
    x0 = max(0, cx - half); x1 = min(W, cx + half)
    y0 = max(0, cy - half); y1 = min(H, cy + half)
    crop = img[y0:y1, x0:x1]
    if crop.shape[0] != size or crop.shape[1] != size:
        pad_y = size - crop.shape[0]
        pad_x = size - crop.shape[1]
        crop = np.pad(crop, ((0, max(0,pad_y)), (0, max(0,pad_x))), mode='edge')
        crop = crop[:size, :size]
    return crop

def _compose_localizer_channels(sid: str, vol_u8: np.ndarray, K: int, crop_size: int=128) -> np.ndarray:
    """Build ``K`` extra input channels from localizer points for one series.

    Each channel is a maximum-intensity projection over a slab of slices
    around a chosen depth index, cropped to ``crop_size`` around a point and
    resized back to the full plane size. Points recorded for ``sid`` in
    ``_LOCALIZERS`` are used first; any remaining channels are filled from
    pseudo-random positions seeded by the series id, so the output for a given
    series is reproducible.

    Args:
        sid: SeriesInstanceUID, used for lookups and as the RNG seed source.
        vol_u8: volume of shape (channels, H, W); only the first
            ``CFG.base_slices`` channels are read.
        K: number of localizer channels to produce.
        crop_size: side length of the crop taken around each point.

    Returns:
        Array of shape ``(K, H, W)`` (``(0, H, W)`` when ``K <= 0``).
    """
    # Inference: do NOT apply train-time dropouts; use points if present, else deterministic random fill.
    if K <= 0:
        return np.zeros((0, vol_u8.shape[1], vol_u8.shape[2]), dtype=vol_u8.dtype)

    pts = _LOCALIZERS.get(sid, [])
    H, W = vol_u8.shape[1], vol_u8.shape[2]
    base = vol_u8[:CFG.base_slices]
    cached_depth = base.shape[0]

    # SOP→rank mapping & orig depth
    series_dir = os.path.join(CFG.series_root, sid)
    sop2rank, hdr_depth = _build_sop_rank_map(series_dir) if os.path.isdir(series_dir) else ({}, 0)
    # Depth preference: manifest value, then number of DICOM headers read, then cached depth.
    use_depth = _SID2DEPTH.get(sid, None) or hdr_depth or cached_depth

    chans = []
    for p in pts[:K]:
        # choose z
        sop = p.get("sop")
        if sop and sop in sop2rank and use_depth > 0:
            z_idx = _rank_to_cached_idx(sop2rank[sop], use_depth, cached_depth)
        else:
            # Points produced by _load_localizers_csv carry no "z"/"f" keys, so for
            # those this fallback resolves to the middle cached slice.
            z_idx = _map_localizer_to_cached_depth(p.get("z"), p.get("f"), cached_depth, use_depth)
        # slab of up to 17 slices centred on z_idx (8 on either side)
        z0 = max(0, z_idx-8); z1 = min(cached_depth, z_idx+9)
        slab = base[z0:z1]
        if slab.size == 0:
            chan = np.zeros((H,W), dtype=vol_u8.dtype)
        else:
            mip = slab.max(axis=0)
            px, py = p.get("x"), p.get("y")
            if px is None or py is None:
                # no usable coordinates: crop around the image centre
                cx, cy = W//2, H//2
            else:
                # NOTE(review): x/y are used as pixel coordinates on the cached H x W
                # grid; confirm the CSV is in that space (not original DICOM pixels).
                cx = int(round(np.clip(px, 0, W-1)))
                cy = int(round(np.clip(py, 0, H-1)))
            crop = _safe_crop_2d(mip, cx, cy, size=crop_size)
            chan = cv2.resize(crop, (W,H), interpolation=cv2.INTER_LINEAR)
        chans.append(chan[np.newaxis, ...])

    # deterministic, per-SID filler for any missing channels
    def _sid_seed(sid: str, salt: str = "infer_locrand") -> int:
        # first 8 hex digits of SHA-1(salt + sid) -> 32-bit seed
        h = hashlib.sha1((salt + sid).encode()).hexdigest()[:8]
        return int(h, 16)

    def _safe_rand_points(H, W, K, rng):
        # sample centres away from the border (inner 3/4 of each axis)
        xs = rng.integers(low=W//8, high=W - W//8, size=K)
        ys = rng.integers(low=H//8, high=H - H//8, size=K)
        return list(zip(xs.tolist(), ys.tolist()))

    if len(chans) < K:
        need = K - len(chans)
        rng = np.random.default_rng(_sid_seed(sid))
        # Draw order (depths first, then x, then y) determines the values; keep it.
        cz = rng.integers(low=0, high=max(1, cached_depth), size=need)
        pts_xy = _safe_rand_points(H, W, need, rng)
        for i in range(need):
            z_idx = int(cz[i])
            z0 = max(0, z_idx-8); z1 = min(cached_depth, z_idx+9)
            slab = base[z0:z1]
            mip = slab.max(axis=0) if slab.size else np.zeros((H,W), dtype=vol_u8.dtype)
            cx, cy = pts_xy[i]
            crop = _safe_crop_2d(mip, cx, cy, size=crop_size)
            chan = cv2.resize(crop, (W, H), interpolation=cv2.INTER_LINEAR)
            chans.append(chan[np.newaxis, ...])

    chans = chans[:K]
    return np.concatenate(chans, axis=0)

def _build_model_input(series_path: str) -> Tuple[np.ndarray, str]:
    """Assemble the channel-stacked network input for one series.

    The series id is the last path component of ``series_path``. The base
    volume comes from ``_load_cached_or_preprocess``; localizer channels from
    ``_compose_localizer_channels`` are appended along axis 0, and the result
    is zero-padded or truncated so that axis 0 has exactly ``ICFG.in_chans``
    entries.

    Args:
        series_path: Directory of the series (a trailing "/" is tolerated).

    Returns:
        ``(volume, sid)`` where ``volume`` is a 3-D array with
        ``ICFG.in_chans`` channels (presumably uint8, going by the naming;
        verify against the cache writer).
    """
    sid = os.path.basename(series_path.rstrip("/"))
    base_vol = _load_cached_or_preprocess(series_path, sid)

    # Localizer crops are only requested when the feature is switched on.
    if getattr(CFG, "use_localizers", False):
        n_local = getattr(CFG, "max_localizer_crops", 0)
    else:
        n_local = 0
    local_chans = _compose_localizer_channels(
        sid, base_vol, n_local, crop_size=getattr(CFG, "local_crop_size", 128)
    )
    stacked = np.concatenate([base_vol, local_chans], axis=0)

    # Reconcile the channel count with what the network was built for.
    n_chan, height, width = stacked.shape
    wanted = ICFG.in_chans
    if n_chan < wanted:
        filler = np.zeros((wanted - n_chan, height, width), dtype=stacked.dtype)
        stacked = np.concatenate([stacked, filler], axis=0)
    elif n_chan > wanted:
        stacked = stacked[:wanted]
    return stacked, sid

def _make_model_for_infer() -> nn.Module:
    """Instantiate the network described by ``ICFG`` without pretrained weights.

    Architecture name, input channels, class count and image size are all
    read from ``ICFG``. When ``ICFG.channels_last`` is set the module is
    converted to the channels-last memory format before being returned.
    """
    build_kwargs = dict(
        in_chans=ICFG.in_chans,
        num_classes=ICFG.num_classes,
        img_size=ICFG.img_size,
        pretrained=False,  # weights are loaded from a checkpoint afterwards
    )
    net = timm.create_model(ICFG.model_name, **build_kwargs)
    if not ICFG.channels_last:
        return net
    return net.to(memory_format=torch.channels_last)

def _find_seed_dirs(base: Path, save_name: str, folds: List[int]) -> List[Path]:
    """Collect run directories named ``{save_name}_seed<X>_fold<F>`` under ``base``.

    Only direct children of ``base`` that are directories are considered. A
    child qualifies when its name starts with ``{save_name}_seed`` and ends
    with ``_fold<F>`` for one of the requested ``folds``.

    Returns:
        The matching paths in sorted order; an empty list when ``base`` does
        not exist or nothing matches.
    """
    if not base.exists():
        return []

    name_prefix = f"{save_name}_seed"
    # str.endswith accepts a tuple; an empty tuple never matches.
    fold_suffixes = tuple(f"_fold{f}" for f in folds)

    matches = [
        entry
        for entry in base.iterdir()
        if entry.is_dir()
        and entry.name.startswith(name_prefix)
        and entry.name.endswith(fold_suffixes)
    ]
    matches.sort()
    return matches

def _find_ckpt_in_dir(dirpath: Path) -> Optional[Path]:
    """Choose the checkpoint file to load from ``dirpath`` (searched recursively).

    Selection order:
      1. the first pattern in ``ICFG.ckpt_prefer`` that has any match;
      2. otherwise the only ``*.pth`` file, if exactly one exists;
      3. otherwise the only ``*.pth`` file whose name contains "ema"
         (case-insensitive), if exactly one exists;
      4. otherwise the first ``*.pth`` file, or ``None`` when there are none.

    Every candidate list is sorted before its first element is taken, so an
    ambiguous directory always resolves to the lexicographically smallest
    path rather than to whatever order the filesystem enumerates entries in.
    """
    # try preferred names anywhere under this dir
    for fname in ICFG.ckpt_prefer:
        hits = sorted(dirpath.rglob(fname))
        if hits:
            return hits[0]

    # else: any *.pth (unique / or best guess)
    all_pth = sorted(dirpath.rglob("*.pth"))
    if not all_pth:
        return None
    if len(all_pth) == 1:
        return all_pth[0]
    ema_pth = [p for p in all_pth if "ema" in p.name.lower()]
    if len(ema_pth) == 1:
        return ema_pth[0]
    return all_pth[0]

def _strip_prefix(state: dict, prefixes=("module.", "model.", "model_ema.", "ema.", "student.")):
    """Return a copy of ``state`` with wrapper prefixes removed from its keys.

    Prefixes are stripped repeatedly until the key starts with none of them,
    so nested wrappers are removed whatever their order
    (``"module.model.w"`` and ``"model.module.w"`` both become ``"w"``).

    Args:
        state: Mapping of parameter name to value; it is not modified.
        prefixes: Leading strings to remove. Empty strings are ignored.

    Returns:
        A new dict with the stripped keys. If two keys strip to the same
        name, the one encountered later in ``state`` wins.
    """
    # An empty prefix matches every key without shortening it and would
    # keep the loop below running forever.
    wrappers = tuple(p for p in prefixes if p)
    out = {}
    for key, value in state.items():
        name = key
        changed = True
        while changed:
            changed = False
            for p in wrappers:
                if name.startswith(p):
                    name = name[len(p):]
                    changed = True
        out[name] = value
    return out

def _load_model_from_ckpt(ckpt_path: Path) -> nn.Module:
    """Build the inference network and populate it from ``ckpt_path``.

    The weights are taken from the checkpoint's ``"state_dict"`` entry, else
    its ``"model"`` entry, else the checkpoint object itself. Wrapper
    prefixes are stripped from the keys and the weights are loaded
    non-strictly; the number of missing or unexpected keys is printed rather
    than treated as an error. The model is returned on ``ICFG.device`` in
    eval mode.
    """
    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    payload = torch.load(ckpt_path, map_location="cpu")

    # First truthy of: payload["state_dict"], payload["model"], payload.
    weights = payload.get("state_dict")
    if not weights:
        weights = payload.get("model")
    if not weights:
        weights = payload
    weights = _strip_prefix(weights)

    net = _make_model_for_infer()
    load_report = net.load_state_dict(weights, strict=False)
    missing = load_report.missing_keys
    unexpected = load_report.unexpected_keys
    if missing:
        print(f"[ckpt] missing {len(missing)} keys (ok if classifier head differs)")
    if unexpected:
        print(f"[ckpt] unexpected {len(unexpected)} keys (ignored)")

    net.to(ICFG.device)
    net.eval()
    return net

# Module-level cache of loaded inference models. It is filled by
# `_ensure_models_loaded` and keyed by the name of the directory each
# checkpoint was found in; a non-empty dict means loading has already happened.
_MODELS: Dict[str, nn.Module] = {}

def _ensure_models_loaded():
    """Populate ``_MODELS`` on first use; later calls return immediately.

    Seed/fold directories are discovered under ``ICFG.base_ckpt_dir``. When
    none match, the base directory itself is treated as the only run
    directory. One checkpoint per directory is loaded. Directories without a
    checkpoint are skipped, and a checkpoint that fails to load only produces
    a warning. Every loaded model is then run once on a random input.

    Raises:
        FileNotFoundError: If no checkpoint could be loaded at all.
    """
    if _MODELS:
        return

    ckpt_root = Path(ICFG.base_ckpt_dir)
    # fallback: treat the base itself as a seed dir (single-seed uploads)
    run_dirs = _find_seed_dirs(ckpt_root, ICFG.save_name, ICFG.folds) or [ckpt_root]

    for run_dir in run_dirs:
        ckpt_file = _find_ckpt_in_dir(run_dir)
        if ckpt_file is None:
            continue
        try:
            net = _load_model_from_ckpt(ckpt_file)
        except Exception as err:
            print(f"Warning: failed to load {ckpt_file}: {err}")
        else:
            _MODELS[run_dir.name] = net

    # _MODELS was empty on entry, so it is still empty iff nothing loaded.
    if not _MODELS:
        raise FileNotFoundError(f"No usable *.pth checkpoints under {ckpt_root}.")

    # optional warmup
    with torch.no_grad():
        probe = torch.randn(
            1, ICFG.in_chans, ICFG.img_size, ICFG.img_size, device=ICFG.device
        )
        for net in _MODELS.values():
            net(probe)

# predict 
@torch.no_grad()
def _predict_single(model: nn.Module, vol_u8: np.ndarray) -> np.ndarray:
    """Run one model on one channel-stacked volume and return class probabilities.

    The volume is converted to float32, divided by 255, given a leading batch
    axis of size 1, optionally laid out channels-last, and moved to
    ``ICFG.device``. The forward pass runs under autocast when
    ``ICFG.use_amp`` is set.

    Args:
        model: Network to evaluate.
        vol_u8: Input volume; it is never modified by this function.

    Returns:
        Sigmoid probabilities with the batch axis removed, clipped to
        ``[1e-6, 1 - 1e-6]``.
    """
    # copy=True matters when the input is already float32: without it `.to`
    # returns a tensor sharing memory with the caller's array, and the
    # in-place div_ would rescale that array again for every ensemble member.
    # For any other dtype `.to` allocates a new tensor anyway.
    x = (
        torch.from_numpy(np.asarray(vol_u8))
        .to(torch.float32, copy=True)
        .div_(255.0)
        .unsqueeze(0)
    )
    if ICFG.channels_last:
        x = x.contiguous(memory_format=torch.channels_last)
    x = x.to(ICFG.device, non_blocking=True)
    with autocast(enabled=ICFG.use_amp):
        logits = model(x)
    # Cast back to float32 before the sigmoid in case autocast produced half precision.
    prob = torch.sigmoid(logits.float()).cpu().numpy().squeeze(0)
    return np.clip(prob, 1e-6, 1-1e-6)

def _predict_ensemble(vol_u8: np.ndarray) -> np.ndarray:
    """Return the element-wise mean of every loaded model's probabilities.

    Models are loaded on demand via ``_ensure_models_loaded``; each one is
    evaluated on the same ``vol_u8`` with ``_predict_single``.
    """
    _ensure_models_loaded()
    per_model = []
    for net in _MODELS.values():
        per_model.append(_predict_single(net, vol_u8))
    return np.stack(per_model, axis=0).mean(axis=0)

def _predict_inner(series_path: str) -> pl.DataFrame:
    """Predict one series and return a single-row frame with ``TARGET_COLS`` columns.

    Builds the model input for ``series_path``, averages the ensemble's
    probabilities, and places them in one row. Exceptions from any stage
    propagate to the caller.
    """
    vol_u8, _ = _build_model_input(series_path)
    pred = _predict_ensemble(vol_u8)
    # The outer list holds exactly one record, so state the orientation
    # explicitly instead of leaving polars to infer it from the shapes.
    return pl.DataFrame([pred.tolist()], schema=TARGET_COLS, orient="row")

def predict(series_path: str) -> pl.DataFrame:
    """Serving entry point: predict one series, falling back to a constant on failure.

    Args:
        series_path: Directory of the series to score.

    Returns:
        A single-row frame with ``TARGET_COLS`` columns. If the pipeline
        raises any ``Exception``, the error is printed and every class is
        given the constant value 0.1 instead.

    The shared directory is recreated empty and memory is released after
    every call, whether or not prediction succeeded.
    """
    try:
        return _predict_inner(series_path)
    except Exception as e:
        # Deliberate catch-all at the serving boundary: always hand back a row.
        print("[infer] fallback due to:", e)
        # One record in the outer list, so state the orientation explicitly
        # rather than leaving polars to infer it from the shapes.
        return pl.DataFrame([[0.1]*NUM_CLASSES], schema=TARGET_COLS, orient="row")
    finally:
        # required cleanup between served series
        shared_dir = "/kaggle/shared"
        shutil.rmtree(shared_dir, ignore_errors=True)
        os.makedirs(shared_dir, exist_ok=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

In [None]:
import kaggle_evaluation.rsna_inference_server as rsna_eval

# Load once at startup (warm)
_ensure_models_loaded()

# Hand the per-series `predict` callback to the competition's inference server.
server = rsna_eval.RSNAInferenceServer(predict)

if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    # The environment variable is set: serve requests.
    server.serve()
else:
    # Otherwise run against the local gateway, then read and show the parquet
    # file below (presumably written by the gateway run; confirm).
    server.run_local_gateway()
    sub_df = pl.read_parquet("/kaggle/working/submission.parquet")
    display(sub_df)