## Import necessary packages

In [1]:
import os
import numpy as np
import nibabel as nib
import dipy.reconst.dti as dti
import dipy.core.gradients as dpg
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

## Data overview

In [2]:
img_t1 = nib.load("AA_041/AA_041_2_MPRAGE_GRAPPA2.nii.gz")
data_t1 = img_t1.get_fdata()

In [3]:
print(img_t1.shape)

(176, 240, 256)


In [4]:
plot_anat(img_t1, draw_cross=False, display_mode='z')

NameError: name 'plot_anat' is not defined

## Visualization of Diffusion Metrics

In [None]:
# Define the subject ID and DWI files
subject_id = "AA_041"
dwi_file = f"{subject_id}/{subject_id}_8_ep2d_diff_30_iso.nii.gz"
bval_file = f"{subject_id}/{subject_id}_8_ep2d_diff_30_iso.bval"
bvec_file = f"{subject_id}/{subject_id}_8_ep2d_diff_30_iso.bvec"

# Load the DWI data
dwi_img = nib.load(dwi_file)
data = dwi_img.get_fdata()
affine = dwi_img.affine

# Load the b-values and b-vectors
bvals = np.loadtxt(bval_file)
bvecs = np.loadtxt(bvec_file).T

# Create a mask to exclude background
mask = data[..., 0] > 0

# Prepare the gradient table
gtab = dpg.gradient_table(bvals, bvecs)

# Fit the diffusion tensor model
dti_model = dti.TensorModel(gtab)
dti_fit = dti_model.fit(data, mask=mask)

# Calculate FA, MD, AD, and RD
fa = dti_fit.fa
md = dti_fit.md
ad = dti_fit.ad
rd = dti_fit.rd

# Plot the diffusion metrics
fig1, ax = plt.subplots(1, 4, figsize=(16, 4),
                        subplot_kw={'xticks': [], 'yticks': []})

fig1.subplots_adjust(hspace=0.3, wspace=0.05)

ax.flat[0].imshow(fa[:, :, fa.shape[2] // 2], cmap="gray", origin="lower")
ax.flat[0].set_title("Functional Anisotropy (FA)")
ax.flat[1].imshow(md[:, :, md.shape[2] // 2], cmap="gray", origin="lower")
ax.flat[1].set_title("Mean Diffusivity (MD)")
ax.flat[2].imshow(ad[:, :, ad.shape[2] // 2], cmap="gray", origin="lower")
ax.flat[2].set_title("Axial Diffusivity (AD)")
ax.flat[3].imshow(rd[:, :, rd.shape[2] // 2], cmap="gray", origin="lower")
ax.flat[3].set_title("Radial Diffusivity (RD)")


plt.tight_layout()
plt.show()