In [1]:
# ==========================================
# PHASE 2: SUBMISSION NOTEBOOK (Internet OFF)
# ==========================================

# -------------------------
# 1. GLOBAL IMPORTS
# -------------------------
# All these libraries are pre-installed in the Kaggle environment.
# No !pip install is needed.

import os
import glob
import random
import warnings
import numpy as np
import pandas as pd
import cv2
import functools
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple, Optional
import gc
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pydicom
import polars as pl

# This is the competition's inference server
import kaggle_evaluation.rsna_inference_server as rsna_inference_server

warnings.filterwarnings('ignore')

# -------------------------
# 2. GLOBAL CONFIGURATION
# -------------------------
class Config:
    # --- This path is from your screenshot ---
    CKPT_DIR = "/kaggle/input/effecient-net-models" 
    
    # --- Model Hyperparameters (must match training) ---
    NUM_FRAMES = 8
    IMAGE_SIZE = 224
    NUM_CLASSES = 14
    MODEL_NAME_BACKBONE = "tf_efficientnetv2_s.in1k"
    
    # --- Feature Flags (must match training) ---
    USE_METADATA = True
    USE_WINDOWING = True
    USE_CLAHE = True

config = Config()

# -------------------------
# 3. GLOBAL SEED & DEVICE
# -------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# -------------------------
# 4. GLOBAL TARGETS
# -------------------------
TARGET_COLS = [
    'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery', 
    'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery',
    'Anterior Communicating Artery', 'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery', 'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery', 'Basilar Tip',
    'Other Posterior Circulation', 'Aneurysm Present'
]

# -------------------------
# 5. PREPROCESSING (Must match training)
# -------------------------
# These functions are identical to your training notebook
def get_windowing_params(modality: str) -> Tuple[float, float]:
    windows = {'CT': (40, 80), 'CTA': (50, 350), 'MRA': (600, 1200), 'MRI': (40, 80), 'MR': (40, 80)}
    return windows.get(modality, (40, 80))

def apply_dicom_windowing(img: np.ndarray, window_center: float, window_width: float) -> np.ndarray:
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)
    img = (img - img_min) / (img_max - img_min + 1e-7)
    return (img * 255).astype(np.uint8)

def apply_clahe_normalization(img: np.ndarray, modality: str) -> np.ndarray:
    if not config.USE_CLAHE: return img.astype(np.uint8)
    img = img.astype(np.uint8)
    if modality in ['CTA', 'MRA']:
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img)
        img_clahe = cv2.convertScaleAbs(img_clahe, alpha=1.1, beta=5)
    elif modality in ['MRI', 'MR']:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img)
        img_clahe = np.power(img_clahe / 255.0, 0.9) * 255
        img_clahe = img_clahe.astype(np.uint8)
    else:
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img)
    return img_clahe

def robust_normalization(volume: np.ndarray) -> np.ndarray:
    p1, p99 = np.percentile(volume.flatten(), [1, 99])
    volume_norm = np.clip(volume, p1, p99)
    if p99 > p1:
        volume_norm = (volume_norm - p1) / (p99 - p1 + 1e-7)
    else:
        volume_norm = np.zeros_like(volume_norm)
    return (volume_norm * 255).astype(np.uint8)

def create_3channel_input_8frame(volume: np.ndarray) -> np.ndarray:
    if len(volume) == 0:
        return np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE, 3), dtype=np.uint8)
    middle_slice = volume[len(volume) // 2]
    mip = np.max(volume, axis=0)
    std_proj = np.std(volume, axis=0).astype(np.float32)
    if std_proj.max() > std_proj.min():
        p1, p99 = np.percentile(std_proj, [5, 95])
        std_proj = np.clip(std_proj, p1, p99)
        std_proj = ((std_proj - p1) / (p99 - p1 + 1e-7) * 255).astype(np.uint8)
    else:
        std_proj = np.zeros_like(std_proj, dtype=np.uint8)
    return np.stack([middle_slice, mip, std_proj], axis=-1)

# This is the Albumentations transform that matches your validation transform
val_transform = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# -------------------------
# 6. MODEL DEFINITION (Must match training)
# -------------------------
class ImprovedMultiFrameModel(nn.Module):
    def __init__(self, num_frames=8, num_classes=14, pretrained=False):
        super(ImprovedMultiFrameModel, self).__init__()
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.use_metadata = config.USE_METADATA
        
        # NOTE: pretrained=False because Internet is OFF
        self.backbone = timm.create_model(
            config.MODEL_NAME_BACKBONE,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg'
        )
        self.feature_dim = self.backbone.num_features
        
        if self.use_metadata:
            self.meta_fc = nn.Sequential(
                nn.Linear(2, 16), nn.ReLU(), nn.Dropout(0.2),
                nn.Linear(16, 32), nn.ReLU()
            )
            classifier_input_dim = self.feature_dim + 32
        else:
            classifier_input_dim = self.feature_dim
            
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, meta=None):
        features = self.backbone(x)
        if self.use_metadata and meta is not None:
            meta_features = self.meta_fc(meta)
            features = torch.cat([features, meta_features], dim=1)
        output = self.classifier(features)
        return output

# --------------------------
# 7. LOAD ENSEMBLE MODELS
# --------------------------
def discover_ckpts(ckpt_dir):
    files = sorted(glob.glob(os.path.join(ckpt_dir, "*.pth")) + glob.glob(os.path.join(ckpt_dir, "*.pt")))
    return files

CKPT_PATHS = discover_ckpts(config.CKPT_DIR)
_LOADED_MODELS = []

if len(CKPT_PATHS) == 0:
    print(f"ERROR: No .pth files found in {config.CKPT_DIR}. Did you update the path?")
else:
    print(f"Found {len(CKPT_PATHS)} checkpoints for ensemble.")
    for p in CKPT_PATHS:
        print(f"Loading checkpoint: {p}")
        model = ImprovedMultiFrameModel(num_frames=config.NUM_FRAMES, num_classes=config.NUM_CLASSES, pretrained=False)
        try:
            ck = torch.load(p, map_location=device)
            
            if isinstance(ck, dict) and 'model_state_dict' in ck:
                state = ck['model_state_dict']
            else:
                state = ck
            
            sd = {}
            for k, v in state.items():
                nk = k[len('module.'):] if k.startswith('module.') else k
                sd[nk] = v
            
            model.load_state_dict(sd, strict=True)
            model.to(device)
            model.eval()
            _LOADED_MODELS.append(model)
        except Exception as e:
            print(f"Warning: Failed to load model {p}. Error: {e}")

print(f"Successfully loaded {len(_LOADED_MODELS)} models.")

# --------------------------
# 8. INFERENCE FUNCTION
# --------------------------
def sample_eight_from_list(paths: List[str]) -> List[str]:
    n = len(paths)
    if n == 0: return []
    if n <= 8:
        out = paths[:]
        while len(out) < 8:
            out += paths[:(8 - len(out))]
        return out[:8]
    idxs = np.linspace(0, n - 1, 8).astype(int).tolist()
    return [paths[i] for i in idxs]

@torch.no_grad()
def predict(series_path: str) -> pl.DataFrame:
    """
    Main inference function called by the server.
    Reads DICOMs, preprocesses them to match training, and runs ensemble inference.
    """
    # ------------- 1. Collect files -------------
    all_filepaths = sorted(glob.glob(os.path.join(series_path, "*.dcm")))
    
    if len(all_filepaths) == 0:
        return pl.DataFrame([tuple([0.5] * len(TARGET_COLS))], schema=TARGET_COLS)

    # ------------- 2. Sample 8 frames -------------
    chosen_paths = sample_eight_from_list(all_filepaths)
    
    # ------------- 3. Preprocess DICOMs (Matching your training) -------------
    volume = []
    modality = 'CT'
    
    for i, fp in enumerate(chosen_paths):
        try:
            ds = pydicom.dcmread(fp, force=True)
            if i == 0:
                modality = getattr(ds, 'Modality', 'CT')
                
            img = ds.pixel_array.astype(np.float32)
            
            if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
                img = img * float(ds.RescaleSlope) + float(ds.RescaleIntercept)
            
            wc, ww = get_windowing_params(modality)
            img = apply_dicom_windowing(img, wc, ww)
            img = apply_clahe_normalization(img, modality)
            
            img = cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE), interpolation=cv2.INTER_AREA)
            volume.append(img)
        except Exception:
            volume.append(np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE), dtype=np.uint8))

    while len(volume) < 8:
        volume.append(np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE), dtype=np.uint8))
        
    vol_np = np.array(volume)
    
    # ------------- 4. Create 3-Channel Input (matching training) -------------
    vol_norm = robust_normalization(vol_np)
    input_hwc = create_3channel_input_8frame(vol_norm)
    
    # ------------- 5. Convert to Tensor (matching training) -------------
    # Apply the same validation transform
    # val_transform handles Normalize and ToTensorV2 (permute + scaling)
    img_normalized = val_transform(image=input_hwc)['image'] 
    img_tensor = img_normalized.unsqueeze(0).to(device)
    
    # Metadata placeholder (real metadata is unavailable in test set)
    meta_tensor = torch.tensor([[0.5, 0.5]], dtype=torch.float32).to(device)

    # ------------- 6. Model Inference (Ensemble Average) -------------
    preds_accum = None
    
    if not _LOADED_MODELS: # Safety check
         return pl.DataFrame([tuple([0.5] * len(TARGET_COLS))], schema=TARGET_COLS)

    for model in _LOADED_MODELS:
        with torch.cuda.amp.autocast():
            # Standard pred
            logits = model(img_tensor, meta_tensor)
            probs = torch.sigmoid(logits)
            
            # TTA (Horizontal Flip)
            flipped = torch.flip(img_tensor, dims=[-1])
            logits_f = model(flipped, meta_tensor)
            probs_f = torch.sigmoid(logits_f)
        
        # Average TTA
        probs = (probs + probs_f) / 2.0
        
        preds_accum = probs if preds_accum is None else preds_accum + probs

    # Average ensemble predictions
    preds_accum = (preds_accum / len(_LOADED_MODELS)).cpu().numpy()

    row = preds_accum[0].tolist()
    predictions = pl.DataFrame([tuple(row)], schema=TARGET_COLS)

    return predictions

# --------------------------
# 9. START SERVER
# --------------------------
print("Starting RSNA Inference Server...")
# This server will call your `predict()` function for each test series
inference_server = rsna_inference_server.RSNAInferenceServer(predict)
inference_server.serve()

  data = fetch_version_info()


Using device: cuda
Found 6 checkpoints for ensemble.
Loading checkpoint: /kaggle/input/effecient-net-models/eightframe_efficientnetv2s_fold0_epoch10_score0.609154.pth
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` or the `torch.serialization.safe_globals([scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only ht