# In [1]:
#!/usr/bin/env python
# coding: utf-8
# RSNA IAD — Single-model (EffNetV2-S) + TTA=4, quiet, <=12h

import os, gc, shutil, warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from contextlib import contextmanager

warnings.filterwarnings("ignore")

# Data handling
import numpy as np
import polars as pl
import pydicom
import cv2

# DL
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import timm

# Augs
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Kaggle API
import kaggle_evaluation.rsna_inference_server as rsna

# -------- Speed knobs --------
# Turn on cuDNN benchmark mode (presumably lets cuDNN choose kernels per input
# shape — see the PyTorch backend notes to confirm the trade-offs).
torch.backends.cudnn.benchmark = True
try:
    # Best-effort: any failure of this call is deliberately ignored so the
    # script still starts when the setting cannot be applied.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

QUIET = True  # keep outputs minimal; gates _log() and the banner printed by main()

def _log(msg):
    """Print ``msg`` to stdout unless the module-level ``QUIET`` flag is set."""
    if QUIET:
        return
    print(msg)

# -------- Device --------
def setup_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
device = setup_device()

# -------- Config --------
# Name of the series identifier column (not referenced elsewhere in the visible code).
ID_COL = 'SeriesInstanceUID'
# Output columns, in order: this list is used as the schema of every prediction
# DataFrame and its length (14) is the expected size of the model output.
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery','Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery','Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery','Right Middle Cerebral Artery',
    'Anterior Communicating Artery','Left Anterior Cerebral Artery','Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery','Right Posterior Communicating Artery',
    'Basilar Tip','Other Posterior Circulation','Aneurysm Present',
]

# Checkpoint path per backbone name (the key is passed to timm.create_model).
# Only the entry selected by CFG.model_selection is loaded by load_models().
MODEL_PATHS = {
    'tf_efficientnetv2_s': '/kaggle/input/rsna-iad-trained-models/models/tf_efficientnetv2_s_fold0_best.pth',
    'convnext_small': '/kaggle/input/rsna-iad-trained-models/models/convnext_small_fold0_best.pth',
    'swin_small_patch4_window7_224': '/kaggle/input/rsna-iad-trained-models/models/swin_small_patch4_window7_224_fold0_best.pth'
}

class InferenceConfig:
    """Mutable bag of inference settings; every instance starts from the same defaults.

    ``image_size`` may be overwritten at load time by the checkpoint's
    training configuration (see ``load_single_model``).
    """

    def __init__(self):
        # --- model choice ---
        self.model_selection = 'tf_efficientnetv2_s'  # key into MODEL_PATHS
        self.use_ensemble = False                     # not read by the visible code

        # --- preprocessing ---
        self.image_size = 512       # square side every slice is resized to
        self.num_slices = 32        # depth of the resampled volume
        self.use_windowing = True   # DICOM windowing vs. per-slice min-max scaling

        # --- inference ---
        self.batch_size = 1                           # used for the warm-up batch
        self.use_amp = torch.cuda.is_available()      # autocast only when CUDA exists
        self.use_tta = True
        self.tta_transforms = 4  # (identity, hflip, vflip, small rotate)

        # --- housekeeping ---
        self.enable_memory_cleanup = True
        self.cleanup_frequency = 16   # run cleanup every N-th prediction
        self.max_retries = 2          # not read by the visible code
        self.fallback_enabled = True  # not read by the visible code

        # Fallback (center, width) windows keyed by the DICOM Modality value,
        # used when a dataset carries no usable window tags.
        # NOTE(review): MR datasets usually report Modality 'MR', so the
        # 'MRA'/'MRI' keys may never match — confirm against the data.
        self.windowing_params = {
            'CT': (40, 80),
            'CTA': (50, 350),
            'MRA': (600, 1200),
            'MRI': (40, 80),
            'default': (40, 80),
        }


CFG = InferenceConfig()

# -------- Model --------
class MultiBackboneModel(nn.Module):
    """Image backbone plus a small metadata MLP feeding one classification head.

    The backbone is created through ``timm.create_model`` with its own head and
    pooling disabled; its output is pooled here, concatenated with a 32-dim
    embedding of the 2-value metadata vector, and passed through an MLP that
    emits ``num_classes`` logits.

    Args:
        model_name: timm model name for the backbone.
        num_classes: number of output logits.
        pretrained: forwarded to ``timm.create_model``.
        drop_rate: dropout rate for the backbone and the classifier head.
        drop_path_rate: forwarded to the backbone for swin/convnext models only.
    """

    def __init__(self, model_name: str, num_classes: int = 14,
                 pretrained: bool = True, drop_rate: float = 0.0, drop_path_rate: float = 0.0):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self._create_backbone(model_name, pretrained, drop_rate, drop_path_rate)
        self._determine_feature_dimensions()
        self._create_classifier(drop_rate)

    def _create_backbone(self, name, pretrained, drop_rate, drop_path_rate):
        """Build the headless, unpooled timm backbone as ``self.backbone``."""
        kw = dict(pretrained=pretrained, in_chans=3, drop_rate=drop_rate, num_classes=0, global_pool='')
        if 'swin' in name:
            # swin additionally receives the configured input resolution
            kw.update({'drop_path_rate': drop_path_rate, 'img_size': CFG.image_size})
        elif 'convnext' in name:
            kw['drop_path_rate'] = drop_path_rate
        self.backbone = timm.create_model(name, **kw)

    def _determine_feature_dimensions(self):
        """Probe the backbone with a zero image to learn its output layout.

        Sets ``num_features`` plus the ``needs_pool`` / ``needs_seq_pool``
        flags consumed by ``_pool_features``. The probe runs in eval mode so it
        cannot touch BatchNorm running statistics or trigger dropout; the
        backbone's previous train/eval mode is restored afterwards.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, CFG.image_size, CFG.image_size)
                f = self.backbone(dummy)
        finally:
            self.backbone.train(was_training)

        if f.ndim == 4:
            # 4-D output: channel count is taken from dim 1, spatial dims are pooled
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[1], True, False
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        elif f.ndim == 3:
            # 3-D output: feature size is the last dim, dim 1 is averaged
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[-1], False, True
        else:
            # already a flat feature vector
            self.num_features, self.needs_pool, self.needs_seq_pool = f.shape[1], False, False

    def _create_classifier(self, drop_rate):
        """Create the metadata MLP (2 -> 32) and the classification head."""
        self.meta_fc = nn.Sequential(
            nn.Linear(2, 16), nn.ReLU(inplace=True), nn.Dropout(0.2),
            nn.Linear(16, 32), nn.ReLU(inplace=True), nn.Dropout(0.1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.num_features + 32, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Dropout(drop_rate),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Dropout(drop_rate * 0.5),
            nn.Linear(256, self.num_classes)
        )

    def _pool_features(self, f):
        """Reduce the backbone output to one feature vector per sample."""
        if self.needs_pool:
            return self.global_pool(f).flatten(1)
        if self.needs_seq_pool:
            return f.mean(1)
        # Fallbacks in case the runtime layout differs from the probed one.
        if f.ndim == 4:
            return F.adaptive_avg_pool2d(f, 1).flatten(1)
        if f.ndim == 3:
            return f.mean(1)
        return f

    def forward(self, image: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
        """Return logits for a batch of images and their metadata rows.

        Args:
            image: image batch with 3 channels (the backbone is built with
                ``in_chans=3``).
            meta: metadata batch whose last dimension is 2.

        Returns:
            Logits with ``num_classes`` values per sample.
        """
        imgf = self._pool_features(self.backbone(image))
        metf = self.meta_fc(meta)
        return self.classifier(torch.cat([imgf, metf], dim=1))

# -------- DICOM helpers --------
@contextmanager
def dicom_error_handler(_):
    """Pass-through context manager: exceptions raised in the body propagate unchanged.

    The argument is accepted for call-site symmetry and is not used.
    """
    yield

def _is_dicom_file(p:str)->bool:
    try:
        with open(p,'rb') as f:
            f.seek(128); return f.read(4)==b'DICM'
    except: return False

def _apply_rescale(img:np.ndarray, ds:pydicom.Dataset)->np.ndarray:
    try:
        slope = float(getattr(ds,'RescaleSlope',1.0))
        intercept = float(getattr(ds,'RescaleIntercept',0.0))
        return img*slope + intercept
    except: return img

def _get_window_from_ds(ds) -> Tuple[float, float]:
    """Return the ``(center, width)`` display window for a dataset.

    Uses the dataset's ``WindowCenter`` / ``WindowWidth`` (first entry when
    multi-valued), with the width floored at 1.0. When the tags are missing or
    unparsable, falls back to ``CFG.windowing_params`` keyed by the dataset's
    ``Modality`` (``'CT'`` if absent), then to its ``'default'`` entry.
    """
    try:
        wc, ww = ds.WindowCenter, ds.WindowWidth
        if isinstance(wc, pydicom.multival.MultiValue):
            wc = wc[0]
        if isinstance(ww, pydicom.multival.MultiValue):
            ww = ww[0]
        wc, ww = float(wc), float(ww)
        if ww <= 1:
            ww = 1.0  # avoid a degenerate (near-)zero-width window
        return wc, ww
    except Exception:
        # Missing attribute, empty value or non-numeric content: use the
        # configured per-modality window. Unlike a bare ``except:`` this does
        # not swallow KeyboardInterrupt/SystemExit.
        modality = getattr(ds, 'Modality', 'CT')
        return CFG.windowing_params.get(modality, CFG.windowing_params['default'])

def apply_dicom_windowing(img: np.ndarray, center: float, width: float) -> np.ndarray:
    """Clip ``img`` to the window ``center ± width/2`` and rescale it to uint8 0..255."""
    lo = center - width / 2.0
    hi = center + width / 2.0
    span = max(hi - lo, 1e-6)  # guards against a zero-width window
    scaled = (np.clip(img, lo, hi) - lo) / span
    return (scaled * 255).astype(np.uint8)

def _process_pixel_array(img: np.ndarray) -> np.ndarray:
    if img.ndim==3 and img.shape[-1]==3:
        img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32)
    elif img.ndim==3:
        img = img[img.shape[0]//2]
    elif img.ndim>3:
        img = img.reshape(img.shape[-2], img.shape[-1])
    return img

def _safe_age(ds):
    try:
        s = str(getattr(ds,'PatientAge','050Y'))[:3]
        d = int(''.join([c for c in s if c.isdigit()]) or 50)
        return max(0,min(d,120))
    except: return 50

def _safe_sex(ds):
    try: return 1 if str(getattr(ds,'PatientSex','M')).upper().startswith('M') else 0
    except: return 0

def process_dicom_series(series_path: str) -> Tuple[np.ndarray, Dict]:
    """Load a DICOM series directory into a fixed-size volume plus metadata.

    Every file under ``series_path`` that has a ``.dcm``/``.dicom`` extension
    or a DICM marker is read, ordered, reduced to one 2-D image, rescaled,
    intensity-normalised (DICOM windowing or min-max, per
    ``CFG.use_windowing``), resized to ``CFG.image_size`` and stacked through
    ``_create_volume_from_slices``.

    Args:
        series_path: directory containing the series (searched recursively).

    Returns:
        ``(volume, metadata)`` where ``metadata`` has the keys ``'modality'``,
        ``'age'`` and ``'sex'``. The default all-zero volume and default
        metadata are returned when no files are found or when more than half
        of the files fail to decode.
    """
    series_path = Path(series_path)
    files = []
    for root, _, fs in os.walk(series_path):
        for fn in fs:
            p = os.path.join(root, fn)
            if fn.lower().endswith(('.dcm', '.dicom')) or _is_dicom_file(p):
                files.append(p)
    if not files:
        return _get_default_volume_and_metadata()

    # os.walk order depends on the filesystem; sort by path first so that
    # slices with equal sort keys keep a deterministic order (sort is stable).
    files.sort()

    def sort_key(p):
        """Order by z position when present, otherwise by InstanceNumber.

        The leading group index keeps z positions and instance numbers from
        being compared against each other in a mixed series.
        """
        try:
            ds = pydicom.dcmread(p, stop_before_pixels=True, force=True)
            ipp = getattr(ds, 'ImagePositionPatient', None)
            if ipp is not None and len(ipp) >= 3:
                return (0, float(ipp[2]))
            return (1, float(int(getattr(ds, 'InstanceNumber', 0))))
        except Exception:
            return (1, 0.0)
    files.sort(key=sort_key)

    slices, metadata, errors = [], {}, 0
    for fp in files:
        try:
            with dicom_error_handler(fp):
                ds = pydicom.dcmread(fp, force=True)
                img = _process_pixel_array(ds.pixel_array.astype(np.float32))
                if not metadata:
                    # Take metadata from the first slice that decodes, not
                    # strictly from files[0], so one bad leading file does not
                    # force the defaults.
                    metadata = {'modality': getattr(ds, 'Modality', 'CT'),
                                'age': _safe_age(ds), 'sex': _safe_sex(ds)}
                img = _apply_rescale(img, ds)
                if CFG.use_windowing:
                    c, w = _get_window_from_ds(ds)
                    img = apply_dicom_windowing(img, c, w)
                else:
                    mn, mx = img.min(), img.max()
                    img = ((img - mn) / max(mx - mn, 1e-6) * 255).astype(np.uint8)
                img = cv2.resize(img, (CFG.image_size, CFG.image_size), interpolation=cv2.INTER_LINEAR)
                slices.append(img)
        except Exception:
            # Skip undecodable slices, but give up on the whole series once
            # more than half of its files have failed.
            errors += 1
            if errors > len(files) * 0.5:
                return _get_default_volume_and_metadata()

    if not metadata:
        metadata = {'age': 50, 'sex': 0, 'modality': 'CT'}

    volume = _create_volume_from_slices(slices)
    return volume, metadata

def _create_volume_from_slices(slices: List[np.ndarray]) -> np.ndarray:
    """Stack ``slices`` into a volume with exactly ``CFG.num_slices`` entries.

    Longer stacks are subsampled at evenly spaced indices; shorter stacks are
    extended by repeating the last slice; an empty list yields an all-zero
    uint8 volume.
    """
    target = CFG.num_slices
    if not slices:
        return np.zeros((target, CFG.image_size, CFG.image_size), np.uint8)

    vol = np.asarray(slices)
    count = len(vol)
    if count > target:
        keep = np.linspace(0, count - 1, target).astype(int)
        return vol[keep]
    if count < target:
        # Edge padding repeats the final slice; for a single slice this is the
        # same as repeating it ``target`` times.
        return np.pad(vol, ((0, target - count), (0, 0), (0, 0)), mode='edge')
    return vol

def _get_default_volume_and_metadata():
    """Return the placeholder result: an all-zero uint8 volume and default metadata."""
    blank_volume = np.zeros((CFG.num_slices, CFG.image_size, CFG.image_size), np.uint8)
    default_meta = {'age': 50, 'sex': 0, 'modality': 'CT'}
    return (blank_volume, default_meta)

# -------- Transforms --------
def get_inference_transform():
    """Build the non-TTA pipeline: channel-wise normalisation, then tensor conversion."""
    normalize = A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225], max_pixel_value=255.0)
    to_tensor = ToTensorV2()
    return A.Compose([normalize, to_tensor])

def get_tta_transforms():
    """Build the four test-time-augmentation pipelines.

    Order: identity, horizontal flip, vertical flip, small rotation. Each
    pipeline ends with the same normalisation and tensor-conversion steps
    (the two transform objects are shared between the pipelines).
    """
    shared_tail = [A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]), ToTensorV2()]
    # NOTE(review): Rotate(limit=15) presumably samples a random angle on each
    # call, which would make TTA predictions non-deterministic — confirm.
    # NOTE(review): the horizontal flip is averaged without swapping the
    # Left/Right label columns — verify this matches how the model was trained.
    geometric_heads = (
        [],                               # identity
        [A.HorizontalFlip(p=1.0)],        # hflip
        [A.VerticalFlip(p=1.0)],          # vflip
        [A.Rotate(limit=15, p=1.0)],      # small rotate
    )
    return [A.Compose(head + shared_tail) for head in geometric_heads]

# -------- Globals --------
# Loaded models keyed by backbone name; filled lazily by load_models().
MODELS: Dict[str, nn.Module] = {}
# Plain (non-TTA) preprocessing pipeline; set by load_models().
TRANSFORM: Optional[A.Compose] = None
# TTA pipelines; set by load_models() only when CFG.use_tta is true.
TTA_TRANSFORMS: Optional[List[A.Compose]] = None
# Number of predictions served so far; drives the periodic cleanup in _manage_memory().
PREDICTION_COUNT = 0

# -------- Loading --------
def _validate_model(model: nn.Module):
    """Run one random input through ``model`` and check the output shape.

    Raises:
        RuntimeError: if the output is not ``(1, len(LABEL_COLS))``.
    """
    side = CFG.image_size
    probe_image = torch.randn(1, 3, side, side, device=device).to(memory_format=torch.channels_last)
    probe_meta = torch.randn(1, 2, device=device)
    with torch.no_grad():
        out = model(probe_image, probe_meta)
    expected_shape = (1, len(LABEL_COLS))
    if out.shape != expected_shape:
        raise RuntimeError(f"Unexpected output shape: {out.shape}")

def load_single_model(model_name: str, model_path: str) -> nn.Module:
    """Build ``model_name``, load its checkpoint and return it ready for inference.

    The checkpoint may be either a dict holding ``'model_state_dict'`` (and
    optionally ``'training_config'``) or a bare state dict. When the training
    config carries ``'image_size'``, ``CFG.image_size`` is overwritten with it
    before the model is built.

    Args:
        model_name: timm backbone name.
        model_path: path to the checkpoint file.

    Returns:
        The model on ``device``, in eval mode, in channels-last memory format.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the weights cannot be loaded even non-strictly, or the
            validation forward pass yields an unexpected shape.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Missing model: {model_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(model_path, map_location=device, weights_only=False)
    tr = ckpt.get('training_config') or {}
    if 'image_size' in tr:
        CFG.image_size = tr['image_size']
    # Accept a bare state dict as well as the wrapped training checkpoint.
    state_dict = ckpt.get('model_state_dict', ckpt)
    model = MultiBackboneModel(model_name, num_classes=tr.get('num_classes', 14),
                               pretrained=False, drop_rate=0.0, drop_path_rate=0.0)
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fall back to a partial load, but report what did not match instead
        # of silently leaving layers at their random initialisation.
        result = model.load_state_dict(state_dict, strict=False)
        _log(f"Non-strict load for {model_name}: "
             f"missing={list(result.missing_keys)} unexpected={list(result.unexpected_keys)}")
    model.to(device)
    model.eval()
    model.to(memory_format=torch.channels_last)
    _validate_model(model)
    return model

def load_models():
    """(Re)load the configured model and build the preprocessing pipelines.

    Replaces the contents of ``MODELS`` with the model named by
    ``CFG.model_selection``, sets ``TRANSFORM`` (and ``TTA_TRANSFORMS`` when
    TTA is enabled), then runs a warm-up pass.
    """
    global MODELS, TRANSFORM, TTA_TRANSFORMS
    MODELS.clear()
    selected = CFG.model_selection
    checkpoint = MODEL_PATHS[selected]
    MODELS[selected] = load_single_model(selected, checkpoint)
    TRANSFORM = get_inference_transform()
    if CFG.use_tta:
        TTA_TRANSFORMS = get_tta_transforms()
    _warmup_models()

def _warmup_models():
    """Push one random batch through every loaded model; failures are ignored."""
    try:
        batch = CFG.batch_size
        side = CFG.image_size
        x = torch.randn(batch, 3, side, side, device=device).to(memory_format=torch.channels_last)
        m = torch.randn(batch, 2, device=device)
        with torch.no_grad():
            for net in MODELS.values():
                with autocast(enabled=CFG.use_amp):
                    net(x, m)   # single pass warmup
        del x, m
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        # Warm-up is optional; never let it block model loading.
        pass

# -------- Prediction --------
def _create_multichannel_input(volume: np.ndarray) -> np.ndarray:
    # middle slice + MIP + STD over full resampled stack (same as your 0.66 setup)
    if volume.size == 0:
        s = np.zeros((CFG.image_size, CFG.image_size), np.uint8)
        return np.stack([s,s,s], -1)
    mid = volume.shape[0] // 2
    middle = volume[mid]
    mip = np.max(volume, axis=0)
    std = volume.astype(np.float32).std(axis=0)
    if std.max() > std.min():
        std = ((std - std.min()) / max(std.max()-std.min(),1e-6) * 255).astype(np.uint8)
    else:
        std = np.full_like(middle, 128, dtype=np.uint8)
    img = np.stack([middle, mip, std], axis=-1)
    return img

def _prepare_metadata_tensor(metadata: Dict) -> torch.Tensor:
    """Encode metadata as a ``(1, 2)`` float32 tensor ``[[age/100, sex]]`` on ``device``.

    Age defaults to 50 and is clipped to ``[0, 1.2]`` after scaling; sex
    defaults to 0 and is clipped to ``[0, 1]``.
    """
    age_scaled = float(np.clip(metadata.get('age', 50) / 100.0, 0.0, 1.2))
    sex_flag = float(np.clip(int(metadata.get('sex', 0)), 0, 1))
    row = [age_scaled, sex_flag]
    return torch.tensor([row], dtype=torch.float32, device=device)

def predict_single_model(model: nn.Module, image: np.ndarray, meta_tensor: torch.Tensor) -> np.ndarray:
    """Return sigmoid probabilities for one image, averaged over TTA when enabled.

    With TTA on, the first ``CFG.tta_transforms`` pipelines of
    ``TTA_TRANSFORMS`` are applied and their probabilities averaged. If TTA is
    off, unavailable, or that selection is empty (e.g. ``tta_transforms == 0``),
    the plain ``TRANSFORM`` is used instead — previously an empty selection
    produced a NaN mean.

    Args:
        model: network called as ``model(image_batch, meta_tensor)``.
        image: array accepted by the albumentations pipelines.
        meta_tensor: metadata batch passed through to the model unchanged.

    Returns:
        Probabilities with singleton dimensions squeezed out.
    """
    def _forward(tfm) -> np.ndarray:
        # One preprocessing + forward pass; keeps the batch dimension.
        x = tfm(image=image)['image'].unsqueeze(0).to(device, non_blocking=True)
        x = x.to(memory_format=torch.channels_last)
        with torch.no_grad(), autocast(enabled=CFG.use_amp):
            out = model(x, meta_tensor)
        return torch.sigmoid(out).float().cpu().numpy()

    tta = TTA_TRANSFORMS[:CFG.tta_transforms] if (CFG.use_tta and TTA_TRANSFORMS) else []
    if tta:
        preds = [_forward(tfm) for tfm in tta]
        return np.mean(preds, axis=0).squeeze()
    return _forward(TRANSFORM).squeeze()

def _validate_predictions(pred: np.ndarray) -> np.ndarray:
    """Coerce ``pred`` to one finite probability per label.

    Wrong-shaped input is resized to ``len(LABEL_COLS)``; NaN becomes 0.1,
    +inf 0.9 and -inf 0.0; the result is clipped to ``[1e-3, 1 - 1e-3]``.
    """
    n_labels = len(LABEL_COLS)
    if pred.shape != (n_labels,):
        pred = np.resize(pred, n_labels)
    finite = np.nan_to_num(pred, nan=0.1, posinf=0.9, neginf=0.0)
    return np.clip(finite, 1e-3, 1-1e-3)

def _manage_memory():
    """Count one prediction and run a cleanup every ``CFG.cleanup_frequency`` calls.

    The cleanup empties the CUDA cache (when CUDA is available) and triggers a
    garbage collection; it is skipped entirely when
    ``CFG.enable_memory_cleanup`` is false.
    """
    global PREDICTION_COUNT
    PREDICTION_COUNT += 1
    if not CFG.enable_memory_cleanup:
        return
    if PREDICTION_COUNT % CFG.cleanup_frequency != 0:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def _predict_inner(series_path: str) -> pl.DataFrame:
    """Run the full pipeline for one series and return a one-row DataFrame.

    Loads the models on first use, preprocesses the series, predicts with the
    configured model, sanitises the probabilities and returns them under the
    ``LABEL_COLS`` schema. Exceptions propagate to the caller.
    """
    global MODELS
    if not MODELS:
        load_models()
    volume, metadata = process_dicom_series(series_path)
    model_input = _create_multichannel_input(volume)
    meta_tensor = _prepare_metadata_tensor(metadata)
    net = MODELS[CFG.model_selection]
    probs = _validate_predictions(predict_single_model(net, model_input, meta_tensor))
    _manage_memory()
    return pl.DataFrame(data=[probs.tolist()], schema=LABEL_COLS, orient='row')

def _create_fallback_predictions() -> pl.DataFrame:
    """Return the constant fallback row: 0.05 for every label except the last, which is 0.1."""
    row = [0.05 for _ in range(len(LABEL_COLS) - 1)]
    row.append(0.1)
    return pl.DataFrame(data=[row], schema=LABEL_COLS, orient='row')

def predict(series_path: str) -> pl.DataFrame:
    """Inference-server entry point: predict one series and never raise.

    Returns the model's one-row prediction DataFrame, or the constant fallback
    row when the path does not exist or anything in the pipeline fails.
    Failures are reported through ``_log`` (silent while ``QUIET`` is set)
    instead of being dropped without trace.

    After every call, ``/kaggle/shared`` is emptied and recreated, and the
    CUDA cache and the garbage collector are flushed, on a best-effort basis.
    """
    try:
        if not os.path.exists(series_path):
            _log(f"Series path not found: {series_path}")
            return _create_fallback_predictions()
        return _predict_inner(series_path)
    except Exception as exc:
        # The server must always get a row back, so fall back rather than raise.
        _log(f"Prediction failed for {series_path}: {exc!r}")
        return _create_fallback_predictions()
    finally:
        try:
            shared = '/kaggle/shared'
            if os.path.exists(shared):
                shutil.rmtree(shared, ignore_errors=True)
            os.makedirs(shared, exist_ok=True)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            gc.collect()
        except Exception as exc:
            # Cleanup is best-effort and must not mask the prediction result.
            _log(f"Post-prediction cleanup failed: {exc!r}")

# -------- Main --------
def main():
    """Start the RSNA inference server around ``predict``.

    Serves requests when the ``KAGGLE_IS_COMPETITION_RERUN`` environment
    variable is set to a non-empty value; otherwise runs the local gateway.
    """
    if not QUIET:
        print("Starting RSNA IAD inference (quiet mode=%s)" % QUIET)
    server = rsna.RSNAInferenceServer(predict)
    in_competition_rerun = bool(os.getenv('KAGGLE_IS_COMPETITION_RERUN'))
    if not in_competition_rerun:
        server.run_local_gateway()
        return
    server.serve()


if __name__ == "__main__":
    main()
