# TR2StrokeSeg: nnU-Net Training and Inference Example

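This notebook demonstrates the complete workflow for training nnU-Net on the Atlas2 dataset and applying the trained model to other datasets. Replace the `/path/to/...` placeholders in the code cells with your own paths before running.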

## 1. Setup Environment

In [None]:
import os
import sys
from pathlib import Path

# Configure the three directories nnU-Net reads from the environment.
# `os.environ.setdefault` keeps any value already exported in the shell
# (e.g. from ~/.bashrc) and only falls back to these placeholders when a
# variable is unset -- replace the placeholders with real paths if needed.
# Shell commands run later with `!` inherit os.environ, so the nnUNetv2_*
# CLI tools see the same values.
NNUNET_DIRS = {
    'nnUNet_raw': '/path/to/nnUNet_raw',
    'nnUNet_preprocessed': '/path/to/nnUNet_preprocessed',
    'nnUNet_results': '/path/to/nnUNet_results',
}
for var_name, placeholder in NNUNET_DIRS.items():
    os.environ.setdefault(var_name, placeholder)

print("Environment variables set:")
for var_name in NNUNET_DIRS:
    print(f"  {var_name}: {os.environ[var_name]}")

## 2. Prepare Atlas2 Dataset

In [None]:
from src.data_preparation.prepare_atlas2 import prepare_atlas2_dataset

# Location of the downloaded Atlas2 data (placeholder -- edit before running).
atlas2_dir = '/path/to/atlas2/dataset'
# The converted dataset is written into nnU-Net's raw-data folder configured
# above. Named distinctly from the prediction `output_dir` used in the
# inference section so the two locations are not confused.
nnunet_raw_dir = os.environ['nnUNet_raw']
# nnU-Net dataset ID; preprocessing, training and inference must all refer to
# this same ID.
dataset_id = 1

# Fail fast with an explicit message if the source folder is wrong, rather
# than relying on how the converter reacts to a missing path.
if not Path(atlas2_dir).is_dir():
    raise FileNotFoundError(f"Atlas2 directory not found: {atlas2_dir}")

# Convert Atlas2 into nnU-Net's raw-data layout.
prepare_atlas2_dataset(atlas2_dir, nnunet_raw_dir, dataset_id)

## 3. Plan and Preprocess

In [None]:
# Run preprocessing (this can take some time)
!nnUNetv2_plan_and_preprocess -d 1 --verify_dataset_integrity

## 4. Train Model

In [None]:
# Training a single fold takes several hours to days, so this cell only
# prints the command to run in a terminal (or via the training script).
# To launch it from the notebook instead, run `!{train_cmd}` in a new cell.
train_configuration = '3d_fullres'
train_fold = 0
# Built from `dataset_id` (defined in the dataset-preparation cell) so the
# dataset ID only needs to be changed in one place.
train_cmd = f"nnUNetv2_train {dataset_id} {train_configuration} {train_fold}"

print("Training command:")
print(train_cmd)
print("\nNote: Training takes considerable time. Run this in a terminal or use the training script.")

## 5. Run Inference on a Test Dataset

In [None]:
from src.inference.predict import predict_on_dataset

# Images to segment and where to write predictions (placeholders -- edit).
test_dataset_dir = '/path/to/test/dataset'
output_dir = '/path/to/output'
# Trained models are read from nnU-Net's results folder configured above.
model_dir = os.environ['nnUNet_results']

# Run prediction. `dataset_id` comes from the dataset-preparation cell so
# inference targets the same dataset ID that was preprocessed and trained.
predict_on_dataset(
    test_dataset_dir=test_dataset_dir,
    output_dir=output_dir,
    model_dir=model_dir,
    dataset_id=dataset_id,
    configuration='3d_fullres',  # must match the configuration that was trained
    folds='0'  # Use single fold for quick testing
)

## 6. Evaluate Predictions (if ground truth is available)

In [None]:
from src.inference.evaluate import evaluate_dataset

# Final predictions are expected in a `final_predictions` subfolder of the
# inference `output_dir` (as the original placeholder path implied) --
# NOTE(review): confirm against src/inference/predict.py. Deriving the path
# keeps this cell in sync when `output_dir` is edited in the inference cell.
pred_dir = os.path.join(output_dir, 'final_predictions')
# Ground-truth label folder and destination of the metrics table (placeholders).
gt_dir = '/path/to/ground/truth'
output_csv = '/path/to/evaluation_results.csv'

# Compare predictions against ground truth and write per-case results to CSV.
evaluate_dataset(pred_dir, gt_dir, output_csv)

## 7. Visualize Results (Optional)

In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np

def _overlay_mask(ax, mask_slice, cmap):
    """Draw the non-zero voxels of a 2D label slice on ``ax``; zeros stay transparent."""
    # Mask the background so the overlay does not tint the whole slice.
    masked = np.ma.masked_where(mask_slice == 0, mask_slice)
    # Pin the colour range: after masking, a single-label slice would otherwise
    # have vmin == vmax and be drawn with the pale low end of the colormap.
    vmax = max(float(mask_slice.max()), 1.0)
    ax.imshow(masked, cmap=cmap, alpha=0.5, vmin=0, vmax=vmax, origin='lower')


def visualize_prediction(image_path, pred_path, gt_path=None, slice_idx=None):
    """
    Visualize image, prediction, and optionally ground truth for one slice.

    Parameters
    ----------
    image_path : str or path-like
        Image file loaded with ``nibabel.load``.
    pred_path : str or path-like
        Predicted segmentation; assumed to have the same voxel grid as the
        image (not checked).
    gt_path : str or path-like, optional
        Ground-truth segmentation. When given (truthy), a third panel is drawn.
    slice_idx : int, optional
        Index along the third array axis. Defaults to the middle slice.

    Each panel shows the image slice in gray; the prediction (red) and ground
    truth (green) are overlaid only where their labels are non-zero. The
    figure is displayed with ``plt.show()``; nothing is returned.
    """
    img = nib.load(image_path).get_fdata()
    pred = nib.load(pred_path).get_fdata()
    gt = nib.load(gt_path).get_fdata() if gt_path else None

    # Select middle slice if not specified
    if slice_idx is None:
        slice_idx = img.shape[2] // 2

    # Transpose + origin='lower' puts the first array axis horizontally and
    # the second one increasing upward in the plot.
    img_slice = img[:, :, slice_idx].T

    # (title, label volume or None, overlay colormap) for each panel.
    panels = [('Input Image', None, None), ('Prediction', pred, 'Reds')]
    if gt is not None:
        panels.append(('Ground Truth', gt, 'Greens'))

    fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (title, labels, cmap) in zip(axes, panels):
        ax.imshow(img_slice, cmap='gray', origin='lower')
        if labels is not None:
            _overlay_mask(ax, labels[:, :, slice_idx].T, cmap)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

# Example usage
# visualize_prediction(
#     image_path='/path/to/image.nii.gz',
#     pred_path='/path/to/prediction.nii.gz',
#     gt_path='/path/to/ground_truth.nii.gz'  # optional
# )