## Scalable pipeline for segmenting CT images using Container Runtime + a MONAI pretrained model

In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# Grab the Snowpark session already bound to this notebook; later cells use it
# to look up the current database and schema for the stage data source.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


### Setup the cluster

In [None]:
import ray 

# Attach to the Ray cluster that is already running (address='auto') rather
# than starting a new local one; ignore_reinit_error lets this cell be re-run.

ray.init(address='auto', ignore_reinit_error=True)


def configure_ray_logger() -> None:
    """Silence Ray logging and progress output for this notebook session.

    Raises the level of the ``ray`` logger, the ``ray.data`` logger and the
    root logger to ``CRITICAL``, then turns off verbose progress and
    per-operator progress bars on the current Ray Data context.

    Note that raising the root logger's level also suppresses messages from
    every other library that relies on the root logger's level.
    """
    # Imported here because no cell of the notebook imports ``logging``;
    # without this the first statement below fails with NameError.
    import logging

    # ``None`` selects the root logger.
    for logger_name in ("ray", "ray.data", None):
        logging.getLogger(logger_name).setLevel(logging.CRITICAL)

    # Configure Ray's data context
    context = ray.data.DataContext.get_current()
    context.execution_options.verbose_progress = False
    context.enable_operator_progress_bars = False

configure_ray_logger()

In [None]:
ray.nodes()

#### Optional: increase the number of nodes in the cluster

In [None]:
from snowflake.ml.runtime_cluster import scale_cluster

# Scale the cluster up to 4 nodes.
# NOTE(review): the first argument is presumably the name of the notebook /
# service whose cluster is being resized — confirm against the
# `scale_cluster` signature of the installed runtime.
scale_cluster("ANDA_TEST_MULTI_NODE_INSTALL", 4)

#### Install dependencies

In [None]:
@ray.remote(num_cpus=0)  # Ensures task does not consume CPU slots
def install_deps():
    """Install the MONAI inference dependencies on the node running this task.

    Returns:
        A status string. On success it contains the node IP and the first
        line of ``pip show monai``; on failure it contains the node IP and
        the output captured from the failing pip command.
    """
    import subprocess
    import sys

    packages = ["monai", "pytorch-ignite", "itk", "gdown", "torchvision", "lmdb", "transformers", "einops", "nibabel"]

    # Invoke pip through the interpreter executing this task, so the packages
    # are installed for that interpreter rather than for whichever `pip`
    # happens to be first on PATH.
    pip_cmd = [sys.executable, "-m", "pip"]

    try:
        # Install dependencies. text=True keeps the captured output as str so
        # the failure message below is readable instead of a bytes repr.
        subprocess.run(pip_cmd + ["install"] + packages, check=True, capture_output=True, text=True)

        # Verify installation
        result = subprocess.run(pip_cmd + ["show", "monai"], capture_output=True, text=True, check=True)
        return f"✅ Dependencies installed on {ray.util.get_node_ip_address()}:\n{result.stdout.splitlines()[0]}"

    except subprocess.CalledProcessError as e:
        return f"❌ Failed on {ray.util.get_node_ip_address()}: {e.stderr if e.stderr else e.stdout}"

# Collect the distinct addresses of all nodes currently reported as alive
nodes = {node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]}

# Launch one install_deps task per node: requesting a small amount of the
# per-node "node:<ip>" resource pins each task to that specific node.
tasks = [install_deps.options(resources={f"node:{node}": 0.01}).remote() for node in nodes]
# Block until every node has reported back
results = ray.get(tasks)

# Print the per-node status strings returned by install_deps
for res in results:
    print(res)

#### Define inference operations

The following code defines a `MonaiInferencer` class that performs inference on a batch of images. 
It first loads the pretrained model from a MONAI bundle, together with the other components needed for inference, such as preprocessing and postprocessing. 
It then saves the output files to a Snowflake stage.

In [None]:
import monai
import pandas as pd
import tempfile
import os

class MonaiInferencer:
    """Callable that segments batches of CT volumes with a pretrained MONAI bundle.

    The constructor downloads the ``spleen_ct_segmentation`` bundle and builds
    its inference components once per instance. Calling the instance with a
    pandas batch writes each ``file_binary`` payload to a temporary
    ``.nii.gz`` file, runs the bundle's evaluator over those files, uploads
    the files found in the bundle's output directory to a Snowflake stage and
    returns the batch without its ``file_binary`` column.
    """

    def __init__(self):
        """Download the bundle and parse the components needed for inference."""
        self.session = get_active_session()
        # Download the bundle which includes pretrained-model check points
        monai.bundle.download(name='spleen_ct_segmentation', bundle_dir='/tmp')
        bundle_root = "/tmp/spleen_ct_segmentation"
        config_file = f"{bundle_root}/configs/inference.json"

        # Parse MONAI Bundle configuration, initialize MONAI
        config = monai.bundle.ConfigParser()
        config.read_config(config_file)
        # Point the bundle at its download location before anything is parsed.
        # NOTE(review): assumes the bundle derives its checkpoint and output
        # paths from `bundle_root` — confirm against inference.json.
        config["bundle_root"] = bundle_root
        self.device = config.get_parsed_content('device')
        self.network = config.get_parsed_content("network")
        self.inferer = config.get_parsed_content("inferer")
        self.preprocessing = config.get_parsed_content("preprocessing")
        self.postprocessing = config.get_parsed_content("postprocessing")
        self.checkpointloader = config.get_parsed_content("checkpointloader")
        self.output_dir = config.get_parsed_content("output_dir")

        # The checkpoint only needs to be loaded into the network once.
        self._weights_loaded = False
        # Local output paths this instance has already uploaded.
        self._uploaded = set()

    def _infer(self, files):
        """Run the bundle's evaluator over ``files`` and upload new outputs.

        Args:
            files: Paths of the local image files to segment.
        """
        # Create data loader
        dataset = monai.data.Dataset(data=[{"image": file} for file in files], transform=self.preprocessing)
        dataloader = monai.data.DataLoader(dataset, batch_size=1, num_workers=0)

        # Set up evaluator based on inference.json
        evaluator = monai.engines.SupervisedEvaluator(
            device=self.device,
            val_data_loader=dataloader,
            network=self.network,
            inferer=self.inferer,
            postprocessing=self.postprocessing,
            amp=True
        )

        # Load the pretrained weights into the network. The handler was parsed
        # in __init__ but previously never applied, which left the network
        # with its initial (untrained) parameters.
        if not self._weights_loaded:
            self.checkpointloader(evaluator)
            self._weights_loaded = True

        # Run inference
        evaluator.run()

        # Save output files to Snowflake stage, skipping files this instance
        # uploaded for an earlier batch so they are not sent again every call.
        # NOTE(review): output names derive from the temporary input file
        # names, so they do not identify the original staged files, and the
        # per-file stage path is passed as the PUT target location — confirm
        # both are intended.
        for root, _, output_names in os.walk(self.output_dir):
            for output_name in output_names:
                local_path = os.path.join(root, output_name)
                if local_path in self._uploaded:
                    continue
                stage_path = f"@ANDA_TEST_STAGE/{output_name}"
                self.session.file.put(local_path, stage_path, overwrite=True)
                self._uploaded.add(local_path)

        return

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Segment every image in ``batch`` and return the batch metadata.

        Args:
            batch: DataFrame with a ``file_binary`` column holding the raw
                bytes of each image file.

        Returns:
            The same DataFrame with the ``file_binary`` column dropped
            in place.
        """
        temp_files = []
        try:
            # Write each binary to a temporary file.
            for binary_content in batch["file_binary"]:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz") as tmp_file:
                    # Register the path before writing so the file is still
                    # cleaned up if the write itself fails.
                    temp_files.append(tmp_file.name)
                    tmp_file.write(binary_content)

            # Use the temporary file paths for inference.
            self._infer(temp_files)
            batch.drop(columns=['file_binary'], inplace=True)
        finally:
            # Clean up temporary files (best effort).
            for file_path in temp_files:
                try:
                    os.remove(file_path)
                except OSError:
                    pass
        return batch

#### Define Data Source

In [None]:
from snowflake.ml.ray.datasource import SFStageBinaryFileDataSource

# Describe the staged input files: everything matching *.nii.gz under the
# given stage location, in the session's current database and schema.
data_source = SFStageBinaryFileDataSource(
    stage_location="ANDA_TEST_STAGE/imagesTs_replica2/",
    database=session.get_current_database(),
    schema=session.get_current_schema(),
    file_pattern="*.nii.gz",
)

# Wrap the data source in a Ray Dataset. MonaiInferencer reads the raw bytes
# of each file from a `file_binary` column of the resulting batches.
ray_dataset = ray.data.read_datasource(data_source)

In [None]:
ray_dataset.count()

### Apply inference operations

In [None]:
# Number of image files handed to each MonaiInferencer call
batch_size=50

# Passing the class (not an instance) lets Ray Data construct the inferencer
# itself, so the bundle download and model setup in __init__ happen once per
# worker rather than once per batch.
# concurrency=4 requests four such workers; num_gpus=1 reserves one GPU each.
# This only defines the transformation; the dataset is consumed in a later cell.
processed_ds = ray_dataset.map_batches(
    MonaiInferencer,
    batch_size=batch_size,
    batch_format='pandas',
    concurrency=4,
    num_gpus=1,
)

### Major benefits
* You only need to worry about the core inference logic
* You can reuse a pre-trained MONAI model/bundle
* Inference scales out to a multi-node cluster for the best performance-cost ratio
* Medical images are streamed to Container Runtime, so there is no requirement on local disk space. 

#### Apply the inference operations to the dataset

In [None]:
processed_ds.to_pandas()

### Use the Ray Dataset directly with Distributed Pytorch Training API

**Note:** the following code will be fully available in the release at the end of April 2025.

In [None]:
from snowflake.ml.data.data_connector import DataConnector
from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor
from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig
from snowflake.ml.data.sharded_data_connector import DataConnector, ShardedDataConnector
from snowflake.ml.modeling.distributors.pytorch import get_context
from torch.utils.data import Dataset, DataLoader

from monai.losses import DiceLoss
from monai.optimizers import Novograd

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_func():
    """Per-worker training loop intended to be launched by ``PyTorchDistributor``.

    Reads rank information from the distributor context, joins an NCCL
    process group, wraps the model in ``DistributedDataParallel`` and trains
    it for a fixed number of epochs on this worker's shard of the ``'train'``
    dataset. Rank 0 prints the average loss of its own batches after each
    epoch. The process group is torn down when the function exits, whether
    training finished or raised.

    ``MyPytorchModel`` is a placeholder: define your own ``torch.nn.Module``
    with that name before running this function.
    """
    # Get the distributed-training context for this worker
    context = get_context()
    rank = context.get_rank()
    local_rank = context.get_local_rank()
    world_size = context.get_world_size()
    num_epochs = 5
    batch_size = 30
    lr = 1e-3

    # Init NCCL; rendezvous details come from the environment ("env://")
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size
    )
    try:
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

        # Grab this worker's shard of the Ray dataset
        train_ds = context.get_dataset_map()['train'].get_shard().to_torch_dataset()

        # Build model, loss and optimizer (define your own PyTorch model)
        model = MyPytorchModel().to(device)
        model = DDP(model, device_ids=[local_rank])
        loss_fn = DiceLoss()
        optimizer = Novograd(model.parameters(), lr=lr)
        data_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
        )

        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            batch_cnt = 0

            # NOTE(review): presumably the dataset can be re-iterated on every
            # epoch — confirm for the object returned by to_torch_dataset().
            for batch in data_loader:
                optimizer.zero_grad()
                # Stack all non-target columns into one feature tensor, move to GPU
                features = torch.stack(
                    [batch[k] for k in batch if k != "target"],
                    dim=1
                ).to(device)
                targets = batch["target"].unsqueeze(1).to(device)

                preds = model(features)
                loss = loss_fn(preds, targets)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_cnt += 1

            # Only let rank 0 print
            if rank == 0:
                if batch_cnt:
                    avg_loss = total_loss / batch_cnt
                    print(f"[Epoch {epoch+1}/{num_epochs}] loss: {avg_loss:.4f}")
                else:
                    # An empty shard would otherwise divide by zero
                    print(f"[Epoch {epoch+1}/{num_epochs}] no batches received")
    finally:
        # Release the process group even when training fails
        dist.destroy_process_group()


# Wrap the Ray dataset read from the stage so it can be handed to the distributor
train_data_connector = DataConnector.from_ray_ds(ray_dataset)

# Create pytorch distributor: train_func runs once per worker, with one worker
# on each of 4 nodes and 6 CPUs + 1 GPU requested per worker.
pytorch_trainer = PyTorchDistributor(  
    train_func=train_func,
    scaling_config=PyTorchScalingConfig(  
        num_nodes=4,  
        num_workers_per_node=1,  
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=6, num_gpus=1),  
    )  
) 

# Launch training; train_func looks the dataset up under the 'train' key
pytorch_trainer.run(
    dataset_map={'train': train_data_connector}
)
