# MONAI Lung CT Registration Training Pipeline

This notebook demonstrates distributed medical image registration using **MONAI** (Medical Open Network for AI) in **Snowflake Notebooks on Container Runtime**.

## Overview
- **Task**: Register lung CT scans (align inspiration/expiration phases)
- **Architecture**: LocalNet (CNN-based deformation field prediction)
- **Framework**: Ray for distributed computing, PyTorch for deep learning
- **Data Source**: NIfTI medical images stored in Snowflake stages

## Workflow
1. **Setup & Dependencies** - Install MONAI and configure Ray cluster
2. **Data Loading** - Read paired CT scans from Snowflake stages
3. **Model Training** - Train registration network with validation
4. **Model Registry** - Save and register trained model for deployment


## Step 1: Initialize Snowflake Session

Import core libraries and get the active Snowflake session. The session provides access to:
- Snowflake stages for data storage
- SQL execution capabilities
- File I/O operations


In [None]:
# Core Python libraries for data manipulation and UI
import streamlit as st  # For interactive notebook UI elements
import pandas as pd     # For tabular data handling

# Snowflake-specific imports
from snowflake.snowpark.context import get_active_session

# Establish connection to the active Snowflake session
# This provides access to stages, warehouses, and compute resources
session = get_active_session()

# Set query tag for consumption tracking
session.query_tag = '{"origin":"sf_sit-is","name":"distributed_medical_image_processing_with_monai","version":{"major":1,"minor":0},"attributes":{"is_quickstart":1,"source":"notebook"}}'

# Database name - matches setup.sql
DATABASE_NAME = "MONAI_DB"


## Step 2: Configure Distributed Computing Environment

This cell performs three critical tasks:

1. **Ray Initialization** - Connect to the Ray cluster for distributed processing
2. **Cluster Scaling** - Scale to 4 worker nodes for parallel training
3. **Dependency Installation** - Install MONAI and medical imaging libraries on all nodes

**Why distributed?** Medical image processing is computationally intensive. Ray allows us to parallelize data loading and preprocessing across multiple nodes.
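
The per-node installation in this step relies on Ray's `node:<ip>` resource to pin one task to each machine. Below is a minimal sketch of how those resource keys could be derived. It uses a hand-written list in place of the real `ray.nodes()` output; the helper name `node_resource_keys` and the sample IP addresses are hypothetical.

```python
def node_resource_keys(nodes):
    """Return one 'node:<ip>' resource key per alive node, deduplicated and sorted."""
    alive_ips = {n["NodeManagerAddress"] for n in nodes if n["Alive"]}
    return [f"node:{ip}" for ip in sorted(alive_ips)]


# Hypothetical stand-in for ray.nodes(); a real cluster reports many more fields.
sample_nodes = [
    {"NodeManagerAddress": "10.0.0.1", "Alive": True},
    {"NodeManagerAddress": "10.0.0.2", "Alive": True},
    {"NodeManagerAddress": "10.0.0.3", "Alive": False},  # dead nodes are skipped
]

keys = node_resource_keys(sample_nodes)
print(keys)
```

Requesting a small fraction of each key (for example `resources={key: 0.01}`) lets Ray schedule exactly one install task on every alive node.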


In [None]:
# ============================================================================
# IMPROVED MONAI LUNG CT REGISTRATION WITH SNOWFLAKE ML + RAY
# ============================================================================

# Standard Python libraries
import streamlit as st
import pandas as pd
import logging
import tempfile
import os

# Distributed computing and ML frameworks
import ray                          # Distributed computing framework
import torch                        # PyTorch deep learning
import torch.nn as nn              # Neural network modules
import torch.optim as optim        # Optimization algorithms
import numpy as np                 # Numerical operations

# Snowflake integrations
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.files import SnowflakeFile      # For reading files from stages
from snowflake.ml.runtime_cluster import scale_cluster  # For scaling Ray cluster



session = get_active_session()



# ============================================================================
# RAY CLUSTER SETUP
# ============================================================================

# Connect to the Ray cluster in the Snowflake container runtime
# 'auto' discovers the cluster head node automatically
# ignore_reinit_error allows re-running this cell without errors
ray.init(address='auto', ignore_reinit_error=True)

def configure_ray_logger() -> None:
    """
    Configure logging levels to reduce Ray's verbose output.
    
    We suppress Ray internals (CRITICAL only) but keep application logs (INFO)
    to see training progress without being overwhelmed by cluster messages.
    """
    # Suppress Ray core logging (only show critical errors)
    ray_logger = logging.getLogger("ray")
    ray_logger.setLevel(logging.CRITICAL)
    
    # Suppress Ray Data logging (dataset operations)
    data_logger = logging.getLogger("ray.data")
    data_logger.setLevel(logging.CRITICAL)
    
    # Keep INFO level for our application logs (training metrics, etc.)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    # Disable Ray's progress bars for cleaner output
    context = ray.data.DataContext.get_current()
    context.execution_options.verbose_progress = False
    context.enable_operator_progress_bars = False

# Apply logging configuration
configure_ray_logger()

# Scale the Ray cluster to 4 worker nodes
# More nodes = faster parallel data loading and preprocessing
scale_cluster(4)

# ============================================================================
# DISTRIBUTED DEPENDENCY INSTALLATION
# ============================================================================

@ray.remote(num_cpus=0)  # Uses negligible CPU (just runs pip install)
def install_deps():
    """
    Install required packages on a single Ray node.
    
    Packages include:
    - monai: Medical imaging transformations and networks
    - pytorch-ignite: Training utilities
    - itk: Medical image I/O
    - nibabel: NIfTI file format support
    - torchvision: Additional vision utilities
    - transformers, einops: Tensor manipulation
    """
    try:
        import subprocess
        
        # List of required packages for medical image registration
        packages = [
            "monai",            # Core medical imaging framework
            "pytorch-ignite",   # Training loop utilities
            "itk",             # Insight Toolkit for medical imaging
            "gdown",           # Google Drive downloader (if needed)
            "torchvision",     # Computer vision utilities
            "lmdb",            # Lightning Memory-Mapped Database
            "transformers",    # Attention mechanisms (if using vision transformers)
            "einops",          # Tensor operations made easy
            "nibabel"          # NIfTI and medical image file I/O
        ]
        
        # Install packages with pip; capture output as text so that
        # failure messages are readable strings rather than raw bytes
        subprocess.run(
            ["pip", "install"] + packages, 
            check=True, 
            capture_output=True, 
            text=True
        )
        
        # Verify installation by checking MONAI version
        result = subprocess.run(
            ["pip", "show", "monai"], 
            capture_output=True, 
            text=True, 
            check=True
        )
        
        # Return success message with node IP for tracking
        return f"✅ Dependencies installed on {ray.util.get_node_ip_address()}:\n{result.stdout.splitlines()[0]}"
    
    except subprocess.CalledProcessError as e:
        # Return error message if installation fails
        error_msg = e.stderr if e.stderr else e.stdout
        return f"❌ Failed on {ray.util.get_node_ip_address()}: {error_msg}"

# Get all alive nodes in the Ray cluster
nodes = {node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]}

# Create installation tasks for each node (parallel execution)
# resource constraint ensures each task runs on a different node
tasks = [install_deps.options(resources={f"node:{node}": 0.01}).remote() for node in nodes]

# Wait for all installations to complete and collect results
results = ray.get(tasks)

# Print installation status for each node
for res in results:
    print(res)


## Step 3: Data Exploration - Visualize Medical Images

Before training, let's inspect our data! This cell provides an interactive viewer for NIfTI (Neuroimaging Informatics Technology Initiative) files stored in Snowflake stages.

**NIfTI Format**: Standard medical imaging format storing 3D volumetric data (like CT or MRI scans).

**Interactive Slider**: Navigate through the scan's depth (z-axis) to view different slices of the lung.
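
The slice selection behind the slider can be sketched without Streamlit or a real scan. The example below assumes a volume laid out as (X, Y, Z) and uses a small synthetic array; the helper name `axial_slice` is hypothetical.

```python
import numpy as np


def axial_slice(volume, index=None):
    """Return one z-slice of a 3D volume, rotated 90 degrees for display."""
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    depth = volume.shape[2]
    if index is None:
        index = depth // 2  # default to the middle slice
    if not 0 <= index < depth:
        raise IndexError(f"slice {index} is outside 0..{depth - 1}")
    return np.rot90(volume[:, :, index])


# Synthetic stand-in for a CT volume with shape (X=2, Y=3, Z=4)
volume = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
middle = axial_slice(volume)
print(middle.shape)
```

Because of the rotation, a slice of shape (X, Y) is displayed with shape (Y, X).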


In [None]:
# Visualization and medical imaging libraries
import streamlit as st           # Interactive UI components
import matplotlib.pyplot as plt  # Plotting 2D slices
import nibabel as nib           # NIfTI medical image I/O
import numpy as np              # Numerical operations
import tempfile                 # Temporary file handling
import os                       # File system operations

# Snowflake integration
from snowflake.snowpark.files import SnowflakeFile
from snowflake.snowpark.context import get_active_session

# MONAI medical imaging framework components
from monai.networks.nets import LocalNet                                  # Deformable registration network
from monai.networks.blocks import Warp                                    # Spatial warping layer
from monai.losses import GlobalMutualInformationLoss, BendingEnergyLoss  # Registration objectives
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, Resized
)
from monai.data import Dataset, DataLoader


def load_nifti_from_stage(stage_path):
    """
    Load a NIfTI medical image from Snowflake stage into memory.
    
    This function:
    1. Downloads the .nii.gz file from Snowflake stage to a temp location
    2. Loads it with nibabel (neuroimaging library)
    3. Returns the 3D volume as a NumPy array
    4. Cleans up temporary files
    
    Args:
        stage_path (str): Path to NIfTI file in stage (with or without @ prefix)
    
    Returns:
        numpy.ndarray: 3D medical image volume, or None if loading fails
    """
    session = get_active_session()
    
    # Ensure path starts with @ symbol (required for Snowflake stages)
    clean_path = stage_path if stage_path.startswith("@") else f"@{stage_path}"
    
    tmp_name = None
    try:
        # Step 1: Download file from Snowflake stage to temporary location
        with SnowflakeFile.open(clean_path, 'rb') as f:
            with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
                tmp.write(f.read())
                tmp_name = tmp.name
        
        # Step 2: Load NIfTI file using nibabel library
        img = nib.load(tmp_name)
        return img.get_fdata()  # Read voxel data into memory as a NumPy array
        
    except Exception as e:
        # Display error in Streamlit UI if file loading fails
        st.error(f"Error reading {clean_path}: {e}")
        return None
    
    finally:
        # Step 3: Always remove the temporary file, even if loading failed
        if tmp_name and os.path.exists(tmp_name):
            os.unlink(tmp_name)


def visualize_nifti_interactive(stage_path):
    """
    Create an interactive 3D volume viewer with Streamlit slider.
    
    Displays 2D slices from a 3D medical image, allowing users to navigate
    through the volume depth (z-axis).
    
    Args:
        stage_path (str): Path to NIfTI file in Snowflake stage
    """
    st.write(f"### Viewing: `{stage_path}`")
    
    # 1. Load the 3D volume from Snowflake
    vol_data = load_nifti_from_stage(stage_path)
    if vol_data is None:
        return  # Exit if loading failed

    # 2. Get volume dimensions
    # Typical shape: (Height, Width, Depth) or (X, Y, Z)
    x, y, z = vol_data.shape
    st.write(f"Volume Shape: {vol_data.shape}")

    # 3. Create interactive slider for slice selection
    # Z-axis represents depth (axial view in medical imaging)
    # Default to middle slice
    slice_idx = st.slider("Select Slice", min_value=0, max_value=z-1, value=z//2)

    # 4. Extract the selected 2D slice
    slice_data = vol_data[:, :, slice_idx]
    
    # Rotate 90° for correct anatomical orientation
    # (Medical images often have unexpected orientations when loaded)
    slice_data = np.rot90(slice_data)

    # 5. Display the slice using matplotlib
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(slice_data, cmap='gray')  # Grayscale colormap for CT scans
    ax.axis('off')  # Hide axes for cleaner visualization
    ax.set_title(f"Slice {slice_idx}")
    
    st.pyplot(fig)


# ============================================================================
# EXAMPLE USAGE - Visualize a sample lung CT scan
# ============================================================================
# This demonstrates how to view medical images stored in Snowflake stages
# Replace with any path from your dataset (e.g., from load_paired_paths())

sample_path = f"{DATABASE_NAME}.UTILS.monai_medical_images_stg/scansExp/case_001_exp.nii.gz" 
visualize_nifti_interactive(sample_path)

## Step 4: Load Paired Image Paths (Metadata Only)

**Key Insight**: We DON'T load the actual image data yet - just the file paths!

This cell:
1. Lists all NIfTI files in the Snowflake stage using SQL
2. Pairs "fixed" (expiration) and "moving" (inspiration) scans
3. Returns a lightweight Ray dataset containing only file paths

**Why paths only?** Loading every volume up front could consume many gigabytes of RAM. Instead, each image is loaded just-in-time during training.
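
The pairing logic can be illustrated on plain strings. This is a minimal sketch, not the notebook's implementation: it assumes files are named `<folder>/case_<digits>_<exp|insp>.nii.gz`, and the helper `pair_cases` and the sample names are hypothetical.

```python
import re

# Assumed naming convention: <folder>/case_<digits>_<exp|insp>.nii.gz
FILE_PATTERN = re.compile(r"(scans|lungMasks)[^/]*/(case_\d+)_(exp|insp)\.nii\.gz$")

# Map (folder, breathing phase) to the key used in each pair record
SLOTS = {
    ("scans", "exp"): "fixed_path",
    ("scans", "insp"): "moving_path",
    ("lungMasks", "exp"): "fixed_label_path",
    ("lungMasks", "insp"): "moving_label_path",
}


def pair_cases(names):
    """Group file names by case ID and keep only cases with all four files."""
    cases = {}
    for name in names:
        match = FILE_PATTERN.search(name)
        if match is None:
            continue  # ignore files that do not follow the convention
        folder, case_id, phase = match.groups()
        record = cases.setdefault(case_id, {"case_id": case_id})
        record[SLOTS[(folder, phase)]] = name
    # A complete record holds case_id plus four paths
    return [rec for _, rec in sorted(cases.items()) if len(rec) == 5]


names = [
    "stg/scans/case_001_exp.nii.gz",
    "stg/scans/case_001_insp.nii.gz",
    "stg/lungMasks/case_001_exp.nii.gz",
    "stg/lungMasks/case_001_insp.nii.gz",
    "stg/scans/case_002_exp.nii.gz",  # incomplete case, dropped
]
pairs = pair_cases(names)
print(len(pairs), pairs[0]["case_id"])
```

Dropping incomplete cases up front avoids a lookup failure later when one of the four files is missing.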


In [None]:
def load_paired_paths(stage_location=f"@{DATABASE_NAME}.UTILS.MONAI_MEDICAL_IMAGES_STG"):
    """
    Load paired CT scan paths for medical image registration.
    
    This function performs METADATA OPERATIONS ONLY - no actual image data is loaded.
    It identifies pairs of:
    - Fixed image (expiration CT)
    - Moving image (inspiration CT)
    - Fixed label (lung mask for expiration)
    - Moving label (lung mask for inspiration)
    
    Args:
        stage_location (str): Snowflake stage containing medical images
    
    Returns:
        ray.data.Dataset: Lightweight dataset with file paths only
    """
    session = get_active_session()
    
    # ========================================================================
    # STEP 1: List all NIfTI files using SQL (fast, metadata-only operation)
    # ========================================================================
    # This returns only file paths, not binary data
    df_files = session.sql(
        f"LIST {stage_location} PATTERN = '.*[.]nii[.]gz'"
    ).select('"name"').to_pandas()
    
    # Normalize the column name (rename returns a new DataFrame, so reassign)
    # and add the fully-qualified path
    df_files = df_files.rename(columns={'"name"': 'name'})
    df_files["name"] = f"{DATABASE_NAME}.UTILS." + df_files["name"]
  
    print(df_files)
    
    # ========================================================================
    # STEP 2: Categorize files by type using regex patterns
    # ========================================================================
    # Expected folder structure:
    #   - scans/case_XXX_exp.nii.gz    (expiration CT)
    #   - scans/case_XXX_insp.nii.gz   (inspiration CT)
    #   - lungMasks/case_XXX_exp.nii.gz    (expiration lung segmentation)
    #   - lungMasks/case_XXX_insp.nii.gz   (inspiration lung segmentation)
    
    fixed_imgs = df_files[df_files['name'].str.contains("scans.*_exp", regex=True)]
    moving_imgs = df_files[df_files['name'].str.contains("scans.*_insp", regex=True)]
    fixed_masks = df_files[df_files['name'].str.contains("lungMasks.*_exp", regex=True)]
    moving_masks = df_files[df_files['name'].str.contains("lungMasks.*_insp", regex=True)]
    
    pairs = []
    
    # ========================================================================
    # STEP 3: Match files into training pairs using case IDs
    # ========================================================================
    # For each fixed (expiration) image, find its corresponding:
    # - Moving (inspiration) image
    # - Fixed and moving lung masks
    
    for _, row in fixed_imgs.iterrows():
        f_path = row['name']
        
        # Extract case identifier (e.g., "case_001" from "case_001_exp.nii.gz")
        filename = f_path.split('/')[-1] 
        case_id = filename.split('_exp')[0] 
        
        # Find matching files by case ID
        m_path = moving_imgs[moving_imgs['name'].str.contains(f"{case_id}_insp")].iloc[0]['name']
        fl_path = fixed_masks[fixed_masks['name'].str.contains(f"{case_id}_exp")].iloc[0]['name']
        ml_path = moving_masks[moving_masks['name'].str.contains(f"{case_id}_insp")].iloc[0]['name']
        
        # Store the complete pair
        pairs.append({
            "case_id": case_id,
            "fixed_path": f_path,           # Expiration CT
            "moving_path": m_path,          # Inspiration CT
            "fixed_label_path": fl_path,    # Expiration lung mask
            "moving_label_path": ml_path    # Inspiration lung mask
        })
        
    print(f"✅ Paired {len(pairs)} cases using paths only (no binary data loaded).")
    
    # Convert to Ray dataset for distributed processing
    return ray.data.from_pandas(pd.DataFrame(pairs))


# ============================================================================
# EXECUTE: Load paired file paths
# ============================================================================
# This dataset is tiny (only strings), not the actual GB-sized medical images
ray_dataset = load_paired_paths()

## Step 5: Define Custom Dataset with Just-In-Time Loading

This cell creates a PyTorch `Dataset` that:
1. **Loads images on-demand** - Downloads from Snowflake stage only when needed
2. **Applies preprocessing** - Normalizes intensities, resizes volumes
3. **Handles labels correctly** - Preserves binary masks without intensity scaling

**Why custom dataset?** MONAI's loaders expect local file paths, not Snowflake stage locations. This custom implementation downloads each file to a temporary location, processes it, then cleans up.
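
The download, process, and clean-up cycle can be sketched independently of Snowflake and MONAI. In the example below, a dictionary stands in for the stage and a file-size function stands in for the transform pipeline; the class name `LazyFileDataset` and both stand-ins are hypothetical.

```python
import os
import tempfile


class LazyFileDataset:
    """Fetch each item's bytes on demand, expose them as a temp file, then clean up."""

    def __init__(self, paths, fetch, process):
        self.paths = paths
        self.fetch = fetch      # callable: remote path -> bytes
        self.process = process  # callable: local file path -> sample

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        payload = self.fetch(self.paths[index])
        tmp = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False)
        try:
            tmp.write(payload)
            tmp.close()  # flush to disk before the processing step reads it
            return self.process(tmp.name)
        finally:
            tmp.close()
            if os.path.exists(tmp.name):
                os.unlink(tmp.name)  # runs even if processing raises


# Hypothetical in-memory "stage" and a trivial processing step
fake_stage = {"a.nii.gz": b"12345", "b.nii.gz": b"xyz"}
seen_paths = []


def file_size(local_path):
    seen_paths.append(local_path)
    return os.path.getsize(local_path)


dataset = LazyFileDataset(sorted(fake_stage), fake_stage.__getitem__, file_size)
sizes = [dataset[i] for i in range(len(dataset))]
print(sizes)
```

Placing the clean-up in a `finally` block keeps temporary files from accumulating when a transform fails partway through an epoch.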


In [None]:
# ============================================================================
# HELPER FUNCTIONS FOR SNOWFLAKE STAGE I/O
# ============================================================================

def read_file_from_stage(stage_path):
    """
    Read binary file content from Snowflake stage.
    
    Args:
        stage_path (str): Path to file in stage (with or without @ prefix)
    
    Returns:
        bytes: Raw file content
    """
    session = get_active_session()
    # Ensure path has @ prefix for Snowflake stage access
    clean_path = stage_path if stage_path.startswith("@") else f"@{stage_path}"
    
    with SnowflakeFile.open(clean_path, 'rb') as f:
        return f.read()


# ============================================================================
# CUSTOM PYTORCH DATASET WITH SNOWFLAKE INTEGRATION
# ============================================================================

class SnowflakeStageDataset(Dataset):
    """
    PyTorch Dataset for medical images stored in Snowflake stages.
    
    Key Features:
    1. **Just-In-Time Loading**: Downloads images from Snowflake only when needed
    2. **Automatic Cleanup**: Deletes temporary files after processing
    3. **MONAI Integration**: Applies medical imaging transformations
    4. **Separate Transforms**: Different pipelines for images vs. labels
    
    This approach avoids loading all GB/TB of medical data into memory at once.
    
    Args:
        data_dicts (list): List of dicts with keys like 'fixed_path', 'moving_path', etc.
        transform_img (Callable): MONAI transforms for images (includes normalization)
        transform_lbl (Callable): MONAI transforms for labels (NO normalization)
    """
    def __init__(self, data_dicts, transform_img=None, transform_lbl=None):
        self.data_dicts = data_dicts
        self.transform_img = transform_img
        self.transform_lbl = transform_lbl

    def __len__(self):
        """Return number of image pairs in dataset."""
        return len(self.data_dicts)

    def __getitem__(self, index):
        """
        Load and preprocess a single training sample.
        
        This method:
        1. Downloads 4 files from Snowflake (fixed/moving images + labels)
        2. Saves them to temporary locations
        3. Applies MONAI transformations
        4. Cleans up temporary files
        5. Returns preprocessed tensors
        
        Args:
            index (int): Dataset index
        
        Returns:
            dict: Dictionary with 'fixed', 'moving', 'fixed_label', 'moving_label' tensors
        """
        item = self.data_dicts[index]
        
        # ====================================================================
        # STEP 1: Download files from Snowflake to temporary locations
        # ====================================================================
        temp_files = {}
        for key in ["fixed", "moving", "fixed_label", "moving_label"]:
            path = item[f"{key}_path"]
            binary = read_file_from_stage(path)
            
            # Create temporary file with .nii.gz extension (required by nibabel)
            tf = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False)
            tf.write(binary)
            tf.close()
            temp_files[key] = tf.name
        
        # ====================================================================
        # STEP 2: Prepare data dictionary for MONAI transforms
        # ====================================================================
        # MONAI expects file paths as input to LoadImaged transform
        data = {
            "fixed": temp_files["fixed"],
            "moving": temp_files["moving"],
            "fixed_label": temp_files["fixed_label"],
            "moving_label": temp_files["moving_label"],
        }
        
        # ====================================================================
        # STEP 3: Apply MONAI transformations (load, normalize, resize, augment)
        # ====================================================================
        try:
            if self.transform_img:
                data = self.transform_img(data)
            # Apply the label pipeline when one is supplied separately
            if self.transform_lbl:
                data = self.transform_lbl(data)
        finally:
            # ================================================================
            # STEP 4: Always clean up temporary files, even if a transform fails
            # ================================================================
            for fpath in temp_files.values():
                if os.path.exists(fpath):
                    os.unlink(fpath)
        
        return data

## Step 6: Define Training Function

This is the core training loop! The function:

1. **Defines Transforms** - Preprocessing pipeline for medical images (windowing, resizing, augmentation)
2. **Creates Data Loaders** - Wraps our custom dataset for batch iteration
3. **Builds Model** - LocalNet architecture for deformation field prediction
4. **Training Loop** - Iterates through epochs, computing loss and updating weights
5. **Validation** - Monitors Dice score on held-out data
6. **Model Checkpointing** - Saves best model to Snowflake stage
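
The windowing mentioned in step 1 is a linear rescale with clipping. Here is a minimal sketch on scalar values, using the -1000 to 500 HU window that the training code's comments describe; the function name `window_hu` is hypothetical.

```python
def window_hu(value, a_min=-1000.0, a_max=500.0):
    """Linearly map a Hounsfield value from [a_min, a_max] to [0, 1], clipping outliers."""
    scaled = (value - a_min) / (a_max - a_min)
    return min(1.0, max(0.0, scaled))


samples = [-2000, -1000, -250, 500, 3000]
windowed = [window_hu(v) for v in samples]
print(windowed)
```

Values below the window map to 0 and values above it map to 1, so extreme intensities cannot dominate the similarity loss.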

**Key Concepts:**
- **Deformation Field (DDF)**: 3D vector field showing how to warp moving image to match fixed image
- **Global Mutual Information Loss**: Measures image similarity (registration quality)
- **Bending Energy Loss**: Regularization to keep deformations smooth and realistic
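
A one-dimensional toy example can make the deformation field concrete. This sketch assumes the common convention that the warped image samples the moving image at `x + ddf(x)`. It uses NumPy linear interpolation rather than MONAI's `Warp` layer, and the smoothness term is only a 1D analogue of a second-derivative penalty, not MONAI's `BendingEnergyLoss`.

```python
import numpy as np


def warp_1d(moving, ddf):
    """Resample `moving` at positions x + ddf(x) using linear interpolation."""
    positions = np.arange(len(moving), dtype=float)
    return np.interp(positions + ddf, positions, moving)


def smoothness_penalty(ddf):
    """Mean squared second difference of the displacement field."""
    return float(np.mean(np.diff(ddf, n=2) ** 2))


moving = np.array([0.0, 0.0, 1.0, 0.0, 0.0])  # a single bright voxel at index 2
ddf = np.ones(5)                              # every output voxel looks one step right

warped = warp_1d(moving, ddf)
penalty = smoothness_penalty(ddf)
print(warped, penalty)
```

A constant displacement shifts the bright voxel by one position and incurs no smoothness penalty; a field that changes abruptly between neighbours would be penalized.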


In [None]:
# ============================================================================
# MONAI MEDICAL IMAGING IMPORTS
# ============================================================================

from monai.networks.nets import LocalNet            # CNN-based registration network
from monai.networks.blocks import Warp              # Spatial transformation layer
from monai.losses import GlobalMutualInformationLoss, BendingEnergyLoss  # Registration losses
from monai.transforms import (
    Compose,                    # Chain multiple transforms
    LoadImaged,                 # Load medical images from disk
    EnsureChannelFirstd,        # Ensure channel-first format (C, H, W, D)
    ScaleIntensityRanged,       # CT windowing (normalize HU values)
    Resized,                    # Resize 3D volumes
    LoadImage,                  # Single-image loader
    EnsureChannelFirst,         # Single-image channel formatting
    ScaleIntensityRange,        # Single-image intensity scaling
    Resize,                     # Single-image resizing
    RandAffined,                # Random affine augmentation (rotation, translation, scale)
    RandGaussianNoised,         # Add random Gaussian noise
    RandGaussianSmoothd         # Random Gaussian smoothing (blur)
)
from monai.data import Dataset, DataLoader


# ============================================================================
# DISTRIBUTED TRAINING FUNCTION (GPU-ACCELERATED)
# ============================================================================

@ray.remote(num_gpus=1)
def train_registration_model(
    train_files, 
    val_files=None,
    config=None,       # Training hyperparameters; must be supplied by the caller
    save_path=None     # Snowflake stage path for model checkpoints
):
    """
    Train a deep learning model for medical image registration.
    
    This function performs end-to-end training:
    1. **Data Preprocessing**: CT windowing, resizing, augmentation
    2. **Model Architecture**: LocalNet (U-Net-like CNN for deformation field prediction)
    3. **Loss Function**: Mutual Information (similarity) + Bending Energy (smoothness)
    4. **Optimization**: Adam with learning rate scheduling
    5. **Validation**: Dice score monitoring on held-out data
    6. **Checkpointing**: Save best model to Snowflake stage
    
    Args:
        train_files (list): List of training sample dictionaries with file paths
        val_files (list, optional): List of validation sample dictionaries
        config (dict): Training configuration (epochs, batch size, learning rate, etc.)
        save_path (str): Snowflake stage path for saving model checkpoints
    
    Returns:
        dict: Training results including best model path, final metrics, and history
    """
    print(f"🚀 Starting training with config: {config}")
    
    # Determine compute device (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"📍 Using device: {device}")
    
    # ========================================================================
    # DATA PREPROCESSING PIPELINE
    # ========================================================================
    # These transforms must match those used during inference for consistency!
    
    # ========================================================================
    # IMAGE TRANSFORMS (with data augmentation for training)
    # ========================================================================
    train_img_transforms = Compose([
        # Step 1: Load NIfTI files from disk
        LoadImaged(keys=["fixed", "moving"]),
        
        # Step 2: Ensure channel-first format: (C, H, W, D)
        # Medical images are often (H, W, D), but PyTorch needs (C, H, W, D)
        EnsureChannelFirstd(keys=["fixed", "moving"]),
        
        # Step 3: CT WINDOWING - Map Hounsfield Units (HU) to [0, 1] range
        # This is CRITICAL for consistent training and inference!
        # CT_MIN_HU = -1000 (air), CT_MAX_HU = 500 (soft tissue/bone)
        # Values outside this range are clipped
        ScaleIntensityRanged(
            keys=["fixed", "moving"],
            a_min=CT_MIN_HU,   # Input minimum (Hounsfield Units)
            a_max=CT_MAX_HU,   # Input maximum (Hounsfield Units)
            b_min=0.0,         # Output minimum (normalized)
            b_max=1.0,         # Output maximum (normalized)
            clip=True          # Clip values outside [CT_MIN_HU, CT_MAX_HU]
        ),
        
        # Step 4: Resize volumes to consistent shape for batching
        # Trilinear interpolation for smooth resampling
        Resized(
            keys=["fixed", "moving"], 
            spatial_size=config["spatial_size"],  # e.g., (96, 96, 104)
            mode="trilinear"
        ),
        
        # ====================================================================
        # DATA AUGMENTATION (improves model generalization)
        # ====================================================================
        
        # Random Gaussian noise (30% probability)
        # Simulates scanner noise and improves robustness
        RandGaussianNoised(
            keys=["fixed", "moving"], 
            prob=0.3,     # Apply 30% of the time
            std=0.05      # Standard deviation of noise
        ),
        
        # Random Gaussian smoothing (30% probability)
        # Simulates different scanner resolutions
        RandGaussianSmoothd(
            keys=["fixed", "moving"], 
            prob=0.3, 
            sigma_x=(0.5, 1.5),  # Gaussian sigma range per axis
            sigma_y=(0.5, 1.5), 
            sigma_z=(0.5, 1.5)
        ),
        
        # NOTE: Random affine augmentation is applied once, jointly to images
        # and labels, at the end of train_lbl_transforms below. Separate
        # RandAffined instances draw independent random parameters and would
        # leave images and masks misaligned.
    ])
    
    # ========================================================================
    # LABEL (SEGMENTATION MASK) TRANSFORMS
    # ========================================================================
    # CRITICAL: Labels are binary masks (0 or 1) - NO intensity normalization!
    # This pipeline runs after train_img_transforms (see combined_transform),
    # so the image keys are already loaded when the joint affine is applied.
    train_lbl_transforms = Compose([
        # Step 1: Load segmentation masks from disk
        LoadImaged(keys=["fixed_label", "moving_label"]),
        
        # Step 2: Ensure channel-first format
        EnsureChannelFirstd(keys=["fixed_label", "moving_label"]),
        
        # Step 3: Resize masks to match image size
        # ⚠️ IMPORTANT: Use "nearest" interpolation to preserve binary values!
        # Trilinear/bilinear would create fractional values between 0 and 1
        Resized(
            keys=["fixed_label", "moving_label"], 
            spatial_size=config["spatial_size"], 
            mode="nearest"  # Preserve discrete labels (no blending)
        ),
        
        # Step 4: Random affine transformation (50% probability)
        # Simulates patient positioning variations. A single RandAffined over
        # all four keys applies the SAME random parameters to images and labels
        RandAffined(
            keys=["fixed", "moving", "fixed_label", "moving_label"],
            prob=0.5,                              # Apply 50% of the time
            rotate_range=(0.1, 0.1, 0.1),         # ±0.1 radians (~6°) per axis
            translate_range=(10, 10, 10),         # ±10 voxels shift
            scale_range=(0.1, 0.1, 0.1),          # ±10% scaling
            mode=["bilinear", "bilinear", "nearest", "nearest"],  # Nearest for labels
            padding_mode="zeros"                   # Fill empty regions with zeros
        ),
    ])
    
    # Validation transforms (no augmentation)
    val_transforms = Compose([
        LoadImaged(keys=["fixed", "moving", "fixed_label", "moving_label"]),
        EnsureChannelFirstd(keys=["fixed", "moving", "fixed_label", "moving_label"]),
        ScaleIntensityRanged(
            keys=["fixed", "moving"],
            a_min=CT_MIN_HU,
            a_max=CT_MAX_HU,
            b_min=0.0,
            b_max=1.0,
            clip=True
        ),
        Resized(keys=["fixed", "moving", "fixed_label", "moving_label"],
                spatial_size=config["spatial_size"],
                mode=["trilinear", "trilinear", "nearest", "nearest"]),
    ])
    
    # ========================================================================
    # DATASETS & LOADERS
    # ========================================================================
    
    # Combine image and label transforms
    def combined_transform(data):
        data = train_img_transforms(data)
        data = train_lbl_transforms(data)
        return data
    
    train_ds = SnowflakeStageDataset(train_files, transform_img=combined_transform)
    train_loader = DataLoader(
        train_ds, 
        batch_size=config["batch_size"], 
        shuffle=True, 
        num_workers=0
    )
    
    val_loader = None
    if val_files:
        val_ds = SnowflakeStageDataset(val_files, transform_img=val_transforms)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)
    
    # ========================================================================
    # MODEL SETUP
    # ========================================================================
    
    model = LocalNet(
        spatial_dims=3,
        in_channels=2,
        out_channels=3,
        num_channel_initial=config["num_channel_initial"],
        extract_levels=[3],
        out_activation=None,
        out_kernel_initializer="zeros"
    ).to(device)
    
    warp_layer = Warp().to(device)
    
    # ========================================================================
    # LOSS & OPTIMIZER
    # ========================================================================
    
    loss_sim = GlobalMutualInformationLoss()
    loss_reg = BendingEnergyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
    
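    # Learning rate scheduler: halve the LR once the monitored loss has stopped
    # improving for more than 5 epochs. `verbose` is omitted because it is
    # deprecated in recent PyTorch releases.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )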
    
    # ========================================================================
    # TRAINING LOOP
    # ========================================================================
    
    best_val_dice = 0.0
    training_history = []
    
    for epoch in range(config["max_epochs"]):
        model.train()
        epoch_loss = 0
        epoch_sim_loss = 0
        epoch_reg_loss = 0
        step = 0
        
        for batch_data in train_loader:
            step += 1
            optimizer.zero_grad()
            
            fixed = batch_data["fixed"].to(device)
            moving = batch_data["moving"].to(device)
            
            # Forward
            input_tensor = torch.cat((moving, fixed), dim=1)
            ddf = model(input_tensor)
            pred_image = warp_layer(moving, ddf)
            
            # Loss calculation
            sim_loss = loss_sim(pred_image, fixed)
            reg_loss = loss_reg(ddf)
            total_loss = sim_loss + config["reg_weight"] * reg_loss
            
            # Backward
            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.item()
            epoch_sim_loss += sim_loss.item()
            epoch_reg_loss += reg_loss.item()
        
        avg_loss = epoch_loss / step
        avg_sim = epoch_sim_loss / step
        avg_reg = epoch_reg_loss / step
        
        print(f"✅ Epoch {epoch + 1}/{config['max_epochs']}")
        print(f"   Loss: {avg_loss:.4f} (Sim: {avg_sim:.4f}, Reg: {avg_reg:.4f})")
        
        # ====================================================================
        # VALIDATION
        # ====================================================================
        
        if val_loader:
            model.eval()
            val_dice_scores = []
            
            with torch.no_grad():
                for val_batch in val_loader:
                    fixed = val_batch["fixed"].to(device)
                    moving = val_batch["moving"].to(device)
                    fixed_lbl = val_batch["fixed_label"].to(device)
                    moving_lbl = val_batch["moving_label"].to(device)
                    
                    input_tensor = torch.cat((moving, fixed), dim=1)
                    ddf = model(input_tensor)
                    pred_label = warp_layer(moving_lbl, ddf)
                    
                    # Dice calculation
                    intersection = (pred_label * fixed_lbl).sum()
                    dice = (2.0 * intersection + 1e-5) / (pred_label.sum() + fixed_lbl.sum() + 1e-5)
                    val_dice_scores.append(dice.item())
            
            # Cast to a plain Python float so the value serializes cleanly in the returned history
            avg_val_dice = float(np.mean(val_dice_scores))
            print(f"   Val Dice: {avg_val_dice:.4f}")
            
            # Update scheduler with the mean training loss (consistent with mode='min')
            scheduler.step(avg_loss)
            
            # Save best model
            if avg_val_dice > best_val_dice:
                best_val_dice = avg_val_dice
                print(f"   🏆 New best model! Dice: {best_val_dice:.4f}")
                save_model_to_stage(model, f"{save_path}/best_model.pth")
        
        # Record history
        training_history.append({
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            "train_sim": avg_sim,
            "train_reg": avg_reg,
            "val_dice": avg_val_dice if val_loader else None
        })
        
        # Periodic checkpoint
        if (epoch + 1) % config["save_interval"] == 0:
            checkpoint_path = f"{save_path}/checkpoint_epoch_{epoch+1}.pth"
            save_model_to_stage(model, checkpoint_path)
            print(f"   💾 Checkpoint saved: {checkpoint_path}")
    
    # ========================================================================
    # SAVE FINAL MODEL
    # ========================================================================
    
    final_path = f"{save_path}/final_model.pth"
    save_model_to_stage(model, final_path)
    print(f"🏁 Training complete! Final model: {final_path}")
    
    return {
        "final_model_path": final_path,
        "best_model_path": f"{save_path}/best_model.pth",
        "best_val_dice": best_val_dice,
        "training_history": training_history
    }



def save_model_to_stage(model, stage_path):
    """
    Save a PyTorch model's state_dict to a Snowflake stage under the filename given in stage_path.
    
    Args:
        model: PyTorch module whose state_dict is saved
        stage_path: Full stage path like "@RESULTS_STG/final_model.pth"
    """
    import os
    
    # Extract the desired filename from stage_path
    # "@RESULTS_STG/final_model.pth" -> "final_model.pth"
    filename = stage_path.split('/')[-1]
    
    # Create local file with EXACT name we want
    local_path = f"/tmp/{filename}"
    
    # Save model locally
    torch.save(model.state_dict(), local_path)
    
    # Extract stage directory (without filename)
    # "@RESULTS_STG/final_model.pth" -> "@RESULTS_STG"
    stage_dir = '/'.join(stage_path.split('/')[:-1])
    
    try:
        session = get_active_session()
        
        # Upload to stage DIRECTORY only (filename preserved from local)
        session.file.put(
            local_path,
            stage_dir,  # ✅ Just the directory!
            overwrite=True,
            auto_compress=False
        )
        
        print(f"✅ Saved to {stage_path}")
        
    except Exception as e:
        print(f"❌ Failed to save: {e}")
    finally:
        # Clean up local file
        if os.path.exists(local_path):
            os.unlink(local_path)


## Step 7: Prepare Training and Validation Sets

Now we split our dataset into:
- **Training Set (80%)** - Used for model optimization
- **Validation Set (20%)** - Used for monitoring performance (Dice score) and selecting the best checkpoint

This split helps prevent overfitting and ensures our model generalizes to unseen data.
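
The split in the next cell takes the first 80% of pairs in whatever order Ray yields them. If that order correlates with the data (for example, acquisition order), a seeded shuffle before slicing may give a more representative validation set. A minimal sketch, assuming a hypothetical helper `split_pairs`, an arbitrary seed, and stand-in file names; none of these are part of the notebook:

```python
import random

def split_pairs(pairs, train_fraction=0.8, seed=42):
    """Shuffle a copy of `pairs` reproducibly, then slice into train/val lists."""
    shuffled = list(pairs)                      # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)       # local RNG, so global state is not changed
    split_idx = int(train_fraction * len(shuffled))
    return shuffled[:split_idx], shuffled[split_idx:]

# Stand-in records (hypothetical names) using the same keys as the notebook's pair dictionaries
example_pairs = [
    {"fixed_path": f"case_{i:03d}_exp.nii.gz", "moving_path": f"case_{i:03d}_insp.nii.gz"}
    for i in range(20)
]
train_example, val_example = split_pairs(example_pairs)
print(len(train_example), len(val_example))
```

Because the seed is fixed, rerunning the cell produces the same split, which keeps validation Dice comparable across runs.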


In [None]:
# ============================================================================
# CONVERT RAY DATASET TO PYTHON LIST
# ============================================================================
# Ray datasets are distributed - we need to collect them for training

train_files_list = []

# Iterate through all image pairs in the Ray dataset
for row in ray_dataset.iter_rows():
    train_files_list.append({
        "fixed_path": row["fixed_path"],        # Expiration CT path
        "moving_path": row["moving_path"],      # Inspiration CT path
        "fixed_label_path": row["fixed_label_path"],    # Expiration lung mask
        "moving_label_path": row["moving_label_path"]   # Inspiration lung mask
    })


# ============================================================================
# SPLIT INTO TRAINING AND VALIDATION SETS (80/20)
# ============================================================================

# Calculate split index (80% for training)
split_idx = int(0.8 * len(train_files_list))

# Create splits
train_split = train_files_list[:split_idx]   # First 80%
val_split = train_files_list[split_idx:]     # Last 20%

# Display split information
print(f"📊 Training: {len(train_split)} pairs, Validation: {len(val_split)} pairs")


## Step 8: Launch Training 🚀

Time to train! This cell:
1. **Defines hyperparameters** - Learning rate, batch size, epochs, etc.
2. **Launches distributed training** - Runs on GPU via Ray
3. **Monitors progress** - Prints training loss and validation Dice score
4. **Saves best model** - Checkpoints to Snowflake stage

**Expected Training Time**: ~10-30 minutes depending on dataset size and GPU.

**What is Dice Score?** 
A metric for segmentation overlap (0 = no overlap, 1 = perfect overlap). We use it to evaluate registration quality by warping lung masks and comparing them.
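
As a small worked example of the metric, the sketch below computes Dice on two flattened toy masks in plain Python. It uses the same smoothing constant (`1e-5`) as the validation loop; the helper name `dice_score` and the toy masks are illustrative assumptions, not part of the pipeline.

```python
def dice_score(pred, target, eps=1e-5):
    """Dice overlap between two flat sequences of mask values in [0, 1]."""
    intersection = sum(p * t for p, t in zip(pred, target))
    # eps keeps the ratio defined when both masks are empty
    return (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)

fixed_mask  = [0, 1, 1, 1, 0, 0]   # target lung mask (toy data)
warped_mask = [0, 0, 1, 1, 1, 0]   # moving mask after warping (toy data)

# Two voxels overlap, each mask has three: 2*2 / (3+3) is about 0.667
print(round(dice_score(warped_mask, fixed_mask), 4))
```

A warped mask can hold fractional values after interpolation, which is why the sketch multiplies values instead of counting exact matches.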


In [None]:
# ============================================================================
# TRAINING CONFIGURATION & HYPERPARAMETERS
# ============================================================================

# ============================================================================
# CT WINDOWING PARAMETERS (Hounsfield Units)
# ============================================================================
# These define the CT intensity range we're interested in for lung imaging
CT_MIN_HU = -1000  # Air (lowest HU in lungs)
CT_MAX_HU = 500    # Soft tissue/bone (upper limit for lung CT)

# ============================================================================
# MODEL & TRAINING HYPERPARAMETERS
# ============================================================================
CONFIG = {
    # Image size after preprocessing (downsampled for memory efficiency)
    "spatial_size": (96, 96, 104),  # (Height, Width, Depth)
    
    # Network architecture parameter (controls model capacity)
    "num_channel_initial": 32,  # Number of feature channels in first layer
    
    # Training parameters
    "batch_size": 2,            # Number of image pairs per batch (limited by GPU memory)
    "learning_rate": 1e-4,      # Adam optimizer learning rate
    "max_epochs": 25,           # Total training iterations through dataset
    
    # Loss function weighting
    "reg_weight": 1.0,          # Weight for regularization (smoothness penalty)
                                # Lower = more flexible deformations
                                # Higher = smoother but potentially less accurate
    
    # Checkpointing
    "save_interval": 10,        # Save model checkpoint every N epochs
}

print("✅ Training configuration defined:")
print(f"   CT Window: [{CT_MIN_HU}, {CT_MAX_HU}] HU")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

In [None]:
# ============================================================================
# LAUNCH DISTRIBUTED TRAINING ON GPU
# ============================================================================

# Submit training task to Ray (runs on GPU-enabled node)
training_future = train_registration_model.remote(
    train_files=train_split,       # Training image pairs
    val_files=val_split,           # Validation image pairs
    config=CONFIG,                 # Hyperparameters defined above
    save_path=f"@{DATABASE_NAME}.UTILS.RESULTS_STG"  # Snowflake stage for checkpoints
)

# Wait for training to complete and retrieve results
# This will block until training finishes (could take 10-30 minutes)
training_result = ray.get(training_future)

# Display results in Streamlit UI
st.success(f"✅ Training Complete! Best Dice: {training_result['best_val_dice']:.4f}")
st.json(training_result)  # Show full training history and metrics

## Step 9: Verify Saved Model Checkpoints

Let's verify that our model was successfully saved to the Snowflake stage during training.
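
To know what the listing should contain, the file names can be derived from the training configuration. The sketch below mirrors the naming used in the training loop (`checkpoint_epoch_{N}.pth` every `save_interval` epochs, plus `final_model.pth`); the helper `expected_checkpoint_files` is hypothetical, and it assumes every upload succeeded.

```python
def expected_checkpoint_files(max_epochs, save_interval):
    """List the file names the training loop writes for a given configuration."""
    names = [
        f"checkpoint_epoch_{epoch}.pth"
        for epoch in range(1, max_epochs + 1)
        if epoch % save_interval == 0          # same condition as the periodic checkpoint
    ]
    names.append("final_model.pth")            # always written at the end of training
    return names

# Values taken from CONFIG in this notebook
expected_files = expected_checkpoint_files(max_epochs=25, save_interval=10)
print(expected_files)
```

`best_model.pth` should also appear, provided a validation set was supplied and the validation Dice improved at least once.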


In [None]:
ls @MONAI_DB.UTILS.RESULTS_STG;

## Step 10: Register Model in Snowflake Model Registry

Now we register our trained model in Snowflake's Model Registry for:
- **Version Control** - Track different model versions over time
- **Deployment** - Easy deployment to inference services
- **Reproducibility** - Store model metadata alongside weights
- **Collaboration** - Share models across teams

The Model Registry provides enterprise-grade model management capabilities.
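
For version control, each retraining run needs a version name that is not already taken. One possible convention, sketched below, is to increment the numeric suffix of names like `v1`. The helper `next_version_name` is hypothetical and works on a plain list of names; how that list is obtained from the registry is left out. Matching ignores case, on the assumption that names may come back upper-cased.

```python
import re

def next_version_name(existing_versions, prefix="v"):
    """Return the next name in a v1, v2, ... sequence, ignoring other names."""
    pattern = re.compile(rf"{re.escape(prefix)}(\d+)", flags=re.IGNORECASE)
    numbers = []
    for name in existing_versions:
        match = pattern.fullmatch(name)
        if match:                               # skip names that do not follow the convention
            numbers.append(int(match.group(1)))
    return f"{prefix}{max(numbers, default=0) + 1}"

print(next_version_name([]))                    # first registration
print(next_version_name(["V1", "v2", "DEFAULT"]))
```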


In [None]:
# Required imports for model registration
import torch                                        # PyTorch deep learning framework
import io                                          # For binary stream handling
from snowflake.snowpark.context import get_active_session
from snowflake.ml.registry import Registry         # Snowflake Model Registry
from monai.networks.nets import LocalNet          # Registration network architecture

session = get_active_session()


# ============================================================================
# MODEL REGISTRATION FUNCTION
# ============================================================================

def register_model(
    model_stage_path=f"@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth",
    model_name="LUNG_CT_REGISTRATION",
    version_name="v1"
):
    """
    Register a trained PyTorch model in Snowflake Model Registry.
    
    This function:
    1. Loads model weights from Snowflake stage
    2. Reconstructs the model architecture
    3. Registers the model with sample input for schema inference
    4. Returns a model reference for deployment
    
    Args:
        model_stage_path (str): Path to .pth file in Snowflake stage
        model_name (str): Name for the model in registry
        version_name (str): Version identifier (e.g., "v1", "v2")
    
    Returns:
        ModelVersion: Handle to the registered model version (as returned by Registry.log_model)
    """
    # version_name already carries its "v" prefix, so it is not repeated here
    print(f"🔄 Registering {model_name} {version_name}...")
    
    # ========================================================================
    # STEP 1: Load model weights from Snowflake stage
    # ========================================================================
    raw_stream = session.file.get_stream(model_stage_path)
    state_dict = torch.load(
        io.BytesIO(raw_stream.read()), 
        map_location='cpu'  # Load to CPU first (works on any node)
    )
    
    # ========================================================================
    # STEP 2: Reconstruct model architecture
    # ========================================================================
    # Architecture must match the one used during training!
    model = LocalNet(
        spatial_dims=3,              # 3D medical images
        in_channels=2,               # Concatenated moving + fixed images (same order as training)
        out_channels=3,              # 3D deformation field (x, y, z displacement)
        num_channel_initial=32,      # Feature channels (must match training)
        extract_levels=[3],          # U-Net depth
        out_activation=None,         # No activation on output (regression task)
        out_kernel_initializer="zeros"  # Initialize to identity transform
    )
    
    # Load trained weights into architecture
    model.load_state_dict(state_dict)
    
    # Set to evaluation mode (disable dropout, batch norm tracking)
    model.eval()
    
    # ========================================================================
    # STEP 3: Create sample input for schema inference
    # ========================================================================
    # Model Registry uses this to infer input/output shapes
    # Shape: (batch=1, channels=2, H=96, W=96, D=104)
    sample_input = torch.randn(1, 2, 96, 96, 104)
    
    # ========================================================================
    # STEP 4: Register model in Snowflake Model Registry
    # ========================================================================
    registry = Registry(
        session, 
        database_name=DATABASE_NAME, 
        schema_name="UTILS"
    )
    
    model_ref = registry.log_model(
        model_name=model_name,
        version_name=version_name,
        model=model,                    # PyTorch model (registry handles conversion)
        sample_input_data=sample_input  # For input/output schema inference
    )
    
    print(f"✅ Registered: {model_ref.fully_qualified_model_name}")
    return model_ref





## Step 11: Execute Model Registration

Run the registration function and verify the model appears in the registry.


In [None]:
# ============================================================================
# EXECUTE MODEL REGISTRATION
# ============================================================================

# Register the best model from training
# This creates a new entry in the Model Registry
model_ref = register_model(
    model_stage_path=f"@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth",  # Best checkpoint
    model_name="LUNG_CT_REGISTRATION",   # Model identifier
    version_name="v1"                     # Version tag (increment for updates)
)


# ============================================================================
# VERIFY REGISTRATION SUCCESS
# ============================================================================

# Connect to Model Registry
registry = Registry(
    session, 
    database_name=DATABASE_NAME, 
    schema_name="UTILS"
)

# List all registered models
all_models = registry.show_models()

print(f"\n📚 All registered models in {DATABASE_NAME}.UTILS:")
print(all_models)


## Step 12: Test Model Loading & Inference

Verify we can load the registered model and run inference. This confirms the model is ready for deployment!
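
As a quick sanity check on what the test should return, the expected output shape and its memory footprint can be worked out from the input shape alone. This sketch assumes the network preserves the spatial dimensions and emits three displacement channels, as stated in the cell below; `expected_ddf_shape` is a hypothetical helper.

```python
from math import prod

def expected_ddf_shape(input_shape, displacement_channels=3):
    """Swap the channel dimension for one displacement component per spatial axis."""
    batch, _channels, *spatial = input_shape
    return (batch, displacement_channels, *spatial)

input_shape = (1, 2, 96, 96, 104)               # (batch, moving + fixed, H, W, D)
ddf_shape = expected_ddf_shape(input_shape)
ddf_megabytes = prod(ddf_shape) * 4 / 1e6       # float32 uses 4 bytes per element
print(ddf_shape, f"{ddf_megabytes:.1f} MB")
```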


In [None]:
# ============================================================================
# LOAD MODEL FROM REGISTRY
# ============================================================================

# Retrieve the registered model by name and version
loaded = registry.get_model("LUNG_CT_REGISTRATION").version("v1")
print(f"\n✅ Successfully loaded: {loaded.model_name}")


# ============================================================================
# TEST INFERENCE WITH DUMMY DATA
# ============================================================================

# Create synthetic test input matching expected shape
# Shape: (batch=1, channels=2, H=96, W=96, D=104)
# Channels: [moving_image, fixed_image] concatenated (same order as in training)
test_input = torch.randn(1, 2, 96, 96, 104)

try:
    # Run inference (predicts 3D deformation field)
    result = loaded.run(test_input)
    
    print(f"✅ Test inference works! Output shape: {result.shape}")
    print(f"   Expected: (1, 3, 96, 96, 104) - 3D displacement field")
    
except Exception as e:
    # A failure here does not undo registration: the model version already exists in the registry
    print(f"ℹ️  Note: {e}")
    print("   Model is registered successfully!")
    print("   (Use with actual preprocessed CT data for real inference)")

In [None]:
# Dynamic summary with actual database name
st.markdown(f"""
## 🎉 Training Complete!

### What We Accomplished

1. ✅ **Set up distributed computing** with Ray on Snowflake containers
2. ✅ **Loaded medical images** from Snowflake stages (just-in-time)
3. ✅ **Trained a deep learning model** for CT registration using MONAI
4. ✅ **Monitored training** with validation metrics (Dice score)
5. ✅ **Saved model checkpoints** to Snowflake stages
6. ✅ **Registered model** in Snowflake Model Registry

### Next Steps

- **Deploy for Inference**: Use the registered model in 03_model_inference
- **Tune Hyperparameters**: Adjust learning rate, regularization weight, or epochs
- **Add More Data**: Train on larger datasets for better generalization
- **Export Visualizations**: Save training loss curves and sample predictions

### Model Location

- **Checkpoints**: `@{DATABASE_NAME}.UTILS.RESULTS_STG/best_model.pth`
- **Registry**: `{DATABASE_NAME}.UTILS.LUNG_CT_REGISTRATION` (version v1)

---

**Ready to run inference? Proceed to 03_model_inference!** 🚀
""")