In [1]:
#!/usr/bin/env python
# coding: utf-8
# RSNA IAD — Fast Profile (<=12h) — single model + 3-center bagging + TTA=2

import os, gc, shutil, warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from contextlib import contextmanager

warnings.filterwarnings("ignore")

# --- Core deps
import numpy as np
import polars as pl
import pandas as pd
import pydicom
import cv2

# --- DL stack
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- Competition API
import kaggle_evaluation.rsna_inference_server as rsna

# ================= Speed knobs =================
# Inputs are always the same size, so let cuDNN benchmark and cache the
# fastest convolution algorithms.
torch.backends.cudnn.benchmark = True
try:
    # Permit reduced-precision (e.g. TF32) float32 matmuls where supported.
    # Older torch builds may lack this API; any failure is ignored.
    torch.set_float32_matmul_precision(precision="high")
except Exception:
    pass

# ================= Device =================
def setup_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


# Module-wide target device for models and tensors.
device = setup_device()

# ================= Config =================
# Series identifier column name (not referenced elsewhere in this file).
ID_COL = "SeriesInstanceUID"

# Output columns in submission order; 'Aneurysm Present' is the last one.
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# Fold-0 checkpoint per backbone. Only CFG.model_selection is loaded in the
# fast profile; the other entries are kept for future ensembling.
MODEL_PATHS = {
    backbone: f'/kaggle/input/rsna-iad-trained-models/models/{backbone}_fold0_best.pth'
    for backbone in (
        'tf_efficientnetv2_s',
        'convnext_small',
        'swin_small_patch4_window7_224',
    )
}

class InferenceConfig:
    """Mutable run-time settings for the fast (single-model) inference profile.

    One module-level instance, ``CFG``, is read throughout the file;
    ``load_single_model`` may overwrite ``image_size`` from a checkpoint's
    training config.
    """

    def __init__(self):
        # Model choice: a single backbone, no ensembling.
        self.model_selection, self.use_ensemble = "tf_efficientnetv2_s", False  # use_ensemble is not read in this file

        # Volume geometry: each series is resampled to num_slices slices of
        # image_size x image_size pixels.
        self.image_size, self.num_slices = 512, 32
        self.use_window = True  # DICOM windowing; False -> per-slice min/max scaling

        # Inference behaviour; mixed precision only when CUDA is present.
        self.batch_size = 1  # only used for the warm-up pass
        self.use_amp = torch.cuda.is_available()
        self.use_tta, self.tta_n = True, 2  # views: identity + horizontal flip

        # Slice bagging: slice_bags slab centers spread over mid-2..mid+2;
        # each slab covers roughly `thickness` of the volume depth.
        self.slice_bags, self.thickness = 3, 0.8

        # Housekeeping.
        self.cleanup_every = 20  # successful predictions between gc / CUDA cache cleanups
        self.max_retries, self.fallback_ok = 2, True  # not read anywhere in this file

        # Fallback (center, width) windows keyed by the DICOM Modality tag,
        # used when WindowCenter/WindowWidth are absent or unparsable.
        # NOTE(review): the standard DICOM Modality code for MR is 'MR', which
        # is not a key here ('MRA'/'MRI' are), so MR series without window
        # tags get the 'default' window — confirm this matches training.
        self.windowing = dict(
            CT=(40, 80),
            CTA=(50, 350),
            MRA=(600, 1200),
            MRI=(40, 80),
            default=(40, 80),
        )


CFG = InferenceConfig()

# ================= Model =================
class MultiBackboneModel(nn.Module):
    """timm image backbone plus a small metadata MLP, fused by an MLP classifier.

    ``forward(image, meta)`` pools the backbone features to a vector, embeds
    the 2-value ``meta`` input (normalized age, sex — built by
    ``_prepare_metadata_tensor``) to 32 dims, concatenates both, and returns
    ``num_classes`` logits.
    """

    def __init__(self, model_name: str, num_classes: int = 14,
                 pretrained: bool = True, drop_rate: float = 0.0, drop_path_rate: float = 0.0):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self._create_backbone(model_name, pretrained, drop_rate, drop_path_rate)
        self._determine_feature_dimensions()
        self._create_classifier(drop_rate)

    def _create_backbone(self, name, pretrained, drop_rate, drop_path_rate):
        """Build a headless timm backbone (no classifier, no global pool)."""
        kw = dict(pretrained=pretrained, in_chans=3, drop_rate=drop_rate, num_classes=0, global_pool='')
        if 'swin' in name:
            # Swin fixes its input resolution at construction time.
            kw.update({'drop_path_rate': drop_path_rate, 'img_size': CFG.image_size})
        elif 'convnext' in name:
            kw['drop_path_rate'] = drop_path_rate
        self.backbone = timm.create_model(name, **kw)

    def _determine_feature_dimensions(self):
        """Probe the backbone with a dummy image to find its feature width and pooling mode.

        The probe runs in eval mode. In train mode, BatchNorm layers would
        update their running statistics from the all-zero dummy input even
        under ``no_grad``. The backbone's previous train/eval mode is restored
        afterwards, so the output shape (the only thing read here) is unchanged.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(1, 3, CFG.image_size, CFG.image_size)
                f = self.backbone(x)
        finally:
            self.backbone.train(was_training)

        if f.ndim == 4:
            # Spatial map: channel count at dim 1, global-average-pooled in forward.
            # NOTE(review): if a backbone returns channels-last (B, H, W, C) maps
            # (possible for swin in some timm versions), shape[1] is not the
            # channel count — confirm for non-default backbones.
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[1], True, False
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        elif f.ndim == 3:
            # Token sequence (B, N, C): mean over tokens in forward.
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[-1], False, True
        else:
            # Already a (B, C) vector.
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[1], False, False

    def _create_classifier(self, drop_rate):
        """Build the 2->32 metadata MLP and the (features+32)->512->256->num_classes head."""
        self.meta_fc = nn.Sequential(
            nn.Linear(2, 16), nn.ReLU(inplace=True), nn.Dropout(0.2),
            nn.Linear(16, 32), nn.ReLU(inplace=True), nn.Dropout(0.1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.num_features + 32, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Dropout(drop_rate),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Dropout(drop_rate * 0.5),
            nn.Linear(256, self.num_classes)
        )

    def _pool_features(self, f):
        """Reduce backbone output to a (B, num_features) vector according to the probed mode."""
        if self.needs_pool: return self.global_pool(f).flatten(1)
        if self.needs_seq_pool: return f.mean(1)
        if f.ndim == 4: return F.adaptive_avg_pool2d(f, 1).flatten(1)
        if f.ndim == 3: return f.mean(1)
        return f

    def forward(self, image, meta):
        """Return class logits for an image batch and its (B, 2) metadata tensor."""
        imgf = self._pool_features(self.backbone(image))
        metf = self.meta_fc(meta)
        return self.classifier(torch.cat([imgf, metf], 1))

# ================= DICOM utils =================
@contextmanager
def dicom_error_handler(_):
    """Pass-through context manager around per-file DICOM work.

    The argument (the file path at the call site) is ignored, and any
    exception raised inside the ``with`` body propagates unchanged. It is kept
    as a hook where per-file error handling could be added later.
    """
    yield

def _is_dicom_file(p: str) -> bool:
    """Return True when *p* has the magic bytes ``b'DICM'`` at byte offset 128.

    Paths that cannot be opened or read (missing, a directory, permission
    denied) return False. Only ``OSError`` is caught; the former bare
    ``except`` also swallowed KeyboardInterrupt/SystemExit.
    """
    try:
        with open(p, 'rb') as f:
            f.seek(128)  # skip the 128-byte preamble
            return f.read(4) == b'DICM'
    except OSError:
        return False

def _apply_rescale(img: np.ndarray, ds: pydicom.Dataset) -> np.ndarray:
    """Return ``img * RescaleSlope + RescaleIntercept`` using the tags of *ds*.

    Missing tags default to slope 1.0 and intercept 0.0. If either tag cannot
    be converted to float (or the arithmetic fails), *img* is returned
    unchanged. An identity rescale returns *img* as-is instead of allocating
    a new array.
    """
    try:
        slope = float(getattr(ds, 'RescaleSlope', 1.0))
        intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
        if slope == 1.0 and intercept == 0.0:
            return img
        return img * slope + intercept
    except Exception:
        # Best-effort: unparsable tags leave pixels as stored. Narrowed from a
        # bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        return img

def _get_window_from_ds(ds) -> Tuple[float, float]:
    """Return the display window ``(center, width)`` for *ds*.

    Prefers the WindowCenter/WindowWidth tags, taking the first value when a
    tag is multi-valued. A width <= 1 is raised to 1.0. If the tags are
    missing or unparsable, returns ``CFG.windowing[Modality]``, or
    ``CFG.windowing['default']`` for unknown modalities. A missing Modality
    tag is treated as 'CT'.

    NOTE(review): the standard DICOM Modality code for MR is 'MR', which is
    not a key of CFG.windowing, so MR series without window tags get the
    'default' window — confirm this matches training.
    """
    try:
        wc, ww = ds.WindowCenter, ds.WindowWidth
        wc = float(wc[0]) if isinstance(wc, pydicom.multival.MultiValue) else float(wc)
        ww = float(ww[0]) if isinstance(ww, pydicom.multival.MultiValue) else float(ww)
        if ww <= 1:
            ww = 1.0
        return wc, ww
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
        modality = getattr(ds, 'Modality', 'CT')
        return CFG.windowing.get(modality, CFG.windowing['default'])

def _process_pixel_array(img: np.ndarray) -> np.ndarray:
    """Reduce a decoded DICOM pixel array to one 2-D grayscale frame.

    - 2-D input is returned unchanged.
    - 3-D input whose last axis has length 3 is treated as RGB and converted
      to grayscale (float32).
    - Other 3-D input is treated as (frames, H, W); the middle frame is kept.
    - 4-D and higher input has its leading axes flattened into one frame axis
      and the middle frame kept. If the last axis has length 3, frames are
      treated as RGB and the chosen frame is converted to grayscale. The old
      ``reshape(H, W)`` raised for any array with more than one frame.
    """
    def _rgb_to_gray(rgb: np.ndarray) -> np.ndarray:
        # Same conversion as before: values are assumed to fit in uint8.
        return cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32)

    if img.ndim == 3 and img.shape[-1] == 3:
        return _rgb_to_gray(img)
    if img.ndim == 3:  # multi-frame -> middle
        return img[img.shape[0] // 2]
    if img.ndim > 3:
        if img.shape[-1] == 3:
            frames = img.reshape(-1, *img.shape[-3:])
            return _rgb_to_gray(frames[frames.shape[0] // 2])
        frames = img.reshape(-1, img.shape[-2], img.shape[-1])
        return frames[frames.shape[0] // 2]
    return img

def _safe_age(ds):
    """Parse PatientAge into integer years clamped to [0, 120].

    Only the first three characters of the tag are read, and the digits among
    them form the value. A missing tag reads as '050Y'. A value with no digits
    gives 50, as does any error while reading or parsing the tag.

    NOTE(review): if PatientAge follows the DICOM AS format (nnnD/W/M/Y), the
    unit suffix is ignored, e.g. '006M' parses as 6. Presumably this mirrors
    training preprocessing; confirm before changing.
    """
    try:
        s = str(getattr(ds, 'PatientAge', '050Y'))[:3]
        d = int(''.join(c for c in s if c.isdigit()) or 50)
        return max(0, min(d, 120))
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
        return 50

def _safe_sex(ds):
    """Encode PatientSex as 1 when it starts with 'M' (case-insensitive), else 0.

    A missing tag reads as 'M' and so gives 1. Any error while reading the tag
    gives 0.

    NOTE(review): the missing-tag default (1) differs from the error default
    (0) and from the fallback metadata's ``sex=0``. Presumably this mirrors
    training; confirm before unifying.
    """
    try:
        return 1 if str(getattr(ds, 'PatientSex', 'M')).upper().startswith('M') else 0
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
        return 0

def apply_dicom_windowing(img: np.ndarray, center: float, width: float) -> np.ndarray:
    """Map intensities in [center - width/2, center + width/2] linearly onto 0..255.

    Values outside the window are clipped first. The result is truncated (not
    rounded) to uint8, and a zero-width window maps everything to 0.
    """
    half = width / 2.0
    lo, hi = center - half, center + half
    span = max(hi - lo, 1e-6)  # guard against division by zero
    unit = (np.clip(img, lo, hi) - lo) / span
    return (unit * 255).astype(np.uint8)

def _find_dicom_files(series_path) -> List[str]:
    """Recursively collect every file under *series_path* that looks like DICOM."""
    found = []
    for root, _, names in os.walk(series_path):
        for fn in names:
            p = os.path.join(root, fn)
            # Cheap extension test first; otherwise sniff the 'DICM' magic bytes.
            if fn.lower().endswith(('.dcm', '.dicom')) or _is_dicom_file(p):
                found.append(p)
    return found


def _slice_sort_key(p: str):
    """Sort key for one file: ImagePositionPatient[2] if present, else InstanceNumber, else 0 on error."""
    try:
        ds = pydicom.dcmread(p, stop_before_pixels=True, force=True)
        ipp = getattr(ds, 'ImagePositionPatient', None)
        if ipp is not None and len(ipp) >= 3:
            return float(ipp[2])
        return int(getattr(ds, 'InstanceNumber', 0))
    except Exception:
        return 0


def _decode_slice(ds) -> np.ndarray:
    """Decode one dataset to an (image_size, image_size) uint8 image: rescale, window, resize."""
    img = _process_pixel_array(ds.pixel_array.astype(np.float32))
    img = _apply_rescale(img, ds)
    if CFG.use_window:
        c, w = _get_window_from_ds(ds)
        img = apply_dicom_windowing(img, c, w)
    else:
        mn, mx = img.min(), img.max()
        img = ((img - mn) / max(mx - mn, 1e-6) * 255).astype(np.uint8)
    return cv2.resize(img, (CFG.image_size, CFG.image_size), interpolation=cv2.INTER_LINEAR)


def process_dicom_series(series_path: str) -> Tuple[np.ndarray, Dict]:
    """Load a series directory into a (num_slices, image_size, image_size) uint8 volume plus metadata.

    Files are found recursively and ordered by ``_slice_sort_key``. Each
    slice is decoded independently; unreadable slices are skipped. If more
    than half of the files fail, or no DICOM files exist, an all-zero volume
    with default metadata is returned.

    Metadata ({'modality', 'age', 'sex'}) comes from the first file whose
    header reads successfully. Previously it came only from file 0 and only if
    its pixels decoded, so one bad first slice discarded age/sex for the whole
    series. If no header can be read, defaults (age 50, sex 0, 'CT') are used.
    """
    files = _find_dicom_files(Path(series_path))
    if not files:
        return _get_default_volume_and_metadata()
    files.sort(key=_slice_sort_key)

    slices, meta, err = [], {}, 0
    for p in files:
        try:
            with dicom_error_handler(p):
                ds = pydicom.dcmread(p, force=True)
                if not meta:
                    meta = {'modality': getattr(ds, 'Modality', 'CT'),
                            'age': _safe_age(ds), 'sex': _safe_sex(ds)}
                slices.append(_decode_slice(ds))
        except Exception:
            err += 1
            if err > len(files) * 0.5:
                return _get_default_volume_and_metadata()

    if not meta:
        meta = {'age': 50, 'sex': 0, 'modality': 'CT'}
    return _create_volume_from_slices(slices), meta

def _create_volume_from_slices(slices: List[np.ndarray]) -> np.ndarray:
    """Stack 2-D slices and resample the depth to exactly CFG.num_slices.

    Deeper stacks are subsampled at evenly spaced (truncated) indices.
    Shallower stacks are padded by repeating the last slice. No slices yields
    an all-zero uint8 volume.
    """
    target = CFG.num_slices
    if not slices:
        return np.zeros((target, CFG.image_size, CFG.image_size), np.uint8)

    stack = np.asarray(slices)
    depth = len(stack)
    if depth > target:
        picks = np.linspace(0, depth - 1, target).astype(int)
        return stack[picks]
    if depth == target:
        return stack
    if depth == 1:
        return np.repeat(stack, target, axis=0)
    return np.pad(stack, ((0, target - depth), (0, 0), (0, 0)), mode='edge')

def _get_default_volume_and_metadata():
    """Return an all-zero (num_slices, image_size, image_size) uint8 volume and default metadata."""
    shape = (CFG.num_slices, CFG.image_size, CFG.image_size)
    defaults = {'age': 50, 'sex': 0, 'modality': 'CT'}
    return np.zeros(shape, np.uint8), defaults

# ================= Transforms =================
def get_inference_transform():
    """Build the non-TTA pipeline: ImageNet mean/std normalization of uint8 input, then HWC -> CHW tensor."""
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    steps = [
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_tta_transforms():
    """Return the TTA pipelines in order: [identity, horizontal flip], each ending in normalize + to-tensor."""
    tail = [A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()]
    identity = A.Compose(tail)
    hflip = A.Compose([A.HorizontalFlip(p=1.0), *tail])
    return [identity, hflip]

# Loaded models keyed by backbone name; filled by load_models().
MODELS: Dict[str, nn.Module] = dict()
# Albumentations pipelines built by load_models(); None until then.
TRANSFORM: Optional[A.Compose] = None
TTA_TRANSFORMS: Optional[List[A.Compose]] = None
# Not referenced elsewhere in this file.
_PRED_COUNT = 0

# ================= Loading =================
def _validate_model(m: nn.Module):
    """Smoke-test *m* on one random channels_last image + metadata row.

    Raises:
        RuntimeError: the output shape is not (1, len(LABEL_COLS)).
    """
    side = CFG.image_size
    expected = (1, len(LABEL_COLS))
    with torch.no_grad():
        probe_img = torch.randn(1, 3, side, side, device=device).to(memory_format=torch.channels_last)
        probe_meta = torch.randn(1, 2, device=device)
        out = m(probe_img, probe_meta)
    if out.shape != expected:
        raise RuntimeError(f"bad output {out.shape}")

def load_single_model(name, path) -> nn.Module:
    """Build backbone *name*, load weights from *path*, and return it ready for inference.

    Accepts either a training checkpoint dict with ``'model_state_dict'``
    (plus an optional ``'training_config'``, whose ``image_size`` overrides
    ``CFG.image_size`` and whose ``num_classes`` sizes the head) or a bare
    state dict. A ``'module.'`` key prefix, as left by DataParallel/DDP
    wrappers, is stripped.

    Previously, with ``strict=False`` a checkpoint whose keys matched nothing
    loaded silently and left randomly initialized weights. That case now
    raises. The model is moved to ``device``, set to eval mode, converted to
    channels_last and smoke-tested.

    Raises:
        FileNotFoundError: *path* does not exist.
        RuntimeError: no checkpoint key matches the model, or the smoke test fails.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    # weights_only=False unpickles arbitrary objects; acceptable only because
    # the checkpoint is our own trusted artifact.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state = ckpt['model_state_dict']
        tr = ckpt.get('training_config') or {}
    else:
        state, tr = ckpt, {}

    if 'image_size' in tr:
        # Must precede construction: the backbone probe (and swin's img_size) read CFG.image_size.
        CFG.image_size = tr['image_size']
    m = MultiBackboneModel(name, num_classes=tr.get('num_classes', 14),
                           pretrained=False, drop_rate=0.0, drop_path_rate=0.0)

    prefix = 'module.'
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
    if not set(state).intersection(m.state_dict()):
        raise RuntimeError(f"checkpoint {path} has no keys matching model {name}")
    m.load_state_dict(state, strict=False)

    m.to(device)
    m.eval()
    m.to(memory_format=torch.channels_last)
    _validate_model(m)
    return m

def load_models():
    """(Re)load the configured model into MODELS, build the transforms, then warm up.

    TTA_TRANSFORMS is only (re)built when CFG.use_tta is set.
    """
    global MODELS, TRANSFORM, TTA_TRANSFORMS
    MODELS.clear()
    selected = CFG.model_selection
    MODELS.update({selected: load_single_model(selected, MODEL_PATHS[selected])})
    TRANSFORM = get_inference_transform()
    if CFG.use_tta:
        TTA_TRANSFORMS = get_tta_transforms()
    _warmup_models()

def _warmup_models():
    """Best-effort warm-up: one forward pass per loaded model on random inputs.

    With ``cudnn.benchmark`` enabled, this presumably lets kernel selection
    happen before the first real series. Any failure is ignored, since
    warm-up is an optimization, not a requirement.
    """
    try:
        side = CFG.image_size
        im = torch.randn(CFG.batch_size, 3, side, side, device=device).to(memory_format=torch.channels_last)
        me = torch.randn(CFG.batch_size, 2, device=device)
        with torch.no_grad():
            for m in MODELS.values():
                with autocast(enabled=CFG.use_amp):
                    _ = m(im, me)  # single pass
        del im, me
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        # Deliberately best-effort. Narrowed from a bare ``except`` so
        # KeyboardInterrupt/SystemExit still propagate.
        pass

# ================= Prediction helpers =================
def _prepare_metadata_tensor(meta: Dict) -> torch.Tensor:
    """Build the (1, 2) float32 metadata input on ``device``.

    Column 0 is age/100 clipped to [0, 1.2] (age defaults to 50). Column 1 is
    sex cast to int and clipped to [0, 1] (defaults to 0).
    """
    age_feature = float(np.clip(meta.get('age', 50) / 100.0, 0.0, 1.2))
    sex_feature = float(np.clip(int(meta.get('sex', 0)), 0, 1))
    row = [age_feature, sex_feature]
    return torch.tensor([row], dtype=torch.float32, device=device)

def _create_multichannel_input_with_thickness(volume: np.ndarray, center_idx: int, frac: float) -> np.ndarray:
    """Build an (H, W, 3) image of [center slice, slab MIP, min-max scaled slab std].

    The center is clipped to [1, depth - 2]. The slab spans +/- max(int(depth*frac/2), 1)
    slices around it, truncated at the volume bounds. If the per-pixel std is
    constant, the third channel is filled with 128.
    """
    depth = volume.shape[0]
    mid = int(np.clip(center_idx, 1, depth - 2))
    half = max(int(depth * frac / 2), 1)
    slab = volume[max(0, mid - half):min(depth, mid + half + 1)]

    spread = slab.astype(np.float32).std(axis=0)
    s_min, s_max = spread.min(), spread.max()
    if s_max > s_min:
        spread = ((spread - s_min) / max(s_max - s_min, 1e-6) * 255).astype(np.uint8)
    else:
        spread = np.full_like(volume[mid], 128, dtype=np.uint8)

    channels = (volume[mid], slab.max(axis=0), spread)
    return np.stack(channels, axis=-1)

def _tta_predict(model, image_np, meta_tensor):
    """Score one HxWx3 image and return the squeezed mean of the sigmoid outputs over views.

    Views are the first CFG.tta_n pipelines of TTA_TRANSFORMS when
    CFG.use_tta is set and TTA_TRANSFORMS is non-empty. Otherwise the single
    TRANSFORM pipeline is used.
    """
    if CFG.use_tta and TTA_TRANSFORMS:
        pipelines = TTA_TRANSFORMS[:CFG.tta_n]
    else:
        pipelines = [TRANSFORM]

    probs = []
    for pipeline in pipelines:
        batch = pipeline(image=image_np)['image'].unsqueeze(0)
        batch = batch.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        with torch.no_grad(), autocast(enabled=CFG.use_amp):
            logits = model(batch, meta_tensor)
        probs.append(torch.sigmoid(logits).float().cpu().numpy())
    return np.mean(probs, 0).squeeze()

def predict_single_model(model: nn.Module, volume: np.ndarray, meta_tensor: torch.Tensor) -> np.ndarray:
    """Average TTA predictions over CFG.slice_bags slab centers near the middle slice.

    With k >= 2 bags, the centers are k evenly spaced points over
    [mid - 2, mid + 2], truncated to int (k=3 gives mid-2, mid, mid+2). With
    k <= 1, a single bag at mid is used. Previously ``np.linspace(a, b, 1)``
    returned ``[a]``, which centered the lone bag at mid-2, and k=0 produced
    a NaN mean of an empty list.
    """
    mid = volume.shape[0] // 2
    n_bags = int(CFG.slice_bags)
    if n_bags <= 1:
        centers = np.array([mid])
    else:
        centers = np.linspace(mid - 2, mid + 2, n_bags).astype(int)
    bag = []
    for c in centers:
        img = _create_multichannel_input_with_thickness(volume, c, CFG.thickness)
        bag.append(_tta_predict(model, img, meta_tensor))
    return np.mean(bag, axis=0)

def _validate_predictions(p: np.ndarray) -> np.ndarray:
    """Coerce *p* to len(LABEL_COLS) finite probabilities in [1e-3, 1 - 1e-3].

    Wrong-shaped input is resized with np.resize (repeat/truncate). NaN
    becomes 0.1, +inf 0.9 and -inf 0.0 before clipping.
    """
    n_labels = len(LABEL_COLS)
    if p.shape != (n_labels,):
        p = np.resize(p, n_labels)
    finite = np.nan_to_num(p, nan=0.1, posinf=0.9, neginf=0.0)
    return np.clip(finite, 1e-3, 1 - 1e-3)

# ================= Orchestration =================
# Set by _predict_inner once load_models() has succeeded.
_MODELS_LOADED = False
# Number of successful predictions, driving the periodic cleanup below.
_COUNTER = 0


def _manage_memory():
    """Count a prediction; every CFG.cleanup_every calls, free cached CUDA memory and run gc."""
    global _COUNTER
    _COUNTER += 1
    if _COUNTER % CFG.cleanup_every:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

def _predict_inner(series_path: str) -> pl.DataFrame:
    """Predict one series and return a one-row DataFrame over LABEL_COLS.

    Models are loaded lazily on the first call. Exceptions propagate to the
    caller.
    """
    global _MODELS_LOADED
    if not _MODELS_LOADED:
        load_models()
        _MODELS_LOADED = True

    volume, meta = process_dicom_series(series_path)
    meta_tensor = _prepare_metadata_tensor(meta)
    model = [*MODELS.values()][0]  # fast profile: a single model
    probs = _validate_predictions(predict_single_model(model, volume, meta_tensor))
    _manage_memory()
    return pl.DataFrame(data=[probs.tolist()], schema=LABEL_COLS, orient='row')

def _create_fallback_predictions() -> pl.DataFrame:
    """Return the constant one-row fallback: 0.05 for every column except 0.1 for the last ('Aneurysm Present')."""
    row = [0.05] * len(LABEL_COLS)
    row[-1] = 0.1
    return pl.DataFrame(data=[row], schema=LABEL_COLS, orient='row')

def predict(series_path: str) -> pl.DataFrame:
    """Inference-server entry point: score one series directory.

    Returns a one-row DataFrame over LABEL_COLS. A missing path, or any
    ``Exception`` during inference, yields the constant fallback row instead.

    After every call, successful or not, '/kaggle/shared' is removed and
    recreated (presumably a per-series scratch directory; confirm with the
    gateway docs), the CUDA cache is emptied, and gc runs. Cleanup failures
    are ignored.
    """
    try:
        if not os.path.exists(series_path):
            return _create_fallback_predictions()
        return _predict_inner(series_path)
    except Exception:
        return _create_fallback_predictions()
    finally:
        try:
            shared = '/kaggle/shared'
            if os.path.exists(shared):
                shutil.rmtree(shared, ignore_errors=True)
            os.makedirs(shared, exist_ok=True)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            gc.collect()
        except Exception:
            # Best-effort cleanup must not mask the prediction. Narrowed from
            # a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            pass

# ================= Main (quiet) =================
def main():
    """Start the RSNA inference server around ``predict``.

    When KAGGLE_IS_COMPETITION_RERUN is set (any non-empty value), serve for
    the scoring rerun; otherwise run the local gateway. Nothing is printed.
    """
    server = rsna.RSNAInferenceServer(predict)
    is_rerun = bool(os.getenv('KAGGLE_IS_COMPETITION_RERUN'))
    run = server.serve if is_rerun else server.run_local_gateway
    run()


if __name__ == "__main__":
    main()
