In [None]:
import os
from pathlib import Path

import numpy as np
import dipy
import dipy.align, dipy.align.imaffine, dipy.viz, dipy.viz.regtools
import dipy.data
import ants
import nibabel as nib

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# Root directory holding this subject's raw scans. Set the PITN_DATA_DIR
# environment variable to point the notebook at another location; the original
# path stays the default so existing runs are unaffected.
data_dir = Path(os.environ.get("PITN_DATA_DIR", "/mnt/storage/data/pitn/uva/001"))
# Fail early and name the offending path rather than raising a bare assertion.
if not data_dir.exists():
    raise FileNotFoundError(f"Data directory not found: {data_dir}")

In [None]:
# Input files for this subject/session: two structural scans, one DWI run and
# that run's gradient table.
t1w_file = data_dir.joinpath("sub-001_ses-01_T1w.nii.gz")
t2w_file = data_dir.joinpath("sub-001_ses-01_T2w.nii.gz")
dwi_file = data_dir.joinpath("sub-001_ses-01_run-2_dwi_epi.nii.gz")
bvals_file = data_dir.joinpath("sub-001_ses-01_run-2_dwi.bval")
bvecs_file = data_dir.joinpath("sub-001_ses-01_run-2_dwi.bvec")

# Load the images with nibabel and the gradient tables as plain-text arrays.
t1, t2, dwi = (nib.load(img_path) for img_path in (t1w_file, t2w_file, dwi_file))
bvals, bvecs = np.loadtxt(bvals_file), np.loadtxt(bvecs_file)

In [None]:
# Fetch the MNI template through DIPY; the second returned item is the folder
# the template files live in.
mni_dir = dipy.data.fetch_mni_template()[1]

# T1-weighted template volume used as the registration target.
mni_file = Path(mni_dir, "mni_icbm152_t1_tal_nlin_asym_09c.nii")
assert mni_file.exists()
mni = nib.load(mni_file)

## Register T1w to MNI Template

In [None]:
# Fixed ("static") image: the MNI template. Moving image: the subject's T1w.
static, static_affine = mni.get_fdata(), mni.affine
moving, moving_affine = t1.get_fdata(), t1.affine

In [None]:
# Resample the moving image onto the static grid once. The result does not
# depend on the viewing axis, so recomputing it on every loop iteration only
# repeats the same work.
moving_on_static = dipy.align.resample(
    moving, static, moving_affine, static_affine
).get_fdata()

# One overlay per image axis (slice_type 0, 1, 2); the slice index is left as
# None so DIPY picks the slice to show.
for slice_i in range(3):
    dipy.viz.regtools.overlay_slices(
        static,
        moving_on_static,
        None,
        slice_i,
        "Static",
        "Moving",
    ).set_dpi(170)

In [None]:
# Coarse initial alignment: build the centre-of-mass transform between the two
# images, then apply it to the moving volume.
center_transform = dipy.align.imaffine.transform_centers_of_mass(
    static, static_affine, moving, moving_affine
)
center_moving = center_transform.transform(moving)

In [None]:
# Visual check of the centre-of-mass alignment, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        center_moving,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Centered",
    )
    fig.set_dpi(170)

In [None]:
# Perform rigid body registration.

# Wrap both volumes as ANTs images; both are given the static image's origin
# (the translation column of its affine).
# NOTE(review): only the origin is forwarded -- voxel spacing and direction are
# left at the ants.from_numpy defaults. Confirm that is intended.
static_ants = ants.from_numpy(static, origin=static_affine[:-1, -1].tolist())
center_moving_ants = ants.from_numpy(center_moving, origin=static_affine[:-1, -1].tolist())

# Settings for the linear registration stage, named the same way as in the
# T2w and DWI registration cells. These are the only tuning values that
# ants.registration receives in this cell.
aff_iterations = (2100, 1200, 1200, 10)
aff_shrink_factors = (6, 4, 2, 1)
aff_smoothing_sigmas = (3, 2, 1, 0)
aff_sampling = 32

registered = ants.registration(
    static_ants,
    center_moving_ants,
    "Rigid",
    aff_iterations=aff_iterations,
    aff_shrink_factors=aff_shrink_factors,
    aff_smoothing_sigmas=aff_smoothing_sigmas,
    aff_sampling=aff_sampling,
)
print(registered)

In [None]:
# Pull the registered moving image out of the ANTs result as a NumPy array.
warped_moving = registered["warpedmovout"]
transformed = warped_moving.numpy()

In [None]:
# Visual check of the rigid registration, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        transformed,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Registered",
    )
    fig.set_dpi(170)

In [None]:
# Read the first forward transform reported by ANTs and inspect its parameters.
forward_transform = ants.read_transform(registered["fwdtransforms"][0], dimension=3)
register_affine = forward_transform.parameters
print(register_affine)

# Everything but the last three parameters is viewed as a 3x3 matrix; the last
# three values are taken as the translation.
register_rot, register_translate = (
    register_affine[:-3].reshape(3, 3),
    register_affine[-3:],
)
print(register_rot)
print(register_translate)

In [None]:
# aff_mat = f_transform.affine[:3, :3]

# u, s, vT = np.linalg.svd(aff_mat)

# # Scale in direction of eigenvectors.
# P = vT.T @ np.diag(s) @ vh
# # Rotation matrix
# R = u @ vT

# eigvals, eigvecs = np.linalg.eig(aff_mat)

In [None]:
# From here on `t1` refers to the registered T1w, stored with the static
# (template) affine.
t1 = nib.Nifti1Image(transformed, affine=static_affine)

## Register T2w to T1w

In [None]:
# Fixed ("static") image: the T1w held in `t1`. Moving image: the T2w.
static, static_affine = t1.get_fdata(), t1.affine
moving, moving_affine = t2.get_fdata(), t2.affine

In [None]:
# Resample the moving image onto the static grid once. The result does not
# depend on the viewing axis, so recomputing it on every loop iteration only
# repeats the same work.
moving_on_static = dipy.align.resample(
    moving, static, moving_affine, static_affine
).get_fdata()

# One overlay per image axis (slice_type 0, 1, 2); the slice index is left as
# None so DIPY picks the slice to show.
for slice_i in range(3):
    dipy.viz.regtools.overlay_slices(
        static,
        moving_on_static,
        None,
        slice_i,
        "Static",
        "Moving",
    ).set_dpi(170)

In [None]:
# Coarse initial alignment: build the centre-of-mass transform between the two
# images, then apply it to the moving (T2w) volume.
center_transform = dipy.align.imaffine.transform_centers_of_mass(
    static, static_affine, moving, moving_affine
)
center_moving = center_transform.transform(moving)

In [None]:
# Visual check of the centre-of-mass alignment, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        center_moving,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Centered",
    )
    fig.set_dpi(170)

In [None]:
# Perform rigid body registration.

# Both volumes are handed to ANTs with the static image's origin (the
# translation column of its affine).
ants_origin = static_affine[:-1, -1].tolist()
static_ants = ants.from_numpy(static, origin=ants_origin)
center_moving_ants = ants.from_numpy(center_moving, origin=ants_origin)

# Settings for the linear registration stage.
aff_iterations = (2100, 1200, 1200, 10)
aff_shrink_factors = (6, 4, 2, 1)
aff_smoothing_sigmas = (3, 2, 1, 0)
aff_sampling = 32

# Rigid registration of the centred T2w onto the static image.
registered = ants.registration(
    fixed=static_ants,
    moving=center_moving_ants,
    type_of_transform="Rigid",
    aff_iterations=aff_iterations,
    aff_shrink_factors=aff_shrink_factors,
    aff_smoothing_sigmas=aff_smoothing_sigmas,
    aff_sampling=aff_sampling,
)
print(registered)

In [None]:
# Pull the registered moving image out of the ANTs result as a NumPy array.
warped_moving = registered["warpedmovout"]
transformed = warped_moving.numpy()

In [None]:
# Visual check of the rigid registration, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        transformed,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Registered",
    )
    fig.set_dpi(170)

In [None]:
# Read the first forward transform reported by ANTs and inspect its parameters.
forward_transform = ants.read_transform(registered["fwdtransforms"][0], dimension=3)
register_affine = forward_transform.parameters
print(register_affine)

# Everything but the last three parameters is viewed as a 3x3 matrix; the last
# three values are taken as the translation.
register_rot, register_translate = (
    register_affine[:-3].reshape(3, 3),
    register_affine[-3:],
)
print(register_rot)
print(register_translate)

In [None]:
# From here on `t2` refers to the registered T2w, stored with the static
# image's affine.
t2 = nib.Nifti1Image(transformed, affine=static_affine)

## Register Average of $b_0$ DWIs to T2w

In [None]:
# Moving image: the mean over all DWI volumes whose b-value is exactly 0.
b0_idx = np.flatnonzero(bvals == 0)
b0_all = dwi.get_fdata()[..., b0_idx]
avg_b0 = b0_all.mean(axis=-1)
moving, moving_affine = avg_b0, dwi.affine

# Fixed ("static") image: the T2w held in `t2`.
static, static_affine = t2.get_fdata(), t2.affine

In [None]:
# Resample the moving image onto the static grid once. The result does not
# depend on the viewing axis, so recomputing it on every loop iteration only
# repeats the same work.
moving_on_static = dipy.align.resample(
    moving, static, moving_affine, static_affine
).get_fdata()

# One overlay per image axis (slice_type 0, 1, 2); the slice index is left as
# None so DIPY picks the slice to show.
for slice_i in range(3):
    dipy.viz.regtools.overlay_slices(
        static,
        moving_on_static,
        None,
        slice_i,
        "Static",
        "Moving",
    ).set_dpi(170)

In [None]:
# Coarse initial alignment: build the centre-of-mass transform between the two
# images, then apply it to the moving (mean b0) volume.
center_transform = dipy.align.imaffine.transform_centers_of_mass(
    static, static_affine, moving, moving_affine
)
center_moving = center_transform.transform(moving)

In [None]:
# Visual check of the centre-of-mass alignment, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        center_moving,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Centered",
    )
    fig.set_dpi(170)

In [None]:
# Perform rigid body registration.

# Both volumes are handed to ANTs with the static image's origin (the
# translation column of its affine).
ants_origin = static_affine[:-1, -1].tolist()
static_ants = ants.from_numpy(static, origin=ants_origin)
center_moving_ants = ants.from_numpy(center_moving, origin=ants_origin)

# Settings for the linear registration stage.
aff_iterations, aff_shrink_factors = (2100, 1200, 1200, 10), (6, 4, 2, 1)
aff_smoothing_sigmas, aff_sampling = (3, 2, 1, 0), 32

# Rigid registration of the centred mean-b0 volume onto the static image.
registered = ants.registration(
    fixed=static_ants,
    moving=center_moving_ants,
    type_of_transform="Rigid",
    aff_iterations=aff_iterations,
    aff_shrink_factors=aff_shrink_factors,
    aff_smoothing_sigmas=aff_smoothing_sigmas,
    aff_sampling=aff_sampling,
)
print(registered)

In [None]:
# Pull the registered moving image out of the ANTs result as a NumPy array.
warped_moving = registered["warpedmovout"]
transformed = warped_moving.numpy()

In [None]:
# Visual check of the rigid registration, one figure per image axis.
for axis in range(3):
    fig = dipy.viz.regtools.overlay_slices(
        static,
        transformed,
        slice_index=None,
        slice_type=axis,
        ltitle="Static",
        rtitle="Registered",
    )
    fig.set_dpi(170)

In [None]:
# Keep the transform object itself (it is applied to every DWI volume later)
# as well as its raw parameter vector.
fwd_transform_file = registered["fwdtransforms"][0]
register_transform = ants.read_transform(fwd_transform_file, dimension=3)
register_affine = register_transform.parameters
print(register_affine)

In [None]:
# Perform the centering and registration transforms on each DWI volume.
# The rigid transform was estimated between ANTs images whose origin was taken
# from `static_affine`. Each DWI volume has to be placed in that same physical
# space before the transform is applied; with the default origin the rotation
# would be applied in a different physical space from the one it was estimated in.
ants_origin = static_affine[:-1, -1].tolist()

# Collect the transformed volumes in a list and stack them once at the end.
new_dwi_volumes = []

# Iterate over the last axis of the DWI array, one 3D volume at a time.
for i, dwi_bk in enumerate(np.moveaxis(dwi.get_fdata(), -1, 0)):
    print(i, "out of ", dwi.shape[-1], end=' | ')
    # Same two steps as for the mean b0: centre-of-mass alignment, then the
    # rigid transform found by ANTs.
    center_dwi_bk = center_transform.transform(dwi_bk)
    center_dwi_ants = ants.from_numpy(center_dwi_bk, origin=ants_origin)
    transformed_dwi_bk = register_transform.apply_to_image(center_dwi_ants)
    new_dwi_volumes.append(transformed_dwi_bk.numpy())

# Put the volume axis back in last position.
new_dwi = np.stack(new_dwi_volumes, axis=-1)

In [None]:
# Rotate the bvecs to match the new orientation of the DWI.
# Step 1: sandwich the centre-of-mass affine between the AffineMap's codomain
# grid-to-world matrix and the inverse of its domain grid-to-world matrix.
# NOTE(review): this presumably folds voxel size/orientation differences between
# the two grids into `full_affine`, in which case `full_rot` below is not a pure
# rotation and the rotated bvecs need not stay unit length -- TODO confirm
# against the bvec convention in use.
full_affine = center_transform.codomain_grid2world.dot(center_transform.affine)
full_affine = full_affine.dot(np.linalg.inv(center_transform.domain_grid2world))

# Construct the registration affine matrix.
# All parameters except the last three are reshaped (row-major) into the 3x3
# block; the last three fill the translation column.
# NOTE(review): the transform's fixed parameters (presumably the centre of
# rotation) are not folded into the translation here. Only the 3x3 block
# reaches `new_bvecs`, so this affects the printed matrix, not the bvecs.
registration_affine = np.eye(4)
register_rot = register_affine[:-3].reshape(3, 3)
registration_affine[:-1, :-1] = register_rot
register_translate = register_affine[-3:]
registration_affine[:-1, -1] = register_translate

# Step 2: compose the centring matrix with the registration matrix.
# NOTE(review): verify the composition order and whether the ANTs matrix needs
# inverting (and an LPS/RAS conversion) before it is applied to gradient
# directions.
full_affine = full_affine.dot(registration_affine)
print(full_affine)

# Only the upper-left 3x3 block is applied; `bvecs` must therefore have three
# rows (one column per gradient direction).
full_rot = full_affine[:-1, :-1]
new_bvecs = full_rot @ bvecs

In [None]:
# Spot-check one gradient direction: the original vector is printed first,
# followed by three ways of computing its rotated counterpart, which should
# all print the same values.
# The index is clamped so the check also runs when the acquisition has 100 or
# fewer volumes (bvecs holds one column per volume).
i = min(100, bvecs.shape[1] - 1)
print(bvecs[:, i])
print(full_rot @ bvecs[:, i])
print((full_rot @ bvecs)[:, i])
print(np.dot(full_rot, bvecs[:, i]))

In [None]:
# From here on `dwi` refers to the transformed DWI series, stored with the
# static image's affine.
dwi = nib.Nifti1Image(new_dwi, affine=static_affine)

## Save out images to file

In [None]:
# Results go into a sub-directory of the data directory and keep the names of
# the files they were derived from.
output_dir = data_dir / "python_registration"
output_dir.mkdir(parents=True, exist_ok=True)

# Images.
t1_out, t2_out, dwi_out = (
    output_dir / source_file.name for source_file in (t1w_file, t2w_file, dwi_file)
)
for image, image_out in ((t1, t1_out), (t2, t2_out), (dwi, dwi_out)):
    nib.save(image, image_out)

# Gradient tables: the rotated b-vectors and the unchanged b-values.
bvecs_out = output_dir / bvecs_file.name
bvals_out = output_dir / bvals_file.name
for table, table_out in ((new_bvecs, bvecs_out), (bvals, bvals_out)):
    np.savetxt(table_out, table, fmt="%.10g")

In [None]:
# Set up a DIPY-only registration: T1w as the fixed image, T2w as the moving one.
static, static_affine = t1.get_fdata(), t1.affine
moving, moving_affine = t2.get_fdata(), t2.affine

In [None]:
# Compose the pipeline of transformations.
transform_pipeline = [
    dipy.align.center_of_mass,
    dipy.align.translation,
    dipy.align.rigid,
    dipy.align.affine,
]

# Transform parameters.
# Set parameters for a Gaussian Pyramid of registrations.
level_iters = [10000, 1000, 100]
sigmas = [3.0, 1.0, 0.0]
factors = [4, 2, 1]
nbins = 32

In [None]:
# Gather the keyword arguments first so the call itself stays short.
registration_kwargs = dict(
    moving_affine=moving_affine,
    static_affine=static_affine,
    nbins=nbins,
    metric="MI",
    pipeline=transform_pipeline,
    level_iters=level_iters,
    sigmas=sigmas,
    factors=factors,
)

# Run the pipeline; DIPY returns the resampled moving image together with the
# affine it found.
transformed, transformed_affine = dipy.align.affine_registration(
    moving, static, **registration_kwargs
)

In [None]:
# Visual check of the DIPY registration, one figure per image axis. Each figure
# is created with dpi=300 and then set to 170.
for axis in (0, 1, 2):
    fig = dipy.viz.regtools.overlay_slices(
        static, transformed, None, axis, "Static", "Registered", dpi=300
    )
    fig.set_dpi(170)

## Register T1w to $b = 0$ DWI

Taken from <https://dipy.org/documentation/1.4.1./examples_built/affine_registration_3d/#example-affine-registration-3d>

In [None]:
# # Set up static and moving images.
# static = dwi.get_fdata()[..., b0_idx]
# static_affine = dwi.affine
# moving = t1.get_fdata()
# moving_affine = t1.affine

In [None]:
# # Create the registration metric.
# nbins = 32
# sampling_prop = None
# metric = dipy.align.imaffine.MutualInformationMetric(nbins, sampling_prop)

# # Set up multi-resolution registration params.
# level_iters = [10000, 1000, 100]
# sigmas = [3.0, 1.0, 0.0]
# factors = [4, 2, 1]

# # Create registration class
# reg = dipy.align.imaffine.AffineRegistration(
#     metric=metric, level_iters=level_iters, sigmas=sigmas, factors=factors
# )

# # Specify transformation to perform.
# transform = dipy.align.transforms.RigidTransform3D()

In [None]:
# params0 = None

# rigid_scale = reg.optimize(
#     static,
#     moving,
#     transform,
#     params0,
#     static_grid2world=static_affine,
#     moving_grid2world=moving_affine,
#     starting_affine="mass",
# )
# transformed = rigid_scale.transform(moving)

In [None]:
# # plt.figure(dpi=120, figsize=(8, 11))
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 0, "Static", "Registered", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 1, "Static", "Registered", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 2, "Static", "Registered", dpi=300
# ).set_dpi(170)

In [None]:
# # Save the registered image.
# transformed_nib_img = nib.Nifti1Image(transformed, affine=rigid_scale.affine)
# nib.save(transformed_nib_img, data_dir / "sub-001_ses-01_T1w_reg_to_b0_DWI.nii.gz")

## Register T2w to $b = 0$ DWI 

Do we need to register the T2w separately? Or will the transform found with the T1w work?

In [None]:
# dipy.viz.regtools.overlay_slices(
#     t1.get_fdata(), t2.get_fdata(), None, 0, "T1", "T2", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     t1.get_fdata(), t2.get_fdata(), None, 1, "T1", "T2", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     t1.get_fdata(), t2.get_fdata(), None, 2, "T1", "T2", dpi=300
# ).set_dpi(170)

In [None]:
# plt.imshow(np.abs(t1.affine - t2.affine))
# plt.colorbar()

Nope, we'll need to find the T2w -> DWI registration on its own. The T1w and T2w affines are largely similar, but different enough to warrant another registration.

In [None]:
# static = dwi.get_fdata()[..., b0_idx]
# static_affine = dwi.affine
# moving = t2.get_fdata()
# moving_affine = t2.affine

In [None]:
# # Create the registration metric.
# nbins = 32
# sampling_prop = None
# metric = dipy.align.imaffine.MutualInformationMetric(nbins, sampling_prop)

# # Set up multi-resolution registration params.
# level_iters = [10000, 1000, 100]
# sigmas = [3.0, 1.0, 0.0]
# factors = [4, 2, 1]

# # Create registration class
# reg = dipy.align.imaffine.AffineRegistration(
#     metric=metric, level_iters=level_iters, sigmas=sigmas, factors=factors
# )

# # Specify transformation to perform.
# transform = dipy.align.transforms.RigidTransform3D()

In [None]:
# params0 = None

# rigid_scale = reg.optimize(
#     static,
#     moving,
#     transform,
#     params0,
#     static_grid2world=static_affine,
#     moving_grid2world=moving_affine,
#     starting_affine="mass",
# )
# transformed = rigid_scale.transform(moving)

In [None]:
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 0, "Static", "Registered", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 1, "Static", "Registered", dpi=300
# ).set_dpi(170)
# dipy.viz.regtools.overlay_slices(
#     static, transformed, None, 2, "Static", "Registered", dpi=300
# ).set_dpi(170)

In [None]:
# # Save the registered image.
# transformed_nib_img = nib.Nifti1Image(transformed, affine=rigid_scale.affine)
# nib.save(transformed_nib_img, data_dir / "sub-001_ses-01_T2w_reg_to_b0_DWI.nii.gz")

---

In [None]:
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash