# NeSVoR2 Pipeline - Interactive Visualization Notebook

This notebook provides an interactive walkthrough of the NeSVoR2 MRI super-resolution pipeline with detailed visualizations at each phase.

## Pipeline Phases:
1. **Load Inputs** - Load stacks, masks, and models
2. **Segmentation** - 2D fetal brain masking
3. **Bias Field Correction** - N4 algorithm
4. **Assessment** - Stack quality metrics
5. **Registration** - Motion correction (SVoRT)
6. **Reconstruction** - NeSVoR training
7. **Sampling** - Generate high-resolution volume
8. **Save Outputs** - Export results

## Setup and Imports

In [None]:
import os
import sys
import json
import logging
from argparse import Namespace
from typing import Dict, List, Any

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
from tqdm.notebook import tqdm

# Set up plotting style: a dark-grid base style plus seaborn's "husl" colour
# palette for all subsequent figures.
# NOTE(review): 'seaborn-v0_8-*' style names only exist in newer matplotlib
# releases (older ones used 'seaborn-darkgrid'); confirm against the installed
# version if this line raises.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the cell output.
%matplotlib inline

# Import NeSVoR modules
from utils import (
    load_stack,
    load_slices,
    save_slices,
    merge_args,
    load_mask,
    Volume,
    Stack,
    Slice,
)
from model.models import INR
from model.train import train
from model.sample import sample_volume, sample_slices
from preprocess import (
    stack_intersect,
    otsu_thresholding,
    thresholding,
    n4_bias_field_correction,
    brain_segmentation,
    assess,
)

# Configure root logging at INFO level with a timestamped format. basicConfig
# is a no-op if the root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Environment report: confirms the imports above succeeded and shows which
# PyTorch build / accelerator this session will use.
_cuda_available = torch.cuda.is_available()
print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Visualization Helper Functions

In [None]:
def show_slice_montage(stack: Stack, title: str = "Stack Slices", num_slices: int = 9, cmap: str = 'gray'):
    """
    Display a montage of evenly spaced slices from a stack.

    Slices are picked with ``np.linspace`` over the full slice range, so the
    first and last slices are included whenever more than one is shown.
    If there is nothing to show (empty stack or ``num_slices <= 0``) a short
    message is printed instead of raising.

    Args:
        stack: Stack object; each 2D image is taken as ``stack.slices[k, 0]``.
        title: Figure title.
        num_slices: Maximum number of slices to display.
        cmap: Matplotlib colormap for display.
    """
    n_total = stack.slices.shape[0]
    n_show = min(num_slices, n_total)
    if n_show <= 0:
        # plt.subplots(0, 0) would raise; report and return instead.
        print(f"{title}: no slices to display")
        return

    indices = np.linspace(0, n_total - 1, n_show, dtype=int)

    rows = int(np.ceil(np.sqrt(n_show)))
    cols = int(np.ceil(n_show / rows))

    # squeeze=False always yields a 2D array of Axes, so a 1x1 grid needs no
    # special case.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)

    for ax, slice_idx in zip(axes.flat, indices):
        slice_data = stack.slices[slice_idx, 0].detach().cpu().numpy()
        ax.imshow(slice_data.T, cmap=cmap, origin='lower')
        ax.set_title(f'Slice {slice_idx}/{n_total-1}')

    # Hide the axes of every panel, including unused trailing panels.
    for ax in axes.flat:
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def show_3d_views(volume_data: torch.Tensor, title: str = "3D Volume Views", masked: bool = False):
    """
    Display the three central orthogonal cross-sections of a 3D volume.

    Args:
        volume_data: 3D tensor or array indexed as (x, y, z).
        title: Figure title.
        masked: If True, voxels equal to zero are masked out, so they are
            drawn with the colormap's "bad" colour (transparent by default)
            instead of black.

    Raises:
        ValueError: If ``volume_data`` is not 3-dimensional.
    """
    if isinstance(volume_data, torch.Tensor):
        volume_data = volume_data.detach().cpu().numpy()
    volume_data = np.asarray(volume_data)
    if volume_data.ndim != 3:
        raise ValueError(
            f"show_3d_views expects a 3D volume, got shape {volume_data.shape}"
        )

    # Middle index along each axis.
    mid_x, mid_y, mid_z = (size // 2 for size in volume_data.shape)

    # (panel title, 2D cross-section) for each orthogonal view.
    views = [
        ('Sagittal View (YZ)', volume_data[mid_x, :, :]),
        ('Coronal View (XZ)', volume_data[:, mid_y, :]),
        ('Axial View (XY)', volume_data[:, :, mid_z]),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (view_title, plane) in zip(axes, views):
        # Transpose so the second array index runs vertically (origin at the
        # bottom).
        plane = plane.T
        if masked:
            plane = np.ma.masked_equal(plane, 0)
        ax.imshow(plane, cmap='gray', origin='lower')
        ax.set_title(view_title)
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def show_stack_info(stack: Stack, stack_idx: int = 0):
    """
    Print a summary of a stack.

    The summary covers tensor layout, geometry, intensity statistics, mask
    coverage and transformation shape.

    Args:
        stack: Stack object.
        stack_idx: Index shown in the header line.
    """
    slices = stack.slices
    # Statistics are taken on a float copy/view because torch.mean and
    # torch.std raise on integer tensors.
    values = slices.detach().float()
    rule = '=' * 60

    print(f"\n{rule}")
    print(f"Stack {stack_idx} Information")
    print(rule)
    print(f"Shape (slices):        {slices.shape}")
    print(f"Data type:             {slices.dtype}")
    print(f"Device:                {slices.device}")
    print(f"Resolution (x, y):     {stack.resolution_x:.3f} mm, {stack.resolution_y:.3f} mm")
    print(f"Slice thickness:       {stack.thickness:.3f} mm")
    print(f"Intensity range:       [{values.min().item():.2f}, {values.max().item():.2f}]")
    print(f"Intensity mean:        {values.mean().item():.2f}")
    print(f"Intensity std:         {values.std().item():.2f}")

    mask = stack.mask
    if mask is not None:
        n_masked = mask.sum().item()
        n_total = mask.numel()
        # Avoid dividing by zero for an empty mask tensor.
        pct = 100 * n_masked / n_total if n_total else 0.0
        print(f"Mask shape:            {mask.shape}")
        print(f"Masked voxels:         {n_masked} / {n_total} ({pct:.1f}%)")
    else:
        print("Mask:                  None")

    # Fall back to 'N/A' when the stack has no transformation, or the
    # transformation is None / does not expose a .shape.
    transformation_shape = getattr(getattr(stack, 'transformation', None), 'shape', 'N/A')
    print(f"Transformation shape:  {transformation_shape}")
    print(f"{rule}\n")


def show_intensity_histogram(stack: Stack, title: str = "Intensity Distribution", bins: int = 50):
    """
    Plot a histogram of the strictly positive intensities of a stack.

    Voxels with intensity <= 0 are excluded before binning.

    Args:
        stack: Stack object whose ``slices`` tensor is histogrammed.
        title: Axes title.
        bins: Number of histogram bins.
    """
    positive = stack.slices > 0
    values = stack.slices[positive].cpu().numpy().flatten()

    _, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.hist(values, bins=bins, alpha=0.7, edgecolor='black')
    ax.set(xlabel='Intensity', ylabel='Frequency', title=title)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def compare_stacks(stack1: Stack, stack2: Stack, title1: str = "Before", title2: str = "After", slice_idx: int = None):
    """
    Show one slice from two stacks side by side, plus their difference.

    The third panel shows ``stack2 - stack1`` on a diverging colormap with
    colour limits symmetric about zero.

    Args:
        stack1: First stack ("before").
        stack2: Second stack ("after").
        title1: Title for the first panel.
        title2: Title for the second panel.
        slice_idx: Slice to compare (default: middle slice of ``stack1``).

    Raises:
        ValueError: If the selected slices of the two stacks differ in shape.
    """
    if slice_idx is None:
        slice_idx = stack1.slices.shape[0] // 2

    # Convert to float64 before subtracting: unsigned integer data would
    # otherwise wrap around and half-precision data would lose precision.
    slice1 = stack1.slices[slice_idx, 0].detach().cpu().numpy().astype(np.float64)
    slice2 = stack2.slices[slice_idx, 0].detach().cpu().numpy().astype(np.float64)
    if slice1.shape != slice2.shape:
        raise ValueError(
            f"Cannot compare slices of different shapes: {slice1.shape} vs {slice2.shape}"
        )

    diff = slice2 - slice1
    # Symmetric colour limits. Fall back to 1.0 when the slices are identical
    # so the colour scale and colorbar are not degenerate (vmin == vmax).
    limit = float(np.abs(diff).max()) or 1.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        (axes[0], slice1, f'{title1} (Slice {slice_idx})'),
        (axes[1], slice2, f'{title2} (Slice {slice_idx})'),
    )
    for ax, image, label in panels:
        ax.imshow(image.T, cmap='gray', origin='lower')
        ax.set_title(label)
        ax.axis('off')

    im = axes[2].imshow(diff.T, cmap='RdBu_r', origin='lower', vmin=-limit, vmax=limit)
    axes[2].set_title('Difference')
    axes[2].axis('off')
    plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()


def plot_training_metrics(metrics_dict: Dict[str, List[float]]):
    """
    Plot each training metric in its own subplot against iteration index.

    An empty mapping prints a short message instead of raising.

    Args:
        metrics_dict: Mapping from metric name to its per-iteration values.
    """
    n_metrics = len(metrics_dict)
    if n_metrics == 0:
        # plt.subplots(1, 0) would raise; report and return instead.
        print("No training metrics to plot")
        return

    # squeeze=False keeps axes 2D even for a single metric.
    fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4), squeeze=False)

    for ax, (name, values) in zip(axes[0], metrics_dict.items()):
        ax.plot(values)
        ax.set_xlabel('Iteration')
        ax.set_ylabel(name)
        ax.set_title(f'{name} over Training')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


print("✓ Visualization functions defined")

## Configuration

Set up the pipeline parameters. Edit these values for your data. At minimum, point `input_stacks` at your stack files and set `thicknesses` to match, with one entry per stack.

In [None]:
# Create configuration.
# Edit the paths and options below for your data. `dtype` is not set in the
# Namespace; it is derived afterwards from `single_precision` and the device.
args = Namespace(
    # Input files (MODIFY THESE PATHS)
    input_stacks=['stack1.nii.gz', 'stack2.nii.gz', 'stack3.nii.gz'],
    stack_masks=None,  # Optional: ['mask1.nii.gz', 'mask2.nii.gz', 'mask3.nii.gz']
    thicknesses=[3.0, 3.0, 3.0],  # Slice thickness in mm, one entry per input stack
    volume_mask=None,  # Optional: 'volume_mask.nii.gz'
    stacks_intersection=False,  # Create volume mask from stack intersection
    
    # Preprocessing options
    segmentation=False,  # Enable 2D brain segmentation
    bias_field_correction=False,  # Enable N4 bias correction
    n_levels_bias=0,
    skip_assessment=True,  # Skip quality assessment
    metric='none',  # Options: 'ncc', 'matrix-rank', 'volume', 'iqa2d', 'iqa3d', 'none'
    filter_method='none',  # Options: 'top', 'bottom', 'threshold', 'percentage', 'none'
    cutoff=0.0,
    batch_size_assess=8,
    no_augmentation_assess=False,
    background_threshold=0.0,
    otsu_thresholding=False,
    
    # Registration options
    registration=True,  # Enable registration
    svort=True,  # Use SVoRT
    svort_version='v2',  # 'v1' or 'v2'
    use_vvr=False,  # Use traditional stack registration
    force_vvr=False,
    force_scanner=False,
    
    # Training options
    skip_reconstruction=False,
    n_iter=1000,  # Number of training iterations (increase for better quality)
    n_epochs=None,
    batch_size=4096,
    learning_rate=0.01,
    single_precision=False,  # Use FP32 instead of FP16 (takes effect via args.dtype below)
    gamma=0.33,
    milestones=[0.5, 0.75, 0.9],
    n_samples=128 * 2,
    
    # Model architecture
    coarsest_resolution=2.0,
    finest_resolution=0.5,
    level_scale=1.39,
    n_features_per_level=2,
    log2_hashmap_size=19,
    width=64,
    depth=2,
    n_features_z=15,
    n_features_slice=16,
    no_transformation_optimization=False,
    no_slice_scale=False,
    no_pixel_variance=False,
    no_slice_variance=False,
    
    # Regularization
    weight_transformation=0.1,
    weight_bias=0.1,
    weight_image=0.01,
    image_regularization='TV',  # 'none', 'TV', 'edge', 'L2'
    weight_deform=0.1,
    delta=0.2,
    img_reg_autodiff=False,
    
    # Deformation options
    deformable=False,
    n_features_deform=8,
    n_features_per_level_deform=4,
    level_scale_deform=1.3819,
    coarsest_resolution_deform=32.0,
    finest_resolution_deform=8.0,
    
    # Sampling options
    output_resolution=0.8,  # Output volume resolution in mm
    n_inference_samples=128,
    inference_batch_size=1024,
    output_intensity_mean=None,
    with_background=False,
    
    # Output files
    output_volume='output_volume.nii.gz',
    output_model='output_model.pt',
    output_slices=None,
    simulated_slices=None,
    output_corrected_stacks=None,
    output_stack_masks=None,
    output_json='output_results.json',
    
    # System options
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    debug=False,
)

# Compute dtype: FP16 only on CUDA and only when single precision was not
# requested. It is derived here rather than set independently so that setting
# `single_precision=True` above actually selects FP32 on a GPU.
args.dtype = (
    torch.float32
    if args.single_precision or args.device.type != 'cuda'
    else torch.float16
)

print("Configuration created:")
print(f"  Device: {args.device}")
print(f"  Input stacks: {len(args.input_stacks)}")
print(f"  Training iterations: {args.n_iter}")
print(f"  Output resolution: {args.output_resolution} mm")

---
# Phase 1: Load Inputs

Load input MRI stacks, masks, and any pre-trained models.

In [None]:
print("="*80)
print("PHASE 1: LOADING INPUTS")
print("="*80)

input_dict = {}

# Load input stacks
if args.input_stacks is not None:
    input_stacks = []
    print(f"\nLoading {len(args.input_stacks)} input stacks...")
    
    for i, f in enumerate(args.input_stacks):
        if not os.path.exists(f):
            print(f"⚠ Warning: File not found: {f}")
            continue
            
        stack = load_stack(
            f,
            args.stack_masks[i] if args.stack_masks is not None else None,
            device=args.device,
        )
        
        if args.thicknesses is not None:
            stack.thickness = args.thicknesses[i]
        
        input_stacks.append(stack)
        print(f"✓ Loaded stack {i + 1}: {f}")
    
    input_dict["input_stacks"] = input_stacks
    print(f"\n✓ Successfully loaded {len(input_stacks)} stacks")

# Load pre-trained model if provided
if hasattr(args, 'input_model') and args.input_model is not None:
    if os.path.exists(args.input_model):
        print(f"\nLoading pre-trained model from {args.input_model}")
        cp = torch.load(args.input_model, map_location=args.device)
        input_dict["model"] = INR(cp["model"]["bounding_box"], cp["args"])
        input_dict["model"].load_state_dict(cp["model"])
        input_dict["mask"] = cp["mask"]
        print("✓ Model loaded successfully")

print("\n" + "="*80)

### Visualize Input Stacks

In [None]:
if "input_stacks" in input_dict:
    for i, stack in enumerate(input_dict["input_stacks"]):
        print(f"\n{'='*80}")
        print(f"STACK {i+1} ANALYSIS")
        print(f"{'='*80}")
        
        # Show detailed information
        show_stack_info(stack, i)
        
        # Show slice montage
        show_slice_montage(stack, title=f"Stack {i+1}: Slice Montage", num_slices=9)
        
        # Show intensity histogram
        show_intensity_histogram(stack, title=f"Stack {i+1}: Intensity Distribution")
        
        # Show 3D views if stack has enough slices
        if stack.slices.shape[0] >= 10:
            volume_data = stack.slices[:, 0, :, :]
            show_3d_views(volume_data, title=f"Stack {i+1}: 3D Views")
else:
    print("No input stacks to visualize")

---
# Phase 2: Segmentation

Apply 2D fetal brain segmentation to create masks for each slice.

In [None]:
print("="*80)
print("PHASE 2: SEGMENTATION")
print("="*80)

if args.segmentation and "input_stacks" in input_dict:
    input_stacks = input_dict["input_stacks"]
    print(f"\nRunning brain segmentation on {len(input_stacks)} stacks...")
    
    # Store original stacks for comparison
    original_stacks = [stack.clone() for stack in input_stacks]
    
    # Run segmentation
    segmented_stacks = brain_segmentation(input_stacks, args)
    
    input_dict["input_stacks"] = segmented_stacks
    input_dict["original_stacks"] = original_stacks
    
    print("✓ Segmentation completed")
else:
    print("\nSegmentation skipped (not enabled)")

print("\n" + "="*80)

### Visualize Segmentation Results

In [None]:
have_segmentation = (
    args.segmentation
    and "input_stacks" in input_dict
    and "original_stacks" in input_dict
)

if not have_segmentation:
    print("No segmentation results to visualize")
else:
    stack_pairs = zip(input_dict["original_stacks"], input_dict["input_stacks"])
    for i, (before, after) in enumerate(stack_pairs):
        print(f"\nStack {i+1} - Segmentation Comparison")

        # Image before vs after segmentation, plus their difference.
        compare_stacks(
            before,
            after,
            title1="Before Segmentation",
            title2="After Segmentation",
        )

        # Middle slice of the segmentation mask, when the stack carries one.
        if after.mask is None:
            continue
        mid_slice = after.mask.shape[0] // 2
        mask_2d = after.mask[mid_slice, 0].cpu().numpy()

        plt.figure(figsize=(6, 6))
        plt.imshow(mask_2d.T, cmap='binary', origin='lower')
        plt.title(f"Stack {i+1}: Segmentation Mask (Slice {mid_slice})")
        plt.axis('off')
        plt.show()

---
# Phase 3: Bias Field Correction

Apply N4 bias field correction to remove intensity inhomogeneities.

In [None]:
print("="*80)
print("PHASE 3: BIAS FIELD CORRECTION")
print("="*80)

if args.bias_field_correction and "input_stacks" in input_dict:
    input_stacks = input_dict["input_stacks"]
    print(f"\nRunning N4 bias field correction on {len(input_stacks)} stacks...")
    
    # Store uncorrected stacks for comparison
    uncorrected_stacks = [stack.clone() for stack in input_stacks]
    
    corrected_stacks = []
    for i, stack in enumerate(tqdm(input_stacks, desc="Bias correction")):
        corrected_stack = n4_bias_field_correction(stack)
        corrected_stacks.append(corrected_stack)
    
    input_dict["input_stacks"] = corrected_stacks
    input_dict["uncorrected_stacks"] = uncorrected_stacks
    input_dict["output_corrected_stacks"] = corrected_stacks
    
    print("✓ Bias field correction completed")
else:
    print("\nBias field correction skipped (not enabled)")

print("\n" + "="*80)

### Visualize Bias Correction Results

In [None]:
if not (args.bias_field_correction and "uncorrected_stacks" in input_dict):
    print("No bias correction results to visualize")
else:
    stack_pairs = zip(input_dict["uncorrected_stacks"], input_dict["input_stacks"])
    for i, (before, after) in enumerate(stack_pairs):
        print(f"\nStack {i+1} - Bias Correction Comparison")

        # Image before vs after correction, plus their difference.
        compare_stacks(
            before,
            after,
            title1="Before Correction",
            title2="After Correction",
        )

        # Histograms of strictly positive intensities, before vs after.
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        panels = (
            (axes[0], before, {}, "Before Correction"),
            (axes[1], after, {'color': 'orange'}, "After Correction"),
        )
        for ax, stk, hist_kwargs, label in panels:
            values = stk.slices[stk.slices > 0].cpu().numpy().flatten()
            ax.hist(values, bins=50, alpha=0.7, edgecolor='black', **hist_kwargs)
            ax.set_title(f"Stack {i+1}: {label}")
            ax.set_xlabel('Intensity')
            ax.set_ylabel('Frequency')

        plt.tight_layout()
        plt.show()

---
# Phase 4: Quality Assessment

Assess the quality of input stacks using various metrics.

In [None]:
print("="*80)
print("PHASE 4: QUALITY ASSESSMENT")
print("="*80)

if not args.skip_assessment and "input_stacks" in input_dict:
    input_stacks = input_dict["input_stacks"]
    print(f"\nAssessing quality of {len(input_stacks)} stacks...")
    print(f"Metric: {args.metric}")
    print(f"Filter method: {args.filter_method}")
    
    augmentation = not args.no_augmentation_assess
    
    filtered_stacks, assessment_results = assess(
        input_stacks,
        args.metric,
        args.filter_method,
        args.cutoff or 0.0,
        args.batch_size_assess,
        augmentation,
        args.device,
    )
    
    input_dict["input_stacks"] = filtered_stacks
    input_dict["assessment_results"] = assessment_results
    
    print("\n✓ Assessment completed")
    print(f"\nQuality Scores:")
    for i, score in enumerate(assessment_results):
        print(f"  Stack {i+1}: {score}")
else:
    print("\nQuality assessment skipped")

print("\n" + "="*80)

### Visualize Assessment Results

In [None]:
if "assessment_results" in input_dict:
    results = input_dict["assessment_results"]
    
    # Bar plot of quality scores
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    stack_indices = np.arange(len(results))
    
    ax.bar(stack_indices, results, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Stack Index')
    ax.set_ylabel(f'Quality Score ({args.metric})')
    ax.set_title('Stack Quality Assessment')
    ax.set_xticks(stack_indices)
    ax.set_xticklabels([f'Stack {i+1}' for i in stack_indices])
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
else:
    print("No assessment results to visualize")

---
# Phase 5: Registration (Motion Correction)

Apply SVoRT registration to correct for motion between slices.

In [None]:
print("="*80)
print("PHASE 5: REGISTRATION")
print("="*80)

if args.registration and "input_stacks" in input_dict:
    input_stacks = input_dict["input_stacks"]
    print(f"\nRunning registration on {len(input_stacks)} stacks...")
    print(f"SVoRT: {args.svort}")
    print(f"SVoRT version: {args.svort_version}")
    
    if args.svort:
        from preprocess.svort import svort_predict
        
        registered_slices = svort_predict(
            dataset=input_stacks,
            device=args.device,
            svort_version=args.svort_version,
            svort=True,
            vvr=args.use_vvr,
            force_vvr=args.force_vvr,
            force_scanner=args.force_scanner,
        )
        
        input_dict["input_slices"] = registered_slices
        print(f"\n✓ SVoRT registration completed: {len(registered_slices)} slices")
    else:
        # Convert stacks to slices without registration
        slices = []
        for stack in input_stacks:
            for i in range(stack.slices.shape[0]):
                slice_img = stack.slices[i]
                slice_mask = stack.mask[i] if stack.mask is not None else None
                slice_obj = Slice(
                    slice_img,
                    slice_mask,
                    stack.transformation[i],
                    stack.resolution_x,
                    stack.resolution_y,
                    stack.thickness,
                )
                slices.append(slice_obj)
        input_dict["input_slices"] = slices
        print(f"\n✓ Created {len(slices)} slices from stacks (no registration)")
else:
    print("\nRegistration skipped")
    # Still need to convert to slices
    if "input_stacks" in input_dict and "input_slices" not in input_dict:
        slices = []
        for stack in input_dict["input_stacks"]:
            for i in range(stack.slices.shape[0]):
                slice_img = stack.slices[i]
                slice_mask = stack.mask[i] if stack.mask is not None else None
                slice_obj = Slice(
                    slice_img,
                    slice_mask,
                    stack.transformation[i],
                    stack.resolution_x,
                    stack.resolution_y,
                    stack.thickness,
                )
                slices.append(slice_obj)
        input_dict["input_slices"] = slices
        print(f"\n✓ Created {len(slices)} slices from stacks")

print("\n" + "="*80)

### Visualize Registration Results

In [None]:
if "input_slices" in input_dict:
    slices = input_dict["input_slices"]
    
    print(f"\nTotal slices after registration: {len(slices)}")
    print(f"\nSlice Information:")
    print(f"  First slice shape: {slices[0].image.shape}")
    print(f"  Resolution (x, y): {slices[0].resolution_x:.3f} mm, {slices[0].resolution_y:.3f} mm")
    print(f"  Thickness: {slices[0].thickness:.3f} mm")
    print(f"  Data type: {slices[0].image.dtype}")
    print(f"  Device: {slices[0].image.device}")
    
    # Show sample slices
    num_samples = min(9, len(slices))
    sample_indices = np.linspace(0, len(slices) - 1, num_samples, dtype=int)
    
    rows = int(np.ceil(np.sqrt(num_samples)))
    cols = int(np.ceil(num_samples / rows))
    
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
    axes = axes.flatten() if num_samples > 1 else [axes]
    
    for idx, ax in enumerate(axes):
        if idx < num_samples:
            slice_idx = sample_indices[idx]
            slice_data = slices[slice_idx].image[0].cpu().numpy()
            
            ax.imshow(slice_data.T, cmap='gray', origin='lower')
            ax.set_title(f'Slice {slice_idx}/{len(slices)-1}')
            ax.axis('off')
        else:
            ax.axis('off')
    
    plt.suptitle("Registered Slices Sample", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("No registered slices to visualize")

---
# Phase 6: Reconstruction (Training)

Train the NeSVoR model using the registered slices.

In [None]:
print("="*80)
print("PHASE 6: RECONSTRUCTION (TRAINING)")
print("="*80)

if not args.skip_reconstruction and "input_slices" in input_dict:
    input_slices = input_dict["input_slices"]
    
    print(f"\nTraining NeSVoR model on {len(input_slices)} slices")
    print(f"Training iterations: {args.n_iter}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Device: {args.device}")
    print(f"Precision: {'FP16' if not args.single_precision else 'FP32'}")
    print("\nStarting training...\n")
    
    # Add progress bar for notebook
    args.progress_bar = tqdm(total=args.n_iter, desc="Training", unit="iter")
    
    try:
        # Train model
        model, output_slices, mask = train(input_slices, args)
        
        input_dict["output_model"] = model
        input_dict["output_slices"] = output_slices
        input_dict["mask"] = mask
        
        print("\n✓ Training completed successfully")
    finally:
        if hasattr(args, 'progress_bar') and args.progress_bar is not None:
            args.progress_bar.close()
            delattr(args, 'progress_bar')
else:
    print("\nReconstruction skipped")

print("\n" + "="*80)

### Visualize Model Information

In [None]:
if "output_model" in input_dict:
    model = input_dict["output_model"]
    mask = input_dict["mask"]
    
    print("\nTrained Model Information:")
    print(f"  Model type: {type(model).__name__}")
    print(f"  Bounding box: {model.bounding_box}")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    
    print(f"\nMask Information:")
    print(f"  Mask shape: {mask.image.shape}")
    print(f"  Mask resolution: {mask.resolution_x:.3f} mm")
    print(f"  Masked voxels: {mask.image.sum().item():,}")
    
    # Visualize mask
    show_3d_views(mask.image[0], title="Reconstruction Mask")
else:
    print("No trained model to visualize")

---
# Phase 7: Sampling

Sample a high-resolution volume from the trained implicit neural representation.

In [None]:
print("="*80)
print("PHASE 7: SAMPLING")
print("="*80)

if args.output_volume and "output_model" in input_dict:
    model = input_dict["output_model"]
    mask = input_dict["mask"]
    
    print(f"\nSampling volume at {args.output_resolution} mm resolution")
    print(f"Number of PSF samples: {args.n_inference_samples}")
    print(f"Batch size: {args.inference_batch_size}")
    print("\nSampling...")
    
    # Sample volume
    output_volume = sample_volume(
        model,
        mask,
        psf_resolution=args.output_resolution,
        batch_size=args.inference_batch_size,
        n_samples=args.n_inference_samples,
    )
    
    input_dict["output_volume"] = output_volume
    
    print(f"\n✓ Volume sampled: shape={output_volume.image.shape}")
    print(f"  Resolution: {output_volume.resolution_x:.3f} mm")
    print(f"  Intensity range: [{output_volume.image.min().item():.2f}, {output_volume.image.max().item():.2f}]")
    
    # Sample simulated slices if requested
    if args.simulated_slices and "output_slices" in input_dict:
        print("\nSampling simulated slices...")
        output_slices = input_dict["output_slices"]
        
        simulated = sample_slices(
            model,
            output_slices,
            mask,
            output_psf_factor=1.0,
            n_samples=args.n_inference_samples,
        )
        
        input_dict["simulated_slices"] = simulated
        print(f"✓ Sampled {len(simulated)} simulated slices")
else:
    print("\nSampling skipped (no output volume requested or model not trained)")

print("\n" + "="*80)

### Visualize Output Volume

In [None]:
if "output_volume" in input_dict:
    volume = input_dict["output_volume"]
    
    print("\nOutput Volume Information:")
    print(f"  Shape: {volume.image.shape}")
    print(f"  Resolution: {volume.resolution_x:.3f} mm")
    print(f"  Data type: {volume.image.dtype}")
    print(f"  Intensity range: [{volume.image.min().item():.2f}, {volume.image.max().item():.2f}]")
    print(f"  Mean intensity: {volume.image.mean().item():.2f}")
    
    # Show 3D orthogonal views
    show_3d_views(volume.image[0], title="Output Volume: Orthogonal Views")
    
    # Show intensity histogram
    data = volume.image[volume.image > 0].cpu().numpy().flatten()
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.hist(data, bins=50, alpha=0.7, edgecolor='black', color='green')
    ax.set_xlabel('Intensity')
    ax.set_ylabel('Frequency')
    ax.set_title('Output Volume: Intensity Distribution')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    # Interactive slider for exploring slices
    from ipywidgets import interact, IntSlider
    
    def show_slice(axis, slice_idx):
        """Interactive slice viewer."""
        vol_data = volume.image[0].cpu().numpy()
        
        if axis == 0:  # Sagittal
            slice_data = vol_data[slice_idx, :, :]
            title = f'Sagittal (YZ) - Slice {slice_idx}/{vol_data.shape[0]-1}'
        elif axis == 1:  # Coronal
            slice_data = vol_data[:, slice_idx, :]
            title = f'Coronal (XZ) - Slice {slice_idx}/{vol_data.shape[1]-1}'
        else:  # Axial
            slice_data = vol_data[:, :, slice_idx]
            title = f'Axial (XY) - Slice {slice_idx}/{vol_data.shape[2]-1}'
        
        plt.figure(figsize=(8, 8))
        plt.imshow(slice_data.T, cmap='gray', origin='lower')
        plt.title(title)
        plt.axis('off')
        plt.show()
    
    # Create interactive viewer
    vol_shape = volume.image[0].shape
    print("\n📊 Interactive Slice Viewer:")
    interact(
        show_slice,
        axis=IntSlider(min=0, max=2, step=1, value=2, description='Axis:'),
        slice_idx=IntSlider(min=0, max=max(vol_shape)-1, step=1, value=max(vol_shape)//2, description='Slice:')
    )
else:
    print("No output volume to visualize")

### Compare Input vs Output

In [None]:
if "input_stacks" in input_dict and "output_volume" in input_dict:
    print("\n📊 Input vs Output Comparison")
    
    # Get first input stack for comparison
    input_stack = input_dict["input_stacks"][0]
    output_volume = input_dict["output_volume"]
    
    print(f"\nInput Stack (Stack 0):")
    print(f"  Shape: {input_stack.slices.shape}")
    print(f"  Resolution: {input_stack.resolution_x:.3f} mm x {input_stack.resolution_y:.3f} mm")
    print(f"  Thickness: {input_stack.thickness:.3f} mm")
    
    print(f"\nOutput Volume:")
    print(f"  Shape: {output_volume.image.shape}")
    print(f"  Resolution: {output_volume.resolution_x:.3f} mm (isotropic)")
    
    # Show side-by-side comparison
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    # Input stack middle slice
    mid_slice = input_stack.slices.shape[0] // 2
    axes[0].imshow(input_stack.slices[mid_slice, 0].cpu().numpy().T, cmap='gray', origin='lower')
    axes[0].set_title(f'Input Stack 0\n(Slice {mid_slice}, {input_stack.thickness:.1f}mm thick)')
    axes[0].axis('off')
    
    # Output volume middle slice
    mid_z = output_volume.image.shape[3] // 2
    axes[1].imshow(output_volume.image[0, :, :, mid_z].cpu().numpy().T, cmap='gray', origin='lower')
    axes[1].set_title(f'Output Volume\n(Axial slice, {output_volume.resolution_x:.1f}mm isotropic)')
    axes[1].axis('off')
    
    plt.suptitle('Resolution Comparison: Input vs Output', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("Cannot compare - missing input or output data")

---
# Phase 8: Save Outputs

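Save the generated outputs to disk. Each output is written only if its path is set in the configuration: `output_volume`, `output_model`, `output_slices`, `simulated_slices` and `output_json`.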

In [None]:
print("="*80)
print("PHASE 8: SAVE OUTPUTS")
print("="*80)

# Save output volume
if args.output_volume and "output_volume" in input_dict:
    print(f"\nSaving output volume to {args.output_volume}")
    volume = input_dict["output_volume"]
    
    # Rescale if requested
    if args.output_intensity_mean is not None:
        volume.rescale(args.output_intensity_mean)
    
    volume.save(args.output_volume, masked=not args.with_background)
    print(f"✓ Volume saved")

# Save trained model
if args.output_model and "output_model" in input_dict:
    print(f"\nSaving model to {args.output_model}")
    torch.save(
        {
            "model": input_dict["output_model"].state_dict(),
            "mask": input_dict["mask"],
            "args": args,
        },
        args.output_model,
    )
    print(f"✓ Model saved")

# Save output slices
if args.output_slices and "output_slices" in input_dict:
    print(f"\nSaving output slices to {args.output_slices}")
    save_slices(args.output_slices, input_dict["output_slices"], sep=True)
    print(f"✓ Slices saved")

# Save simulated slices
if args.simulated_slices and "simulated_slices" in input_dict:
    print(f"\nSaving simulated slices to {args.simulated_slices}")
    save_slices(args.simulated_slices, input_dict["simulated_slices"], sep=False)
    print(f"✓ Simulated slices saved")

# Save JSON results
if args.output_json:
    print(f"\nSaving configuration and results to {args.output_json}")
    output_data = vars(args).copy()
    
    # Convert device to serializable format
    if "device" in output_data:
        output_data["device"] = str(output_data["device"])
    if "dtype" in output_data:
        output_data["dtype"] = str(output_data["dtype"])
    if "progress_bar" in output_data:
        del output_data["progress_bar"]
    
    # Add assessment results if available
    if "assessment_results" in input_dict:
        output_data["assessment_results"] = input_dict["assessment_results"]
    
    with open(args.output_json, "w") as f:
        json.dump(output_data, f, indent=2)
    print(f"✓ Results saved")

print("\n" + "="*80)
print("✓ PIPELINE COMPLETED SUCCESSFULLY")
print("="*80)

---
# Summary

Display a summary of the entire pipeline execution.

In [None]:
print("\n" + "="*80)
print("PIPELINE SUMMARY")
print("="*80)

print("\n📥 Input:")
if "input_stacks" in input_dict:
    print(f"  Number of stacks: {len(input_dict['input_stacks'])}")
    for i, stack in enumerate(input_dict["input_stacks"]):
        print(f"    Stack {i+1}: {stack.slices.shape[0]} slices, {stack.resolution_x:.2f}mm resolution")

print("\n⚙️  Processing:")
print(f"  Segmentation: {'✓ Applied' if args.segmentation else '✗ Skipped'}")
print(f"  Bias correction: {'✓ Applied' if args.bias_field_correction else '✗ Skipped'}")
print(f"  Assessment: {'✓ Applied' if not args.skip_assessment else '✗ Skipped'}")
print(f"  Registration: {'✓ Applied (SVoRT ' + args.svort_version + ')' if args.registration else '✗ Skipped'}")

if "input_slices" in input_dict:
    print(f"\n  Total slices for reconstruction: {len(input_dict['input_slices'])}")

if "output_model" in input_dict:
    model = input_dict["output_model"]
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n🧠 Model:")
    print(f"  Training iterations: {args.n_iter}")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Precision: {'FP16' if not args.single_precision else 'FP32'}")

if "output_volume" in input_dict:
    volume = input_dict["output_volume"]
    print(f"\n📤 Output:")
    print(f"  Volume shape: {volume.image.shape}")
    print(f"  Resolution: {volume.resolution_x:.3f} mm (isotropic)")
    print(f"  File: {args.output_volume}")
    
    if args.output_model:
        print(f"  Model: {args.output_model}")

print("\n" + "="*80)
print("✅ All tasks completed!")
print("="*80)