# RSNA 2025 Intracranial Aneurysm Detection - Two-Step Inference

This notebook performs two-step inference using trained 2.5D EfficientNet hybrid models.

## Two-Step Approach
1. **Step 1**: A 14-class model predicts the 13 anatomical locations plus "Aneurysm Present".
2. **Step 2**: A binary model predicts a refined "Aneurysm Present" probability.
3. **Final Output**: The binary model's "Aneurysm Present" probability overwrites the 14th class of the Step 1 output.

## Model Details
- Architecture: tf_efficientnet_b0
- Training: 5-fold cross-validation
- Input: 2.5D windows (5-slice)
- Dual-stream: Full image + ROI processing
- Models: 14-class (2025-09-11-20-34-47) + Binary (2025-09-15-05-27-19)


In [None]:
import os
import gc
import re
import cv2
import math
import numpy as np
import pandas as pd
import polars as pl
import pydicom
import torch
import torch.nn as nn
import timm
from collections import defaultdict
from typing import List, Tuple
import shutil
from sklearn.metrics import roc_auc_score

# Import normalization functions from dedicated module
from normalization import normalize_dicom_series, apply_rescale_intercept_slope

# Import utility functions
from utils import (
    take_window, coords_to_px, crop_and_resize_hwc, make_bbox_px, 
    valid_coords, load_cached_volume, window_to_full_and_roi
)

# Import data processing functions
from data_processing import sort_dicom_slices

# Import model classes and creation functions
from model import HybridAneurysmModel, BinaryAneurysmModel, create_model, create_binary_model

# Kaggle server
import kaggle_evaluation.rsna_inference_server

# ========= Competition schema =========
# Name of the series identifier column. In this notebook it is only used to
# label the single-row prediction frame before being dropped again in predict().
ID_COL = 'SeriesInstanceUID'
# Output columns, in emission order. The order is load-bearing:
#   * it is used verbatim as the column schema of the prediction frame,
#   * len(LABEL_COLS) is the number of outputs of the 14-class model config,
#   * the last entry ('Aneurysm Present') is addressed by index -1 when the
#     binary model's prediction overwrites it.
# Do not reorder or insert entries without updating those call sites.
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# ========= Inference config =========
# IMG_SIZE is passed as target_size to normalize_dicom_series and as the size
# argument of window_to_full_and_roi.
IMG_SIZE = 224
# OFFSETS is handed to take_window for every slice index; presumably these are
# slice offsets relative to the centre slice -- confirm in utils.take_window.
OFFSETS = (-2, -1, 0, 1, 2)   # window length 5
# One input channel per offset; used as in_channels for both model configs.
IN_CHANS = len(OFFSETS)
# Number of slice windows sent through a model per forward pass.
BATCH_SIZE = 128
# How per-slice probabilities are reduced to one value per series.
# NOTE: any value other than "max" or "mean" is treated as topk_mean.
AGGREGATE = "max"  # max/mean/topk_mean
# NOTE(review): USE_ROI is not read anywhere in the visible code; the pipeline
# always feeds all-zero coords to window_to_full_and_roi.
USE_ROI = False     # coords not available on test → use same stream for full+roi
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model weights location - update this path to match your uploaded dataset
# Every existing directory listed here is walked recursively by the
# discover_checkpoints_* functions; missing directories are skipped.
CANDIDATE_MODEL_DIRS = [
    "/kaggle/input/2025-09-11-20-34-47",        # 14-class models
    "/kaggle/input/2025-09-15-05-27-19",        # binary models
    "/kaggle/working",                           # runtime dir
    ".",                                         # current dir
]


In [None]:
# ========= Model configuration =========
# Model classes are imported from model.py
# HybridAneurysmModel: 14-class classification model (first step)
# BinaryAneurysmModel: Binary classification model (second step)

class SimpleConfig:
    """Minimal attribute bag handed to the model factory functions.

    Stores the three constructor arguments unchanged as the attributes
    ``architecture``, ``in_channels`` and ``num_classes``.
    """

    def __init__(self, architecture: str, in_channels: int, num_classes: int):
        # Assigned in declaration order so vars(self) lists them predictably.
        self.architecture, self.in_channels, self.num_classes = (
            architecture,
            in_channels,
            num_classes,
        )

# Create config objects for both model types
# Both share the backbone name and the number of input channels (one per
# window offset); they differ only in the number of outputs:
# one per label column for the 14-class model, a single output for the binary model.
config_14class = SimpleConfig("tf_efficientnet_b0", IN_CHANS, len(LABEL_COLS))
config_binary = SimpleConfig("tf_efficientnet_b0", IN_CHANS, 1)


In [None]:
# ========= Model loading and discovery =========
# Module-level caches filled lazily by get_models_14class / get_models_binary.
# None means "not loaded yet"; after the first successful call the model lists
# are reused for every subsequent series.
_ckpt_cache_14class = None  # (arch, path) pairs returned by discover_checkpoints_14class
_ckpt_cache_binary = None   # (arch, path) pairs returned by discover_checkpoints_binary
_models_14class = None      # (arch, model) pairs for the 14-class ensemble
_models_binary = None       # (arch, model) pairs for the binary ensemble

def discover_checkpoints_14class() -> List[Tuple[str, str]]:
    """Discover 14-class model checkpoints.

    Recursively walks every existing directory in ``CANDIDATE_MODEL_DIRS`` and
    keeps ``.pth`` files whose name contains ``tf_efficientnet_b0``, ``fold``
    and either ``best`` or ``final``, and that belong to the 14-class run:
    either the containing directory path mentions ``2025-09-11-20-34-47`` or
    the file name contains ``14class`` (case-insensitive).

    Returns:
        ``(architecture, checkpoint_path)`` pairs sorted by path. Each physical
        file appears at most once, under the first path it was reached by.
    """
    arch = 'tf_efficientnet_b0'
    found: List[Tuple[str, str]] = []
    # CANDIDATE_MODEL_DIRS lists both "/kaggle/working" and ".", so when the
    # current directory is /kaggle/working the same file is reachable under two
    # spellings. Track resolved paths so it is not loaded (and ensembled) twice.
    seen_real_paths = set()
    for base in CANDIDATE_MODEL_DIRS:
        if not os.path.isdir(base):
            continue
        for root, _, files in os.walk(base):
            # Check if it's from the 14-class training (models/2025-09-11-20-34-47)
            root_is_14class_run = '2025-09-11-20-34-47' in root
            for f in files:
                looks_like_checkpoint = (
                    f.endswith('.pth')
                    and arch in f
                    and 'fold' in f
                    and ('best' in f or 'final' in f)
                )
                if not looks_like_checkpoint:
                    continue
                if not (root_is_14class_run or '14class' in f.lower()):
                    continue
                path = os.path.join(root, f)
                real_path = os.path.realpath(path)
                if real_path in seen_real_paths:
                    continue
                seen_real_paths.add(real_path)
                found.append((arch, path))
    found.sort(key=lambda x: x[1])
    return found

def discover_checkpoints_binary() -> List[Tuple[str, str]]:
    """Discover binary model checkpoints.

    Recursively walks every existing directory in ``CANDIDATE_MODEL_DIRS`` and
    keeps ``.pth`` files whose name contains ``tf_efficientnet_b0``, ``fold``
    and either ``best`` or ``final``, and that belong to the binary run:
    either the containing directory path mentions ``2025-09-15-05-27-19`` or
    the file name contains ``binary`` (case-insensitive).

    Returns:
        ``(architecture, checkpoint_path)`` pairs sorted by path. Each physical
        file appears at most once, under the first path it was reached by.
    """
    arch = 'tf_efficientnet_b0'
    found: List[Tuple[str, str]] = []
    # CANDIDATE_MODEL_DIRS lists both "/kaggle/working" and ".", so when the
    # current directory is /kaggle/working the same file is reachable under two
    # spellings. Track resolved paths so it is not loaded (and ensembled) twice.
    seen_real_paths = set()
    for base in CANDIDATE_MODEL_DIRS:
        if not os.path.isdir(base):
            continue
        for root, _, files in os.walk(base):
            # Check if it's from the binary training (models/2025-09-15-05-27-19)
            root_is_binary_run = '2025-09-15-05-27-19' in root
            for f in files:
                looks_like_checkpoint = (
                    f.endswith('.pth')
                    and arch in f
                    and 'fold' in f
                    and ('best' in f or 'final' in f)
                )
                if not looks_like_checkpoint:
                    continue
                if not (root_is_binary_run or 'binary' in f.lower()):
                    continue
                path = os.path.join(root, f)
                real_path = os.path.realpath(path)
                if real_path in seen_real_paths:
                    continue
                seen_real_paths.add(real_path)
                found.append((arch, path))
    found.sort(key=lambda x: x[1])
    return found

def load_hybrid_model(arch_name: str, weight_path: str, num_classes: int) -> nn.Module:
    """Load one 14-class model from a checkpoint file.

    The network is built with ``create_model(config_14class)``; the checkpoint
    is then loaded strictly, so any missing or unexpected key raises.

    Args:
        arch_name: Architecture name from checkpoint discovery. Currently not
            used: the architecture comes from ``config_14class``.
        weight_path: Path to the ``.pth`` checkpoint.
        num_classes: Expected number of outputs. Currently not used: the
            value comes from ``config_14class``.

    Returns:
        The model in eval mode on ``DEVICE``.
    """
    model = create_model(config_14class)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(weight_path, map_location=DEVICE)

    # Handle different state dict formats. The two cases are independent: a
    # training checkpoint may wrap the weights under 'model_state_dict' AND
    # carry the 'module.' prefix, so unwrap first and then strip the prefix.
    if isinstance(state, dict) and 'model_state_dict' in state:
        state = state['model_state_dict']
    if isinstance(state, dict) and any(k.startswith('module.') for k in state):
        prefix = 'module.'
        # Strip only a leading prefix; a 'module.' in the middle of a key is
        # part of the real parameter name.
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    model.load_state_dict(state, strict=True)
    model.eval().to(DEVICE)
    return model

def load_binary_model(arch_name: str, weight_path: str, num_classes: int) -> nn.Module:
    """Load one binary model from a checkpoint file.

    The network is built with ``create_binary_model(config_binary)``; the
    checkpoint is then loaded strictly, so any missing or unexpected key raises.

    Args:
        arch_name: Architecture name from checkpoint discovery. Currently not
            used: the architecture comes from ``config_binary``.
        weight_path: Path to the ``.pth`` checkpoint.
        num_classes: Expected number of outputs. Currently not used: the
            value comes from ``config_binary``.

    Returns:
        The model in eval mode on ``DEVICE``.
    """
    model = create_binary_model(config_binary)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(weight_path, map_location=DEVICE)

    # Handle different state dict formats. The two cases are independent: a
    # training checkpoint may wrap the weights under 'model_state_dict' AND
    # carry the 'module.' prefix, so unwrap first and then strip the prefix.
    if isinstance(state, dict) and 'model_state_dict' in state:
        state = state['model_state_dict']
    if isinstance(state, dict) and any(k.startswith('module.') for k in state):
        prefix = 'module.'
        # Strip only a leading prefix; a 'module.' in the middle of a key is
        # part of the real parameter name.
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    model.load_state_dict(state, strict=True)
    model.eval().to(DEVICE)
    return model

def get_models_14class() -> List[Tuple[str, nn.Module]]:
    """Return the 14-class ensemble, loading and caching it on first use.

    Checkpoints that fail to load are reported and skipped.

    Raises:
        FileNotFoundError: no 14-class checkpoint was discovered.
        RuntimeError: checkpoints were discovered but none could be loaded.
    """
    global _ckpt_cache_14class, _models_14class

    cached = _models_14class
    if cached is not None:
        return cached

    checkpoints = discover_checkpoints_14class()
    _ckpt_cache_14class = checkpoints
    if not checkpoints:
        raise FileNotFoundError('No 14-class model checkpoints found. Make sure model dataset is attached.')

    loaded: List[Tuple[str, nn.Module]] = []
    for arch, ckpt_path in checkpoints:
        try:
            net = load_hybrid_model(arch, ckpt_path, len(LABEL_COLS))
            loaded.append((arch, net))
            print(f"Loaded 14-class model: {os.path.basename(ckpt_path)}")
        except Exception as err:
            # Best effort: one unreadable fold must not prevent using the rest.
            print(f"Failed to load 14-class model {ckpt_path}: {err}")

    if not loaded:
        raise RuntimeError('Failed to load any 14-class checkpoints from discovered files.')

    _models_14class = loaded
    print(f"Loaded {len(loaded)} 14-class models total")
    return loaded

def get_models_binary() -> List[Tuple[str, nn.Module]]:
    """Return the binary ensemble, loading and caching it on first use.

    Checkpoints that fail to load are reported and skipped.

    Raises:
        FileNotFoundError: no binary checkpoint was discovered.
        RuntimeError: checkpoints were discovered but none could be loaded.
    """
    global _ckpt_cache_binary, _models_binary

    cached = _models_binary
    if cached is not None:
        return cached

    checkpoints = discover_checkpoints_binary()
    _ckpt_cache_binary = checkpoints
    if not checkpoints:
        raise FileNotFoundError('No binary model checkpoints found. Make sure model dataset is attached.')

    loaded: List[Tuple[str, nn.Module]] = []
    for arch, ckpt_path in checkpoints:
        try:
            # The binary model has a single output.
            net = load_binary_model(arch, ckpt_path, 1)
            loaded.append((arch, net))
            print(f"Loaded binary model: {os.path.basename(ckpt_path)}")
        except Exception as err:
            # Best effort: one unreadable fold must not prevent using the rest.
            print(f"Failed to load binary model {ckpt_path}: {err}")

    if not loaded:
        raise RuntimeError('Failed to load any binary checkpoints from discovered files.')

    _models_binary = loaded
    print(f"Loaded {len(loaded)} binary models total")
    return loaded


In [None]:
# ========= Two-step inference pipeline =========
def _aggregate_slice_probs(probs_all: np.ndarray, n_slices: int) -> np.ndarray:
    """Reduce per-slice probabilities (slices on axis 0) to one value per output."""
    if AGGREGATE == 'max':
        return probs_all.max(axis=0)
    if AGGREGATE == 'mean':
        return probs_all.mean(axis=0)
    # topk_mean: per output, average the k largest slice probabilities.
    k = max(1, n_slices // 5)
    return np.sort(probs_all, axis=0)[-k:].mean(axis=0)


def _ensemble_series_probs(models, vol: np.ndarray, coords: np.ndarray, n_outputs: int) -> np.ndarray:
    """Run every model of one ensemble over all slice windows and average them.

    Each batch of windows is built once and fed to all models, because window
    construction does not depend on the model. Per-model slice probabilities
    are reduced with the ``AGGREGATE`` rule and then averaged across models.

    Args:
        models: ``(arch, model)`` pairs; each model is called as
            ``model(full, roi, coords)``.
        vol: Normalised volume with slices on axis 0.
        coords: One coordinate row per slice.
        n_outputs: Width of the all-zero fallback used when ``vol`` has no slices.
    """
    n_slices = vol.shape[0]
    per_model_chunks = [[] for _ in models]

    for start in range(0, n_slices, BATCH_SIZE):
        stop = min(start + BATCH_SIZE, n_slices)
        fulls, rois = [], []
        for c in range(start, stop):
            win = take_window(vol, c, OFFSETS)
            full_chw, roi_chw = window_to_full_and_roi(win, coords[c], IMG_SIZE)
            fulls.append(full_chw)
            rois.append(roi_chw)
        xb_full = torch.from_numpy(np.stack(fulls).astype(np.float32)).to(DEVICE)
        xb_roi = torch.from_numpy(np.stack(rois).astype(np.float32)).to(DEVICE)
        cb = torch.from_numpy(coords[start:stop].astype(np.float32)).to(DEVICE)

        for chunks, (_, model) in zip(per_model_chunks, models):
            logits = model(xb_full, xb_roi, cb)
            chunks.append(torch.sigmoid(logits).cpu().numpy())

    series_probs = []
    for chunks in per_model_chunks:
        if chunks:
            probs_all = np.concatenate(chunks, axis=0)
        else:
            # Empty volume: fall back to a single all-zero row.
            probs_all = np.zeros((1, n_outputs), dtype=np.float32)
        series_probs.append(_aggregate_slice_probs(probs_all, n_slices))

    # Release cached GPU blocks before the next ensemble runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return np.mean(np.stack(series_probs, axis=0), axis=0)


@torch.no_grad()
def predict_series_probs_two_step(dicoms) -> np.ndarray:
    """
    Two-step inference pipeline:
    1. First step: 14-class model predicts all 13 anatomical locations + "Aneurysm Present"
    2. Second step: Binary model predicts refined "Aneurysm Present"
    3. Final output: Use binary model's "Aneurysm Present" to overwrite 14th class

    Args:
        dicoms: Sorted series as produced by ``sort_dicom_slices``; passed
            unchanged to ``normalize_dicom_series``.

    Returns:
        One probability per entry of ``LABEL_COLS``: the fold-averaged 14-class
        output with its last element replaced by the fold-averaged binary output.
    """
    # Get both model types
    models_14class = get_models_14class()
    models_binary = get_models_binary()

    # Build normalized volume [N,H,W] uint8 (matching training data format)
    vol = normalize_dicom_series(dicoms, target_size=IMG_SIZE, apply_rescale=True)
    N = vol.shape[0]
    # Prepare coords zeros on test
    coords = np.zeros((N, 2), dtype=np.float32)

    # Step 1: 14-class prediction
    print("Step 1: Running 14-class models...")
    probs_14class = _ensemble_series_probs(models_14class, vol, coords, len(LABEL_COLS))

    # Step 2: Binary prediction
    print("Step 2: Running binary models...")
    probs_binary = _ensemble_series_probs(models_binary, vol, coords, 1)
    # The binary head's output shape is not visible here; ravel() accepts both
    # a (B, 1) and a (B,) head.
    binary_present = float(np.ravel(probs_binary)[0])

    # Step 3: Combine results - use binary model's "Aneurysm Present" to overwrite 14th class
    final_probs = probs_14class.copy()
    final_probs[-1] = binary_present  # Overwrite "Aneurysm Present" with binary prediction

    print(f"14-class 'Aneurysm Present': {probs_14class[-1]:.4f}")
    print(f"Binary 'Aneurysm Present': {binary_present:.4f}")
    print(f"Final 'Aneurysm Present': {final_probs[-1]:.4f}")

    return final_probs


In [None]:
# ========= Kaggle-required predict function =========
def predict(series_path: str) -> pl.DataFrame | pd.DataFrame:
    """Kaggle entry point: predict all label probabilities for one series.

    Collects every ``.dcm`` file (extension matched case-insensitively) under
    ``series_path``, sorts the slices and runs the two-step pipeline. If the
    directory contains no DICOM files, all probabilities are 0.0.

    Args:
        series_path: Directory holding the DICOM files of one series.

    Returns:
        A single-row polars DataFrame with one column per entry of
        ``LABEL_COLS`` (no ID column, as the server expects features only).
    """
    try:
        # Collect all DICOM files
        filepaths = [
            os.path.join(root, name)
            for root, _, files in os.walk(series_path)
            for name in files
            if name.lower().endswith('.dcm')
        ]

        if filepaths:
            # Sort DICOMs and perform two-step inference
            dicoms = sort_dicom_slices(filepaths)
            row = predict_series_probs_two_step(dicoms).tolist()
        else:
            # Return zeros if no DICOMs found
            row = [0.0] * len(LABEL_COLS)

        # Build output (one row); the server expects the label columns only.
        return pl.DataFrame(data=[row], schema=list(LABEL_COLS), orient='row')
    finally:
        # Required cleanup to avoid disk pressure. Done in `finally` so it also
        # runs on the no-DICOM path and when inference raises.
        shutil.rmtree('/kaggle/shared', ignore_errors=True)


In [None]:
# Smoke-test checkpoint discovery (no weights are loaded here)
print("Testing two-step model discovery...")
try:
    # 14-class ensemble
    checkpoints_14class = discover_checkpoints_14class()
    print(f"Found {len(checkpoints_14class)} 14-class model checkpoints:")
    for arch, path in checkpoints_14class:
        print(f"  - {arch}: {path}")

    # Binary ensemble
    checkpoints_binary = discover_checkpoints_binary()
    print(f"Found {len(checkpoints_binary)} binary model checkpoints:")
    for arch, path in checkpoints_binary:
        print(f"  - {arch}: {path}")

    # Warn early: inference raises later if either ensemble is empty.
    if not checkpoints_14class:
        print("WARNING: No 14-class models found!")
    if not checkpoints_binary:
        print("WARNING: No binary models found!")

except Exception as e:
    print(f"Error discovering models: {e}")

# Report when this is a local/test run rather than a competition rerun
if not os.environ.get('KAGGLE_IS_COMPETITION_RERUN'):
    print("Running in test mode - not competition submission")


In [None]:
# ========= Start RSNA inference server =========
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(predict)

if not os.environ.get('KAGGLE_IS_COMPETITION_RERUN'):
    # Test mode - run the local gateway, then inspect whatever it wrote
    inference_server.run_local_gateway()
    if not os.path.exists('/kaggle/working/submission.parquet'):
        print("No submission file generated")
    else:
        print("Submission file created successfully")
        submission_df = pl.read_parquet('/kaggle/working/submission.parquet')
        print(f"Submission shape: {submission_df.shape}")
        print(submission_df.head())
else:
    # Competition mode - serve predictions to the evaluation gateway
    inference_server.serve()
