# PFT_FEM Python API Demonstration

This notebook walks through the **Posterior Fossa Tumor Finite Element Modeling (PFT_FEM)** Python API, one pipeline stage at a time. The pipeline simulates tumor growth in the cerebellum and brainstem and generates synthetic MRI images from the result.

## Pipeline Overview

The simulation consists of five main stages:

1. **Atlas Loading** - Load the SUIT cerebellar atlas
2. **Mesh Generation** - Create a tetrahedral FEM mesh
3. **Tumor Growth Simulation** - Solve reaction-diffusion equations with mechanical coupling
4. **MRI Image Generation** - Generate synthetic MRI sequences
5. **Results Export** - Save outputs in NIfTI format

```
SUIT Atlas → Mesh Generation → FEM Simulation → MRI Synthesis → File Output
    ↓              ↓                  ↓              ↓              ↓
 AtlasData    TetMesh         TumorState[]    Dict[seq→volume]  NIfTI Files
```

## Setup and Imports

First, let's import the necessary modules from the `pft_fem` package.

In [None]:
# Core imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# PFT_FEM API imports
from pft_fem import (
    # Atlas Loading
    SUITAtlasLoader,
    AtlasProcessor,
    
    # Mesh Generation
    MeshGenerator,
    TetMesh,
    
    # FEM Solver
    TumorGrowthSolver,
    MaterialProperties,
    TissueType,
    TumorState,
    
    # MRI Simulation
    MRISimulator,
    TumorParameters,
    MRISequence,
    SimulationResult,
    
    # I/O Operations
    NIfTIWriter,
    load_nifti,
    save_nifti,
    
    # Spatial Transforms
    SpatialTransform,
    ANTsTransformExporter,
    
    # Biophysical Constraints
    BiophysicalConstraints,
)

# Configure matplotlib for inline display
%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 100

print("PFT_FEM API imported successfully!")

---

## Stage 1: Atlas Loading

The first stage loads the SUIT (Spatially Unbiased Infratentorial Template) cerebellar atlas. This provides:
- A T1-weighted template image
- Anatomical region labels (30 cerebellar/brainstem regions)
- Coordinate transformation matrices

The `SUITAtlasLoader` can load either a real SUIT atlas from disk or generate a synthetic version for testing.
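
The coordinate transformation listed above is, for NIfTI-style data, conventionally a 4×4 affine that maps voxel indices to world coordinates in mm. The following is a minimal NumPy sketch of that mapping. The affine values and the helper names `voxel_to_world` and `world_to_voxel` are made up for illustration; they are not taken from the SUIT atlas or from `pft_fem`.

```python
import numpy as np

# Hypothetical affine: 2 mm isotropic voxels with a shifted origin (illustrative values only)
affine = np.array([
    [2.0, 0.0, 0.0, -70.0],
    [0.0, 2.0, 0.0, -100.0],
    [0.0, 0.0, 2.0, -75.0],
    [0.0, 0.0, 0.0, 1.0],
])

def voxel_to_world(ijk, affine):
    """Map voxel indices (i, j, k) to world coordinates (x, y, z) in mm."""
    homogeneous = np.append(np.asarray(ijk, dtype=float), 1.0)
    return (affine @ homogeneous)[:3]

def world_to_voxel(xyz, affine):
    """Inverse mapping: world coordinates in mm to (fractional) voxel indices."""
    homogeneous = np.append(np.asarray(xyz, dtype=float), 1.0)
    return (np.linalg.inv(affine) @ homogeneous)[:3]

world = voxel_to_world((35, 50, 30), affine)
voxel = world_to_voxel(world, affine)

# The voxel size is the length of each of the first three affine columns
voxel_size = np.linalg.norm(affine[:3, :3], axis=0)
```

The same idea applies to `atlas_data.affine` and `atlas_data.voxel_size` in the next cell, assuming the loader follows this convention.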

In [None]:
# Initialize the atlas loader
# Pass atlas_dir=None to use synthetic atlas, or provide path to real SUIT atlas
loader = SUITAtlasLoader(atlas_dir=None)

# Load the atlas data
atlas_data = loader.load()

# Examine the loaded data
print("=== Atlas Data Summary ===")
print(f"Template shape: {atlas_data.template.shape}")
print(f"Template dtype: {atlas_data.template.dtype}")
print(f"Labels shape: {atlas_data.labels.shape}")
print(f"Voxel size: {atlas_data.voxel_size} mm")
print(f"Number of regions: {len(atlas_data.regions)}")
print(f"\nAffine matrix:\n{atlas_data.affine}")

### Exploring Atlas Regions

The SUIT atlas contains 30 labeled regions covering the cerebellum and brainstem.

In [None]:
# List all atlas regions
print("=== Atlas Regions ===")
print(f"{'Label':<8} {'Name':<30} {'Hemisphere':<12}")
print("-" * 50)

for label, region in sorted(atlas_data.regions.items()):
    print(f"{label:<8} {region.name:<30} {region.hemisphere:<12}")

### Using AtlasProcessor

The `AtlasProcessor` class provides utilities for working with atlas data, including tissue extraction and mask generation.
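
How `get_tissue_mask` selects voxels internally is not shown in this notebook. One plausible approach, sketched here under that assumption, is label membership on the integer label volume. The label IDs and the toy volume are hypothetical and do not correspond to real SUIT labels.

```python
import numpy as np

# Toy 4x4x4 label volume; 0 is background (all label IDs here are hypothetical)
labels = np.zeros((4, 4, 4), dtype=np.int32)
labels[1:3, 1:3, 1:3] = 5     # pretend label 5 is a cerebellar lobule
labels[0, 0, :2] = 29         # pretend label 29 is a brainstem region

cerebellum_ids = [5]
brainstem_ids = [29]

# Boolean masks: True wherever the voxel label belongs to the requested group
cerebellum = np.isin(labels, cerebellum_ids)
brainstem = np.isin(labels, brainstem_ids)
all_tissue = labels > 0

# Volume = voxel count * voxel volume, converted from mm^3 to cm^3
voxel_size = (2.0, 2.0, 2.0)                 # mm, assumed for this toy example
voxel_volume = float(np.prod(voxel_size))    # mm^3
cerebellum_cm3 = cerebellum.sum() * voxel_volume / 1000.0
```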

In [None]:
# Create an atlas processor
processor = AtlasProcessor(atlas_data)

# Get tissue masks for different structures
cerebellum_mask = processor.get_tissue_mask("cerebellum")
brainstem_mask = processor.get_tissue_mask("brainstem")
full_mask = processor.get_tissue_mask("all")

print("=== Tissue Mask Statistics ===")
print(f"Cerebellum voxels: {np.sum(cerebellum_mask):,}")
print(f"Brainstem voxels: {np.sum(brainstem_mask):,}")
print(f"Total tissue voxels: {np.sum(full_mask):,}")

# Calculate volumes (voxel count * voxel volume)
voxel_volume = np.prod(atlas_data.voxel_size)  # mm^3
print(f"\nCerebellum volume: {np.sum(cerebellum_mask) * voxel_volume / 1000:.1f} cm³")
print(f"Brainstem volume: {np.sum(brainstem_mask) * voxel_volume / 1000:.1f} cm³")

### Visualizing the Atlas

In [None]:
# Visualize atlas slices
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Get center slices
mid_x = atlas_data.template.shape[0] // 2
mid_y = atlas_data.template.shape[1] // 2
mid_z = atlas_data.template.shape[2] // 2

# imshow plots in voxel indices (no extent is given), so axes are labeled in voxels
views = [
    ('Sagittal', lambda v: v[mid_x, :, :], 'Y', 'Z'),
    ('Coronal', lambda v: v[:, mid_y, :], 'X', 'Z'),
    ('Axial', lambda v: v[:, :, mid_z], 'X', 'Y'),
]

# Top row: Template (T1); bottom row: Labels
rows = [
    ('T1 Template', atlas_data.template, 'gray'),
    ('Region Labels', atlas_data.labels, 'nipy_spectral'),
]

for r, (row_title, volume, cmap) in enumerate(rows):
    for c, (view, take_slice, xlab, ylab) in enumerate(views):
        axes[r, c].imshow(take_slice(volume).T, cmap=cmap, origin='lower')
        axes[r, c].set_title(f'{row_title} - {view}')
        axes[r, c].set_xlabel(f'{xlab} (voxels)')
        axes[r, c].set_ylabel(f'{ylab} (voxels)')

plt.tight_layout()
plt.suptitle('SUIT Cerebellar Atlas', y=1.02, fontsize=14)
plt.show()

---

## Stage 2: Mesh Generation

The second stage converts the volumetric atlas into a tetrahedral mesh suitable for FEM simulation.

The `MeshGenerator` class creates a `TetMesh` object containing:
- **Nodes**: 3D coordinates of mesh vertices
- **Elements**: Tetrahedral connectivity (4 node indices per element)
- **Labels**: Tissue type at each node
- **Boundary info**: Surface nodes for boundary conditions
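
To make the node/element layout concrete, here is a small sketch of the "five tetrahedra per voxel" idea on a single unit cube, stored the same way as described above: an `(N, 3)` coordinate array plus an `(M, 4)` index array. This is an illustration only; the vertex ordering and the helper `tet_volumes` are assumptions, not the internals of `MeshGenerator`.

```python
import numpy as np

# Corners of a unit voxel, indexed 0..7
corners = np.array([
    [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
    [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1],
], dtype=float)

# One central tetrahedron plus four corner tetrahedra (4 node indices each)
five_tets = np.array([
    [0, 3, 5, 6],   # central tetrahedron
    [1, 0, 3, 5],   # corner at (1, 0, 0)
    [2, 0, 3, 6],   # corner at (0, 1, 0)
    [4, 0, 5, 6],   # corner at (0, 0, 1)
    [7, 3, 5, 6],   # corner at (1, 1, 1)
])

def tet_volumes(nodes, elements):
    """Volume of each tetrahedron: |(b - a) x (c - a) . (d - a)| / 6."""
    a, b, c, d = (nodes[elements[:, i]] for i in range(4))
    triple = np.einsum('ij,ij->i', np.cross(b - a, c - a), d - a)
    return np.abs(triple) / 6.0

volumes = tet_volumes(corners, five_tets)
total_volume = volumes.sum()
```

The central tetrahedron has volume 1/3 and each corner tetrahedron 1/6, so the five elements fill the cube exactly. In a full grid, neighboring voxels need mirrored copies of this pattern so that the face diagonals of adjacent cubes coincide.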

In [None]:
# Initialize the mesh generator
# subdivision_method: "five" or "six" tetrahedra per voxel
generator = MeshGenerator(subdivision_method="five")

# Generate mesh from the tissue mask
print("Generating tetrahedral mesh...")
mesh = generator.from_mask(
    mask=full_mask,
    voxel_size=atlas_data.voxel_size,
    labels=atlas_data.labels,
    affine=atlas_data.affine,
    simplify=True  # Reduce mesh complexity while preserving geometry
)

print("\n=== Mesh Statistics ===")
print(f"Number of nodes: {mesh.nodes.shape[0]:,}")
print(f"Number of elements: {mesh.elements.shape[0]:,}")
print(f"Number of boundary nodes: {len(mesh.boundary_nodes):,}")
print("Node coordinate range:")
print(f"  X: [{mesh.nodes[:, 0].min():.1f}, {mesh.nodes[:, 0].max():.1f}] mm")
print(f"  Y: [{mesh.nodes[:, 1].min():.1f}, {mesh.nodes[:, 1].max():.1f}] mm")
print(f"  Z: [{mesh.nodes[:, 2].min():.1f}, {mesh.nodes[:, 2].max():.1f}] mm")

### Mesh Quality Analysis

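Good mesh quality is essential for accurate FEM results: flat or sliver-like elements degrade the conditioning of the stiffness matrix. `compute_quality_metrics()` reports element volume and quality statistics for the mesh.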
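
The exact quality definition used by `compute_quality_metrics()` is not stated here. As a rough illustration, the sketch below uses one common normalized shape measure, $q = 6\sqrt{2}\,V / l_{\mathrm{rms}}^3$, where $V$ is the element volume and $l_{\mathrm{rms}}$ the root-mean-square edge length. It equals 1 for a regular tetrahedron and approaches 0 for degenerate elements. The function name `tet_quality` is hypothetical.

```python
import numpy as np
from itertools import combinations

def tet_quality(vertices):
    """Normalized shape quality of one tetrahedron (1 = regular, near 0 = sliver)."""
    vertices = np.asarray(vertices, dtype=float)
    a, b, c, d = vertices
    volume = abs(np.dot(np.cross(b - a, c - a), d - a)) / 6.0
    # Squared lengths of the six edges
    sq_edges = [np.sum((vertices[i] - vertices[j]) ** 2)
                for i, j in combinations(range(4), 2)]
    l_rms = np.sqrt(np.mean(sq_edges))
    return 6.0 * np.sqrt(2.0) * volume / l_rms ** 3

# Regular tetrahedron (alternate corners of a cube)
q_regular = tet_quality([(1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1)])

# Corner tetrahedron of a unit cube, as produced by a voxel subdivision
q_corner = tet_quality([(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)])

# Nearly flat "sliver" element
q_sliver = tet_quality([(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0.01)])
```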

In [None]:
# Compute mesh quality metrics
metrics = mesh.compute_quality_metrics()

print("=== Mesh Quality Metrics ===")
print(f"Number of elements: {metrics['num_elements']:,}")
print(f"Number of nodes: {metrics['num_nodes']:,}")
print("\nElement Volume Statistics:")
print(f"  Min volume: {metrics['min_volume']:.4f} mm³")
print(f"  Max volume: {metrics['max_volume']:.4f} mm³")
print(f"  Mean volume: {metrics['mean_volume']:.4f} mm³")
print(f"  Total volume: {metrics['total_volume']:.1f} mm³")
print("\nElement Quality (aspect ratio):")
print(f"  Min quality: {metrics['min_quality']:.3f}")
print(f"  Max quality: {metrics['max_quality']:.3f}")
print(f"  Mean quality: {metrics['mean_quality']:.3f}")

### Visualizing the Mesh

In [None]:
# Visualize mesh nodes (3D scatter plot of a subset)
fig = plt.figure(figsize=(15, 5))

# Sample nodes for visualization (too many to plot all)
sample_size = min(5000, mesh.nodes.shape[0])
sample_idx = np.random.choice(mesh.nodes.shape[0], sample_size, replace=False)
sample_nodes = mesh.nodes[sample_idx]
sample_labels = mesh.node_labels[sample_idx]

# 3D scatter plot
ax1 = fig.add_subplot(131, projection='3d')
scatter = ax1.scatter(
    sample_nodes[:, 0], sample_nodes[:, 1], sample_nodes[:, 2],
    c=sample_labels, cmap='nipy_spectral', s=1, alpha=0.5
)
ax1.set_xlabel('X (mm)')
ax1.set_ylabel('Y (mm)')
ax1.set_zlabel('Z (mm)')
ax1.set_title('Mesh Nodes (3D view)')

# 2D projections
ax2 = fig.add_subplot(132)
ax2.scatter(sample_nodes[:, 0], sample_nodes[:, 1], c=sample_labels, cmap='nipy_spectral', s=1, alpha=0.3)
ax2.set_xlabel('X (mm)')
ax2.set_ylabel('Y (mm)')
ax2.set_title('XY Projection (Axial)')
ax2.set_aspect('equal')

ax3 = fig.add_subplot(133)
ax3.scatter(sample_nodes[:, 0], sample_nodes[:, 2], c=sample_labels, cmap='nipy_spectral', s=1, alpha=0.3)
ax3.set_xlabel('X (mm)')
ax3.set_ylabel('Z (mm)')
ax3.set_title('XZ Projection (Coronal)')
ax3.set_aspect('equal')

plt.tight_layout()
plt.show()

---

## Stage 3: Tumor Growth Simulation (FEM)

The third stage simulates tumor growth using coupled reaction-diffusion and mechanical equilibrium equations.

### Physical Models

**Reaction-Diffusion (Fisher-Kolmogorov equation):**
$$\frac{\partial c}{\partial t} = D\nabla^2 c + \rho c\left(1 - \frac{c}{K}\right)$$

Where:
- $c$ = tumor cell density
- $D$ = diffusion coefficient (cell migration)
- $\rho$ = proliferation rate (cell division)
- $K$ = carrying capacity
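
As a rough illustration of the reaction-diffusion term, the following sketch integrates the Fisher-Kolmogorov equation on a 1D grid with explicit finite differences. It is not the FEM discretization used by `TumorGrowthSolver`; the grid spacing, domain size, and zero-flux boundaries are assumptions, while $D$, $\rho$, and $K$ reuse the values from this notebook's material definition.

```python
import numpy as np

D = 0.15      # mm^2/day, diffusion coefficient
rho = 0.012   # 1/day, proliferation rate
K = 1.0       # carrying capacity (normalized)
dx = 1.0      # mm, grid spacing (assumed)
dt = 1.0      # day, time step
n_steps = 30  # 30 days

# The explicit diffusion update is stable only if D * dt / dx^2 <= 0.5
assert D * dt / dx**2 <= 0.5

x = np.arange(-50.0, 51.0, dx)
c = np.where(np.abs(x) <= 5.0, 0.8, 0.0)   # seed: radius 5 mm, density 0.8
c0 = c.copy()

for _ in range(n_steps):
    padded = np.pad(c, 1, mode='edge')      # zero-flux (Neumann) boundaries
    laplacian = (padded[:-2] - 2.0 * c + padded[2:]) / dx**2
    # Forward Euler step: diffusion plus logistic growth
    c = c + dt * (D * laplacian + rho * c * (1.0 - c / K))

mass_initial = c0.sum() * dx
mass_final = c.sum() * dx
```

With zero-flux boundaries the diffusion term conserves total cell mass, so any increase from `mass_initial` to `mass_final` comes from the logistic growth term, and the density stays between 0 and $K$.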

**Mechanical Equilibrium (Linear Elasticity):**
$$\nabla \cdot \sigma + f = 0$$

Where $f = \alpha \cdot c \cdot \nabla c$ represents growth-induced body forces.
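
For isotropic linear elasticity, the stress is commonly written as $\sigma = \lambda\,\mathrm{tr}(\varepsilon)\,I + 2\mu\,\varepsilon$, with the Lamé parameters $\lambda$ and $\mu$ derived from Young's modulus $E$ and Poisson's ratio $\nu$. The sketch below does that conversion for the values used in this notebook ($E = 3000$ Pa, $\nu = 0.45$). Whether the solver uses this parameterization internally is an assumption, and `lame_parameters` is a hypothetical helper.

```python
def lame_parameters(young_modulus, poisson_ratio):
    """Convert (E, nu) to the Lame parameters (lambda, mu) for isotropic elasticity."""
    if not -1.0 < poisson_ratio < 0.5:
        raise ValueError("Poisson's ratio must lie in (-1, 0.5)")
    lam = (young_modulus * poisson_ratio
           / ((1.0 + poisson_ratio) * (1.0 - 2.0 * poisson_ratio)))
    mu = young_modulus / (2.0 * (1.0 + poisson_ratio))
    return lam, mu

lam, mu = lame_parameters(3000.0, 0.45)

# Bulk modulus; it grows without bound as nu approaches 0.5 (incompressibility)
bulk = lam + 2.0 * mu / 3.0
```

Here $\lambda$ is roughly nine times $\mu$, which is what "nearly incompressible" means in practice: the tissue resists volume change far more than shear.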

In [None]:
# Define material properties for brain tissue
material = MaterialProperties(
    young_modulus=3000.0,          # Pa (brain tissue stiffness)
    poisson_ratio=0.45,            # Nearly incompressible
    proliferation_rate=0.012,      # 1/day (cell division rate)
    diffusion_coefficient=0.15,    # mm²/day (cell migration)
    carrying_capacity=1.0,         # Maximum normalized cell density
    growth_stress_coefficient=0.1  # Pa (coupling strength)
)

print("=== Material Properties ===")
print(f"Young's Modulus: {material.young_modulus} Pa")
print(f"Poisson's Ratio: {material.poisson_ratio}")
print(f"Proliferation Rate: {material.proliferation_rate} /day")
print(f"Diffusion Coefficient: {material.diffusion_coefficient} mm²/day")
print(f"Carrying Capacity: {material.carrying_capacity}")

In [None]:
# Initialize the FEM solver
solver = TumorGrowthSolver(mesh, material)

# Place tumor seed at center of cerebellum
tumor_center = (
    mesh.nodes[:, 0].mean(),  # X center
    mesh.nodes[:, 1].mean(),  # Y center  
    mesh.nodes[:, 2].mean()   # Z center
)

# Define initial tumor state (seed)
initial_state = TumorState.initial(
    mesh=mesh,
    seed_center=tumor_center,
    seed_radius=5.0,          # Initial radius in mm
    seed_density=0.8          # Initial cell density (normalized)
)

print("=== Initial Tumor State ===")
print(f"Seed center: ({tumor_center[0]:.1f}, {tumor_center[1]:.1f}, {tumor_center[2]:.1f}) mm")
print(f"Seed radius: 5.0 mm")
print(f"Initial volume: {solver.compute_tumor_volume(initial_state):.1f} mm³")
print(f"Max cell density: {initial_state.cell_density.max():.2f}")

In [None]:
# Run the tumor growth simulation
print("Running tumor growth simulation...")
print("="*50)

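# Track progress with a callback
def progress_callback(state):
    # state.time is a float, so compare against the nearest multiple of 5
    # with a tolerance instead of testing `% 5 == 0` exactly
    if abs(state.time - 5 * round(state.time / 5)) < 1e-6:  # Print every 5 days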
        volume = solver.compute_tumor_volume(state)
        max_disp = solver.compute_max_displacement(state)
        print(f"Day {state.time:3.0f}: Volume = {volume:8.1f} mm³, Max displacement = {max_disp:.2f} mm")

# Simulate tumor growth
states = solver.simulate(
    initial_state=initial_state,
    duration=30.0,            # Simulation duration in days
    dt=1.0,                   # Time step in days
    callback=progress_callback
)

print("="*50)
print(f"Simulation complete! Generated {len(states)} time points.")

In [None]:
# Analyze tumor growth over time
times = [s.time for s in states]
volumes = [solver.compute_tumor_volume(s) for s in states]
max_displacements = [solver.compute_max_displacement(s) for s in states]

# Get final state
final_state = states[-1]

print("=== Simulation Results ===")
print(f"Initial tumor volume: {volumes[0]:.1f} mm³")
print(f"Final tumor volume: {volumes[-1]:.1f} mm³")
print(f"Volume increase: {volumes[-1]/volumes[0]:.1f}x")
print(f"Maximum tissue displacement: {max_displacements[-1]:.2f} mm")

In [None]:
# Plot tumor growth dynamics
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Volume over time
axes[0].plot(times, volumes, 'b-', linewidth=2)
axes[0].fill_between(times, volumes, alpha=0.3)
axes[0].set_xlabel('Time (days)')
axes[0].set_ylabel('Tumor Volume (mm³)')
axes[0].set_title('Tumor Volume Growth')
axes[0].grid(True, alpha=0.3)

# Displacement over time
axes[1].plot(times, max_displacements, 'r-', linewidth=2)
axes[1].fill_between(times, max_displacements, alpha=0.3, color='red')
axes[1].set_xlabel('Time (days)')
axes[1].set_ylabel('Max Displacement (mm)')
axes[1].set_title('Maximum Tissue Displacement')
axes[1].grid(True, alpha=0.3)

# Growth rate (derivative)
growth_rates = np.gradient(volumes, times)
axes[2].plot(times, growth_rates, 'g-', linewidth=2)
axes[2].fill_between(times, growth_rates, alpha=0.3, color='green')
axes[2].set_xlabel('Time (days)')
axes[2].set_ylabel('Growth Rate (mm³/day)')
axes[2].set_title('Tumor Growth Rate')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Visualizing Tumor Cell Density

In [None]:
# Visualize cell density distribution at different time points
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Select time points to visualize
time_indices = [0, len(states)//4, len(states)//2, 3*len(states)//4, -1]

# zip stops after the five selected time points, leaving the sixth subplot empty
for ax, ti in zip(axes.flat, time_indices):
    state = states[ti]
    
    # Get nodes with significant tumor density
    tumor_mask = state.cell_density > 0.1
    tumor_nodes = mesh.nodes[tumor_mask]
    tumor_density = state.cell_density[tumor_mask]
    
    if len(tumor_nodes) > 0:
        scatter = ax.scatter(
            tumor_nodes[:, 0], tumor_nodes[:, 1],
            c=tumor_density, cmap='hot', s=5, alpha=0.7,
            vmin=0, vmax=1
        )
        plt.colorbar(scatter, ax=ax, label='Cell Density')
    
    ax.set_xlabel('X (mm)')
    ax.set_ylabel('Y (mm)')
    ax.set_title(f'Day {state.time:.0f}')
    ax.set_aspect('equal')

# Hide unused subplot
axes[1, 2].axis('off')

plt.suptitle('Tumor Cell Density Evolution (XY projection)', fontsize=14)
plt.tight_layout()
plt.show()

---

## Stage 4: MRI Image Generation

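The fourth stage generates synthetic MRI images of the atlas with the simulated tumor. The `MRISimulator` class provides a high-level interface for this. Note that it is constructed from the atlas and a `TumorParameters` object rather than from the Stage 3 `states`, so `run_full_pipeline` below produces its own `tumor_states`.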

### Supported MRI Sequences

| Sequence | Description | Characteristics |
|----------|-------------|----------------|
| **T1** | Anatomical imaging | Gray matter darker than white matter |
| **T2** | Fluid-sensitive | CSF appears bright |
| **FLAIR** | Fluid-attenuated | CSF suppressed, edema bright |
| **T1_contrast** | Gadolinium enhanced | Enhancing tumor rim |
| **DWI** | Diffusion-weighted | Restricted diffusion in tumor |
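
How `MRISimulator` turns tissue maps into intensities is not shown in this notebook. A common simplification, sketched here under that assumption, is to evaluate the spin-echo and inversion-recovery signal equations per tissue. The relaxation times and proton densities below are rough illustrative values, not parameters taken from `pft_fem`, and the function names are hypothetical.

```python
import numpy as np

# Rough illustrative tissue parameters (times in ms); not taken from PFT_FEM
TISSUES = {
    'white_matter': {'pd': 0.70, 't1': 600.0, 't2': 80.0},
    'gray_matter': {'pd': 0.80, 't1': 950.0, 't2': 100.0},
    'csf': {'pd': 1.00, 't1': 4000.0, 't2': 2000.0},
}

def spin_echo(pd, t1, t2, tr, te):
    """Spin-echo signal: PD * (1 - exp(-TR/T1)) * exp(-TE/T2)."""
    return pd * (1.0 - np.exp(-tr / t1)) * np.exp(-te / t2)

def inversion_recovery(pd, t1, t2, tr, te, ti):
    """Inversion-recovery magnitude signal (the basis of FLAIR)."""
    recovery = 1.0 - 2.0 * np.exp(-ti / t1) + np.exp(-tr / t1)
    return pd * np.abs(recovery) * np.exp(-te / t2)

def null_ti(t1, tr):
    """Inversion time at which a tissue with the given T1 produces zero signal."""
    return t1 * (np.log(2.0) - np.log(1.0 + np.exp(-tr / t1)))

# Short TR/TE gives T1 weighting; long TR/TE gives T2 weighting
t1w = {name: spin_echo(tr=500.0, te=15.0, **p) for name, p in TISSUES.items()}
t2w = {name: spin_echo(tr=4000.0, te=100.0, **p) for name, p in TISSUES.items()}

# FLAIR: choose TI so that the CSF signal is nulled
flair_tr = 9000.0
flair_ti = null_ti(TISSUES['csf']['t1'], flair_tr)
flair = {name: inversion_recovery(tr=flair_tr, te=100.0, ti=flair_ti, **p)
         for name, p in TISSUES.items()}
```

With these values the T1-weighted signal orders as white matter > gray matter > CSF, the T2-weighted signal is brightest in CSF, and the FLAIR signal from CSF vanishes, matching the characteristics in the table above.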

In [None]:
# Define tumor parameters for MRI simulation
tumor_params = TumorParameters(
    center=tumor_center,
    initial_radius=5.0,
    proliferation_rate=0.012,
    diffusion_rate=0.15,
    necrotic_threshold=0.9,    # Density threshold for necrotic core
    edema_extent=10.0,         # Edema extends beyond tumor (mm)
    enhancement_ring=True      # Model contrast-enhancing rim
)

print("=== Tumor Parameters ===")
print(f"Center: ({tumor_params.center[0]:.1f}, {tumor_params.center[1]:.1f}, {tumor_params.center[2]:.1f}) mm")
print(f"Initial radius: {tumor_params.initial_radius} mm")
print(f"Proliferation rate: {tumor_params.proliferation_rate} /day")
print(f"Diffusion rate: {tumor_params.diffusion_rate} mm²/day")
print(f"Necrotic threshold: {tumor_params.necrotic_threshold}")
print(f"Edema extent: {tumor_params.edema_extent} mm")

In [None]:
# Initialize MRI simulator
simulator = MRISimulator(atlas_data, tumor_params)

# Run the full pipeline (or use step-by-step control)
print("Running full simulation pipeline...")
print("="*50)

result = simulator.run_full_pipeline(
    duration_days=30.0,
    sequences=[
        MRISequence.T1,
        MRISequence.T2,
        MRISequence.FLAIR,
        MRISequence.T1_CONTRAST,
        MRISequence.DWI
    ],
    verbose=True
)

print("="*50)
print("\n=== Simulation Result Summary ===")
print(f"Number of time points: {len(result.tumor_states)}")
print(f"MRI sequences generated: {list(result.mri_images.keys())}")
print(f"Output volume shape: {result.mri_images['T1'].shape}")
print(f"Tumor mask voxels: {np.sum(result.tumor_mask):,}")
print(f"Edema mask voxels: {np.sum(result.edema_mask):,}")

In [None]:
# Visualize all MRI sequences
sequences = list(result.mri_images.keys())
n_sequences = len(sequences)

fig, axes = plt.subplots(n_sequences, 3, figsize=(15, 4*n_sequences))

# Get center slices
shape = result.mri_images['T1'].shape
mid_x, mid_y, mid_z = shape[0]//2, shape[1]//2, shape[2]//2

for row, seq_name in enumerate(sequences):
    mri = result.mri_images[seq_name]
    
    # Sagittal
    axes[row, 0].imshow(mri[mid_x, :, :].T, cmap='gray', origin='lower')
    axes[row, 0].set_title(f'{seq_name} - Sagittal')
    axes[row, 0].set_xlabel('Y')
    axes[row, 0].set_ylabel('Z')
    
    # Coronal
    axes[row, 1].imshow(mri[:, mid_y, :].T, cmap='gray', origin='lower')
    axes[row, 1].set_title(f'{seq_name} - Coronal')
    axes[row, 1].set_xlabel('X')
    axes[row, 1].set_ylabel('Z')
    
    # Axial
    axes[row, 2].imshow(mri[:, :, mid_z].T, cmap='gray', origin='lower')
    axes[row, 2].set_title(f'{seq_name} - Axial')
    axes[row, 2].set_xlabel('X')
    axes[row, 2].set_ylabel('Y')

plt.tight_layout()
plt.suptitle('Synthetic MRI Sequences', y=1.01, fontsize=14)
plt.show()

In [None]:
# Visualize tumor and edema masks overlaid on T1
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

t1 = result.mri_images['T1']
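# Cast to bool so the masks work for boolean indexing and `~` in create_overlay
tumor_mask = np.asarray(result.tumor_mask).astype(bool)
edema_mask = np.asarray(result.edema_mask).astype(bool)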

# Create RGB overlay
def create_overlay(image_slice, tumor_slice, edema_slice):
    # Normalize image to [0, 1]
    img_norm = (image_slice - image_slice.min()) / (image_slice.max() - image_slice.min() + 1e-8)
    
    # Create RGB image
    rgb = np.stack([img_norm, img_norm, img_norm], axis=-1)
    
    # Overlay tumor in red
    rgb[tumor_slice, 0] = 1.0
    rgb[tumor_slice, 1] = 0.3
    rgb[tumor_slice, 2] = 0.3
    
    # Overlay edema in yellow (only where not tumor)
    edema_only = edema_slice & ~tumor_slice
    rgb[edema_only, 0] = 1.0
    rgb[edema_only, 1] = 0.8
    rgb[edema_only, 2] = 0.2
    
    return rgb

# Sagittal
overlay = create_overlay(t1[mid_x, :, :].T, tumor_mask[mid_x, :, :].T, edema_mask[mid_x, :, :].T)
axes[0].imshow(overlay, origin='lower')
axes[0].set_title('Sagittal View')
axes[0].set_xlabel('Y')
axes[0].set_ylabel('Z')

# Coronal
overlay = create_overlay(t1[:, mid_y, :].T, tumor_mask[:, mid_y, :].T, edema_mask[:, mid_y, :].T)
axes[1].imshow(overlay, origin='lower')
axes[1].set_title('Coronal View')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Z')

# Axial
overlay = create_overlay(t1[:, :, mid_z].T, tumor_mask[:, :, mid_z].T, edema_mask[:, :, mid_z].T)
axes[2].imshow(overlay, origin='lower')
axes[2].set_title('Axial View')
axes[2].set_xlabel('X')
axes[2].set_ylabel('Y')

# Add legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=(1.0, 0.3, 0.3), label='Tumor'),  # same RGB as the tumor overlay
    Patch(facecolor=(1.0, 0.8, 0.2), label='Edema')   # same RGB as the edema overlay
]
fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98))

plt.suptitle('T1 with Tumor and Edema Overlay', fontsize=14)
plt.tight_layout()
plt.show()

---

## Stage 5: Results Export

The final stage saves all simulation outputs to NIfTI format for use with standard neuroimaging tools.

In [None]:
# Set up output directory
output_dir = Path("./simulation_output")
output_dir.mkdir(parents=True, exist_ok=True)

# Initialize the NIfTI writer
writer = NIfTIWriter(
    output_dir=str(output_dir),
    affine=atlas_data.affine,
    base_name="pft_demo"
)

print(f"Output directory: {output_dir.absolute()}")

In [None]:
# Write all simulation results
print("Saving simulation results...")
print("="*50)

output_paths = writer.write_simulation_results(
    result,
    export_transform=True  # Also export ANTs-compatible spatial transforms
)

print("\n=== Generated Files ===")
for name, path in output_paths.items():
    print(f"  {name}: {path}")

In [None]:
# Verify saved files by loading one back
t1_loaded = load_nifti(output_paths['mri_T1'])

print("=== Verification ===")
print(f"Loaded T1 shape: {t1_loaded.shape}")
print(f"Matches original: {np.allclose(t1_loaded, result.mri_images['T1'])}")

---

## Advanced Usage

### Step-by-Step Pipeline Control

For more control over the simulation, you can run each stage independently.

In [None]:
# Example: Step-by-step pipeline execution
print("=== Step-by-Step Pipeline ===")

# Step 1: Load atlas
print("\n[Step 1] Loading atlas...")
loader = SUITAtlasLoader()
atlas = loader.load()
print(f"  Atlas shape: {atlas.template.shape}")

# Step 2: Create processor and mesh
print("\n[Step 2] Generating mesh...")
proc = AtlasProcessor(atlas)
mask = proc.get_tissue_mask("all")
gen = MeshGenerator()
msh = gen.from_mask(mask, atlas.voxel_size, atlas.labels, atlas.affine)
print(f"  Mesh: {msh.nodes.shape[0]} nodes, {msh.elements.shape[0]} elements")

# Step 3: Initialize simulator with custom parameters
print("\n[Step 3] Setting up simulator...")
params = TumorParameters(
    center=(5.0, -3.0, 2.0),  # Slightly off-center
    initial_radius=4.0,
    proliferation_rate=0.015,  # Faster growth
    diffusion_rate=0.20        # More diffuse
)
sim = MRISimulator(atlas, params)
sim.setup(mesh_resolution=2.0)
print("  Simulator configured")

# Step 4: Run growth simulation
print("\n[Step 4] Simulating growth...")
tumor_states = sim.simulate_growth(
    duration_days=20.0,
    time_step=2.0,  # Larger time steps
    verbose=False
)
print(f"  Generated {len(tumor_states)} states")

# Step 5: Generate specific MRI sequences
print("\n[Step 5] Generating MRI images...")
mri = sim.generate_mri(
    tumor_state=tumor_states[-1],
    sequences=[MRISequence.T1, MRISequence.FLAIR],
    TR=500.0,   # Repetition time
    TE=15.0,    # Echo time
    TI=1200.0   # Inversion time (for FLAIR)
)
print(f"  Sequences: {list(mri.keys())}")

print("\n[Done] Step-by-step pipeline complete!")

### Biophysical Constraints

The `BiophysicalConstraints` class allows tissue-specific material properties.

In [None]:
# Configure tissue-specific properties
constraints = BiophysicalConstraints()

# Get properties for different tissue types
tissue_types = [TissueType.GRAY_MATTER, TissueType.WHITE_MATTER, 
                TissueType.CSF, TissueType.TUMOR]

print("=== Tissue-Specific Properties ===")
print(f"{'Tissue':<15} {'Stiffness':<12} {'Diffusion':<12} {'Notes'}")
print("-" * 60)

for tissue in tissue_types:
    props = constraints.get_properties(tissue)
    notes = {
        TissueType.GRAY_MATTER: "Baseline",
        TissueType.WHITE_MATTER: "Stiffer, faster diffusion along fibers",
        TissueType.CSF: "Fluid barrier",
        TissueType.TUMOR: "Dense, restricted diffusion"
    }.get(tissue, "")
    print(f"{tissue.name:<15} {props.stiffness_factor:<12.2f} {props.diffusion_factor:<12.2f} {notes}")

### Spatial Transforms

The simulation generates spatial transforms that map coordinates from the original SUIT space to the deformed (tumor-affected) space.
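
The internals of `SpatialTransform.apply` are not shown here. A common representation, assumed in this sketch, is a dense displacement field in mm sampled on the voxel grid, with the deformed position given by $p' = p + u(p)$ (a forward, or "push", convention). The helpers `trilinear_sample` and `apply_displacement` and the toy field are hypothetical.

```python
import numpy as np

def trilinear_sample(field, ijk):
    """Trilinearly interpolate a (nx, ny, nz, 3) vector field at fractional index ijk."""
    ijk = np.asarray(ijk, dtype=float)
    upper = np.array(field.shape[:3]) - 2
    base = np.clip(np.floor(ijk).astype(int), 0, upper)  # lower corner of the cell
    t = ijk - base                                       # fractional offsets
    value = np.zeros(field.shape[3])
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                weight = (
                    (t[0] if dx else 1.0 - t[0])
                    * (t[1] if dy else 1.0 - t[1])
                    * (t[2] if dz else 1.0 - t[2])
                )
                value += weight * field[base[0] + dx, base[1] + dy, base[2] + dz]
    return value

def apply_displacement(point_mm, field, affine):
    """Move a world-space point by the displacement sampled at its location."""
    point_mm = np.asarray(point_mm, dtype=float)
    ijk = (np.linalg.inv(affine) @ np.append(point_mm, 1.0))[:3]
    return point_mm + trilinear_sample(field, ijk)

# Toy example: 2 mm voxels, x-displacement growing linearly along the first axis
affine = np.diag([2.0, 2.0, 2.0, 1.0])
field = np.zeros((5, 5, 5, 3))
field[..., 0] = 0.1 * np.arange(5, dtype=float)[:, None, None]

original = np.array([5.0, 4.0, 2.0])
deformed = apply_displacement(original, field, affine)
displacement = np.linalg.norm(deformed - original)
```

Points outside the grid are handled here by linear extrapolation from the nearest cell, which is a simplification; a production implementation would need an explicit policy for out-of-bounds points.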

In [None]:
# Access spatial transform from simulation result
transform = result.spatial_transform

print("=== Spatial Transform ===")
print(f"Transform type: {type(transform).__name__}")
print(f"Deformation field shape: {transform.deformation_field.shape}")
print(f"Max deformation magnitude: {np.max(np.linalg.norm(transform.deformation_field, axis=-1)):.2f} mm")

# Example: Transform a point
original_point = np.array([45.0, 55.0, 45.0])  # Point in SUIT space
deformed_point = transform.apply(original_point)

print(f"\nExample point transformation:")
print(f"  Original: {original_point}")
print(f"  Deformed: {deformed_point}")
print(f"  Displacement: {np.linalg.norm(deformed_point - original_point):.2f} mm")

In [None]:
# Export transforms in ANTs-compatible format
exporter = ANTsTransformExporter(output_dir=str(output_dir))

ants_paths = exporter.export(
    transform=result.spatial_transform,
    base_name="pft_demo_transform"
)

print("=== ANTs Transform Files ===")
for name, path in ants_paths.items():
    print(f"  {name}: {path}")

### Simulation Metadata

The simulation result includes detailed metadata about parameters and statistics.

In [None]:
# Explore simulation metadata
import json

print("=== Simulation Metadata ===")
print(json.dumps(result.metadata, indent=2, default=str))

---

## Cleanup

Optionally remove the output files created during this demo.

In [None]:
# Uncomment to remove output directory
# import shutil
# shutil.rmtree(output_dir)
# print(f"Removed {output_dir}")

print("Demo complete! Output files are in:", output_dir.absolute())

---

## Summary

This notebook demonstrated the complete PFT_FEM Python API:

1. **Atlas Loading**: `SUITAtlasLoader` and `AtlasProcessor` for loading and manipulating the SUIT cerebellar atlas

2. **Mesh Generation**: `MeshGenerator` and `TetMesh` for creating tetrahedral FEM meshes

3. **Tumor Growth Simulation**: `TumorGrowthSolver`, `MaterialProperties`, and `TumorState` for coupled reaction-diffusion and mechanical simulation

4. **MRI Image Generation**: `MRISimulator`, `TumorParameters`, and `MRISequence` for synthetic MRI generation

5. **Results Export**: `NIfTIWriter`, `SpatialTransform`, and `ANTsTransformExporter` for saving outputs

### Key Classes Reference

| Stage | Main Classes | Purpose |
|-------|--------------|---------|
| Atlas | `SUITAtlasLoader`, `AtlasProcessor`, `AtlasData` | Load and process SUIT atlas |
| Mesh | `MeshGenerator`, `TetMesh` | Create FEM mesh |
| FEM | `TumorGrowthSolver`, `MaterialProperties`, `TumorState` | Simulate tumor growth |
| MRI | `MRISimulator`, `TumorParameters`, `MRISequence` | Generate synthetic MRI |
| I/O | `NIfTIWriter`, `load_nifti`, `save_nifti` | File operations |
| Transforms | `SpatialTransform`, `ANTsTransformExporter` | Coordinate mappings |

For more information, see the [README](../README.md) or run `pft-simulate --help` for CLI usage.