# MONAI Lung CT Registration - Distributed Inference

This notebook demonstrates **distributed inference** using the trained registration model from Notebook 3.

## Overview
- **Task**: Apply trained registration model to medical images at scale
- **Architecture**: Same LocalNet model, now deployed for inference
- **Framework**: Ray for parallel processing across multiple GPUs
- **Data Source**: Trained model from Snowflake Model Registry + NIfTI images from stages

## Workflow
1. **Setup & Dependencies** - Configure Ray cluster and install libraries
2. **Load Trained Model** - Retrieve model from Snowflake Model Registry
3. **Define Inference Pipeline** - Create distributed inference class
4. **Run Parallel Inference** - Process multiple cases simultaneously
5. **Save Results** - Store registered images and metrics to Snowflake

## Key Features
- **Model Registry Integration**: Load versioned models directly from Snowflake
- **GPU Acceleration**: Parallel inference across multiple GPUs
- **Automatic Scaling**: Ray distributes workload across available compute
- **Result Persistence**: Save registered images back to Snowflake stages


## Step 1: Initialize Snowflake Session

Set up connection to Snowflake for accessing:
- Model Registry (to load trained models)
- Stages (for reading input images and saving results)
- Compute resources (Ray cluster)
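The query tag set in the next cell is a hand-written JSON string. The sketch below builds the same tag with `json.dumps`, which keeps it valid JSON if fields change later. The helper name `build_query_tag` is hypothetical. The field values are copied from the cell.

```python
import json


def build_query_tag(name: str, major: int, minor: int) -> str:
    """Build the consumption-tracking query tag as compact JSON (hypothetical helper)."""
    tag = {
        "origin": "sf_sit-is",
        "name": name,
        "version": {"major": major, "minor": minor},
        "attributes": {"is_quickstart": 1, "source": "notebook"},
    }
    # separators=(",", ":") drops the spaces json.dumps inserts by default
    return json.dumps(tag, separators=(",", ":"))


tag = build_query_tag("distributed_medical_image_processing_with_monai", 1, 0)
# In the notebook, this string would be assigned to session.query_tag
print(tag)
```

Because Python dicts keep insertion order, the output matches the literal string used in the cell character for character.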


In [None]:
# Core Python libraries
import streamlit as st  # Interactive UI for Snowflake notebooks
import pandas as pd     # Data manipulation

# Snowflake integration
from snowflake.snowpark.context import get_active_session

# Establish Snowflake session connection
# Provides access to Model Registry, stages, and compute resources
session = get_active_session()

# Set query tag for consumption tracking
session.query_tag = '{"origin":"sf_sit-is","name":"distributed_medical_image_processing_with_monai","version":{"major":1,"minor":0},"attributes":{"is_quickstart":1,"source":"notebook"}}'

# Database name - matches setup.sql
DATABASE_NAME = "MONAI_DB"


## Step 2: Configure Distributed Inference Environment

Set up the distributed computing environment for parallel inference:

1. **Ray Cluster** - Initialize connection to distributed compute cluster
2. **Cluster Scaling** - Scale to 4 nodes for parallel GPU inference
3. **Dependencies** - Install MONAI and medical imaging libraries on all workers
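The install step in the next cell pins one task to every alive node. It does this by requesting a small fraction of Ray's built-in per-node `node:<ip>` resource. The sketch below shows that selection logic in plain Python, run against hand-written dicts shaped like `ray.nodes()` entries. Only the `Alive` and `NodeManagerAddress` keys used in the cell are assumed. The helper name `per_node_resource_requests` is hypothetical, and it sorts the addresses only to make the output deterministic.

```python
def per_node_resource_requests(nodes: list[dict], amount: float = 0.01) -> list[dict]:
    """Return one Ray resource request per distinct alive node (hypothetical helper)."""
    # The set comprehension deduplicates nodes reporting the same address
    alive_ips = sorted({n["NodeManagerAddress"] for n in nodes if n["Alive"]})
    # Each dict is what would be passed as install_deps.options(resources=...)
    return [{f"node:{ip}": amount} for ip in alive_ips]


fake_nodes = [
    {"NodeManagerAddress": "10.0.0.1", "Alive": True},
    {"NodeManagerAddress": "10.0.0.2", "Alive": True},
    {"NodeManagerAddress": "10.0.0.3", "Alive": False},  # dead node is skipped
]
print(per_node_resource_requests(fake_nodes))
# [{'node:10.0.0.1': 0.01}, {'node:10.0.0.2': 0.01}]
```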

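**Why distributed inference?**
- Process multiple medical images simultaneously across GPUs
- Reduce total wall-clock time: with one case per GPU at a time, four GPUs work through a set of cases roughly four times faster than a single GPU processing them one after another (less in practice, because of model loading and stage I/O)
- Keep all available GPUs busy instead of leaving most of them idle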
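As a rough illustration of the time savings, the sketch below computes an idealized wall-clock time when equal-cost cases are spread evenly over workers. The case count (10) and per-case time (30 s) are hypothetical. The model ignores scheduling, model loading, and stage I/O.

```python
import math


def wall_clock_seconds(num_cases: int, seconds_per_case: float, num_workers: int) -> float:
    """Idealized wall-clock time when equal-cost cases are spread evenly over workers."""
    # Each worker handles at most ceil(num_cases / num_workers) cases back to back
    rounds = math.ceil(num_cases / num_workers)
    return rounds * seconds_per_case


sequential = wall_clock_seconds(10, 30.0, num_workers=1)  # 300.0
parallel = wall_clock_seconds(10, 30.0, num_workers=4)    # 90.0
print(sequential, parallel, sequential / parallel)
```

The speedup here is about 3.3x rather than 4x because 10 cases do not divide evenly over 4 workers: two workers run a third round while the others sit idle.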


In [None]:
# ============================================================================
# DISTRIBUTED INFERENCE SETUP WITH RAY
# ============================================================================

# Standard Python libraries
import streamlit as st
import pandas as pd
import logging
import tempfile
import os

# Deep learning and distributed computing
import ray                  # Distributed computing framework
import torch                # PyTorch deep learning
import numpy as np          # Numerical operations

# Snowflake integrations
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.files import SnowflakeFile      # Read files from stages
from snowflake.ml.runtime_cluster import scale_cluster  # Scale Ray cluster

session = get_active_session()

# ============================================================================
# RAY CLUSTER INITIALIZATION
# ============================================================================

# Connect to Ray cluster in Snowflake container runtime
# 'auto' automatically discovers the cluster head node
ray.init(address='auto', ignore_reinit_error=True)

def configure_ray_logger() -> None:
    """
    Configure logging to reduce Ray's verbose output during inference.
    
    Suppresses Ray internals while keeping application logs for tracking
    inference progress and results.
    """
    # Suppress Ray core logging (only critical errors)
    ray_logger = logging.getLogger("ray")
    ray_logger.setLevel(logging.CRITICAL)
    
    # Suppress Ray Data logging (dataset operations)
    data_logger = logging.getLogger("ray.data")
    data_logger.setLevel(logging.CRITICAL)
    
    # Keep INFO level for our inference logs (progress, metrics)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    # Disable progress bars for cleaner output
    context = ray.data.DataContext.get_current()
    context.execution_options.verbose_progress = False
    context.enable_operator_progress_bars = False

# Apply logging configuration
configure_ray_logger()

# Scale the Ray cluster to 4 nodes in total (the requested size counts the head node)
# so inference can run with one GPU per node
scale_cluster(4)

# ============================================================================
# DISTRIBUTED DEPENDENCY INSTALLATION
# ============================================================================

@ray.remote(num_cpus=0)  # Reserve no CPU slots so the install task can schedule on busy nodes
def install_deps():
    """
    Install required Python packages on a single Ray worker node.
    
    This runs on each node in the cluster to ensure all workers have
    the necessary libraries for medical image inference.
    
    Returns:
        str: Success or failure message with node IP address
    """
    import subprocess
    import sys

    # Required packages for medical image registration inference
    packages = [
        "monai",            # Medical imaging framework
        "pytorch-ignite",   # Engines/handlers used by MONAI workflows
        "itk",              # Medical image I/O
        "gdown",            # Google Drive downloader (sample data only)
        "torchvision",      # Vision utilities
        "lmdb",             # LMDB key-value store (MONAI LMDBDataset)
        "transformers",     # Hugging Face Transformers (optional MONAI dependency)
        "einops",           # Tensor rearrangement operations
        "nibabel",          # NIfTI file format support
    ]

    try:
        # Install with the worker's own interpreter; capture output as text
        # so error messages are readable strings rather than bytes
        subprocess.run(
            [sys.executable, "-m", "pip", "install", *packages],
            check=True,
            capture_output=True,
            text=True,
        )

        # Verify the MONAI installation
        result = subprocess.run(
            [sys.executable, "-m", "pip", "show", "monai"],
            capture_output=True,
            text=True,
            check=True,
        )

        # Return success message with node identifier
        return f"✅ Dependencies installed on {ray.util.get_node_ip_address()}:\n{result.stdout.splitlines()[0]}"

    except subprocess.CalledProcessError as e:
        # Return error message if installation fails
        error_msg = e.stderr or e.stdout
        return f"❌ Failed on {ray.util.get_node_ip_address()}: {error_msg}"

# Get all alive nodes in the Ray cluster
nodes = {node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]}

# Create installation task for each node (runs in parallel)
tasks = [install_deps.options(resources={f"node:{node}": 0.01}).remote() for node in nodes]

# Wait for all installations to complete
results = ray.get(tasks)

# Print installation status for each node
for res in results:
    print(res)


## Step 3: Define Distributed Inference Class

This cell creates a `RegistryBasedInferencer` class that:

1. **Loads Model from Registry** - Retrieves trained model from Snowflake Model Registry
2. **Preprocesses Images** - Applies same transforms as training (CT windowing, resizing)
3. **Runs Inference** - Predicts deformation fields and applies warping
4. **Calculates Metrics** - Computes Dice score to evaluate registration quality
5. **Saves Results** - Writes registered images back to Snowflake stages
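Two of the steps above are simple arithmetic: CT windowing (step 2) and the Dice metric (step 4). The sketch below shows both on plain Python numbers and flattened voxel lists. The real pipeline applies the same arithmetic to tensors through MONAI's `ScaleIntensityRange` and PyTorch. The function names `ct_window` and `soft_dice` are hypothetical. The defaults mirror `CT_MIN_HU`/`CT_MAX_HU` and the `1e-5` smoothing term used in the inference class.

```python
def ct_window(hu: float, a_min: float = -1000.0, a_max: float = 500.0,
              b_min: float = 0.0, b_max: float = 1.0) -> float:
    """Linear CT windowing with clipping, the arithmetic behind ScaleIntensityRange(clip=True)."""
    scaled = (hu - a_min) / (a_max - a_min)   # map [a_min, a_max] onto [0, 1]
    scaled = min(max(scaled, 0.0), 1.0)       # clip values outside the window
    return scaled * (b_max - b_min) + b_min   # rescale to [b_min, b_max]


def soft_dice(pred: list[float], target: list[float], eps: float = 1e-5) -> float:
    """Dice overlap with smoothing: (2 * |A ∩ B| + eps) / (|A| + |B| + eps)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)


print(ct_window(-250.0))                          # 0.5 (middle of the window)
print(ct_window(2000.0))                          # 1.0 (clipped)
print(soft_dice([1, 1, 0, 0], [1, 0, 1, 0]))      # close to 0.5 (half overlap)
```

Because warped labels are interpolated, registered masks can hold fractional values. The product-based intersection therefore gives a "soft" Dice score rather than a strictly binary one.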

**Key Design**: This class is designed to run on Ray workers, with each worker processing a batch of cases independently. Ray handles distribution across GPUs automatically.
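Ray Data's `map_batches` accepts a callable class. Each worker constructs the class once and then calls it on successive batches, which is why model loading sits in `__init__` rather than in `__call__`. The sketch below imitates that lifecycle in plain Python. The class `ToyInferencer` is hypothetical, and a hand-rolled loop stands in for Ray. It does not use or reproduce Ray's scheduler.

```python
class ToyInferencer:
    """Stands in for RegistryBasedInferencer: expensive setup once, cheap calls per batch."""

    instances_created = 0  # class-level counter showing __init__ runs once per worker

    def __init__(self, scale: float = 2.0):
        ToyInferencer.instances_created += 1
        self.scale = scale  # in the real class, model weights are loaded here

    def __call__(self, batch: list[dict]) -> list[dict]:
        # Process each case independently, like the per-row loop in __call__
        return [{"case_id": row["case_id"], "score": row["value"] * self.scale} for row in batch]


def run_on_one_worker(batches: list[list[dict]]) -> list[dict]:
    """Simulate one worker: construct the class once, then call it on every assigned batch."""
    worker = ToyInferencer()
    results = []
    for batch in batches:
        results.extend(worker(batch))
    return results


batches = [[{"case_id": "case_001", "value": 0.1}], [{"case_id": "case_002", "value": 0.2}]]
out = run_on_one_worker(batches)
print(ToyInferencer.instances_created, [r["case_id"] for r in out])
```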

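**Model Registry vs Direct Loading**: The class looks up the model version in the Snowflake Model Registry, but the trained weights are always read from a stage path (`fallback_stage_path`). If the registry lookup fails, the class loads directly from that stage path instead, and it raises an error when no path is provided. Without `fallback_stage_path`, a successful registry lookup still leaves the model untrained, so always pass it.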
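The error-handling order described above is: try the registry path first, fall back to the stage path, and raise if no fallback exists. The sketch below shows that control flow independently of Snowflake. It is simplified: in the real class, the registry branch also reads its weights from the stage path. `load_registry` and `load_stage` are hypothetical callables standing in for those reads.

```python
from typing import Callable, Optional, Tuple


def load_with_fallback(
    load_registry: Callable[[], bytes],
    load_stage: Optional[Callable[[], bytes]],
) -> Tuple[str, bytes]:
    """Try the primary loader, fall back to the stage loader, else raise."""
    try:
        return "registry", load_registry()
    except Exception as primary_error:
        if load_stage is None:
            raise RuntimeError(
                f"Model loading failed and no fallback provided: {primary_error}"
            ) from primary_error
        return "stage", load_stage()


def broken_registry() -> bytes:
    raise ConnectionError("registry unavailable")  # simulated registry failure


source, weights = load_with_fallback(broken_registry, lambda: b"state-dict-bytes")
print(source, weights)  # stage b'state-dict-bytes'
```

Chaining with `from primary_error` keeps the original registry failure in the traceback. This makes it easier to tell a network problem from a missing model version.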


In [None]:
# ============================================================================
# DISTRIBUTED INFERENCE IMPLEMENTATION
# ============================================================================
# This cell implements a complete inference pipeline using:
# - Snowflake Model Registry for model loading
# - Ray for distributed parallel processing
# - MONAI for medical image transformations

# Core libraries
import streamlit as st
import pandas as pd
import torch                    # PyTorch for model inference
import ray                      # Distributed computing
import tempfile                 # Temporary file management
import os
import io                       # Binary I/O operations
import numpy as np

# Snowflake integrations
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.files import SnowflakeFile
from snowflake.ml.registry import Registry        # Model Registry access

# MONAI medical imaging components
from monai.networks.nets import LocalNet          # Registration network architecture
from monai.networks.blocks import Warp            # Spatial warping layer
from monai.transforms import (
    Compose,                    # Chain transformations
    LoadImage,                  # Load medical images
    EnsureChannelFirst,         # Ensure channel-first format
    ScaleIntensityRange,        # CT windowing (HU normalization)
    Resize                      # Resize 3D volumes
)

session = get_active_session()


# ============================================================================
# INFERENCE CONFIGURATION
# ============================================================================

# CT windowing parameters (must match training)
CT_MIN_HU = -1000  # Air (minimum Hounsfield Unit)
CT_MAX_HU = 500    # Soft tissue/bone (maximum HU)

# Inference configuration parameters
INFERENCE_CONFIG = {
    "spatial_size": (96, 96, 104),              # Volume dimensions (H, W, D)
    "num_channel_initial": 32,                  # Model architecture parameter
    "database": DATABASE_NAME,               # Snowflake database
    "schema": "UTILS",                          # Schema containing model
    "model_name": "LUNG_CT_REGISTRATION",       # Model name in registry
    "model_version": "v1"                       # Model version to use
}


# ============================================================================
# DISTRIBUTED INFERENCE CLASS
# ============================================================================

class RegistryBasedInferencer:
    """
    Ray-compatible inference class for medical image registration.
    
    This class is designed to run on Ray workers for distributed inference:
    - Each worker loads the model from Snowflake Model Registry
    - Processes batches of medical image pairs independently
    - Applies deformation fields to register images
    - Calculates registration quality metrics (Dice score)
    - Saves results back to Snowflake stages
    
    **Distributed Execution**:
    Ray instantiates this class on each worker, and the __call__ method
    processes batches of data in parallel across multiple GPUs.
    
    **Model Loading**:
    Supports two modes:
    1. Model Registry: Load versioned model from Snowflake registry
    2. Fallback: Load directly from stage path if registry fails
    """
    
    def __init__(
        self,
        model_name: str = "LUNG_CT_REGISTRATION",
        model_version: str = "v1",
        database: str = None,
        schema: str = "UTILS",
        spatial_size: tuple = (96, 96, 104),
        num_channel_initial: int = 32,
        save_to_stage: bool = True,
        fallback_stage_path: str = None  # Fallback if registry loading fails
    ):
        """
        Initialize the inference worker with model and configuration.
        
        This method is called once per Ray worker when the class is instantiated.
        It loads the model, sets up preprocessing transforms, and prepares for inference.
        
        Args:
            model_name (str): Model name in Snowflake Model Registry
            model_version (str): Version identifier (e.g., "v1", "v2")
            database (str): Snowflake database containing the model
            schema (str): Schema containing the model
            spatial_size (tuple): Target image size (H, W, D) for preprocessing
            num_channel_initial (int): Model architecture parameter (must match training)
            save_to_stage (bool): Whether to save registered images to Snowflake stages
            fallback_stage_path (str): Direct path to model weights if registry fails
                                       (e.g., f"@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth")
        """
        # Store configuration
        self.model_name = model_name
        self.model_version = model_version
        self.database = database
        self.schema = schema
        self.spatial_size = spatial_size
        self.num_channel_initial = int(num_channel_initial)
        self.save_to_stage = save_to_stage
        self.fallback_stage_path = fallback_stage_path
        
        # ====================================================================
        # STEP 1: Establish Snowflake session
        # ====================================================================
        try:
            self.session = get_active_session()
            print(f"✅ Worker connected to Snowflake")
        except Exception as e:
            print(f"⚠️ Worker session failed: {e}")
            self.session = None
        
        # ====================================================================
        # STEP 2: Detect compute device (GPU preferred)
        # ====================================================================
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"📍 Using device: {self.device}")
        
        # ====================================================================
        # STEP 3: Load trained model from registry or fallback
        # ====================================================================
        self._load_model_from_registry()
        
        # ====================================================================
        # STEP 4: Initialize warping layer for applying deformations
        # ====================================================================
        self.warp_layer = Warp().to(self.device)
        
        # ====================================================================
        # STEP 5: Setup preprocessing transforms (must match training)
        # ====================================================================
        self._setup_preprocessing()
    
    def _load_model_from_registry(self):
        """
        Load trained model from Snowflake Model Registry with fallback support.
        
        Loading Strategy:
        1. Try to connect to Model Registry and retrieve model reference
        2. Create model architecture (LocalNet) matching training configuration
        3. Load trained weights from fallback stage path
        4. Set model to evaluation mode
        
        **Note**: Currently uses fallback_stage_path for weight loading because
        direct model.run() from registry has compatibility issues with Ray workers.
        For production, ensure fallback_stage_path points to latest trained model.
        """
        print(f"📥 Loading model from registry: {self.model_name} {self.model_version}")
        
        try:
            # ================================================================
            # STEP 1: Connect to Snowflake Model Registry
            # ================================================================
            registry = Registry(
                session=self.session,
                database_name=self.database,
                schema_name=self.schema
            )
            
            # Get model reference (metadata and versioning info)
            model_ref = registry.get_model(self.model_name).version(self.model_version)
            print(f"✅ Model reference obtained: {model_ref.fully_qualified_model_name}")
            
            # ================================================================
            # STEP 2: Reconstruct model architecture
            # ================================================================
            # Must match the architecture used during training!
            self.model = LocalNet(
                spatial_dims=3,              # 3D medical images
                in_channels=2,               # Concatenated fixed + moving images
                out_channels=3,              # 3D deformation field (x, y, z)
                num_channel_initial=self.num_channel_initial,  # Feature channels
                extract_levels=[3],          # U-Net depth
                out_activation=None,         # No activation (regression task)
                out_kernel_initializer="zeros"  # Identity transform initialization
            ).to(self.device)
            
            # ================================================================
            # STEP 3: Load trained weights
            # ================================================================
            # Currently using fallback path for weight loading
            # Registry.run() has known issues with Ray distributed workers
            
            if self.fallback_stage_path:
                print(f"📥 Loading weights from fallback path: {self.fallback_stage_path}")
                
                # Download model weights from Snowflake stage
                raw_stream = self.session.file.get_stream(self.fallback_stage_path)
                model_buffer = io.BytesIO(raw_stream.read())
                
                # Load state dict (trained parameters)
                state_dict = torch.load(model_buffer, map_location=self.device)
                self.model.load_state_dict(state_dict)
            else:
                # No weights path: the model keeps its untrained initial weights
                print("⚠️ No fallback_stage_path provided. Model is running with untrained weights.")
                print("   Pass fallback_stage_path to load the trained weights.")
            
            # ================================================================
            # STEP 4: Set to evaluation mode
            # ================================================================
            # Disables dropout, batch norm tracking, etc.
            self.model.eval()
            print(f"✅ Model loaded successfully on {self.device}")
            
        except Exception as e:
            # Registry loading failed, try direct stage loading
            print(f"❌ Failed to load from registry: {e}")
            print(f"   Attempting fallback to stage path...")
            
            if self.fallback_stage_path:
                self._load_from_stage_fallback()
            else:
                raise RuntimeError(f"Model loading failed and no fallback provided: {e}")
    
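    def _load_from_stage_fallback(self):
        """Fallback: load trained weights directly from the stage path."""
        if self.session is None:
            raise RuntimeError("No Snowflake session on this worker; cannot read weights from stage")
        print(f"📥 Loading from stage: {self.fallback_stage_path}")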
        
        raw_stream = self.session.file.get_stream(self.fallback_stage_path)
        model_buffer = io.BytesIO(raw_stream.read())
        state_dict = torch.load(model_buffer, map_location=self.device)
        
        self.model = LocalNet(
            spatial_dims=3,
            in_channels=2,
            out_channels=3,
            num_channel_initial=self.num_channel_initial,
            extract_levels=[3],
            out_activation=None,
            out_kernel_initializer="zeros"
        ).to(self.device)
        
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print(f"✅ Model loaded from stage on {self.device}")
    
    def _setup_preprocessing(self):
        """
        Setup MONAI preprocessing transforms for images and labels.
        
        **CRITICAL**: These transforms must EXACTLY match training preprocessing!
        Any mismatch will cause poor inference results.
        
        Image Pipeline:
        1. Load NIfTI file
        2. Ensure channel-first format
        3. Apply CT windowing (normalize HU to [0,1])
        4. Resize to model input size
        
        Label Pipeline:
        1. Load NIfTI file
        2. Ensure channel-first format
        3. Resize with nearest neighbor (preserve binary values)
        """
        # ====================================================================
        # IMAGE PREPROCESSING (with CT windowing)
        # ====================================================================
        self.preprocess_image = Compose([
            LoadImage(image_only=True),           # Load NIfTI file
            EnsureChannelFirst(),                 # Add channel dimension
            ScaleIntensityRange(                  # CT windowing (CRITICAL!)
                a_min=CT_MIN_HU,                  # Input min: -1000 HU (air)
                a_max=CT_MAX_HU,                  # Input max: 500 HU (tissue)
                b_min=0.0,                        # Output min: 0
                b_max=1.0,                        # Output max: 1
                clip=True                         # Clip outliers
            ),
            Resize(                               # Resize to model dimensions
                spatial_size=self.spatial_size, 
                mode="trilinear"                  # Smooth interpolation
            ),
        ])
        
        # ====================================================================
        # LABEL (SEGMENTATION MASK) PREPROCESSING
        # ====================================================================
        # NO intensity normalization - labels are binary (0 or 1)
        self.preprocess_label = Compose([
            LoadImage(image_only=True),           # Load NIfTI file
            EnsureChannelFirst(),                 # Add channel dimension
            Resize(                               # Resize to model dimensions
                spatial_size=self.spatial_size, 
                mode="nearest"                    # Nearest neighbor (preserve binary)
            ),
        ])
    
    def _read_file_from_stage(self, stage_path: str) -> bytes:
        """Read file from Snowflake stage"""
        clean_path = stage_path.strip()
        if not clean_path.startswith("@"):
            clean_path = f"@{clean_path}"
        with SnowflakeFile.open(clean_path, 'rb') as f:
            return f.read()
    
    def _load_and_preprocess(self, file_binary: bytes, is_label: bool = False) -> torch.Tensor:
        """Write NIfTI bytes to a temp file, apply the image or label transforms, and return a (1, C, H, W, D) tensor on the device."""
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
            tmp.write(file_binary)
            tmp_path = tmp.name
        
        try:
            if is_label:
                image = self.preprocess_label(tmp_path)
            else:
                image = self.preprocess_image(tmp_path)
            return image.unsqueeze(0).to(self.device)
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    
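    def _save_result(self, tensor: torch.Tensor, case_id: str, suffix: str = "") -> str:
        """
        Save a registered volume to RESULTS_STG as registered_<case_id><suffix>.nii.gz.

        Returns the stage path, or None if saving is disabled or fails.
        Note: an identity affine is written, so the original scanner geometry
        is not preserved in the output file.
        """
        if not self.save_to_stage or self.session is None:
            return None

        import nibabel as nib

        output_filename = f"registered_{case_id}{suffix}.nii.gz"
        stage_dir = f"@{self.database}.{self.schema}.RESULTS_STG"

        # session.file.put() keeps the local file name, so write the volume
        # under its final name inside a temporary directory
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, output_filename)
            array = tensor.cpu().numpy()[0, 0]  # Drop batch and channel dims
            nib.save(nib.Nifti1Image(array, affine=np.eye(4)), local_path)

            try:
                self.session.file.put(local_path, stage_dir, overwrite=True, auto_compress=False)
                return f"{stage_dir}/{output_filename}"
            except Exception as e:
                print(f"⚠️ Failed to save {output_filename}: {e}")
                return None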
    
    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """
        Process a batch of medical image pairs (main inference entry point).
        
        This method is called by Ray for each batch of data in distributed processing.
        Each worker processes its assigned batch independently in parallel.
        
        For each case:
        1. Load images from Snowflake stages
        2. Apply preprocessing transforms
        3. Run inference (predict deformation field)
        4. Apply warping to register images
        5. Calculate Dice score (registration quality metric)
        6. Save results back to Snowflake
        
        Args:
            batch (pd.DataFrame): Batch of cases with columns:
                - case_id: Unique case identifier
                - fixed_path: Path to expiration CT scan
                - moving_path: Path to inspiration CT scan
                - fixed_label_path: Path to expiration lung mask
                - moving_label_path: Path to inspiration lung mask
        
        Returns:
            pd.DataFrame: Results with columns:
                - case_id, status, dice_score, output_image, output_label,
                  model_name, model_version (and error if failed)
        """
        results = []
        
        # Process each case in the batch sequentially
        for _, row in batch.iterrows():
            case_id = row.get('case_id', 'unknown')
            
            try:
                print(f"🔄 Processing case: {case_id}")
                
                # ============================================================
                # STEP 1: Load images from Snowflake stages
                # ============================================================
                fixed_bin = self._read_file_from_stage(row['fixed_path'])
                moving_bin = self._read_file_from_stage(row['moving_path'])
                fixed_lbl_bin = self._read_file_from_stage(row['fixed_label_path'])
                moving_lbl_bin = self._read_file_from_stage(row['moving_label_path'])
                
                # ============================================================
                # STEP 2: Preprocess images (CT windowing, resizing)
                # ============================================================
                fixed_img = self._load_and_preprocess(fixed_bin, is_label=False)
                moving_img = self._load_and_preprocess(moving_bin, is_label=False)
                fixed_lbl = self._load_and_preprocess(fixed_lbl_bin, is_label=True)
                moving_lbl = self._load_and_preprocess(moving_lbl_bin, is_label=True)
                
                # ============================================================
                # STEP 3: Run inference
                # ============================================================
                with torch.no_grad():  # Disable gradient computation (inference only)
                    # Concatenate images for model input (channel dimension)
                    # Shape: (1, 2, H, W, D) where 2 = [moving, fixed]
                    input_tensor = torch.cat((moving_img, fixed_img), dim=1)
                    
                    # Predict 3D deformation field
                    # Output shape: (1, 3, H, W, D) where 3 = [dx, dy, dz]
                    ddf = self.model(input_tensor)
                    
                    # Apply warping to register moving image/label to fixed space
                    reg_label = self.warp_layer(moving_lbl, ddf)
                    reg_image = self.warp_layer(moving_img, ddf)
                    
                    # Calculate Dice score (overlap between registered and fixed masks)
                    # Range: [0, 1] where 1 = perfect overlap
                    intersection = (reg_label * fixed_lbl).sum()
                    dice = (2.0 * intersection + 1e-5) / (reg_label.sum() + fixed_lbl.sum() + 1e-5)
                
                # ============================================================
                # STEP 4: Save results to Snowflake stages
                # ============================================================
                img_path = self._save_result(reg_image, case_id, "_img")
                lbl_path = self._save_result(reg_label, case_id, "_label")
                
                # ============================================================
                # STEP 5: Record results
                # ============================================================
                # Every row carries the same columns (success or failure) so
                # batches from different workers combine into one schema
                results.append({
                    "case_id": case_id,
                    "status": "success",
                    "dice_score": float(dice.item()),
                    "output_image": img_path,
                    "output_label": lbl_path,
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "error": None
                })
                
                print(f"✅ {case_id}: Dice = {dice.item():.4f}")
                
            except Exception as e:
                # Log error and continue with next case
                print(f"❌ Failed on {case_id}: {e}")
                results.append({
                    "case_id": case_id,
                    "status": "failed",
                    "dice_score": 0.0,
                    "output_image": None,
                    "output_label": None,
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "error": str(e)
                })
        
        # Return results as DataFrame for Ray to aggregate
        return pd.DataFrame(results)



## Step 4: Define Distributed Inference Orchestrator

This function orchestrates the distributed inference process:

1. **Configures Ray Workers** - Sets up inference parameters for all workers
2. **Partitions Dataset** - Divides data across multiple workers
3. **Launches Parallel Inference** - Executes inference across GPUs simultaneously
4. **Aggregates Results** - Collects and combines results from all workers
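Step 2 can be pictured with the simplified model below. It splits N cases into a fixed number of contiguous blocks of near-equal size, similar in spirit to the `repartition(20)` call in the next cell. This is an illustration only: how Ray Data actually sizes blocks and assigns them to workers is decided by Ray. `split_into_blocks` and the case IDs are hypothetical.

```python
def split_into_blocks(items: list, num_blocks: int) -> list[list]:
    """Split items into num_blocks contiguous blocks whose sizes differ by at most one."""
    base, extra = divmod(len(items), num_blocks)
    blocks, start = [], 0
    for i in range(num_blocks):
        size = base + (1 if i < extra else 0)  # the first `extra` blocks get one more item
        blocks.append(items[start:start + size])
        start += size
    return blocks


cases = [f"case_{i:03d}" for i in range(1, 11)]  # 10 hypothetical case IDs
print([len(b) for b in split_into_blocks(cases, 4)])  # [3, 3, 2, 2]
```

In this simplified model, asking for 20 blocks with only 10 cases leaves half of the blocks empty. A high block count mainly pays off for larger datasets.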

**Key Parameter**: `num_workers` controls parallelism (typically matches number of available GPUs).

**Result**: A DataFrame containing registration results (Dice scores, output paths) for all processed cases.
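Once the results are collected, a useful next step is to separate failures from successes before averaging Dice. Failed cases carry a placeholder `dice_score` of 0.0 and would otherwise drag the mean down. The sketch below works over rows shaped like the dicts that `__call__` appends. `summarize_results` and the sample rows are hypothetical.

```python
from statistics import mean


def summarize_results(rows: list[dict]) -> dict:
    """Count successes and failures, and average Dice over successful cases only."""
    ok = [r["dice_score"] for r in rows if r["status"] == "success"]
    failed = [r["case_id"] for r in rows if r["status"] == "failed"]
    return {
        "num_success": len(ok),
        "num_failed": len(failed),
        "failed_cases": failed,
        # Placeholder 0.0 scores from failed cases are excluded from the mean
        "mean_dice": mean(ok) if ok else None,
    }


rows = [
    {"case_id": "case_001", "status": "success", "dice_score": 0.90},
    {"case_id": "case_002", "status": "success", "dice_score": 0.80},
    {"case_id": "case_003", "status": "failed", "dice_score": 0.0},
]
print(summarize_results(rows))
```

With the pandas DataFrame returned by the orchestrator, the same filter reads `df[df["status"] == "success"]["dice_score"].mean()`.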


In [None]:
def run_distributed_inference_with_registry(
    ray_dataset,
    model_name: str = "LUNG_CT_REGISTRATION",
    model_version: str = "V1",
    fallback_stage_path: str = f'@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth',
    num_workers: int = 4,
    batch_size: int = 1
):
    """
    Orchestrate distributed inference across multiple GPU workers.
    
    This function:
    1. Configures inference parameters (model, preprocessing, etc.)
    2. Partitions the dataset across workers
    3. Launches Ray workers with GPU allocation
    4. Aggregates results from all workers
    
    **Distributed Execution**:
    - Ray creates `num_workers` instances of RegistryBasedInferencer
    - Each worker gets 1 GPU and processes batches independently
    - Results are automatically aggregated by Ray
    
    Args:
        ray_dataset (ray.data.Dataset): Dataset with case information
            Required columns: case_id, fixed_path, moving_path, 
                            fixed_label_path, moving_label_path
        model_name (str): Model name in Snowflake Model Registry
        model_version (str): Model version to use
        fallback_stage_path (str): Direct path to model weights file
            Example: f"@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth"
        num_workers (int): Number of parallel workers (typically = number of GPUs)
        batch_size (int): Number of cases each worker processes per batch
    
    Returns:
        pd.DataFrame: Results with columns:
            - case_id: Case identifier
            - status: "success" or "failed"
            - dice_score: Registration quality metric [0, 1]
            - output_image: Path to registered image
            - output_label: Path to registered label
            - model_name, model_version: Model identifiers
    """
    
    st.subheader(f"🚀 Distributed Inference: {model_name} ({model_version})")
    
    # ========================================================================
    # STEP 1: Configure inference parameters for all workers
    # ========================================================================
    inference_args = {
        "model_name": model_name,
        "model_version": model_version,
        "database": INFERENCE_CONFIG["database"],
        "schema": INFERENCE_CONFIG["schema"],
        "spatial_size": INFERENCE_CONFIG["spatial_size"],
        "num_channel_initial": INFERENCE_CONFIG["num_channel_initial"],
        "save_to_stage": True,
        "fallback_stage_path": fallback_stage_path  # Critical for loading weights!
    }
    
    st.info(f"🔧 Using {num_workers} workers with {batch_size} case(s) per batch")
    
    # ========================================================================
    # STEP 2: Partition dataset for parallel processing
    # ========================================================================
    # Repartition into 20 blocks for fine-grained parallelism
    # Ray will distribute these blocks across the workers
    ds_partitioned = ray_dataset.repartition(20)
    
    # ========================================================================
    # STEP 3: Define distributed inference
    # ========================================================================
    # map_batches is lazy: it only builds the execution plan. Ray starts the
    # num_workers RegistryBasedInferencer instances once execution begins.
    with st.spinner(f"Running inference on {num_workers} GPUs..."):
        results = ds_partitioned.map_batches(
            RegistryBasedInferencer,               # Inference class to instantiate
            fn_constructor_kwargs=inference_args,  # Arguments for __init__
            batch_size=batch_size,                 # Cases per batch
            concurrency=num_workers,               # Number of parallel workers
            num_gpus=1,                            # GPU allocation per worker
            batch_format="pandas"                  # Input/output format
        )
        
        # ========================================================================
        # STEP 4: Execute and collect results from all workers
        # ========================================================================
        # to_pandas() triggers execution and gathers every worker's output
        final_df = results.to_pandas()
    
    st.success(f"✅ Inference complete! Processed {len(final_df)} cases.")
    
    return final_df

## Step 5: Load Test Dataset Paths

Load file paths for the medical images we want to process. 

**Important**: This loads only metadata (file paths), not the actual image data! The actual images will be loaded just-in-time during inference by each worker.

This function identifies pairs of:
- Fixed images (expiration CT scans)
- Moving images (inspiration CT scans)
- Corresponding lung segmentation masks

**Note**: In production, you might load a different test set than what was used for training validation.
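The pairing logic can be sketched without Snowflake as a single pass that groups paths by case ID and keeps only cases that have all four files. This is a stdlib illustration, not the notebook's implementation: `pair_cases`, `PATTERN`, and `KEYS` are hypothetical names, and the regex assumes files sit directly under `scans/` or `lungMasks/` and are named `<case_id>_<exp|insp>.nii.gz`.

```python
import re
from collections import defaultdict

# Assumed layout: .../<scans|lungMasks>/<case_id>_<exp|insp>.nii.gz
PATTERN = re.compile(r"(scans|lungMasks)/([^/]+?)_(exp|insp)\.nii\.gz$")

# Map (folder, breathing phase) to the column names the inferencer expects.
KEYS = {
    ("scans", "exp"): "fixed_path",              # expiration CT
    ("scans", "insp"): "moving_path",            # inspiration CT
    ("lungMasks", "exp"): "fixed_label_path",    # expiration lung mask
    ("lungMasks", "insp"): "moving_label_path",  # inspiration lung mask
}


def pair_cases(paths):
    """Group stage paths by case ID; keep only cases with all four files."""
    by_case = defaultdict(dict)
    for path in paths:
        match = PATTERN.search(path)
        if match is None:
            continue  # ignore files that do not follow the naming convention
        folder, case_id, phase = match.groups()
        by_case[case_id][KEYS[(folder, phase)]] = path
    return [{"case_id": case_id, **files}
            for case_id, files in sorted(by_case.items())
            if len(files) == len(KEYS)]


paths = [
    "db.utils.stg/scans/case_001_exp.nii.gz",
    "db.utils.stg/scans/case_001_insp.nii.gz",
    "db.utils.stg/lungMasks/case_001_exp.nii.gz",
    "db.utils.stg/lungMasks/case_001_insp.nii.gz",
    "db.utils.stg/scans/case_002_exp.nii.gz",   # incomplete case, dropped
    "db.utils.stg/readme.txt",                  # not a NIfTI file, ignored
]
pairs = pair_cases(paths)
```

Grouping into a dictionary touches each path once, whereas filtering the full file list per case (as a `DataFrame` lookup loop does) scales with cases × files; for a few dozen cases either approach is fast.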


In [None]:
def load_paired_paths(stage_location=f"@{DATABASE_NAME}.UTILS.MONAI_MEDICAL_IMAGES_STG"):
    """
    Load paired CT scan paths for inference (metadata only, no binary data).
    
    This function:
    1. Lists all NIfTI files in Snowflake stage using SQL
    2. Categorizes files into fixed/moving images and masks
    3. Pairs them by case ID
    4. Returns a Ray dataset with file paths only
    
    Args:
        stage_location (str): Snowflake stage containing medical images
    
    Returns:
        ray.data.Dataset: Lightweight dataset containing only file paths
    """
    session = get_active_session()
    
    # ========================================================================
    # STEP 1: List all NIfTI files using SQL (metadata only, fast)
    # ========================================================================
    df_files = session.sql(
        f"LIST {stage_location} PATTERN = '.*.nii.gz'"
    ).select('"name"').to_pandas()
    
    # Clean up column names and add fully-qualified path
    df_files = df_files.rename(columns={'"name"': 'name'})  # rename() returns a copy; reassign it
    df_files["name"] = f"{DATABASE_NAME}.UTILS." + df_files["name"]
    
    print(df_files)
    
    # ========================================================================
    # STEP 2: Categorize files by type using regex patterns
    # ========================================================================
    # Expected folder structure:
    #   - scans/case_XXX_exp.nii.gz    (expiration CT)
    #   - scans/case_XXX_insp.nii.gz   (inspiration CT)
    #   - lungMasks/case_XXX_exp.nii.gz    (expiration mask)
    #   - lungMasks/case_XXX_insp.nii.gz   (inspiration mask)
    
    fixed_imgs = df_files[df_files['name'].str.contains("scans.*_exp", regex=True)]
    moving_imgs = df_files[df_files['name'].str.contains("scans.*_insp", regex=True)]
    fixed_masks = df_files[df_files['name'].str.contains("lungMasks.*_exp", regex=True)]
    moving_masks = df_files[df_files['name'].str.contains("lungMasks.*_insp", regex=True)]
    
    pairs = []
    
    # ========================================================================
    # STEP 3: Match files into inference pairs by case ID
    # ========================================================================
    for _, row in fixed_imgs.iterrows():
        f_path = row['name']
        
        # Extract case identifier (e.g., "case_001" from "case_001_exp.nii.gz")
        filename = f_path.split('/')[-1]
        case_id = filename.split('_exp')[0]
        
        # Find matching files for this case (literal substring match, not regex)
        m_match = moving_imgs[moving_imgs['name'].str.contains(f"{case_id}_insp", regex=False)]
        fl_match = fixed_masks[fixed_masks['name'].str.contains(f"{case_id}_exp", regex=False)]
        ml_match = moving_masks[moving_masks['name'].str.contains(f"{case_id}_insp", regex=False)]
        
        # Skip incomplete cases instead of failing with IndexError on .iloc[0]
        if m_match.empty or fl_match.empty or ml_match.empty:
            print(f"⚠️ Skipping {case_id}: missing moving image or lung mask")
            continue
        
        # Store complete pair
        pairs.append({
            "case_id": case_id,
            "fixed_path": f_path,                          # Expiration CT
            "moving_path": m_match.iloc[0]['name'],         # Inspiration CT
            "fixed_label_path": fl_match.iloc[0]['name'],   # Expiration lung mask
            "moving_label_path": ml_match.iloc[0]['name']   # Inspiration lung mask
        })
        
    print(f"✅ Paired {len(pairs)} cases using paths only (no binary data loaded).")
    
    # Convert to Ray dataset for distributed processing
    return ray.data.from_pandas(pd.DataFrame(pairs))


# ============================================================================
# EXECUTE: Load test dataset paths
# ============================================================================
# This dataset is tiny (only strings), not the actual GB-sized medical images
ray_dataset = load_paired_paths()

## Step 6: Verify Trained Model Availability

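Quick check to confirm that the trained weights file (`best_model.pth`) exists in the stage before running inference. The unqualified name `@RESULTS_STG` resolves against the session's current database and schema, which should be the `UTILS` schema of your database to match `fallback_stage_path`.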
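Eyeballing the `LIST` output works, but the same check can fail fast in code. The sketch below assumes the listing has been reduced to `(name, size)` pairs, mirroring the `name` and `size` columns that `LIST` returns; `find_model_file` is a hypothetical helper, not part of Snowpark.

```python
def find_model_file(listing, filename="best_model.pth"):
    """Return the stage path of the weights file, or raise if it is missing or empty.

    `listing` is assumed to be an iterable of (name, size_in_bytes) tuples.
    """
    matches = [(name, size) for name, size in listing
               if name.rsplit("/", 1)[-1] == filename]  # compare the file name only
    if not matches:
        raise FileNotFoundError(f"{filename} not found in stage listing")
    name, size = matches[0]
    if size == 0:
        raise ValueError(f"{name} is empty; check the model export from training")
    return name


listing = [("results_stg/best_model.pth", 4_200_000), ("results_stg/log.txt", 512)]
model_path = find_model_file(listing)
```

Running a check like this before Step 7 surfaces a missing checkpoint as one clear error instead of one failure per Ray worker.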


In [None]:
ls @RESULTS_STG;

## Step 7: Execute Distributed Inference 🚀

**This is where the magic happens!**

Launching parallel inference across multiple GPUs. The system will:
1. Distribute the dataset across 4 GPU workers
2. Each worker loads the trained model
3. Process cases in parallel
4. Save registered images to Snowflake stages
5. Return metrics (Dice scores) for all cases

**Expected Time**: Depends on dataset size and number of GPUs. With 4 GPUs processing 10 cases, expect ~2-5 minutes.

**Monitoring**: Watch the output for per-case progress and Dice scores!
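The time estimate follows from simple wave arithmetic: with `batch_size=1` and one case per GPU at a time, 10 cases on 4 workers run in ⌈10 / 4⌉ = 3 waves. The sketch below assumes equal per-case cost and no idle workers while cases remain; `estimate_wall_time_s` is a hypothetical helper, and the per-case and startup times are illustrative inputs you would measure, not benchmarks.

```python
import math


def estimate_wall_time_s(num_cases, num_workers, per_case_s, startup_s=0.0):
    """Rough wall-clock estimate for one-case-per-worker inference.

    Assumes every case takes the same time and that workers stay busy while
    cases remain, so cases finish in ceil(num_cases / num_workers) waves.
    startup_s covers per-worker model loading, which happens in parallel.
    """
    waves = math.ceil(num_cases / num_workers)
    return startup_s + waves * per_case_s


# Illustrative inputs only: 30 s per case, 20 s to load the model on each worker.
estimate = estimate_wall_time_s(10, 4, per_case_s=30.0, startup_s=20.0)
print(estimate)  # 110.0
```

The same arithmetic shows the ceiling effect on small test sets: 5 and 8 workers both need 2 waves for 10 cases, so the extra GPUs would sit partly idle.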


In [None]:
# ============================================================================
# RUN DISTRIBUTED INFERENCE
# ============================================================================

# Execute inference with trained model from registry
# This will process all cases in the ray_dataset using 4 parallel GPU workers
out_df = run_distributed_inference_with_registry(ray_dataset)

# Display summary statistics
st.write("### Inference Results Summary")
st.write(f"**Total cases processed**: {len(out_df)}")
st.write(f"**Successful**: {(out_df['status'] == 'success').sum()}")
st.write(f"**Failed**: {(out_df['status'] == 'failed').sum()}")
st.write(f"**Average Dice Score**: {out_df[out_df['status'] == 'success']['dice_score'].mean():.4f}")

## Step 8: Review Inference Results

Display the complete results DataFrame with all inference metrics and output paths.
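The `dice_score` column is a Dice overlap between lung masks, assumed here to compare the registered moving mask against the fixed mask (the inferencer defined earlier determines the exact pair). For binary masks A and B, Dice = 2|A ∩ B| / (|A| + |B|). This stdlib sketch computes it on flattened 0/1 masks and buckets scores using the thresholds from the closing summary of this notebook; `dice` and `dice_band` are hypothetical helpers, not MONAI APIs.

```python
def dice(mask_a, mask_b):
    """Dice = 2|A ∩ B| / (|A| + |B|) for flat binary masks (sequences of 0/1)."""
    if len(mask_a) != len(mask_b):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for a, b in zip(mask_a, mask_b) if a and b)
    total = sum(1 for a in mask_a if a) + sum(1 for b in mask_b if b)
    # Convention used here: two empty masks count as perfect agreement.
    return 1.0 if total == 0 else 2.0 * intersection / total


def dice_band(score):
    """Bucket a score: > 0.8 excellent, 0.6-0.8 good, < 0.6 needs review."""
    if score > 0.8:
        return "excellent"
    if score >= 0.6:
        return "good"
    return "review"


score = dice([1, 1, 0, 0], [1, 0, 1, 0])  # one shared voxel, two voxels per mask
print(score, dice_band(score))            # 0.5 review
```

Applied to the results table, the same banding makes it easy to pull out the low-overlap cases for visual review.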


In [None]:
out_df

## Step 9: Persist Results to Snowflake Table

Save the inference results to a Snowflake table for:
- **Long-term storage** - Results persist beyond the notebook session
- **Analytics** - Query and analyze results using SQL
- **Downstream processing** - Other applications can access the results
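- **Audit trail** - Track inference runs over time (the cell below uses `overwrite=True`, which keeps only the latest run; set `overwrite=False` to append each run instead)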

The table will contain case IDs, Dice scores, output paths, and model versions.
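Appending runs is only useful as an audit trail if each row records which run produced it. The sketch below stamps result records with a shared run ID and UTC timestamp; `tag_run`, `run_id`, and `run_ts` are hypothetical names, not columns the notebook currently writes.

```python
import uuid
from datetime import datetime, timezone


def tag_run(records, run_id=None, run_ts=None):
    """Return copies of result rows stamped with the same run-level metadata."""
    run_id = run_id or uuid.uuid4().hex                          # 32-char hex ID
    run_ts = run_ts or datetime.now(timezone.utc).isoformat()    # ISO-8601 UTC time
    return [{**row, "run_id": run_id, "run_ts": run_ts} for row in records]


rows = [{"case_id": "case_001", "dice_score": 0.83},
        {"case_id": "case_002", "dice_score": 0.71}]
tagged = tag_run(rows)
```

In the notebook itself, the pandas equivalent would be something like `out_df.assign(run_id=..., run_ts=...)` followed by `session.write_pandas(..., overwrite=False)`, so each run adds rows that can be filtered by `run_id` instead of replacing the previous results.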


In [None]:
# ============================================================================
# SAVE RESULTS TO SNOWFLAKE TABLE
# ============================================================================

# Write results DataFrame to a Snowflake table
# This enables SQL-based analysis and persistence beyond the notebook session.
# write_pandas quotes column names by default, so upper-case them first;
# otherwise unquoted SQL such as `SELECT case_id ...` would not resolve.
session.write_pandas(
    out_df.rename(columns=str.upper),    # Results DataFrame (upper-case column names)
    'MONAI_PAIRED_LUNG_RESULTS',         # Table name
    database=DATABASE_NAME,              # Target database
    schema='RESULTS',                    # Target schema
    auto_create_table=True,              # Create table if it doesn't exist
    overwrite=True                       # Replace existing data (use False to append)
)

st.success(f"✅ Results saved to {DATABASE_NAME}.RESULTS.MONAI_PAIRED_LUNG_RESULTS")
st.info("💡 You can now query these results using SQL in Snowflake!")


In [None]:
# Dynamic summary with actual database name
sql_example = f"""SELECT 
    case_id,
    dice_score,
    status,
    model_name,
    model_version
FROM {DATABASE_NAME}.RESULTS.MONAI_PAIRED_LUNG_RESULTS
ORDER BY dice_score DESC;"""

st.markdown(f"""
## 🎉 Inference Complete!

### What We Accomplished

1. ✅ **Configured distributed environment** with Ray cluster (4 GPU workers)
2. ✅ **Loaded trained model** from Snowflake Model Registry
3. ✅ **Processed medical images** in parallel across multiple GPUs
4. ✅ **Generated registrations** for all test cases
5. ✅ **Calculated quality metrics** (Dice scores)
6. ✅ **Saved results** to Snowflake stages and tables

### Results Analysis

- **Registration Quality**: Check Dice scores in the results table
  - Dice > 0.8: Excellent registration
  - Dice 0.6-0.8: Good registration
  - Dice < 0.6: May need review

### Accessing Results

**Registered Images**: Available in `@{DATABASE_NAME}.UTILS.RESULTS_STG`
- Format: `registered_<case_id>_img.nii.gz` (registered CT scan)
- Format: `registered_<case_id>_label.nii.gz` (registered lung mask)

**Metrics Table**: `{DATABASE_NAME}.RESULTS.MONAI_PAIRED_LUNG_RESULTS`
""")

# Display SQL in a code block
st.code(sql_example, language="sql")

st.markdown(f"""
### Next Steps

- **Visualize Results**: Load registered images and compare with originals
- **Quality Control**: Review cases with low Dice scores
- **Deploy to Production**: Use this pipeline for clinical workflows
- **Retrain Model**: If results are suboptimal, retrain with more data or tuned hyperparameters

---

**Distributed medical image inference with Snowflake + Ray + MONAI = Complete! 🚀**
""")