[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/KristoferLintonReid/hgsoc-prognosis-CTradiomics/blob/main/examples/nnunet_inference.ipynb)

# nnUNet Inference and Dice Score Calculation

This notebook demonstrates how to run inference with a trained nnUNet (v1) model and compute per-label Dice scores against a ground-truth segmentation.

**Compatibility:**
- **Local Machine (macOS/Linux):** Runs with your local environment.
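- **Google Colab:** Automatically installs dependencies and mounts Google Drive. A GPU is used when the runtime provides one (select a GPU via *Runtime → Change runtime type*); otherwise inference falls back to the CPU.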

**Prerequisites:**
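- **Model:** A trained nnUNet (v1) model folder (e.g., `Task001_OVARIAN`), expected at `nnunet_trained_models/nnUNet/3d_fullres/Task001_OVARIAN` inside the working directory.
- **Input Data:** The image file (`.nii.gz`) to segment (`tcga-09-0367.nii.gz` by default).
- **Ground Truth:** (Optional) A segmentation file (`.nii.gz`, same shape as the input) for Dice calculation (`tcga-09-0367-seg.nii.gz` by default).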


In [None]:
import os
import shutil
import subprocess
import sys
import numpy as np

# --- Environment Detection ---
# Colab is recognised by the availability of the `google.colab` package.
try:
    from google.colab import drive
except ImportError:
    IS_COLAB = False
else:
    IS_COLAB = True

print("Detected Google Colab environment." if IS_COLAB else "Detected Local environment.")

# --- Colab Setup ---
if not IS_COLAB:
    # Local runs resolve every path relative to the current working directory.
    base_dir = os.path.abspath(".")
else:
    print("Installing dependencies for Colab...")
    # NOTE(review): numpy is already imported in this kernel above; if pip changes
    # its version, a runtime restart may be required before it takes effect — confirm.
    pip_command = [sys.executable, "-m", "pip", "install", "nnunet", "nibabel", "numpy<2"]
    subprocess.check_call(pip_command)

    print("Mounting Google Drive...")
    drive.mount('/content/drive')

    # NOTE: Update this path to match your Drive structure if needed.
    drive_examples_dir = "/content/drive/MyDrive/hgsoc-prognosis-CTradiomics/examples"
    if os.path.exists(drive_examples_dir):
        base_dir = drive_examples_dir
    else:
        # Fall back to the ephemeral Colab workspace.
        print(f"Warning: Base directory {drive_examples_dir} does not exist. Using local workspace.")
        base_dir = "/content"

print(f"Working Directory: {base_dir}")

In [None]:
import nibabel as nib
import torch

# --- Configuration ---
# Paths for Model and Data
local_results_folder = os.path.join(base_dir, "nnunet_trained_models")

# Input/Output Directories (Temporary)
temp_input_dir = os.path.join(base_dir, "temp_inference_input")
temp_output_dir = os.path.join(base_dir, "temp_inference_output")

os.makedirs(temp_input_dir, exist_ok=True)
os.makedirs(temp_output_dir, exist_ok=True)

# Case Configuration
# nnUNet expects inputs named "<case>_<4-digit channel index>.nii.gz";
# a single-channel image therefore gets the "_0000" suffix.
case_identifier = "Test_Case"
input_image_name = "tcga-09-0367.nii.gz"
ground_truth_name = "tcga-09-0367-seg.nii.gz"

input_image_path = os.path.join(base_dir, input_image_name)
ground_truth_path = os.path.join(base_dir, ground_truth_name)
formatted_input_path = os.path.join(temp_input_dir, f"{case_identifier}_0000.nii.gz")

# --- File Verification ---
if not os.path.exists(input_image_path):
    print(f"ERROR: Input file '{input_image_name}' not found in {base_dir}.")
    if IS_COLAB:
        print("Please upload the file to your Google Drive folder or the Colab workspace.")
    # Remove any copy left over from an earlier run. Later steps treat the
    # existence of formatted_input_path as "input ready", so a stale file would
    # otherwise be segmented in place of the missing image.
    if os.path.exists(formatted_input_path):
        os.remove(formatted_input_path)
else:
    print(f"Found input image: {input_image_path}")
    shutil.copy(input_image_path, formatted_input_path)

# Check for Model
# Expected nnUNet (v1) results layout: <results folder>/nnUNet/<configuration>/<task>.
model_check_path = os.path.join(local_results_folder, "nnUNet", "3d_fullres", "Task001_OVARIAN")
if not os.path.exists(model_check_path):
    print(f"WARNING: Model folder not found at {model_check_path}")
    if IS_COLAB:
        print("Please ensure your Google Drive paths are correct or upload the 'nnunet_trained_models' folder.")
else:
    print("Model folder found.")

# Check GPU
# use_gpu only selects the inference flags later; the device itself is chosen by nnUNet.
if torch.cuda.is_available():
    print(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    use_gpu = True
else:
    print("No GPU detected. Using CPU (this will be slower).")
    use_gpu = False

In [None]:
# --- Set Environment Variables ---
# Folder locations read by nnUNet (v1); they are placed in os.environ so the
# nnUNet_predict subprocess launched later can see them.
nnunet_env = {
    "RESULTS_FOLDER": local_results_folder,
    "nnUNet_raw_data_base": os.path.join(local_results_folder, "nnUNet_raw_data_base"),
    "nnUNet_preprocessed": os.path.join(local_results_folder, "nnUNet_preprocessed"),
}
os.environ.update(nnunet_env)

print("Environment variables set.")

In [None]:
# --- Run Inference ---
import shlex

task_id = "001"
config = "3d_fullres"

# Fast mode predicts with fold 0 only and disables test-time augmentation (TTA).
# It is always used on CPU. On a GPU, set this to False to use nnUNet_predict's
# defaults (all available folds ensembled + TTA): more accurate but slower.
fast_mode_on_gpu = True

# Construct Command
cmd = [
    "nnUNet_predict",
    "-i", temp_input_dir,
    "-o", temp_output_dir,
    "-t", task_id,
    "-m", config,
]
fast_flags = ["-f", "0", "--disable_tta"]

if not use_gpu:
    print("Optimizing for CPU: Single fold, No TTA.")
    cmd.extend(fast_flags)
elif fast_mode_on_gpu:
    print("GPU enabled. Using fast inference (Fold 0, TTA disabled). Set fast_mode_on_gpu = False for full quality.")
    cmd.extend(fast_flags)
else:
    print("GPU enabled. Using full-quality inference (all folds + TTA).")

# shlex.join quotes arguments, so the echoed command can be copy-pasted even
# when paths contain spaces.
print(f"Running command: {shlex.join(cmd)}")

# nnUNet (v1) skips cases whose prediction already exists unless
# --overwrite_existing is given, and a failed or skipped run leaves any old file
# in place. Delete the previous prediction up front so the Dice step can never
# score a stale result.
prediction_path = os.path.join(temp_output_dir, f"{case_identifier}.nii.gz")
if os.path.exists(prediction_path):
    os.remove(prediction_path)

if not (os.path.exists(formatted_input_path) and os.path.exists(model_check_path)):
    print("Skipping inference: missing input file or model folder.")
else:
    try:
        # Stream the merged stdout/stderr line by line; the context manager
        # closes the pipe and reaps the process even if streaming is interrupted.
        with subprocess.Popen(
            cmd,
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as process:
            print("\n--- Inference Output Stream ---")
            for line in process.stdout:
                print(line, end='')
            return_code = process.wait()
    except OSError as e:
        # Typically FileNotFoundError when nnUNet_predict is not on PATH.
        print(f"Error during execution: {e}")
    else:
        if return_code == 0:
            print("\nInference completed successfully.")
        else:
            print(f"\nError: Process finished with exit code {return_code}")

In [None]:
# --- Calculate Dice Score ---
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Return the (smoothed) Dice similarity coefficient of two binary masks.

    Dice = (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)

    Args:
        y_true: Ground-truth mask; array-like of 0/1 or bool values, any shape.
        y_pred: Predicted mask with the same number of elements as ``y_true``
            (both are flattened in C order before comparison).
        smooth: Constant added to numerator and denominator. It prevents
            division by zero and makes two empty masks score 1.0.

    Returns:
        The Dice score; for binary inputs it lies in [0, 1].
    """
    # ravel() avoids the copy that flatten() always makes for contiguous arrays,
    # and asarray() also accepts plain Python sequences.
    y_true_f = np.asarray(y_true).ravel()
    y_pred_f = np.asarray(y_pred).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

if not os.path.exists(ground_truth_path):
    print(f"Ground truth file '{ground_truth_name}' not found.")
else:
    print(f"Loading Ground Truth from {ground_truth_path}...")
    gt_nii = nib.load(ground_truth_path)
    # Segmentation labels are expected to be integer-valued; rounding guards the
    # exact label comparisons below against float artefacts from get_fdata().
    gt_data = np.rint(gt_nii.get_fdata()).astype(np.int32)

    prediction_filename = f"{case_identifier}.nii.gz"
    prediction_path = os.path.join(temp_output_dir, prediction_filename)

    if not os.path.exists(prediction_path):
        print(f"Error: Prediction file not found at {prediction_path}")
    else:
        print(f"Loading Prediction from {prediction_path}...")
        pred_nii = nib.load(prediction_path)
        pred_data = np.rint(pred_nii.get_fdata()).astype(np.int32)

        if gt_data.shape != pred_data.shape:
            print(f"Warning: Shape mismatch! GT: {gt_data.shape}, Pred: {pred_data.shape}")
        else:
            unique_labels = np.unique(gt_data)
            print(f"Unique labels in GT: {unique_labels}")

            # Score every label present in EITHER volume: a label that appears
            # only in the prediction is a pure false positive (Dice ~ 0) and
            # would otherwise go unreported.
            labels_to_score = np.union1d(unique_labels, np.unique(pred_data))

            for label in labels_to_score:
                if label == 0:
                    continue  # background is not scored

                # Boolean masks use 1/8 of the memory of float64 masks.
                gt_binary = gt_data == label
                pred_binary = pred_data == label

                score = dice_coefficient(gt_binary, pred_binary)
                print(f"Dice Score for Label {int(label)}: {score:.4f}")

In [None]:
# --- Cleanup (Optional) ---
# Set to True to delete the temporary input/output folders created above.
# This also deletes the nnUNet prediction written to the output folder.
CLEANUP_TEMP_DIRS = False

if CLEANUP_TEMP_DIRS:
    shutil.rmtree(temp_input_dir, ignore_errors=True)
    shutil.rmtree(temp_output_dir, ignore_errors=True)
    print("Temporary directories removed.")