# In [None]:
# ==========================================
# PHASE 2: 9-MODEL ENSEMBLE SUBMISSION SCRIPT
# ==========================================
# This script loads ALL model files, automatically finds
# the 9 best (3 arch x 3 folds), and uses the
# official rsna_inference_server for submission.

# -------------------------
# 1. GLOBAL IMPORTS
# -------------------------
import functools
import gc
import glob
import os
import random
import re
import shutil
import typing
import warnings
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Competition-specific import ---
# Inside the competition environment the real `kaggle_evaluation` package is
# importable. When it is not (e.g. running the notebook elsewhere), a mock that
# exposes the same `RSNAInferenceServer(predict).serve()` surface used at the
# bottom of this script is installed instead, so the script still completes.
try:
    import kaggle_evaluation.rsna_inference_server as rsna_inference_server
    import polars as pl # Demo uses polars
    print("Successfully imported rsna_inference_server.")
except ImportError:
    import polars as pl # Still need polars

    class MockEnv:
        """Stand-in for the ``rsna_inference_server`` module when it cannot be imported."""

        class RSNAInferenceServer:
            """No-op server: keeps the endpoint callables but never serves requests."""

            def __init__(self, *endpoint_listeners):
                self.endpoint_listeners = endpoint_listeners
                print("MockEnv.RSNAInferenceServer created; no requests will be served.")

            def serve(self):
                print("MockEnv.RSNAInferenceServer.serve() called. Skipping.")

        def __init__(self):
            print("Creating simple MockEnv for 'Save Version' run.")

        # Older iter_test/predict API, kept so existing callers keep working.
        def iter_test(self):
            print("MockEnv.iter_test() called, returning empty list.")
            return [] # Return an empty iterator

        def predict(self, submission_df):
            print("MockEnv.predict() called. Skipping.")

    rsna_inference_server = MockEnv()
    print("Mock rsna_inference_server environment created.")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pydicom

# Silence every Python warning for the rest of the run to keep the log short.
# NOTE(review): this also hides deprecation/user warnings from torch, timm,
# pydicom, etc. that could point at loading or preprocessing problems.
warnings.filterwarnings('ignore')
print(f"Imports successful. PyTorch version: {torch.__version__}")

# -------------------------
# 2. GLOBAL CONFIGURATION
# -------------------------
class Config:
    """Static settings for this submission run.

    Hyperparameters, feature flags and the architecture list must mirror the
    values used when the checkpoints in ``MODEL_DIR`` were trained.
    """

    # --- Paths ---------------------------------------------------------------
    # !!! EDIT THIS: dataset holding every trained checkpoint (*.pth file). !!!
    MODEL_DIR = "/kaggle/input/newmodels"

    # --- Model hyperparameters (must match training) -------------------------
    NUM_FRAMES = 8
    IMAGE_SIZE = 224
    NUM_CLASSES = 14
    NUM_FOLDS = 3  # number of CV folds trained per architecture

    # --- Preprocessing feature flags (must match training) -------------------
    USE_METADATA = True
    # NOTE(review): USE_WINDOWING is not read by the helpers in this script;
    # DICOM windowing is always applied.
    USE_WINDOWING = True
    USE_CLAHE = True

    # --- Inference options ---------------------------------------------------
    TTA_ENABLED = True  # also score a horizontally flipped copy and average

    # --- Architectures: (checkpoint filename prefix, timm backbone name) -----
    MODELS_TO_TRAIN = [
        ("effnetv2s", "tf_efficientnetv2_s.in1k"),
        ("convnext_tiny", "convnext_tiny.fb_in1k"),
        ("maxvit_tiny", "maxvit_tiny_tf_224.in1k"),
    ]
    # prefix -> backbone name lookup
    MODELS_TO_TRAIN_DICT = dict(MODELS_TO_TRAIN)


config = Config()

# -------------------------
# 3. GLOBAL SEED & DEVICE
# -------------------------
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    cuDNN is deliberately left non-deterministic with autotuning enabled,
    which favours speed over bit-exact reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True


set_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# -------------------------
# 4. GLOBAL TARGETS
# -------------------------
# Output columns, in submission order: 13 artery-location labels followed by
# the overall 'Aneurysm Present' target.
TARGET_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# ----------------------------------------------------
# 5. RE-USED HELPER FUNCTIONS (FROM TRAINING SCRIPT)
# ----------------------------------------------------
def get_windowing_params(modality: str) -> Tuple[float, float]:
    """Return the ``(window_center, window_width)`` pair for a DICOM modality.

    Any modality not listed falls back to the (40, 80) window.
    """
    fallback = (40, 80)
    window_table = {
        'CT': fallback,
        'CTA': (50, 350),
        'MRA': (600, 1200),
        'MRI': fallback,
        'MR': fallback,
    }
    try:
        return window_table[modality]
    except KeyError:
        return fallback

def apply_dicom_windowing(img: np.ndarray, window_center: float, window_width: float) -> np.ndarray:
    """Clip *img* to a display window and rescale it to 8-bit.

    The window spans ``window_center +/- window_width // 2``. Floor division
    is kept as in the original helper. Values inside the window map linearly
    onto 0..255 and are truncated to ``uint8``.
    """
    half_width = window_width // 2
    low, high = window_center - half_width, window_center + half_width
    windowed = np.clip(img, low, high)
    # The epsilon keeps a zero-width window from dividing by zero.
    unit = (windowed - low) / (high - low + 1e-7)
    return (unit * 255).astype(np.uint8)

def apply_clahe_normalization(img: np.ndarray, modality: str) -> np.ndarray:
    """Apply modality-specific CLAHE contrast enhancement to an 8-bit slice.

    With ``config.USE_CLAHE`` off, the image is only cast to ``uint8``.
    Otherwise a CLAHE with 8x8 tiles runs with a per-modality clip limit and
    post-processing:

    * CTA / MRA: clip limit 3.0, then a linear boost (``alpha=1.1, beta=5``).
    * MRI / MR: clip limit 2.0, then a 0.9 gamma curve.
    * anything else: clip limit 2.5, no post-processing.
    """
    img_u8 = img.astype(np.uint8)
    if not config.USE_CLAHE:
        return img_u8

    is_angio = modality in ['CTA', 'MRA']
    is_mr = (not is_angio) and modality in ['MRI', 'MR']
    clip_limit = 3.0 if is_angio else (2.0 if is_mr else 2.5)
    enhanced = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8)).apply(img_u8)

    if is_angio:
        return cv2.convertScaleAbs(enhanced, alpha=1.1, beta=5)
    if is_mr:
        gamma_corrected = np.power(enhanced / 255.0, 0.9) * 255
        return gamma_corrected.astype(np.uint8)
    return enhanced

def robust_normalization(volume: np.ndarray) -> np.ndarray:
    """Rescale *volume* to ``uint8`` using its 1st/99th percentiles.

    Intensities are clipped to ``[p1, p99]`` and stretched linearly to 0..255.
    When ``p99 > p1`` does not hold (a flat volume), the result is all zeros
    with the input's shape. An empty input returns an all-zero
    ``(IMAGE_SIZE, IMAGE_SIZE)`` image.
    """
    if not volume.size:
        return np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE), dtype=np.uint8)

    lo, hi = np.percentile(volume.ravel(), [1, 99])
    clipped = np.clip(volume, lo, hi)
    if not hi > lo:
        return np.zeros(clipped.shape, dtype=np.uint8)

    stretched = (clipped - lo) / (hi - lo + 1e-7)
    return (stretched * 255).astype(np.uint8)

def create_3channel_input_8frame(volume: np.ndarray) -> np.ndarray:
    """Collapse a stack of slices into a single 3-channel image.

    For a ``(frames, H, W)`` stack (presumably uint8; confirm at call sites),
    the output is ``(H, W, 3)`` with these channels:

    0. the middle slice,
    1. the maximum-intensity projection over the stack,
    2. the per-pixel standard deviation over the stack, clipped to its
       5th-95th percentiles and rescaled to 0..255 (all zeros when constant).

    An empty input yields an all-zero ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` image.
    """
    if not len(volume):
        return np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE, 3), dtype=np.uint8)

    spread = np.std(volume, axis=0).astype(np.float32)
    if spread.max() > spread.min():
        q_lo, q_hi = np.percentile(spread, [5, 95])
        spread = np.clip(spread, q_lo, q_hi)
        spread = ((spread - q_lo) / (q_hi - q_lo + 1e-7) * 255).astype(np.uint8)
    else:
        spread = np.zeros_like(spread, dtype=np.uint8)

    channels = (volume[len(volume) // 2], np.max(volume, axis=0), spread)
    return np.stack(channels, axis=-1)

def smart_8_frame_sampling(volume_paths: List[str]) -> List[str]:
    """Pick exactly 8 paths from a series (or none for an empty series).

    * 0 paths: returns ``[]``.
    * 1-8 paths: cycles through them in order until 8 are collected.
    * more than 8: skips the first 10% of the series, then takes 8 indices
      spaced by ``remaining // 8``. ``np.linspace`` is a fallback should the
      stepped indices ever collide.
    """
    total = len(volume_paths)
    if not total:
        return []
    if total <= 8:
        return [volume_paths[k % total] for k in range(8)]

    offset = max(0, int(total * 0.1))
    stride = max(1, (total - offset) // 8)
    picks = [min(offset + k * stride, total - 1) for k in range(8)]
    if len(set(picks)) < 8:
        picks = np.linspace(offset, total - 1, 8).astype(int).tolist()
    return [volume_paths[k] for k in picks]

# ---------------------------------
# 6. MODEL DEFINITION (FROM TRAINING)
# ---------------------------------
class ImprovedMultiFrameModel(nn.Module):
    """timm backbone + optional 2-value metadata branch + MLP head emitting logits.

    The attribute names (``backbone``, ``meta_fc``, ``classifier``) and the layer
    order inside each ``nn.Sequential`` determine the state_dict keys, so they
    must stay identical to the training definition for ``load_state_dict``.
    """
    def __init__(self, model_name_backbone: str, num_frames=8, num_classes=14, pretrained=False):
        """
        Args:
            model_name_backbone: timm model name passed to ``timm.create_model``.
            num_frames: stored on the instance; not otherwise used in this class.
            num_classes: number of output logits produced by the head.
            pretrained: forwarded to ``timm.create_model``.
        """
        super(ImprovedMultiFrameModel, self).__init__()
        self.model_name_backbone = model_name_backbone
        self.num_frames = num_frames
        self.num_classes = num_classes
        # Read from the global config at construction time, not passed in.
        self.use_metadata = config.USE_METADATA
        
        # num_classes=0 drops timm's classification head, so the backbone
        # returns globally average-pooled features of size num_features.
        self.backbone = timm.create_model(
            self.model_name_backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg'
        )
        self.feature_dim = self.backbone.num_features
        
        if self.use_metadata:
            # Small MLP mapping a 2-value metadata vector to 32 features that
            # are concatenated to the image features.
            self.meta_fc = nn.Sequential(
                nn.Linear(2, 16), nn.ReLU(), nn.Dropout(0.2),
                nn.Linear(16, 32), nn.ReLU()
            )
            classifier_input_dim = self.feature_dim + 32
        else:
            classifier_input_dim = self.feature_dim
            
        # MLP head: in -> 512 -> 256 -> num_classes (raw logits, no sigmoid).
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, 512),
            nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, meta=None):
        """Return logits of shape (batch, num_classes).

        Args:
            x: image batch for the backbone (presumably (B, 3, H, W); confirm).
            meta: (B, 2) metadata tensor, used only when ``use_metadata`` is on.

        NOTE(review): when ``use_metadata`` is True but ``meta`` is None, the
        classifier receives ``feature_dim`` inputs instead of the
        ``feature_dim + 32`` it was built for, and its first Linear will fail.
        """
        features = self.backbone(x)
        if self.use_metadata and meta is not None:
            meta_features = self.meta_fc(meta)
            features = torch.cat([features, meta_features], dim=1)
        output = self.classifier(features)
        return output

# -------------------------
# 7. INFERENCE: TRANSFORMS
# -------------------------
def get_model_transforms() -> Dict[str, A.Compose]:
    """Build the normalisation pipelines keyed by ``"pytorch"`` and ``"tensorflow"``.

    Both pipelines run ``A.Normalize`` followed by ``ToTensorV2`` (HWC array to
    CHW tensor). They differ only in the mean/std used: ImageNet statistics
    for ``"pytorch"`` and 0.5 / 0.5 for ``"tensorflow"``.
    """
    norm_stats = {
        "pytorch": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "tensorflow": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    }
    return {
        family: A.Compose([A.Normalize(mean=mean, std=std), ToTensorV2()])
        for family, (mean, std) in norm_stats.items()
    }


_TRANSFORMS = get_model_transforms()

# -------------------------
# 8. INFERENCE: MODEL LOADING
# -------------------------
def load_all_models() -> Dict[str, List[nn.Module]]:
    """
    Loads all 9 models (3 archs x 3 folds) from the MODEL_DIR.
    This function automatically finds the BEST scoring model for each fold.
    """
    loaded_models = {}
    
    print(f"Loading models from: {config.MODEL_DIR}")
    
    for prefix, backbone_name in config.MODELS_TO_TRAIN:
        print(f"\n--- Loading architecture: {prefix} (Backbone: {backbone_name}) ---")
        
        fold_models = []
        for fold in range(config.NUM_FOLDS):
            # 1. Find all models for this prefix AND fold
            model_paths = glob.glob(os.path.join(config.MODEL_DIR, f"{prefix}_fold{fold}*.pth"))
            
            if not model_paths:
                print(f"FATAL: No model found for {prefix} Fold {fold}. Check dataset.")
                return None # Signal failure
            
            # 2. Find the one with the best score in its name
            try:
                best_path = sorted(model_paths, key=lambda x: float(x.split('score')[-1].replace('.pth', '')))[-1]
            except Exception as e:
                print(f"  ERROR: Could not parse score from filenames for {prefix} Fold {fold}: {e}")
                print(f"  Using first file found: {model_paths[0]}")
                best_path = model_paths[0]
                
            print(f"  Loading best model for Fold {fold}: {os.path.basename(best_path)}")
            
            # 3. Load the model
            model = ImprovedMultiFrameModel(
                model_name_backbone=backbone_name,
                pretrained=False 
            )
            
            try:
                ck = torch.load(best_path, map_location=device, weights_only=False)
                model.load_state_dict(ck['model_state_dict'])
            except Exception as e:
                print(f"    ERROR loading checkpoint: {e}")
                return None # Signal failure
            
            model.to(device)
            model.eval()
            fold_models.append(model)
            
        loaded_models[prefix] = fold_models
        
    print(f"\nSuccessfully loaded {sum(len(m) for m in loaded_models.values())} total models.")
    return loaded_models

# Load every checkpoint once, at import time, so each predict() call reuses them.
# A falsy value here (None on failure) makes predict() return 0.5 for every target.
_LOADED_MODELS = load_all_models()

# -------------------------
# 9. INFERENCE FUNCTION (The "predict" function)
# -------------------------
def _uniform_prediction() -> pl.DataFrame:
    """Return a one-row frame with 0.5 for every target (used when scoring is impossible)."""
    return pl.DataFrame([tuple([0.5] * len(TARGET_COLS))], schema=TARGET_COLS, orient="row")


def _read_volume(chosen_paths: List[str]) -> np.ndarray:
    """Read, window, CLAHE-enhance and resize the sampled DICOM slices.

    The modality of the first path (default 'CT') selects the window and CLAHE
    settings for every slice. Unreadable slices become black frames, and the
    stack is zero-padded to 8 frames.

    NOTE(review): a multi-frame DICOM yields a 3-D ``pixel_array``, which CLAHE
    rejects, so such slices end up blank. Confirm this matches training.
    """
    blank = np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE), dtype=np.uint8)
    volume = []
    modality = 'CT'
    for i, fp in enumerate(chosen_paths):
        try:
            ds = pydicom.dcmread(fp, force=True)
            if i == 0:
                modality = getattr(ds, 'Modality', 'CT')

            img = ds.pixel_array.astype(np.float32)
            if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
                img = img * float(ds.RescaleSlope) + float(ds.RescaleIntercept)

            wc, ww = get_windowing_params(modality)
            img = apply_dicom_windowing(img, wc, ww)
            img = apply_clahe_normalization(img, modality)
            img = cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE), interpolation=cv2.INTER_AREA)
            volume.append(img)
        except Exception:
            # Best effort: one unreadable slice should not sink the whole series.
            volume.append(blank.copy())

    while len(volume) < 8:
        volume.append(blank.copy())
    return np.array(volume)


@torch.no_grad()
def _ensemble_probabilities(input_hwc: np.ndarray, meta_tensor: torch.Tensor) -> Optional[List[float]]:
    """Average sigmoid outputs over flip-TTA (if enabled), then folds, then architectures.

    Returns one probability per target, or None if no model produced output.
    """
    use_amp = device.type == 'cuda'
    arch_probs = []
    for prefix, fold_models in _LOADED_MODELS.items():
        if not fold_models:
            continue
        backbone_name = config.MODELS_TO_TRAIN_DICT[prefix]
        # Backbones whose timm name contains 'tf_' use the 0.5/0.5 normalisation.
        transform = _TRANSFORMS["tensorflow" if 'tf_' in backbone_name else "pytorch"]
        img_tensor = transform(image=input_hwc)['image'].unsqueeze(0).to(device)
        # Loop-invariant: the flipped view is the same for every fold model.
        flipped = torch.flip(img_tensor, dims=[-1]) if config.TTA_ENABLED else None

        fold_probs = []
        for model in fold_models:
            with torch.autocast(device_type=device.type, enabled=use_amp):
                logits = model(img_tensor, meta_tensor)
                logits_f = model(flipped, meta_tensor) if flipped is not None else None
            # Sigmoid in float32: fp16 resolves values near 1 only to ~5e-4,
            # which would create needless ties between series.
            probs = torch.sigmoid(logits.float())
            if logits_f is not None:
                probs = (probs + torch.sigmoid(logits_f.float())) / 2.0
            fold_probs.append(probs)

        # Level 2: fold averaging
        arch_probs.append(torch.stack(fold_probs, dim=0).mean(dim=0))

    if not arch_probs:
        return None
    # Level 3: architecture averaging
    final_probs = torch.stack(arch_probs, dim=0).mean(dim=0)
    return final_probs[0].cpu().tolist()


@torch.no_grad()
def predict(series_path: str) -> pl.DataFrame:
    """Score one DICOM series with the loaded ensemble.

    Called by the inference server once per test series. The function:

    1. samples 8 of the series' ``*.dcm`` files;
    2. preprocesses them with the helpers above;
    3. collapses them into a 3-channel image;
    4. averages the ensemble's sigmoid outputs.

    Returns:
        A one-row polars DataFrame with columns ``TARGET_COLS``. All values are
        0.5 when models are not loaded, the series has no ``.dcm`` files, or
        inference raises.
    """
    try:
        if not _LOADED_MODELS:
            print("ERROR: Models are not loaded. Returning 0.5")
            return _uniform_prediction()

        all_filepaths = sorted(glob.glob(os.path.join(series_path, "*.dcm")))
        if not all_filepaths:
            return _uniform_prediction()

        vol_np = _read_volume(smart_8_frame_sampling(all_filepaths))
        input_hwc = create_3channel_input_8frame(robust_normalization(vol_np))

        # Placeholder metadata: no per-patient values are read here.
        meta_tensor = torch.tensor([[0.5, 0.5]], dtype=torch.float32, device=device)

        row = _ensemble_probabilities(input_hwc, meta_tensor)
        if row is None:
            return _uniform_prediction()
        return pl.DataFrame([tuple(row)], schema=TARGET_COLS, orient="row")
    except Exception as exc:
        # A crash here would fail the whole hidden-test rerun; log and fall back.
        print(f"ERROR during inference for {series_path}: {exc}. Returning 0.5")
        return _uniform_prediction()
    finally:
        # NOTE: the RSNA demo notebook requires clearing /kaggle/shared after each
        # series to avoid "out of disk space" errors during the rerun. `shutil`
        # was imported for this but never used.
        shutil.rmtree('/kaggle/shared', ignore_errors=True)

# --------------------------
# 10. START SERVER
# --------------------------
print("Starting RSNA Inference Server...")
# `predict` is registered as the endpoint; the server calls it once per test series.
inference_server = rsna_inference_server.RSNAInferenceServer(predict)
inference_server.serve()

# --- Release the ensemble once serve() hands control back (it may not, while serving) ---
# At module scope locals() and globals() are the same mapping, so one pop with a
# default is equivalent to the guarded `del`.
globals().pop('_LOADED_MODELS', None)
gc.collect()
torch.cuda.empty_cache()

print("\nScript finished.")