# Rodent Brain Atlas Registration and Electrode Deformation

This notebook demonstrates a complete pipeline for registering rodent brain atlases and performing electrode deformation analysis. The workflow includes:

1. **Initial Setup and Data Loading**
2. **Rigid Registration** - Center initialization followed by rigid (rotation and translation) alignment of atlas to subject
3. **Affine Registration** - Linear alignment refinement (adds scaling and shear)
4. **Non-linear Registration** - Deformable registration for precise anatomical matching
5. **Electrode Deformation Analysis** - Modeling tissue deformation around implanted electrodes

## Overview

This pipeline registers the Waxholm Space (WHS) rat brain atlas to individual subject brain images and then models tissue deformation around implanted electrodes. The process supports:

- Accurate anatomical labeling of subject brain regions
- Understanding tissue displacement due to electrode implantation
- Quantifying registration quality through Jacobian determinant analysis

## Dependencies

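The notebook requires the following packages:
- `nilearn` - Neuroimaging analysis and visualization
- `nibabel` - NIfTI image I/O and resampling
- `SimpleITK` - Image registration toolkit
- `MONAI` - Medical imaging I/O and transformations
- `numpy` - Numerical computing
- `scipy` - Morphological operations for mask cleanup (`scipy.ndimage`)

It also imports local helper modules that must be importable (e.g., placed in the notebook's directory): `utils`, `aligner`, `warp_utils`, `warper`, and `deform_image_by_electrode`. BrainSuite is needed separately for brain extraction (BSE).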

---

## 1. Import Libraries and Initialize Environment

Import all necessary libraries for image processing, registration, and visualization.

In [None]:

import os
import sys
# Local helper modules (utils, aligner, warp_utils, warper) must be importable, e.g. placed in the current directory
import nilearn.image as ni
import nibabel as nb
from nilearn.plotting import plot_anat, plot_prob_atlas, show, plot_stat_map
import SimpleITK as sitk
from utils import pad_nifti_image, multires_registration, interpolate_zeros
from aligner import Aligner
from warp_utils import apply_warp
import numpy as np
from monai.transforms import LoadImage, EnsureChannelFirst
from warper import Warper
from nibabel.processing import resample_to_output



# %matplotlib notebook
# import gui


## 2. Configure File Paths

Define all input and output file paths for the registration pipeline. 

**Input Files:**
- Subject BSE T2 image
- Atlas BSE T2 template  
- Atlas anatomical labels

**Output Files:**
- Registered atlas images at each stage
- Transformation files
- Quality metrics (Jacobian determinants)

In [None]:
# Resample the original reoriented image to 0.1 mm isotropic voxels with nibabel, preserving the field of view
dir_name = '/project2/ajoshi_1183/data'
if not os.path.isdir(dir_name):
    dir_name = '/home/ajoshi/project2_ajoshi_1183/data'

fname = f"{dir_name}/RodentTools/for_Seymour/11_15_2025/MRI/Raw T2/r59/r59.r.nii.gz"
subbase=f"{dir_name}/RodentTools/for_Seymour/11_15_2025/MRI/Raw T2/r59/R59"

if not os.path.isfile(subbase + ".reoriented.nii.gz"):
    img = nb.load(fname)
    resampled_img = resample_to_output(img, voxel_sizes=(0.1, 0.1, 0.1), order=1)  # order=1: linear interpolation
    nb.save(resampled_img, subbase + ".reoriented.nii.gz")

#subbase = f"{dir_name}/RodentTools/data/test4/29408.native"#
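The comment above says the resampling preserves the field of view (FOV). A minimal numpy sketch of one way such a grid can be derived: take the world-space bounding box of the input voxel centers and tile it with the requested voxel size. `fov_preserving_grid` is a hypothetical helper written for illustration; it reflects our understanding that `nibabel.processing.resample_to_output` also produces an axis-aligned output grid covering the input, but it is not nibabel's code.

```python
import itertools
import numpy as np

def fov_preserving_grid(in_shape, in_affine, voxel_sizes):
    """Hypothetical helper: (out_shape, out_affine) of an axis-aligned grid with
    the given voxel sizes that covers every input voxel center."""
    in_affine = np.asarray(in_affine, dtype=float)
    voxel_sizes = np.asarray(voxel_sizes, dtype=float)
    # The 8 corner voxel indices of the input volume (voxel centers, not edges)
    corners = np.array(list(itertools.product(*[(0, n - 1) for n in in_shape])), dtype=float)
    # Map corner indices to world (mm) coordinates
    world = corners @ in_affine[:3, :3].T + in_affine[:3, 3]
    lo, hi = world.min(axis=0), world.max(axis=0)
    # Number of output samples per axis; round first to absorb floating-point noise
    out_shape = np.ceil(np.round((hi - lo) / voxel_sizes, 6)).astype(int) + 1
    out_affine = np.diag(np.append(voxel_sizes, 1.0))
    out_affine[:3, 3] = lo  # output grid starts at the bounding-box minimum
    return tuple(int(n) for n in out_shape), out_affine

# 10 voxels at 0.5 mm span 4.5 mm between first and last centers -> 19 samples at 0.25 mm
shape, aff = fov_preserving_grid((10, 10, 10), np.diag([0.5, 0.5, 0.5, 1.0]), (0.25, 0.25, 0.25))
print(shape)  # (19, 19, 19)
```

For an oblique input the bounding box is larger than the original volume, so the resampled image can contain extra background voxels.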


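Run BrainSuite's Brain Surface Extractor (BSE) and its interactive mask tool outside this notebook to generate the brain mask and the skull-stripped image (`R59.reoriented.bse.nii.gz`) used below.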

In [None]:

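sub_bse_t2 = subbase + ".reoriented.bse.nii.gz"  # skull-stripped subject image (alternative: subbase + ".bfc.nii.gz")

# BrainSuite BSE is run outside this notebook; stop early if its output is missing
if not os.path.isfile(sub_bse_t2):
    raise FileNotFoundError(f"Brain-extracted image not found: {sub_bse_t2}. Run BrainSuite BSE first.")
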
atlas_bse_t2 = f"{dir_name}/RodentTools/Atlases/Waxholm/WHS_SD_rat_atlas_v4_pack/WHS_SD_rat_T2star_v1.01.bse.nii.gz"
atlas_labels = f"{dir_name}/RodentTools/Atlases/Waxholm/WHS_SD_rat_atlas_v4_pack/WHS_SD_rat_atlas_v4.nii.gz"

centered_atlas = subbase+".atlas.cent.nii.gz"
centered_atlas_labels = subbase+".atlas.cent.label.nii.gz"
cent_transform_file = subbase+".cent.reg.tfm"
inv_cent_transform_file = subbase+".cent.reg.inv.tfm"
centered_atlas_linreg = subbase+".atlas.lin.nii.gz"
centered_atlas_linreg_labels = subbase+".atlas.lin.label.nii.gz"
lin_reg_map_file = subbase+".lin_ddf.map.nii.gz"


### Create Brain Mask for Damaged Tissue

**Critical for ex vivo brains with cuts/artifacts!**

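This mask tells the registration algorithm which regions to use for alignment. Areas outside the mask are ignored, which keeps the algorithm from trying to match damaged or cut regions.

Two approaches:
1. **Automatic** (quick): Atlas-guided masking, which intersects the atlas brain mask (warped with the rigid and affine transforms) with an intensity threshold of the subject image (run the cell below)
2. **Manual** (better): Create or edit the mask in ITK-SNAP or BrainSuite for precise control over excluded regions

**Note:** The automatic method reads `cent_transform_file` and `lin_reg_map_file`, which are written by the rigid (Section 4) and affine (Section 5) registration steps. Run those sections first; the mask is only needed by the non-linear registration in Section 6.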


In [None]:
# Create or load brain mask to exclude damaged regions
# --- Method 1: Atlas-Guided Masking (More Robust) ---

# Step 1: Warp the atlas brain mask to subject space using the rigid and affine transforms
print("Warping atlas mask to subject space...")

# Create a mask from the atlas BSE image if it doesn't exist
atlas_mask_file = atlas_bse_t2.replace('.nii.gz', '.mask.nii.gz') 
if not os.path.isfile(atlas_mask_file):
    print("Atlas mask not found, creating one from BSE image...")
    atlas_img = nb.load(atlas_bse_t2)
    atlas_data = atlas_img.get_fdata()
    atlas_mask_data = (atlas_data > np.percentile(atlas_data[atlas_data>0], 5)).astype(np.uint8)
    atlas_mask_img = nb.Nifti1Image(atlas_mask_data, atlas_img.affine, atlas_img.header)
    nb.save(atlas_mask_img, atlas_mask_file)

# We need to apply the rigid and affine transforms to the atlas mask
# First, apply the centering (rigid) transform
centered_atlas_mask_file = subbase + ".atlas.cent.mask.nii.gz"
moving_mask = sitk.ReadImage(atlas_mask_file, sitk.sitkUInt8)
fixed_image = sitk.ReadImage(sub_bse_t2, sitk.sitkFloat32)
final_transform = sitk.ReadTransform(cent_transform_file)
moved_mask = sitk.Resample(moving_mask, fixed_image, final_transform, sitk.sitkNearestNeighbor)
sitk.WriteImage(moved_mask, centered_atlas_mask_file)

# Then, apply the affine transform using the displacement field
aff_disp_field, _ = LoadImage(image_only=False)(lin_reg_map_file)
aff_disp_field = EnsureChannelFirst()(aff_disp_field)
centered_mask, _ = LoadImage(image_only=False)(centered_atlas_mask_file)
centered_mask = EnsureChannelFirst()(centered_mask)
warped_atlas_mask = apply_warp(aff_disp_field[None,], centered_mask[None,], centered_mask[None,], interp_mode="nearest")
warped_atlas_mask_data = warped_atlas_mask[0, 0].detach().cpu().numpy()

# Step 2: Create an intensity-based mask from the subject image
print("Creating intensity mask from subject...")
sub_img = nb.load(sub_bse_t2)
sub_data = sub_img.get_fdata()
intensity_mask = (sub_data > np.percentile(sub_data[sub_data > 0], 5)).astype(np.uint8)

# Step 3: Combine the masks (intersection) and clean up
print("Combining masks and cleaning up...")
combined_mask = (warped_atlas_mask_data * intensity_mask).astype(np.uint8)

from scipy import ndimage
combined_mask = ndimage.binary_fill_holes(combined_mask)
combined_mask = ndimage.binary_opening(combined_mask, iterations=2)  # Opening removes small isolated specks and thin bridges
combined_mask = ndimage.binary_dilation(combined_mask, iterations=3)  # Dilate past the opened boundary to keep edge tissue

# Step 4: Save and visualize the final mask
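mask_img = nb.Nifti1Image(combined_mask.astype(np.uint8), sub_img.affine, sub_img.header)  # ndimage returns bool
mask_img.set_data_dtype(np.uint8)  # store as a compact binary image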
target_mask_file = subbase + ".atlas_guided_mask.nii.gz"
nb.save(mask_img, target_mask_file)
print(f"Atlas-guided brain mask saved to: {target_mask_file}")

# Visualize the final mask as a contour over the subject image
d = plot_anat(sub_bse_t2, vmax=np.percentile(sub_data, 99.9), vmin=0, title="Subject with Atlas-Guided Mask Overlay")
d.add_contours(target_mask_file, levels=[0.5], colors='g')  # green contour = mask boundary

# --- Method 2 (RECOMMENDED FOR PRECISION): Manually create/edit mask in ITK-SNAP or BrainSuite ---
# Then load it here:
# target_mask_file = subbase + ".manual_mask.nii.gz"  # Your manually created mask
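Before trusting the combined mask, it can help to check how well the warped atlas mask and the intensity mask agree. A minimal sketch of the Dice coefficient, 2|A∩B| / (|A| + |B|); `dice` is a hypothetical helper, and a low value would suggest that the rigid/affine alignment or the percentile threshold needs attention.

```python
import numpy as np

def dice(mask_a, mask_b):
    """Dice overlap of two binary masks: 2|A and B| / (|A| + |B|)."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * np.logical_and(a, b).sum() / denom

# Toy 1-D example: each mask has 3 voxels and they share 2
a = np.array([1, 1, 1, 0, 0])
b = np.array([0, 1, 1, 1, 0])
print(dice(a, b))
```

In the mask cell this could be called as `dice(warped_atlas_mask_data > 0.5, intensity_mask)`, assuming both arrays are on the same subject grid.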


In [None]:

nonlin_reg_map_file = subbase+".nonlin_ddf.map.nii.gz"
inv_nonlin_reg_map_file = subbase+".inv.nonlin_ddf.map.nii.gz"
centered_atlas_nonlinreg = subbase+".atlas.nonlin.nii.gz"
centered_atlas_nonlinreg_labels = subbase+".atlas.nonlin.label.nii.gz"
jac_det_file = subbase+".jacobian_det.nii.gz"
inv_jac_det_file = subbase+".inv.jacobian_det.nii.gz"

## 3. Initial Data Visualization

Visualize the input data before registration to understand the initial alignment between subject and atlas images.

In [None]:

plot_anat(sub_bse_t2)
d = plot_anat(atlas_bse_t2)
d.add_contours(atlas_labels, cmap="prism")


In [None]:
d=plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99.9),vmin=0)
d.add_contours(atlas_labels, cmap="prism")


## 4. Rigid Registration (Center Alignment)

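Perform an initial rigid registration. The transform is initialized by aligning the geometric centers of the atlas and subject image volumes (`CenteredTransformInitializerFilter.GEOMETRY`), and a rigid transform (3D rotation and translation, `Euler3DTransform`) is then optimized with multi-resolution registration. This step corrects gross positioning differences and gives the subsequent registrations a good starting point.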

**Key Steps:**
1. Initialize centered transform using geometry
2. Perform multi-resolution registration
3. Save forward and inverse transformations
4. Apply transformation to atlas and labels

In [None]:
fixed_image = sitk.ReadImage(sub_bse_t2, sitk.sitkFloat32)
moving_image = sitk.ReadImage(atlas_bse_t2, sitk.sitkFloat32)
initial_transform = sitk.CenteredTransformInitializer(
    fixed_image,
    moving_image,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.GEOMETRY,
)

final_transform, _ = multires_registration(
    fixed_image, moving_image, initial_transform)


# save the transformation in a file
sitk.WriteTransform(final_transform, cent_transform_file)

# invert the transform and also save to a file
inv_transform = final_transform.GetInverse()
sitk.WriteTransform(inv_transform, inv_cent_transform_file)

# load from the file and apply the transformation
final_transform = sitk.ReadTransform(cent_transform_file)
moved_image = sitk.Resample(moving_image, fixed_image, final_transform)

sitk.WriteImage(moved_image, centered_atlas)

moving_image = sitk.ReadImage(atlas_labels, sitk.sitkUInt16)
moved_image = sitk.Resample(
    moving_image,
    fixed_image,
    transform=final_transform,
    interpolator=sitk.sitkNearestNeighbor,
)
sitk.WriteImage(moved_image, centered_atlas_labels)





### 4.1 Visualize Rigid Registration Results

Check the quality of the rigid registration by overlaying the centered atlas with the subject image.

In [None]:

plot_anat(centered_atlas)
d=plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99.9),vmin=0)
d.add_contours(centered_atlas_labels, cmap="prism")

## 5. Affine Registration

Perform affine (linear) registration to account for translation, rotation, scaling, and shearing differences between the atlas and subject. This builds on the rigid registration to provide better anatomical alignment.

**Parameters:**
- **Loss function:** Mean squared error (`loss="mse"`, as set in the cell below)
- **Transformation type:** Affine (12 degrees of freedom)
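Whether an intensity-difference loss such as mean squared error or a cross-correlation loss suits this step depends on how similar the atlas and subject intensities are. A minimal numpy sketch of both metrics computed globally over the volume; the registration code may use a local (windowed) cross-correlation instead, which this notebook does not show. With a pure contrast and brightness change, MSE penalizes perfectly aligned images while normalized cross-correlation (NCC) still reports near-perfect agreement.

```python
import numpy as np

def mse(a, b):
    """Mean squared intensity difference (lower is better)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.mean((a - b) ** 2))

def ncc(a, b, eps=1e-8):
    """Global normalized cross-correlation in [-1, 1] (higher is better)."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / (np.sqrt((a ** 2).sum() * (b ** 2).sum()) + eps))

rng = np.random.default_rng(0)
atlas = rng.random((8, 8, 8))
subject = 2.0 * atlas + 3.0  # same anatomy, different contrast and brightness
print(mse(atlas, subject), ncc(atlas, subject))
```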

In [None]:
aligner = Aligner()
aligner.affine_reg(
    fixed_file=sub_bse_t2,
    moving_file=centered_atlas,
    output_file=centered_atlas_linreg,
    ddf_file=lin_reg_map_file,
    loss="mse",
)


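### 5.1 Apply Affine Transform to Labels

First overlay the affine-registered atlas image on the subject as a quick check, then apply the computed affine displacement field to the atlas labels with nearest-neighbor interpolation so that label IDs are preserved.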

In [None]:
d=plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99.9),vmin=0)
d.add_contours(centered_atlas_linreg, cmap="prism")

In [None]:
disp_field, meta = LoadImage(image_only=False)(lin_reg_map_file)
disp_field = EnsureChannelFirst()(disp_field)
print(disp_field.shape)

at1, meta = LoadImage(image_only=False)(centered_atlas_labels)
at_lab = EnsureChannelFirst()(at1)
print(at_lab.shape)

warped_lab = apply_warp(
    disp_field[None,], at_lab[None,], at_lab[None,], interp_mode="nearest"
)
nb.save(
    nb.Nifti1Image(
        np.uint16(warped_lab[0, 0].detach().cpu().numpy()), at_lab.affine),  # store labels as integers
    centered_atlas_linreg_labels,
)
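Warping labels must never blend label IDs, which is why `interp_mode="nearest"` is used. The sketch below shows the idea behind a pull-back (backward) warp with nearest-neighbor lookup. `warp_labels_nearest` is hypothetical, and it assumes the displacement field is in voxel units and defined on the fixed grid, a convention we have not confirmed for the local `apply_warp` helper.

```python
import numpy as np

def warp_labels_nearest(labels, ddf):
    """Pull-back warp of a label volume with a dense displacement field.

    labels: (X, Y, Z) integer array in moving space.
    ddf:    (3, X, Y, Z) displacement in voxel units (assumption), on the fixed grid:
            output[x] = labels[round(x + ddf[:, x])].
    """
    grid = np.indices(labels.shape, dtype=float)   # (3, X, Y, Z) voxel coordinates
    coords = np.rint(grid + ddf).astype(int)       # nearest moving-space voxel
    for axis, n in enumerate(labels.shape):        # clamp to the volume (edge padding)
        np.clip(coords[axis], 0, n - 1, out=coords[axis])
    return labels[coords[0], coords[1], coords[2]]

labels = np.zeros((4, 1, 1), dtype=np.uint16)
labels[2, 0, 0] = 7
ddf = np.zeros((3, 4, 1, 1))
ddf[0] = 1.0  # every output voxel samples one voxel further along axis 0
print(warp_labels_nearest(labels, ddf)[:, 0, 0])  # [0 7 0 0]
```

Because each output voxel copies exactly one input voxel, the result keeps the input dtype and contains only label IDs that already existed.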


### 5.2 Compare Registration Results

Visual comparison between rigid and affine registration results to assess improvement in anatomical alignment.

In [None]:
d = plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99),vmin=0)
d.add_contours(centered_atlas_labels, cmap="prism")
d = plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99.95),vmin=np.percentile(nb.load(sub_bse_t2).get_fdata(), 15))
d.add_contours(centered_atlas_linreg_labels, cmap="hsv")


## 6. Non-linear Registration

Perform deformable (non-linear) registration to capture local anatomical variations that cannot be corrected by linear transformations alone. This step uses a neural network-based approach for precise anatomical matching.

**Parameters optimized for ex vivo brains with artifacts:**
- **Network input size:** 96×96×96 (increased from 64 for better resolution)
- **Learning rate:** 5e-5 (decreased for stability with damaged tissue)
- **Maximum epochs:** 7500 (increased for better convergence)
- **Loss function:** Cross-correlation (CC) - robust to intensity variations
- **Regularization penalty:** 2.0 (increased to prevent overfitting to artifacts)

**Note:** For ex vivo brains with cuts or damage, consider:
1. Creating a brain mask to exclude damaged regions
2. Using the `target_mask` parameter to focus registration on intact tissue
3. Increasing regularization (reg_penalty) to prevent unrealistic deformations
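The `reg_penalty` trades image similarity against deformation smoothness. A minimal sketch of one common form of that trade-off, a gradient (diffusion) penalty on the displacement field added to the similarity loss. The regularizer actually used inside `Warper` (it could be bending energy or something else) is not shown in this notebook, so `gradient_smoothness` and `total_loss` are hypothetical illustrations of the weighting, not its implementation.

```python
import numpy as np

def gradient_smoothness(ddf, spacing=1.0):
    """Mean squared spatial gradient of a (3, X, Y, Z) displacement field (diffusion regularizer)."""
    total = 0.0
    for comp in ddf:                             # each displacement component
        for g in np.gradient(comp, spacing):     # derivatives along X, Y, Z
            total += np.mean(g ** 2)
    return total

def total_loss(similarity_loss, ddf, reg_penalty):
    """Similarity term plus weighted smoothness term; larger reg_penalty favors smoother fields."""
    return similarity_loss + reg_penalty * gradient_smoothness(ddf)

ddf = np.zeros((3, 5, 4, 4))
ddf[0] = 0.1 * np.arange(5)[:, None, None]  # uniform stretch along axis 0
print(gradient_smoothness(ddf), total_loss(0.5, ddf, reg_penalty=2.0))
```

A constant (pure translation) field costs nothing under this penalty, while sharp local displacements at cut edges are penalized heavily, which is why raising `reg_penalty` suppresses deformations that chase artifacts.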


### Tips for Improving Registration Quality with Damaged Ex Vivo Brains

**Common Issues with Ex Vivo Tissue:**
- Cuts, tears, or missing tissue sections
- Uneven fixation causing local distortions
- Artifacts from sectioning or handling
- Different contrast properties vs in vivo imaging

**Strategies to Improve Results:**

1. **Masking** (most important): Create a binary mask excluding damaged regions
   - Use `target_mask` parameter in `nonlinear_reg()`
   - Mask out cuts, tears, and artifact regions
   - Focus registration only on intact tissue

2. **Regularization**: Increase `reg_penalty` (1.5-3.0)
   - Prevents unrealistic deformations at artifact boundaries
   - Smooths out registration in problematic areas
   - Higher values = smoother, more constrained deformations

3. **Network Resolution**: Increase `nn_input_size` (96 or 128)
   - Better captures fine anatomical details
   - May require more GPU memory
   - Generally improves results for high-res images

4. **Learning Rate**: Decrease `lr` (1e-5 to 5e-5)
   - More stable convergence with noisy/damaged data
   - Prevents overshooting during optimization
   - May need more epochs to converge

5. **Training Duration**: Increase `max_epochs` (7500-10000)
   - Allows more time for convergence
   - Monitor loss to ensure convergence (should plateau)
   - Use early stopping if loss stops improving

6. **Loss Function**: Keep `loss="cc"` (cross-correlation)
   - Most robust for intensity variations in ex vivo
   - Alternative: Try `loss="mi"` (mutual information) for severe artifacts

7. **Pre-processing**: Improve affine registration first
   - Ensure good initial alignment before non-linear
   - Consider increasing affine max_epochs
   - Check rigid registration quality visually
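The tips above recommend watching for a loss plateau. If a per-epoch loss history is available (we have not confirmed that `Warper` exposes one), a small check like this hypothetical `has_plateaued` helper can flag when more epochs are unlikely to help.

```python
def has_plateaued(losses, patience=200, min_delta=1e-4):
    """True if the best loss in the last `patience` epochs is not at least
    `min_delta` better than the best loss before them."""
    if len(losses) <= patience:
        return False  # not enough history to judge
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_before - best_recent < min_delta

improving = [1.0 / (i + 1) for i in range(50)]
flat = [1.0] * 30 + [0.5] * 30
print(has_plateaued(improving, patience=10), has_plateaued(flat, patience=10))
```

`patience` and `min_delta` are illustrative defaults; suitable values depend on the loss scale and on `max_epochs`.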


In [None]:
nonlin_reg = Warper()
nonlin_reg.nonlinear_reg(
    target_file=sub_bse_t2,
    moving_file=centered_atlas_linreg,
    output_file=centered_atlas_nonlinreg,
    target_mask=target_mask_file,  # Use MASK to exclude damaged regions
    ddf_file=nonlin_reg_map_file,
    inv_ddf_file=inv_nonlin_reg_map_file,
    reg_penalty=2.0,  # INCREASED: More regularization for damaged tissue
    nn_input_size=96,  # INCREASED: Better resolution for fine details
    lr=5e-5,  # DECREASED: More stable convergence
    max_epochs=7500,  # INCREASED: More iterations for convergence
    loss="cc",
    jacobian_determinant_file=jac_det_file,
    inv_jacobian_determinant_file=inv_jac_det_file,
)


### 6.1 Apply Non-linear Transformation

Apply the computed deformation field to transform the atlas labels to match the subject anatomy.

In [None]:
disp_field, meta = LoadImage(image_only=False)(nonlin_reg_map_file)
disp_field = EnsureChannelFirst()(disp_field)
print(disp_field.shape)

at1, meta = LoadImage(image_only=False)(centered_atlas_linreg_labels)
at_lab = EnsureChannelFirst()(at1)
print(at_lab.shape)

warped_lab = apply_warp(
    disp_field[None,], at_lab[None,], at_lab[None,], interp_mode="nearest"
)
nb.save(
    nb.Nifti1Image(
        np.uint16(warped_lab[0, 0].detach().cpu().numpy()), at_lab.affine),
    centered_atlas_nonlinreg_labels,
)


### 6.2 Compare Registration Stages

Visual comparison between affine and non-linear registration results to evaluate the improvement in anatomical detail matching.

In [None]:
d = plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99),vmin=0)
d.add_contours(centered_atlas_linreg_labels, cmap="prism")
d = plot_anat(sub_bse_t2, vmax=np.percentile(nb.load(sub_bse_t2).get_fdata(), 99),vmin=0)
d.add_contours(centered_atlas_nonlinreg_labels, cmap="prism")
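Contours show where boundaries moved; per-label volumes quantify by how much. A minimal sketch with hypothetical helpers `label_volumes` and `volume_change` that compares region volumes between two label maps on the same grid, such as the affine and non-linear results.

```python
import numpy as np

def label_volumes(labels, voxel_volume_mm3):
    """Map each non-zero label ID to its volume in mm^3."""
    ids, counts = np.unique(np.asarray(labels), return_counts=True)
    return {int(i): float(c) * voxel_volume_mm3 for i, c in zip(ids, counts) if i != 0}

def volume_change(before, after):
    """Relative volume change per label present in either map (NaN if absent before)."""
    out = {}
    for k in sorted(set(before) | set(after)):
        b, a = before.get(k, 0.0), after.get(k, 0.0)
        out[k] = (a - b) / b if b > 0 else float("nan")
    return out

lin = np.array([[0, 1, 1], [2, 2, 2]])
non = np.array([[0, 1, 1], [1, 2, 2]])
print(volume_change(label_volumes(lin, 1.0), label_volumes(non, 1.0)))
```

For the notebook's images, the voxel volume can be taken as `np.prod(img.header.get_zooms()[:3])` of the loaded label image, and the arrays as `img.get_fdata().astype(int)`.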


## 7. Registration Quality Assessment

Evaluate the quality of the non-linear registration using Jacobian determinant analysis. The Jacobian determinant measures local volume changes and helps identify areas of expansion/compression in the deformation field.

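**Interpretation:**
- **Jacobian = 1:** No local volume change
- **Jacobian > 1:** Local expansion
- **Jacobian < 1:** Local compression
- **Values close to 0:** Extreme compression, often a sign of registration artifacts
- **Values ≤ 0:** Folding (the mapping is locally non-invertible), which indicates a registration failure in that region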

### 7.1 Forward Jacobian Determinant

Analysis of the forward transformation (atlas → subject) Jacobian determinant. Subtracting 1 centers the map at 0, so the stat map shows expansion as positive and compression as negative values.

In [None]:
plot_anat(jac_det_file, colorbar=True)

from nilearn.image import new_img_like

jac = nb.load(jac_det_file)
jac = jac.get_fdata() - 1  # 0 = no volume change
jac = new_img_like(sub_bse_t2, jac)
plot_stat_map(jac, sub_bse_t2, title='Jac det')

### 7.2 Inverse Jacobian Determinant

Analysis of the inverse transformation (subject → atlas) Jacobian determinant to assess bidirectional registration quality.

In [None]:
plot_anat(inv_jac_det_file, colorbar=True)

jac = nb.load(inv_jac_det_file)
jac = jac.get_fdata() - 1  # 0 = no volume change
jac = new_img_like(sub_bse_t2, jac)
plot_stat_map(jac, sub_bse_t2, title='Jac det inv')
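The volume-change interpretation in this section follows from the definition J(x) = det(I + ∇u(x)) for a displacement field u. A minimal numpy sketch that computes it with finite differences and summarizes folding; `jacobian_determinant` and `folding_fraction` are hypothetical helpers, the field is assumed to be (3, X, Y, Z) with units matching `spacing`, and `Warper` may compute its Jacobian maps differently.

```python
import numpy as np

def jacobian_determinant(ddf, spacing=(1.0, 1.0, 1.0)):
    """det(I + grad u) per voxel for a (3, X, Y, Z) displacement field u."""
    # grads[i][j] = d u_i / d x_j (central differences inside, one-sided at edges)
    grads = [np.gradient(ddf[i], *spacing) for i in range(3)]
    J = np.empty(ddf.shape[1:] + (3, 3))
    for i in range(3):
        for j in range(3):
            J[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    return np.linalg.det(J)  # batched determinant over all voxels

def folding_fraction(jac_det):
    """Fraction of voxels where the mapping folds (det <= 0)."""
    return float(np.mean(jac_det <= 0))

ddf = np.zeros((3, 5, 4, 4))
ddf[0] = 0.1 * np.arange(5)[:, None, None]  # 10% stretch along axis 0
jd = jacobian_determinant(ddf)
print(jd.mean(), folding_fraction(jd))
```

Reporting `folding_fraction` inside the brain mask alongside the maps gives a single number to compare between parameter settings.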

## 8. Electrode Deformation Analysis

This section demonstrates specialized electrode deformation modeling, which is crucial for understanding how brain tissue deforms around implanted electrodes. This analysis helps in:

- **Correcting registration artifacts** caused by electrode implantation
- **Modeling tissue displacement** around foreign objects
- **Improving anatomical accuracy** in regions affected by electrodes

The deformation modeling uses specific target points that represent the electrode center and tip positions, allowing for precise modeling of the cylindrical deformation field around the electrode tract.

### Key Parameters:
- **Target points:** Coordinates defining electrode center and tip
- **Deformation model:** Cylindrical field around electrode tract  
- **Label handling:** Nearest-neighbor interpolation for discrete labels
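Any cylindrical field around a tract starts from the distance of each location to the electrode axis. The internals of `deform_image_by_electrode` are not shown here; the following is a hypothetical geometric sketch, assuming the two target points are world coordinates in mm, of how points within a given radius of the segment between them could be identified. A distance-to-segment test of this kind describes a capsule (a cylinder with rounded ends) rather than a flat-capped cylinder.

```python
import numpy as np

def distance_to_segment(points, p0, p1):
    """Euclidean distance from each point (N, 3) to the segment p0-p1 (same units, e.g. mm)."""
    points = np.atleast_2d(np.asarray(points, dtype=float))
    p0 = np.asarray(p0, dtype=float)
    p1 = np.asarray(p1, dtype=float)
    d = p1 - p0
    if d @ d == 0:  # degenerate segment: distance to a single point
        return np.linalg.norm(points - p0, axis=1)
    # Projection parameter along the segment, clamped so points beyond the ends use the endpoints
    t = np.clip((points - p0) @ d / (d @ d), 0.0, 1.0)
    closest = p0 + t[:, None] * d
    return np.linalg.norm(points - closest, axis=1)

def in_electrode_tract(points, p0, p1, radius):
    """Boolean mask of points within `radius` of the electrode segment (a capsule)."""
    return distance_to_segment(points, p0, p1) <= radius

print(distance_to_segment([[3, 4, 5], [0, 0, 12], [0, 0, -1]], (0, 0, 0), (0, 0, 10)))  # [5. 2. 1.]
```

Voxel centers of an image can be converted to world coordinates with `nibabel.affines.apply_affine(img.affine, voxel_indices)` before applying this test.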

In [None]:
# Import required modules for electrode deformation
from deform_image_by_electrode import deform_image_by_electrode
import SimpleITK as sitk

### 8.1 First Electrode Deformation

Model deformation around the first electrode using the linearly registered atlas as the starting point.

In [None]:
# First electrode deformation example
# Define paths for the first electrode (input: affine-registered atlas labels)
target_path = subbase + ".atlas.lin.label.nii.gz"
target_electrode_path = subbase + ".atlas.lin.electrode1.label.nii.gz"
target_electrode_deformed_path = subbase + ".reoriented.atlas.lin.electrode1.deformed.label.nii.gz"  # input to the second electrode

# Define target points for first electrode (center and tip)
target_pts = [[-2.08, -1.74, 6.86], [-2.18, -3.12, -0.54]]

# Perform electrode deformation
deform_image_by_electrode(
    target_path=target_path,
    target_electrode_path=target_electrode_path,
    target_electrode_deformed_path=target_electrode_deformed_path,
    target_pts=target_pts,  # electrode center and tip coordinates
    islabel=True,  # input is a label map: use nearest-neighbor interpolation
)

### 8.2 Second Electrode Deformation

Model deformation around the second electrode using the result from the first electrode deformation. This sequential approach accounts for the cumulative effects of multiple electrode insertions.

In [None]:
# Second electrode deformation example
# Define paths for the second electrode (using first electrode's deformed result as input)
target_path = subbase + ".reoriented.atlas.lin.electrode1.deformed.label.nii.gz"
target_electrode_path = subbase + ".reoriented.atlas.lin.electrode2.label.nii.gz"
target_electrode_deformed_path = subbase + ".reoriented.atlas.lin.electrode2.deformed.label.nii.gz"

# Define target points for second electrode (center and tip)
target_pts = [[-4.00, -5.72, 6.86], [-4.70, -6.62, -0.40]]

# Perform electrode deformation
deform_image_by_electrode(
    target_path=target_path,
    target_electrode_path=target_electrode_path,
    target_electrode_deformed_path=target_electrode_deformed_path,
    target_pts=target_pts,  # electrode center and tip coordinates
    islabel=True,  # input is a label map: use nearest-neighbor interpolation
)
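The two cells above follow one pattern: each electrode reads the deformed label map written by the previous one, so the last output accumulates all deformations. A minimal sketch of that chaining for any number of electrodes; `electrode_chain_paths` is a hypothetical helper whose file naming follows the second cell's pattern, so its electrode-1 label filename differs slightly from the first cell's.

```python
def electrode_chain_paths(subbase, first_input, n_electrodes):
    """Input/output paths for applying electrode deformations one after another.

    Electrode k reads the deformed label map written by electrode k-1.
    """
    steps = []
    current = first_input
    for k in range(1, n_electrodes + 1):
        stem = f"{subbase}.reoriented.atlas.lin.electrode{k}"
        step = {
            "target_path": current,
            "target_electrode_path": stem + ".label.nii.gz",
            "target_electrode_deformed_path": stem + ".deformed.label.nii.gz",
        }
        steps.append(step)
        current = step["target_electrode_deformed_path"]  # feed this output to the next electrode
    return steps

for step in electrode_chain_paths("R59", "R59.atlas.lin.label.nii.gz", 2):
    print(step["target_path"], "->", step["target_electrode_deformed_path"])
```

In the notebook, each step could then be passed as `deform_image_by_electrode(target_pts=pts, islabel=True, **step)` together with that electrode's center and tip coordinates.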

---

## Summary

This notebook demonstrated a complete pipeline for rodent brain atlas registration and electrode deformation analysis:

### Registration Pipeline:
1. **Rigid Registration** - Center initialization followed by rigid alignment
2. **Affine Registration** - Linear transformation refinement
3. **Non-linear Registration** - Deformable registration for precise anatomical matching
4. **Quality Assessment** - Jacobian determinant analysis

### Electrode Deformation Modeling:
- Specialized deformation field modeling around electrode tracts
- Sequential processing for multiple electrodes
- Tissue displacement correction for improved registration accuracy

### Key Outputs:
- Registered atlas images at each transformation stage
- Anatomical labels aligned to subject space
- Quality metrics (Jacobian determinants)
- Electrode deformation-corrected registrations

### Applications:
- Accurate anatomical labeling of rodent brain regions
- Quantification of tissue deformation due to electrode implantation  
- Quality control for registration accuracy
- Research in neuroscience and medical imaging

---

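**Note:** File paths in this notebook are configured for a specific dataset (R59). Before running the pipeline, update `dir_name`, `fname`, and `subbase` in Section 2 to match your data organization; the remaining output paths are derived from `subbase`. The electrode target points in Section 8 are also specific to this subject.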