# Deep Learning Segmentation: 3D U-Net Inference

This notebook demonstrates how to use a pre-trained 3D U-Net (e.g., from MONAI or nnU-Net) to segment brain MRI volumes.

- Loads preprocessed images
- Runs inference with a pre-trained model
- Saves predicted masks
- Computes Dice/Jaccard metrics if ground truth is available
- Visualizes results for a single example


In [1]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from pathlib import Path
import torch
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from sklearn.metrics import jaccard_score


## 1. Load a Pre-trained 3D U-Net Model
(This example uses MONAI's UNet. Replace with nnU-Net or your own model as needed.)


In [2]:
# Example: create a MONAI UNet and load weights (replace with your checkpoint)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    ).to(device)
# model.load_state_dict(torch.load('path_to_checkpoint.pth', map_location=device))
model.eval()


  return torch._C._cuda_getDeviceCount() > 0


TypeError: UNet.__init__() got an unexpected keyword argument 'dimensions'

## 2. Define Preprocessing and Inference Pipeline


In [None]:
def preprocess_image(img_path):
    img = nib.load(str(img_path))
    data = img.get_fdata().astype(np.float32)
    # Normalize to [0, 1]
    data = (data - data.min()) / (data.max() - data.min())
    # Add channel and batch dimensions
    data = data[None, None, ...]  
    # shape: (1, 1, X, Y, Z)
    return torch.from_numpy(data).to(device)


## 3. Inference on a Single Example


In [None]:
img_path = Path('../data/preprocessed/IBSR_10_zscore.nii.gz')
input_tensor = preprocess_image(img_path)
with torch.no_grad():
    output = sliding_window_inference(input_tensor, roi_size=(96,96,96), sw_batch_size=1, predictor=model)
    pred = torch.argmax(output, dim=1).cpu().numpy()[0]
# Save predicted mask
    img = nib.load(str(img_path))
    nib.save(nib.Nifti1Image(pred.astype(np.uint8), img.affine), 'IBSR_10_unet_pred.nii.gz')
# Visualize central slice
    slice_idx = pred.shape[2] // 2
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1); plt.imshow(img.get_fdata()[:,:,slice_idx], cmap='gray'); plt.title('Image')
    plt.subplot(1,2,2); plt.imshow(pred[:,:,slice_idx], cmap='hot', alpha=0.7); plt.title('Predicted Mask')
    plt.show()


## 4. Compute Dice/Jaccard Metrics (if ground truth available)


In [None]:
gt_path = Path('../data/subset/IBSR_10/segmentation/analyze/IBSR_10_seg_ana.img')
if gt_path.exists():
    gt_img = nib.load(str(gt_path))
    gt_data = gt_img.get_fdata()
    gt_bin = (gt_data > 0)
    pred_bin = (pred > 0)
    # Dice
    intersection = np.logical_and(gt_bin, pred_bin).sum()
    dice = 2. * intersection / (gt_bin.sum() + pred_bin.sum())
    # Jaccard
    jaccard = intersection / np.logical_or(gt_bin, pred_bin).sum()
    print(f'Dice: {dice:.3f}, Jaccard: {jaccard:.3f}')
else:
    print('Ground truth not found.')


## 5. Batch Inference for All Images


In [None]:
input_dir = Path('/home/orion23/Documents/repos/neurocut-tms-navigation-segmentation/data/preprocessed')
out_dir = Path('/home/orion23/Documents/repos/neurocut-tms-navigation-segmentation/data/deep_learning_segmented')
out_dir.mkdir(exist_ok=True)
for img_path in input_dir.glob('*_zscore.nii.gz'):
    input_tensor = preprocess_image(img_path)
    with torch.no_grad():
        output = sliding_window_inference(input_tensor, roi_size=(96,96,96), sw_batch_size=1, predictor=model)
        pred = torch.argmax(output, dim=1).cpu().numpy()[0]
    img = nib.load(str(img_path))
    nib.save(nib.Nifti1Image(pred.astype(np.uint8), img.affine), out_dir / f'{img_path.stem}_unet_pred.nii.gz')
    # Optionally compute metrics if ground truth exists
    subject = '_'.join(img_path.stem.split('_')[:2])
    gt_path = Path(f'/home/orion23/Documents/repos/neurocut-tms-navigation-segmentation/data/subset/{subject}/segmentation/analyze/{subject}_seg_ana.img')
    if gt_path.exists():
        gt_img = nib.load(str(gt_path))
        gt_data = gt_img.get_fdata()
        gt_bin = (gt_data > 0)
        pred_bin = (pred > 0)
        intersection = np.logical_and(gt_bin, pred_bin).sum()
        dice = 2. * intersection / (gt_bin.sum() + pred_bin.sum())
        jaccard = intersection / np.logical_or(gt_bin, pred_bin).sum()
        print(f'{subject}: Dice={dice:.3f}, Jaccard={jaccard:.3f}')
    else:
        print(f'{subject}: Ground truth not found.')
