# RSNA 2025 Intracranial Aneurysm Detection - Inference

This notebook performs inference using the trained 2.5D EfficientNet hybrid model.

## Model Details
- Architecture: tf_efficientnet_b0
- Training: 5-fold cross-validation
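- Input: 2.5D windows (5 adjacent slices stacked as input channels)
- Dual-stream: Full image + ROI processing (at inference both streams receive the same full-image window, because ROI coordinates are not available for the test set)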


In [None]:
import os
import gc
import re
import cv2
import math
import numpy as np
import pandas as pd
import polars as pl
import pydicom
import torch
import torch.nn as nn
import timm
from collections import defaultdict
from typing import List, Tuple
import shutil
from sklearn.metrics import roc_auc_score

# Kaggle server
import kaggle_evaluation.rsna_inference_server

# ========= Competition schema =========
# Column name used for the series identifier in the competition schema.
ID_COL = 'SeriesInstanceUID'
# Output columns, in order. The classifier head is built with len(LABEL_COLS)
# outputs, so output index i is reported under LABEL_COLS[i].
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# ========= Inference config =========
# Every slice is resized to IMG_SIZE x IMG_SIZE before being fed to the model.
IMG_SIZE = 224
# Slice offsets, relative to the centre slice, that make up one 2.5D window.
OFFSETS = (-2, -1, 0, 1, 2)   # window length 5
# One input channel per slice in the window.
IN_CHANS = len(OFFSETS)
# Number of windows sent through the model per forward pass.
BATCH_SIZE = 16
# How per-window probabilities are reduced to one value per series; any value
# other than "max" or "mean" selects the top-k mean.
AGGREGATE = "max"  # max/mean/topk_mean
# NOTE(review): this flag is not read by any code visible in this notebook.
USE_ROI = False     # coords not available on test → use same stream for full+roi
# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model weights location - update this path to match your uploaded dataset
# Directories are searched recursively, in this order; missing ones are skipped.
CANDIDATE_MODEL_DIRS = [
    "/kaggle/input/2025-09-11-20-34-47",
    "/kaggle/working",                           # runtime dir
    ".",                                         # current dir
]


In [None]:
# ========= Model definition (Hybrid full + ROI + coords) =========
class HybridAneurysmModel(nn.Module):
    """Classifier that fuses a full-image stream, an ROI stream and 2-D coords.

    Both image streams go through the *same* backbone (shared weights). Their
    feature vectors are concatenated with a 64-d embedding of the coordinates
    and passed to a dropout + linear head that emits ``num_classes`` logits.

    Args:
        base_model_name: timm architecture name for the backbone.
        num_classes: number of output logits.
    """

    def __init__(self, base_model_name: str, num_classes: int):
        super().__init__()

        # num_classes=0 asks timm for a headless backbone (feature vectors
        # rather than logits); weights are expected to come from a checkpoint.
        self.backbone = timm.create_model(
            base_model_name,
            pretrained=False,
            num_classes=0,
            in_chans=IN_CHANS,
        )
        self.feature_dim = self.backbone.num_features

        # Small MLP lifting the (2,) coordinate vector to 64 features.
        coord_layers = [nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 64)]
        self.coord_fc = nn.Sequential(*coord_layers)

        # Head input = two backbone feature vectors + the coordinate embedding.
        fused_dim = 2 * self.feature_dim + 64
        head_layers = [nn.Dropout(0.3), nn.Linear(fused_dim, num_classes)]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x_full: torch.Tensor, x_roi: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """Return raw logits for a batch of (full image, ROI image, coords)."""
        streams = (
            self.backbone(x_full),
            self.backbone(x_roi),
            self.coord_fc(coords.float()),
        )
        fused = torch.cat(streams, dim=1)
        return self.fc(fused)


In [None]:
# ========= Helper functions =========
def sort_dicom_slices(filepaths: List[str]):
    """Read the given DICOM files and return the datasets in slice order.

    Datasets are ordered by the third component of ``ImagePositionPatient``.
    If that key cannot be computed for any dataset, the whole list is ordered
    by ``InstanceNumber`` instead; a missing, empty or non-numeric instance
    number counts as 0 so that the fallback itself can never raise.

    Args:
        filepaths: paths of the DICOM files belonging to one series.

    Returns:
        The list of datasets read with ``pydicom.dcmread(..., force=True)``,
        sorted in place.
    """
    dicoms = [pydicom.dcmread(fp, force=True) for fp in filepaths]

    def _instance_number(d) -> int:
        # The attribute may exist but hold None / an empty or malformed value.
        try:
            return int(getattr(d, 'InstanceNumber', 0) or 0)
        except (TypeError, ValueError):
            return 0

    try:
        dicoms.sort(key=lambda d: float(d.ImagePositionPatient[2]))
    except Exception:
        # Best-effort: any failure to build the position key switches the
        # whole series to instance-number ordering.
        dicoms.sort(key=_instance_number)
    return dicoms

def _normalize_to_uint8(arr: np.ndarray, modality, ct_min: float = 0, ct_max: float = 500) -> np.ndarray:
    """Scale a float image (or stack of frames) to uint8 in [0, 255].

    CT uses the fixed window ``[ct_min, ct_max]``. Every other modality uses
    the [1st, 99th] percentile window, falling back to min-max scaling when
    that window is empty and to all zeros when the image is constant.
    """
    if modality == 'CT':
        arr = np.clip(arr, ct_min, ct_max)
        arr = (arr - ct_min) / (ct_max - ct_min)
        return (arr * 255).astype(np.uint8)

    p1, p99 = np.percentile(arr, [1, 99])
    if p99 > p1:
        arr = np.clip(arr, p1, p99)
        arr = (arr - p1) / (p99 - p1)
        return (arr * 255).astype(np.uint8)

    img_min, img_max = arr.min(), arr.max()
    if img_max > img_min:
        arr = (arr - img_min) / (img_max - img_min)
        return (arr * 255).astype(np.uint8)

    return np.zeros_like(arr, dtype=np.uint8)


def series_to_tensor_chw(dicoms) -> np.ndarray:
    """Convert DICOM datasets into one normalized uint8 volume of shape [N, H, W].

    Despite the ``chw`` in the name, the first axis indexes slices (or frames),
    not channels. For each dataset the pixel data is rescaled with
    ``RescaleSlope`` / ``RescaleIntercept``, normalized per modality (a missing
    ``Modality`` is treated as MR) and resized to ``IMG_SIZE x IMG_SIZE``.
    3-D pixel arrays are treated as multi-frame and contribute one slice per
    frame. Datasets whose pixel data is empty or cannot be decoded are skipped
    with a printed message instead of aborting the whole series.

    Args:
        dicoms: iterable of datasets exposing ``pixel_array`` and, optionally,
            ``RescaleSlope``, ``RescaleIntercept`` and ``Modality``.

    Returns:
        uint8 array of shape [N, IMG_SIZE, IMG_SIZE]; a single all-zero slice
        when no usable image was found.
    """
    resized = []
    for d in dicoms:
        # Decoding can fail (no pixel data, unsupported transfer syntax, ...).
        try:
            arr = d.pixel_array
        except Exception as e:
            print(f"Skipping slice with unreadable pixel data: {e}")
            continue
        if arr is None or arr.size == 0:
            continue
        arr = arr.astype(np.float32)

        # Apply RescaleSlope and RescaleIntercept; an empty tag (None) is
        # treated like a missing one.
        slope = getattr(d, 'RescaleSlope', 1)
        intercept = getattr(d, 'RescaleIntercept', 0)
        if slope is None:
            slope = 1
        if intercept is None:
            intercept = 0
        if slope != 1 or intercept != 0:
            arr = arr * float(slope) + float(intercept)

        # Modality-specific normalization (matching training data processing).
        # For multi-frame data the statistics are taken over the whole stack.
        arr = _normalize_to_uint8(arr, getattr(d, 'Modality', 'MR'))

        # A 3-D array is a multi-frame DICOM: resize frame by frame.
        frames = arr if arr.ndim == 3 else [arr]
        resized.extend(
            cv2.resize(frame, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
            for frame in frames
        )

    if not resized:
        # fallback to zeros to avoid crashes (rare)
        return np.zeros((1, IMG_SIZE, IMG_SIZE), dtype=np.uint8)

    return np.stack(resized, axis=0)  # [N,H,W] uint8 - matching training data format

def take_window_from_volume(vol_nhw: np.ndarray, center_idx: int, offsets=OFFSETS) -> np.ndarray:
    """Gather the slices at ``center_idx + offset`` into one float32 window.

    Indices that fall outside the volume are clamped to the first / last
    slice, so windows near the edges repeat the border slice.

    Args:
        vol_nhw: volume whose first axis indexes slices.
        center_idx: index of the window's centre slice.
        offsets: slice offsets relative to ``center_idx``.

    Returns:
        Array of shape [len(offsets), H, W] with dtype float32.
    """
    last = vol_nhw.shape[0] - 1
    picks = []
    for off in offsets:
        idx = center_idx + off
        if idx < 0:
            idx = 0
        if idx > last:
            idx = last
        picks.append(idx)
    window = vol_nhw[picks, :, :]
    # Model input is float32; skip the copy when the data already is.
    return window.astype(np.float32, copy=False)

def window_to_full_and_roi(win_chw: np.ndarray, coords: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the (full, roi) inputs for one window.

    ``coords`` is accepted but not used: no ROI crop is made, so both streams
    are the very same array object that was passed in.
    """
    full_stream = roi_stream = win_chw
    return full_stream, roi_stream


In [None]:
# ========= Model loading and discovery =========
# Checkpoint list returned by discover_checkpoints() the last time get_models()
# had to load models; None before that.
_ckpt_cache = None
# Loaded (arch_name, model) pairs, filled once by get_models() and then reused
# by every later call; None until loading has succeeded.
_models = None

def discover_checkpoints(model_dirs: "List[str] | None" = None) -> List[Tuple[str, str]]:
    """Find checkpoint files under the candidate model directories.

    A file qualifies when its name ends with ``.pth`` and contains
    ``tf_efficientnet_b0``, ``fold`` and either ``best`` or ``final``.
    Directories are walked recursively; ones that do not exist are skipped.

    The same file reached through more than one directory (for example
    ``/kaggle/working`` and ``.`` when that is the current directory) is
    reported only once, under the first path it was found at. Otherwise the
    model would be loaded twice and counted twice in the ensemble.

    Args:
        model_dirs: directories to search; defaults to ``CANDIDATE_MODEL_DIRS``.

    Returns:
        List of ``(arch_name, path)`` tuples sorted by path.
    """
    search_dirs = CANDIDATE_MODEL_DIRS if model_dirs is None else model_dirs
    arch = 'tf_efficientnet_b0'  # Fixed architecture

    found: List[Tuple[str, str]] = []
    seen_real_paths = set()
    for base in search_dirs:
        if not os.path.isdir(base):
            continue
        for root, _, files in os.walk(base):
            for f in files:
                is_checkpoint = (
                    f.endswith('.pth')
                    and arch in f
                    and 'fold' in f
                    and ('best' in f or 'final' in f)
                )
                if not is_checkpoint:
                    continue
                path = os.path.join(root, f)
                # Different spellings of one file resolve to the same real path.
                real_path = os.path.realpath(path)
                if real_path in seen_real_paths:
                    continue
                seen_real_paths.add(real_path)
                found.append((arch, path))
    # stable ordering
    found.sort(key=lambda x: x[1])
    return found

def load_hybrid_model(arch_name: str, weight_path: str) -> nn.Module:
    """Build a HybridAneurysmModel and load its weights from ``weight_path``.

    Accepted checkpoint layouts:
      * a bare state dict, or a dict holding it under ``'model_state_dict'``;
      * keys carrying a leading ``module.`` prefix;
      * head parameters saved as ``classifier.*`` instead of ``fc.*``.

    These are handled independently, so a wrapped checkpoint whose keys also
    carry the ``module.`` prefix loads as well.

    Args:
        arch_name: backbone architecture name passed to the model.
        weight_path: path of the checkpoint file.

    Returns:
        The model in eval mode on ``DEVICE``.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when the keys or
            shapes still do not match after renaming.
    """
    model = HybridAneurysmModel(base_model_name=arch_name, num_classes=len(LABEL_COLS))
    # NOTE(security): torch.load is pickle-based; only load checkpoints from a
    # trusted source.
    state = torch.load(weight_path, map_location=DEVICE)

    # Unwrap a training checkpoint down to the state dict itself.
    if isinstance(state, dict) and 'model_state_dict' in state:
        state = state['model_state_dict']

    if isinstance(state, dict):
        renamed = {}
        for key, value in state.items():
            # Only leading prefixes are rewritten, never text inside a key.
            if key.startswith('module.'):
                key = key[len('module.'):]
            # Fix layer name mismatch: classifier -> fc
            if key.startswith('classifier.'):
                key = 'fc.' + key[len('classifier.'):]
            renamed[key] = value
        state = renamed

    model.load_state_dict(state, strict=True)
    model.eval().to(DEVICE)
    return model

def get_models() -> List[Tuple[str, nn.Module]]:
    """Return the loaded ensemble, loading it on the first call only.

    On the first call every discovered checkpoint is loaded; a checkpoint that
    fails to load is reported and skipped. The result is stored in the
    module-level ``_models`` and returned as-is by later calls.

    Returns:
        List of ``(arch_name, model)`` pairs.

    Raises:
        FileNotFoundError: no checkpoint file was discovered.
        RuntimeError: checkpoints were discovered but none could be loaded.
    """
    global _ckpt_cache, _models

    if _models is None:
        _ckpt_cache = discover_checkpoints()
        if not _ckpt_cache:
            raise FileNotFoundError('No model checkpoints found. Make sure model dataset is attached.')

        loaded: List[Tuple[str, nn.Module]] = []
        for arch, path in _ckpt_cache:
            try:
                net = load_hybrid_model(arch, path)
                loaded.append((arch, net))
                print(f"Loaded model: {os.path.basename(path)}")
            except Exception as e:
                # One bad checkpoint must not take the whole ensemble down.
                print(f"Failed to load {path}: {e}")

        if not loaded:
            raise RuntimeError('Failed to load any checkpoints from discovered files.')

        _models = loaded
        print(f"Loaded {len(_models)} models total")

    return _models


In [None]:
# ========= Inference pipeline =========
@torch.no_grad()
def predict_series_probs(dicoms) -> np.ndarray:
    """Predict one probability per label for a whole series.

    Every slice of the normalized volume becomes the centre of one 2.5D
    window. Windows are processed in batches of ``BATCH_SIZE``; each batch is
    built and moved to ``DEVICE`` once and then passed to every model, rather
    than being rebuilt for each model. Per model, the sigmoid outputs of all
    windows are reduced according to ``AGGREGATE`` ("max", "mean", otherwise
    the mean of the top ``N // 5`` values per label), and the per-model
    results are averaged.

    Args:
        dicoms: sorted datasets of one series.

    Returns:
        1-D array with ``len(LABEL_COLS)`` probabilities.
    """
    models = get_models()
    # Build normalized volume [N,H,W] uint8 (matching training data format)
    vol = series_to_tensor_chw(dicoms)
    n_slices = vol.shape[0]
    # No coordinates are available at test time: feed zeros.
    coords = np.zeros((n_slices, 2), dtype=np.float32)

    def _to_device(arrays) -> torch.Tensor:
        batch = np.stack(arrays).astype(np.float32, copy=False)
        return torch.from_numpy(batch).to(DEVICE)

    # One list of per-batch probability arrays for each model.
    per_model_probs = [[] for _ in models]
    for start in range(0, n_slices, BATCH_SIZE):
        stop = min(start + BATCH_SIZE, n_slices)
        streams = [
            window_to_full_and_roi(take_window_from_volume(vol, c, OFFSETS), coords[c])
            for c in range(start, stop)
        ]
        xb_full = _to_device([full for full, _ in streams])
        # When both streams are the same arrays, reuse the tensor instead of
        # copying identical data to the device a second time.
        if all(roi is full for full, roi in streams):
            xb_roi = xb_full
        else:
            xb_roi = _to_device([roi for _, roi in streams])
        cb = torch.from_numpy(coords[start:stop]).to(DEVICE)

        for batch_probs, (_, model) in zip(per_model_probs, models):
            logits = model(xb_full, xb_roi, cb)
            batch_probs.append(torch.sigmoid(logits).cpu().numpy())

    all_model_probs = []
    for batch_probs in per_model_probs:
        if batch_probs:
            probs_all = np.concatenate(batch_probs, axis=0)
        else:
            probs_all = np.zeros((1, len(LABEL_COLS)), dtype=np.float32)

        if AGGREGATE == 'max':
            series_prob = probs_all.max(axis=0)
        elif AGGREGATE == 'mean':
            series_prob = probs_all.mean(axis=0)
        else:  # topk_mean
            k = max(1, n_slices // 5)
            # Sorting along axis 0 orders each label column independently.
            series_prob = np.sort(probs_all, axis=0)[-k:].mean(axis=0)
        all_model_probs.append(series_prob)

    # free cached GPU memory before the next series
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # ensemble (probability average)
    return np.mean(np.stack(all_model_probs, axis=0), axis=0)


In [None]:
# ========= Kaggle-required predict function =========
def predict(series_path: str) -> pl.DataFrame | pd.DataFrame:
    """Kaggle entry point: predict all label probabilities for one series.

    Collects every ``.dcm`` file (extension matched case-insensitively) under
    ``series_path``, runs the ensemble on the sorted slices and returns a
    one-row frame with one column per entry of ``LABEL_COLS`` and no ID
    column. A series without any DICOM file yields a row of zeros.

    ``/kaggle/shared`` is removed on every exit path, including the empty
    series case and failures, to avoid disk pressure.

    Args:
        series_path: directory holding the DICOM files of one series.

    Returns:
        ``pl.DataFrame`` with a single row of probabilities.
    """
    try:
        # Collect all DICOM files
        filepaths = []
        for root, _, files in os.walk(series_path):
            for f in files:
                if f.lower().endswith('.dcm'):
                    filepaths.append(os.path.join(root, f))

        if filepaths:
            # Sort DICOMs and perform inference
            dicoms = sort_dicom_slices(filepaths)
            row = predict_series_probs(dicoms).tolist()
        else:
            # Return zeros if no DICOMs found
            row = [0.0] * len(LABEL_COLS)

        # Server expects features only (without ID_COL), one row per series.
        return pl.DataFrame(data=[row], schema=LABEL_COLS, orient='row')
    finally:
        # Required cleanup to avoid disk pressure
        shutil.rmtree('/kaggle/shared', ignore_errors=True)


In [None]:
# Test model loading and basic functionality
# Smoke test: list the checkpoint files that discovery can see. Nothing is
# loaded here, and a failure is only printed so the notebook keeps running.
print("Testing model discovery...")
try:
    checkpoints = discover_checkpoints()
    print(f"Found {len(checkpoints)} model checkpoints:")
    for arch, path in checkpoints:
        print(f"  - {arch}: {path}")
except Exception as e:
    print(f"Error discovering models: {e}")

# Test if we're in competition mode
# The environment variable is unset (or empty) outside a competition rerun.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("Running in test mode - not competition submission")


In [None]:
# ========= Start RSNA inference server =========
# Hand the predict() callback to the competition's inference server.
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Competition mode - serve predictions
    inference_server.serve()
else:
    # Test mode - run local gateway
    inference_server.run_local_gateway()
    # NOTE(review): the local gateway presumably writes submission.parquet;
    # this code only checks whether the file exists and previews it.
    if os.path.exists('/kaggle/working/submission.parquet'):
        print("Submission file created successfully")
        submission_df = pl.read_parquet('/kaggle/working/submission.parquet')
        print(f"Submission shape: {submission_df.shape}")
        print(submission_df.head())
    else:
        print("No submission file generated")
