# NeSVoR Inference Example

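This notebook demonstrates how to run inference with a pre-trained NeSVoR model. Inference here means sampling the model's learned implicit neural representation on a regular voxel grid to produce a 3D volume.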

## What you need:
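- A trained model file (the `.pt` checkpoint saved by training). Only use checkpoints from a trusted source, because loading one unpickles Python objects.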
- Desired output resolution

## What you get:
- High-resolution 3D volume (.nii.gz)

## Setup

In [None]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from model.models import INR
from model.sample import sample_volume
from utils import Volume

# Configure logging.
# `force=True` replaces any handlers already attached to the root logger (e.g. from an
# earlier run of this cell, or possibly installed by the kernel / an imported library).
# Without it, `basicConfig` is a silent no-op in that case and the level/format are ignored.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)

# Report the runtime environment so results can be tied to a PyTorch/CUDA setup.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Configuration

**IMPORTANT**: Update the paths below to match your files!

In [None]:
# ===== UPDATE THESE PATHS =====
MODEL_PATH = "path/to/your/trained_model.pt"  # checkpoint written by training
OUTPUT_PATH = "output_volume.nii.gz"          # where the sampled volume is written

# --- Sampling parameters ---
OUTPUT_RESOLUTION = 0.8      # target output resolution in mm (smaller value = finer result)
N_INFERENCE_SAMPLES = 128    # number of PSF samples (more = better quality, but slower)
INFERENCE_BATCH_SIZE = 1024  # query batch size (larger = faster, but more memory)

# --- Optional parameters ---
OUTPUT_INTENSITY_MEAN = None  # e.g. 1000.0 to rescale intensities in Step 7; None = no rescaling
WITH_BACKGROUND = False       # True saves with masked=False in Step 8 (background kept)

# --- Compute device: CUDA when available, otherwise CPU ---
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Echo the effective configuration.
for config_line in (
    "Configuration:",
    f"  Model: {MODEL_PATH}",
    f"  Output: {OUTPUT_PATH}",
    f"  Resolution: {OUTPUT_RESOLUTION} mm",
    f"  Device: {DEVICE}",
):
    print(config_line)

## Step 1: Load Trained Model

In [None]:
print("Loading trained model...")

# Load the checkpoint onto the target device.
# Besides tensors, the checkpoint stores Python objects: `mask` is used below as an object
# with `.image` / `.resolution_x`, and `args` is forwarded to INR. PyTorch >= 2.6 defaults
# to `weights_only=True`, which refuses to unpickle such objects, so opt out explicitly
# (the argument exists since PyTorch 1.13).
# SECURITY: with weights_only=False, torch.load uses pickle, which can execute arbitrary
# code — only load checkpoints that come from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Extract components
model_state = checkpoint["model"]  # INR state_dict; also carries the "bounding_box" entry
mask = checkpoint["mask"]          # reconstruction mask volume
args = checkpoint["args"]          # training hyper-parameters needed to rebuild the INR

# Rebuild the network from its saved bounding box + hyper-parameters, then restore weights.
model = INR(model_state["bounding_box"], args)
model.load_state_dict(model_state)
model.to(DEVICE)
model.eval()

print("✓ Model loaded successfully")
print("\nModel Information:")
print(f"  Bounding box: {model.bounding_box.cpu().numpy()}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

print("\nMask Information:")
print(f"  Shape: {mask.image.shape}")
print(f"  Resolution: {mask.resolution_x:.3f} mm")
# Count voxels inside the mask. Comparing with 0 yields an integer count even when the mask
# is stored as floats (a plain float sum would print e.g. "1,234.0").
print(f"  Masked voxels: {int((mask.image > 0).sum().item()):,}")

## Step 2: Visualize Reconstruction Mask

In [None]:
# Show the reconstruction mask as three orthogonal central slices.
# Volume image tensors may or may not carry a leading singleton channel; drop it only when
# present so both a 4-D (1, A, B, C) and a plain 3-D (A, B, C) tensor work. Indexing [0]
# unconditionally would turn a 3-D tensor into a 2-D slice and break the views below.
mask_data = mask.image.detach().cpu().numpy()
if mask_data.ndim == 4:
    mask_data = mask_data[0]

# (array axis to slice along, panel title).
# NOTE(review): the anatomical names assume the array axes are ordered (x, y, z);
# confirm against the Volume layout used by `utils`.
mask_views = [
    (0, 'Sagittal View (YZ)'),
    (1, 'Coronal View (XZ)'),
    (2, 'Axial View (XY)'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (axis, title) in zip(axes, mask_views):
    # Central slice along `axis`; transposed so the last remaining axis is drawn vertically.
    mid_slice = np.take(mask_data, mask_data.shape[axis] // 2, axis=axis)
    ax.imshow(mid_slice.T, cmap='binary', origin='lower')
    ax.set_title(title)
    ax.axis('off')

plt.suptitle('Reconstruction Mask', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 3: Sample High-Resolution Volume

This is the main inference step where we query the trained implicit neural representation to generate a high-resolution volume.

In [None]:
print("Sampling high-resolution volume...")
print(f"  Resolution: {OUTPUT_RESOLUTION} mm")
print(f"  PSF samples: {N_INFERENCE_SAMPLES}")
print(f"  Batch size: {INFERENCE_BATCH_SIZE}")
print("\nThis may take a few minutes...")

# `psf_resolution` (by its name) sets the PSF width; the output voxel grid presumably comes
# from the volume passed as the second argument. Passing the training mask unchanged would
# therefore keep the mask's voxel size no matter what OUTPUT_RESOLUTION says. Resample the
# mask to the requested resolution first.
# NOTE(review): upstream NeSVoR resamples the mask the same way before sampling; this assumes
# utils.Volume keeps its `resample(resolution, transformation)` API — confirm for this fork.
sampling_mask = mask.resample(OUTPUT_RESOLUTION, None)

# Pure inference: disable autograd so no computation graph is kept across the batched queries.
with torch.no_grad():
    output_volume = sample_volume(
        model,
        sampling_mask,
        psf_resolution=OUTPUT_RESOLUTION,
        batch_size=INFERENCE_BATCH_SIZE,
        n_samples=N_INFERENCE_SAMPLES,
    )

print("\n✓ Volume sampled successfully!")
print("\nOutput Volume Information:")
print(f"  Shape: {output_volume.image.shape}")
print(f"  Resolution: {output_volume.resolution_x:.3f} mm (isotropic)")
print(f"  Data type: {output_volume.image.dtype}")
print(f"  Intensity range: [{output_volume.image.min().item():.2f}, {output_volume.image.max().item():.2f}]")
# The mean runs over every voxel of the grid, background included.
print(f"  Mean intensity (all voxels): {output_volume.image.mean().item():.2f}")

## Step 4: Visualize Output Volume

In [None]:
# Show the sampled volume as three orthogonal central slices.
# Drop a leading singleton channel only when present, so both 4-D (1, A, B, C) and plain
# 3-D (A, B, C) image tensors are handled (unconditional [0] breaks the 3-D case).
vol_data = output_volume.image.detach().cpu().numpy()
if vol_data.ndim == 4:
    vol_data = vol_data[0]

# (array axis to slice along, panel title).
# NOTE(review): the anatomical names assume array axes ordered (x, y, z) — confirm.
volume_views = [
    (0, 'Sagittal View (YZ)'),
    (1, 'Coronal View (XZ)'),
    (2, 'Axial View (XY)'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (axis, title) in zip(axes, volume_views):
    # Central slice along `axis`; transposed so the last remaining axis is drawn vertically.
    mid_slice = np.take(vol_data, vol_data.shape[axis] // 2, axis=axis)
    ax.imshow(mid_slice.T, cmap='gray', origin='lower')
    ax.set_title(title)
    ax.axis('off')

# Label with the resolution of the volume actually sampled, not the requested setting.
actual_resolution = float(output_volume.resolution_x)
plt.suptitle(f'Output Volume ({actual_resolution:g}mm isotropic)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 5: Intensity Distribution

In [None]:
# Histogram of foreground intensities (voxels with intensity > 0), with mean and median marked.
foreground = output_volume.image[output_volume.image > 0].detach().cpu().numpy().ravel()

fig, ax = plt.subplots(1, 1, figsize=(10, 4))
if foreground.size == 0:
    # Nothing above zero: mean/median would be NaN, so explain instead of plotting NaN lines.
    ax.text(0.5, 0.5, 'No voxels with intensity > 0', ha='center', va='center',
            transform=ax.transAxes)
else:
    fg_mean = float(foreground.mean())
    fg_median = float(np.median(foreground))
    ax.hist(foreground, bins=50, alpha=0.7, edgecolor='black', color='steelblue')
    ax.axvline(fg_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {fg_mean:.2f}')
    ax.axvline(fg_median, color='green', linestyle='--', linewidth=2, label=f'Median: {fg_median:.2f}')
    ax.legend()

ax.set_xlabel('Intensity')
ax.set_ylabel('Voxel count')
ax.set_title('Output Volume: Intensity Distribution (intensity > 0)')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 6: Interactive Slice Viewer (Optional)

Use the sliders to explore different slices of the volume. This step requires the `ipywidgets` package.

In [None]:
from ipywidgets import interact, IntSlider

# Label for each array axis (same labelling as the static views).
SLICE_VIEW_NAMES = {0: 'Sagittal (YZ)', 1: 'Coronal (XZ)', 2: 'Axial (XY)'}


def _viewer_volume():
    """Return the current output image tensor as 3-D (drops a leading singleton channel if present)."""
    image = output_volume.image
    return image[0] if image.ndim == 4 else image


def show_slice(axis, slice_idx):
    """Interactive slice viewer.

    Draws slice ``slice_idx`` of the current ``output_volume`` along array axis ``axis``.
    Axes 0 and 1 are used as given; any other value falls back to axis 2, as before.
    ``slice_idx`` is clamped to the chosen axis: the shared slider spans the *largest* axis,
    which would otherwise index past the end of a shorter one.
    """
    volume = _viewer_volume()
    axis = axis if axis in (0, 1) else 2
    n_slices = volume.shape[axis]
    slice_idx = min(slice_idx, n_slices - 1)

    # Select on the tensor first so only one 2-D slice (not the whole volume) is copied to host.
    slice_data = volume.select(axis, slice_idx).detach().cpu().numpy()
    title = f'{SLICE_VIEW_NAMES[axis]} - Slice {slice_idx}/{n_slices-1}'

    plt.figure(figsize=(8, 8))
    plt.imshow(slice_data.T, cmap='gray', origin='lower')
    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.show()


# Create interactive viewer
vol_shape = tuple(_viewer_volume().shape)
print("📊 Interactive Slice Viewer:")
print("  Axis 0: Sagittal (YZ)")
print("  Axis 1: Coronal (XZ)")
print("  Axis 2: Axial (XY)")

# Trailing `;` hides the function repr that `interact` returns.
interact(
    show_slice,
    axis=IntSlider(min=0, max=2, step=1, value=2, description='Axis:'),
    slice_idx=IntSlider(min=0, max=max(vol_shape)-1, step=1, value=max(vol_shape)//2, description='Slice:'),
);

## Step 7: (Optional) Rescale Intensity

Optionally rescale the output so that its mean intensity equals `OUTPUT_INTENSITY_MEAN`. This step is skipped when that setting is `None`.

In [None]:
# Optionally rescale the output so its mean intensity matches OUTPUT_INTENSITY_MEAN.
# NOTE(review): `rescale` mutates `output_volume` in place, and it may compute its mean only
# over masked voxels. The "New mean" printed below covers every voxel, so the two can differ.
if OUTPUT_INTENSITY_MEAN is None:
    print("No intensity rescaling requested")
else:
    print(f"Rescaling intensity to mean={OUTPUT_INTENSITY_MEAN}")
    output_volume.rescale(OUTPUT_INTENSITY_MEAN)
    rescaled_image = output_volume.image
    new_min = rescaled_image.min().item()
    new_max = rescaled_image.max().item()
    print(f"✓ New intensity range: [{new_min:.2f}, {new_max:.2f}]")
    print(f"✓ New mean: {rescaled_image.mean().item():.2f}")

## Step 8: Save Output Volume

In [None]:
print(f"Saving volume to {OUTPUT_PATH}...")

# Create the destination directory if needed, so saving into a fresh sub-folder does not fail.
Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)

# Save volume; with WITH_BACKGROUND=False the volume is saved with masked=True.
output_volume.save(OUTPUT_PATH, masked=not WITH_BACKGROUND)

print("✓ Volume saved successfully!")
print(f"\nOutput file: {OUTPUT_PATH}")
print(f"  Shape: {output_volume.image.shape}")
# Report the resolution of the volume that was actually written, not the requested setting.
print(f"  Resolution: {output_volume.resolution_x:.3f} mm (isotropic)")
print(f"  Background: {'Included' if WITH_BACKGROUND else 'Masked'}")

## Summary

In [None]:
# One-screen summary of this run: input model, sampling settings, and resulting volume.
divider = "=" * 80
param_count = sum(p.numel() for p in model.parameters())
final_image = output_volume.image
intensity_lo = final_image.min().item()
intensity_hi = final_image.max().item()

summary_sections = [
    ("\n📥 Input:", [
        f"  Model: {MODEL_PATH}",
        f"  Parameters: {param_count:,}",
    ]),
    ("\n⚙️  Settings:", [
        f"  Resolution: {OUTPUT_RESOLUTION} mm",
        f"  PSF samples: {N_INFERENCE_SAMPLES}",
        f"  Device: {DEVICE}",
    ]),
    ("\n📤 Output:", [
        f"  File: {OUTPUT_PATH}",
        f"  Shape: {final_image.shape}",
        f"  Resolution: {output_volume.resolution_x:.3f} mm (isotropic)",
        f"  Intensity: [{intensity_lo:.2f}, {intensity_hi:.2f}]",
    ]),
]

print(divider)
print("INFERENCE SUMMARY")
print(divider)
for section_header, section_lines in summary_sections:
    print(section_header)
    for section_line in section_lines:
        print(section_line)

print("\n" + divider)
print("✅ INFERENCE COMPLETED SUCCESSFULLY")
print(divider)