In [1]:
import os
import scipy
import logging
import tempfile
import numpy as np
import nibabel as nib
import multiprocessing
from dipy.data import get_sphere
from dipy.io import read_bvals_bvecs
from joblib import Parallel, delayed
from dipy.core.sphere import Sphere
from dipy.reconst.shm import sf_to_sh
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from scilpy.reconst.multi_processes import fit_from_model, convert_sh_basis
from scilpy.reconst.raw_signal import compute_sh_coefficients
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
from dipy.reconst.shm import real_sh_descoteaux_from_index, sh_to_sf
from scilpy.utils.bvec_bval_tools import (DEFAULT_B0_THRESHOLD,
                                          check_b0_threshold, identify_shells,
                                          is_normalized_bvecs, normalize_bvecs)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Input set built from the `dwi_fodf` series and its gradient files.
# NOTE: this cell has no execution count in the saved notebook, and `n`, `vol`,
# `og_file`, `ob_file`, `vec_folder`, `val_folder` and `mask` are all re-bound
# by the next cell.
subject_dir = '/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs001/ses-s1Bx2/prequal_dwi_cat'

# Open the image once: `vol` keeps the handle (its affine is used when saving)
# and `n` is the voxel data, instead of parsing the same file twice.
vol = nib.load(os.path.join(subject_dir, 'dwi_fodf.nii.gz'))
n = vol.get_fdata()

# Nominal (scanner-reported) gradient files.
og_file = os.path.join(subject_dir, 'bvec_fodf.bvec')
ob_file = os.path.join(subject_dir, 'bval_fodf.bval')

# Folders holding one voxelwise b-vector / b-value volume per gradient.
vec_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bvec'
val_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bval'

mask = nib.load(os.path.join(subject_dir, 'mask.nii.gz')).get_fdata()
mean_b0_vol = nib.load(os.path.join(subject_dir, 'b0_mean.nii.gz')).get_fdata()

In [48]:
# Input set built from the two-shell acquisition named in `dwi_prefix` and the
# `*_tensor` voxelwise gradient folders.
subject_dir = '/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs001/ses-s1Bx2/prequal_dwi_cat'
dwi_prefix = os.path.join(subject_dir, 'sub-cIVs001_ses-s1Bx2_acq-b1000b2000n96r21x21x22peAPP_run-1_dwi')

# Open the image once: `vol` keeps the handle (its affine is used when saving)
# and `n` is the voxel data, instead of parsing the same file twice.
vol = nib.load(dwi_prefix + '.nii.gz')
n = vol.get_fdata()

# Nominal (scanner-reported) gradient files.
og_file = dwi_prefix + '.bvec'
ob_file = dwi_prefix + '.bval'

# Folders holding one voxelwise b-vector / b-value volume per gradient.
vec_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bvec_tensor'
val_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bval_tensor'

mask = nib.load(os.path.join(subject_dir, 'mask.nii.gz')).get_fdata()
# mean_b0_vol = nib.load('/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs001/ses-s1Bx2/prequal_dwi_cat/b0_mean.nii.gz').get_fdata()

In [2]:
def compute_dwi_attenuation(dwi_weights: np.ndarray, b0: np.ndarray):
    """Compute signal attenuation by dividing the DWI signal by the b0.

    Values above the b0 are clipped to the b0 first, so the result never
    exceeds 1. Voxels whose division is not finite (e.g. b0 == 0) are set
    to 0. The input arrays are not modified.

    Parameters
    ----------
    dwi_weights : np.ndarray of shape (X, Y, Z, #gradients)
        Diffusion weighted images.
    b0 : np.ndarray of shape (X, Y, Z)
        B0 image.

    Returns
    -------
    dwi_attenuation : np.ndarray
        Signal attenuation (diffusion weights normalized by the b0), same
        shape as ``dwi_weights``.
    """
    b0 = b0[..., None]  # Add a trailing axis so it broadcasts over gradients.

    # Make sure that, in every voxel, weights are not above the b0. Should
    # always be the case, but with noise we never know.
    erroneous_voxels = np.any(dwi_weights > b0, axis=3)
    nb_erroneous_voxels = np.sum(erroneous_voxels)
    if nb_erroneous_voxels != 0:
        logging.info("# of voxels where `dwi_signal > b0` in any direction: "
                     "{}".format(nb_erroneous_voxels))
        # np.minimum returns a new array, so the caller's data is untouched.
        dwi_weights = np.minimum(dwi_weights, b0)

    # Division by a zero b0 is expected (background voxels) and is repaired
    # just below, so silence the divide/invalid RuntimeWarnings here rather
    # than emitting one per call.
    with np.errstate(divide='ignore', invalid='ignore'):
        dwi_attenuation = dwi_weights / b0

    # Replace the inf/nan left by a division by 0.
    dwi_attenuation[~np.isfinite(dwi_attenuation)] = 0.

    return dwi_attenuation

In [3]:
# DWI TO SF (With the voxelwise bvec and bval)
def reconstruct_signal_at_voxel(i,j,k,n,og_bvec,og_bval,bvec_stack,bval_stack,dwi_hat):
    """Re-express one voxel's DWI signal on the nominal gradient directions.

    The voxel's signal is rescaled by the ratio of voxelwise to nominal
    b-values, normalized by its mean b0, fitted with spherical harmonics on
    the voxelwise b-vectors, then evaluated on the nominal b-vectors and
    multiplied back by the b0.

    Parameters
    ----------
    i, j, k : int
        Voxel indices.
    n : array indexed as n[i][j][k] -> 1D signal over all volumes.
    og_bvec, og_bval : nominal b-vectors / b-values.
    bvec_stack : array indexed as [i, j, k, volume, component].
    bval_stack : array indexed as [i, j, k, volume].
    dwi_hat : output array; ``dwi_hat[i, j, k, :]`` receives the
        reconstructed diffusion-weighted values (b0 volumes excluded).

    Returns
    -------
    None. The result is written into ``dwi_hat``.
    """
    # Fixed settings of the SH fit.
    sh_order = 10
    basis_type = 'tournier07'
    smooth = 0.00
    use_attenuation = True
    force_b0_threshold = False
    mask = None  # No mask is applied; the branch below is kept as-is.

    voxel_signal = n[i][j][k]
    nominal_table = gradient_table_from_bvals_bvecs(og_bval, og_bvec)
    voxel_bvecs = bvec_stack[i, j, k, :, :]
    voxel_bvals = bval_stack[i, j, k, :]
    voxel_table = gradient_table_from_bvals_bvecs(voxel_bvals, voxel_bvecs)

    # Everything below is selected with the b0 mask of the *voxelwise* table,
    # including the nominal b-values and b-vectors.
    is_b0 = voxel_table.b0s_mask
    is_dw = np.logical_not(is_b0)
    bvecs = voxel_table.bvecs
    bvals = voxel_table.bvals

    voxel_signal = np.reshape(voxel_signal, [1, 1, 1, bvals.shape[0]])

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    # Validation call; the returned threshold is not used afterwards.
    b0_threshold = check_b0_threshold(force_b0_threshold, bvals.min())

    # Shells are identified but the single-shell check is disabled.
    shell_values, _ = identify_shells(bvals)
    shell_values.sort()
    # if shell_values.shape[0] != 2 or shell_values[0] > b0_threshold:
    #     raise ValueError("Can only work on single shell signals.")

    # Diffusion-weighted directions and signal only.
    bvecs = bvecs[is_dw]
    weights = voxel_signal[..., is_dw]

    # Scale signal with the b-value correction:
    #   S' = b0 * exp( log(S / b0) / (b_voxel / b_nominal) )
    b0 = voxel_signal[..., is_b0].mean(axis=3)
    norm_gg = bvals[is_dw] / og_bval[is_dw]
    log_attenuation = np.log(weights / b0)
    weights_scaled = b0 * np.exp(log_attenuation / norm_gg)

    # Compute attenuation using the b0.
    if use_attenuation:
        weights_scaled = compute_dwi_attenuation(weights_scaled, b0)

    # SF -> SH on the voxelwise directions.
    sphere = Sphere(xyz=bvecs)
    sh = sf_to_sh(weights_scaled, sphere, sh_order, basis_type, smooth=smooth)

    if mask is not None:
        sh *= mask[..., None]

    # SH -> SF on the nominal directions.
    og_bvecs = nominal_table.bvecs
    if not is_normalized_bvecs(og_bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        og_bvecs = normalize_bvecs(og_bvecs)
    og_bvecs = og_bvecs[is_dw]
    og_sphere = Sphere(xyz=og_bvecs)
    sf = sh_to_sf(sh, og_sphere, sh_order=sh_order, basis_type=basis_type)

    # SF -> DWI (inverse of compute_dwi_attenuation): multiply back by the b0.
    dwi_hat[i, j, k, :] = sf * b0[..., None]


In [None]:
# Spatial dimensions of the DWI and number of volumes in the b=0/2000 subset.
# `ind_0_2000` is defined in the per-shell setup cell further down.
print(*n.shape[:3], ind_0_2000[0].shape[0])

In [None]:
# Reconstruct every voxel using all volumes at once.
# Relies on `og_bval`, `og_bvec`, `bvec_stack` and `bval_stack`, which are
# created in the per-shell setup cell further down.

# Indices of the b0 (b == 0) and diffusion-weighted (b != 0) volumes.
ind_b0 = np.squeeze(np.nonzero(og_bval==0))
ind_non_b0 = np.squeeze(np.nonzero(og_bval))

num_cores = 10
path = tempfile.mkdtemp()
dwi_hat_path = os.path.join(path,'dwi_hat.mmap')

# The output holds one value per diffusion-weighted volume, so size the last
# axis from the data rather than hard-coding it (previously the literal 56).
dwi_hat = np.memmap(dwi_hat_path, dtype=float,
                    shape=(n.shape[0],n.shape[1],n.shape[2],ind_non_b0.size),
                    mode='w+')
xaxis = range(n.shape[0])
yaxis = range(n.shape[1])
zaxis = range(n.shape[2])
# mean_b0 = np.mean(n[:,:,:,ind_b0] , 3)

# Run in parallel: one task per voxel, each writing its result into `dwi_hat`.
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel)(i,j,k,n,og_bvec,og_bval,bvec_stack,bval_stack,dwi_hat)
    for k in zaxis for j in yaxis for i in xaxis
)


In [None]:
# Display the temporary directory that holds the memory-mapped outputs.
path

In [4]:
# TO RUN FOR SHELLS SEPARATELY
subject_dir = '/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs001/ses-s1Bx2/prequal_dwi_cat'
dwi_prefix = os.path.join(subject_dir, 'sub-cIVs001_ses-s1Bx2_acq-b1000b2000n96r21x21x22peAPP_run-1_dwi')

# Open the image once: `vol` keeps the handle (its affine is used when saving)
# and `n` is the voxel data, instead of parsing the same file twice.
vol = nib.load(dwi_prefix + '.nii.gz')
n = vol.get_fdata()
og_file = dwi_prefix + '.bvec'
ob_file = dwi_prefix + '.bval'
vec_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bvec_tensor'
val_folder = '/home/local/VANDERBILT/kanakap/try_fx_emp/bval_tensor'
mask = nib.load(os.path.join(subject_dir, 'mask.nii.gz')).get_fdata()

# Nominal gradients and the volume indices of each shell.
# Shell membership uses exact equality on the nominal b-values.
og_bval, og_bvec = read_bvals_bvecs(ob_file,og_file)
ind_1000 = np.where(og_bval == 1000)
ind_2000 = np.where(og_bval == 2000)
ind_b0 = np.nonzero(og_bval==0)
ind_b0 = np.squeeze(ind_b0)
# b0 volumes together with one shell, as needed for a per-shell fit.
ind_0_1000 = np.where((og_bval == 0) | (og_bval == 1000))
ind_0_2000 = np.where((og_bval == 0) | (og_bval == 2000))

# LOAD THE VOXELWISE BVALS AND BVECS
# One NIfTI per gradient volume; files are taken in sorted filename order, so
# that order must match the volume order of the DWI series.

# Each b-vector file is a 4-D volume (X, Y, Z, C). Stacking on a new axis 3
# gives (X, Y, Z, N, C) directly. The previous expand/transpose followed by a
# bare `.squeeze()` gave the same array, but would also have dropped any other
# axis that happened to have length 1 (e.g. a single gradient file).
bvec_vols = []
for fname in sorted(os.listdir(vec_folder)):
    if fname.endswith('.nii.gz'):
        bvec_vol = nib.load(os.path.join(vec_folder, fname)).get_fdata()
        bvec_vols.append(bvec_vol)
bvec_stack = np.stack(bvec_vols, 3)

# Each b-value file is stacked the same way on a new axis 3.
bval_vols = []
for fname in sorted(os.listdir(val_folder)):
    if fname.endswith('.nii.gz'):
        bval_vol = nib.load(os.path.join(val_folder, fname)).get_fdata()
        bval_vols.append(bval_vol)
bval_stack = np.stack(bval_vols, 3)

# Shared settings for the per-shell runs.
num_cores = 10
path = tempfile.mkdtemp()
xaxis, yaxis, zaxis = (range(extent) for extent in n.shape[:3])

# for dwi with 0 1000
# `len1` holds the volume indices of the b=0 and b=1000 volumes.
len1 = ind_0_1000[0]
dwi_hat_path1 = os.path.join(path,'dwi_hat1.mmap')
# Output buffer: one value per b=1000 volume (the b0 volumes are not reconstructed).
dwi_hat1 = np.memmap(
    dwi_hat_path1,
    dtype=float,
    shape=n.shape[:3] + (ind_1000[0].shape[0],),
    mode='w+',
)

# Restrict signal, nominal gradients and voxelwise gradients to those volumes.
data = n[..., len1]
org_bvec = og_bvec[len1,:]
org_bval = og_bval[len1]
corr_bvec = bvec_stack[:,:,:,len1,:]
corr_bval = bval_stack[..., len1]

# One task per voxel; each task writes its result into `dwi_hat1`.
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel)(
        i, j, k, data, org_bvec, org_bval, corr_bvec, corr_bval, dwi_hat1
    )
    for k in zaxis
    for j in yaxis
    for i in xaxis
)




In [None]:
# Shape of the voxelwise b-vectors restricted to the volumes in `len2`.
# NOTE: `len2` is assigned in the b=0/2000 cell that follows this one.
bvec_stack[:,:,:,len2,:].shape

In [5]:
# for dwi with 0 2000
# Reuses `path`, `num_cores` and the axis ranges from the b=0/1000 cell.
# `len2` holds the volume indices of the b=0 and b=2000 volumes.
len2 = ind_0_2000[0]
dwi_hat_path2 = os.path.join(path,'dwi_hat2.mmap')
# Output buffer: one value per b=2000 volume (the b0 volumes are not reconstructed).
dwi_hat2 = np.memmap(
    dwi_hat_path2,
    dtype=float,
    shape=n.shape[:3] + (ind_2000[0].shape[0],),
    mode='w+',
)

# Restrict signal, nominal gradients and voxelwise gradients to those volumes.
data = n[..., len2]
org_bvec = og_bvec[len2,:]
org_bval = og_bval[len2]
corr_bvec = bvec_stack[:,:,:,len2,:]
corr_bval = bval_stack[..., len2]

# One task per voxel; each task writes its result into `dwi_hat2`.
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel)(
        i, j, k, data, org_bvec, org_bval, corr_bvec, corr_bval, dwi_hat2
    )
    for k in zaxis
    for j in yaxis
    for i in xaxis
)



In [None]:
# Shape of the voxelwise b-values restricted to the volumes in `len1`
# (the b=0 and b=1000 indices).
bval_stack[:,:,:,len1].shape

In [None]:
# Shape of the nominal b-vector table returned by read_bvals_bvecs.
og_bvec.shape

In [None]:
# Indices of the volumes whose nominal b-value is exactly 0, squeezed into a
# plain index array; the trailing expression displays them.
ind_b0 = np.squeeze(np.nonzero(og_bval == 0))
ind_b0

In [None]:
# Inspect the measured signal (all volumes) at voxel (28, 38, 23).
n[28][38][23]#[..., np.logical_not(b0_mask)]

In [None]:
# Inspect the reconstructed signal at the same voxel (28, 38, 23).
# NOTE: `dwi_hat` is only created by the all-volumes run cell; the per-shell
# cells write to `dwi_hat1` / `dwi_hat2` instead.
dwi_hat[28][38][23]

In [6]:
# Reassemble a full 4-D series: original b0 volumes plus the reconstructed
# b=1000 and b=2000 volumes, each placed back at its original volume index.
dwmri_corrected = np.zeros(n.shape[:4])

dwmri_corrected[..., ind_b0] = n[..., ind_b0]
dwmri_corrected[..., ind_1000[0]] = dwi_hat1
dwmri_corrected[..., ind_2000[0]] = dwi_hat2

# Saved as float32 with the affine of the input image, in the working directory.
nib.save(
    nib.Nifti1Image(dwmri_corrected.astype(np.float32), vol.affine),
    "dwi_hat_combined_whole_img.nii",
)

In [None]:
# Shape of the stacked voxelwise b-vectors.
bvec_stack.shape

In [None]:
# DWI TO SF
def reconstruct_signal_at_voxel_bval(i,j,k,mask,dwi_hat,mean_b0_vol,ind_non_b0_shape,og_bval,bval_stack,dwi_hat_bval):
    """Apply only the b-value rescaling to one voxel's diffusion-weighted signal.

    Computes ``b0 * exp(log(S / b0) / (b_voxel / b_nominal))`` for the voxel
    and stores it in ``dwi_hat_bval[i, j, k, :]``.

    Parameters
    ----------
    i, j, k : int
        Voxel indices.
    mask : unused (the mask test is commented out).
    dwi_hat : array indexed as dwi_hat[i][j][k] -> diffusion-weighted signal.
    mean_b0_vol : array indexed as mean_b0_vol[i][j][k] -> b0 value.
    ind_non_b0_shape : int
        Number of diffusion-weighted volumes in ``dwi_hat``.
    og_bval : nominal b-values.
    bval_stack : array indexed as [i, j, k, volume].
    dwi_hat_bval : output array written in place.

    Notes
    -----
    The voxelwise b-vectors are read from the module-level ``bvec_stack``;
    it is not a parameter, so it must exist before this function is called.

    Returns
    -------
    None.
    """
    #if mask[i,j,k] == 1:
    voxel_signal = dwi_hat[i][j][k]
    b0 = mean_b0_vol[i][j][k]
    voxel_bvals = bval_stack[i, j, k, :]
    voxel_bvecs = bvec_stack[i, j, k, :, :]
    voxel_table = gradient_table_from_bvals_bvecs(voxel_bvals, voxel_bvecs)

    # Only the b0 mask and b-values of the voxelwise table are needed.
    is_dw = np.logical_not(voxel_table.b0s_mask)
    bvals = voxel_table.bvals
    voxel_signal = np.reshape(voxel_signal, [1, 1, 1, ind_non_b0_shape])

    # Ratio of voxelwise to nominal b-value for each diffusion-weighted volume.
    norm_gg = bvals[is_dw] / og_bval[is_dw]

    log_attenuation = np.log(voxel_signal / b0)
    dwi_hat_bval[i, j, k, :] = b0 * np.exp(log_attenuation / norm_gg)


In [None]:
# Apply the b-value-only correction to every voxel of `dwi_hat`.
# Relies on `path`, `dwi_hat`, `ind_b0` and `ind_non_b0` from earlier cells.
num_cores = 10

dwi_hat_bval_path = os.path.join(path,'dwi_hat_bval.mmap')
# Output buffer: one value per diffusion-weighted volume.
dwi_hat_bval = np.memmap(
    dwi_hat_bval_path,
    dtype=float,
    shape=n.shape[:3] + (ind_non_b0.shape[0],),
    mode='w+',
)
xaxis, yaxis, zaxis = (range(extent) for extent in n.shape[:3])

# Mean over the b0 volumes, used as the reference signal per voxel.
mean_b0 = n[..., ind_b0].mean(axis=3)

# One task per voxel; each task writes its result into `dwi_hat_bval`.
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel_bval)(
        i, j, k, mask, dwi_hat, mean_b0, ind_non_b0.shape[0],
        og_bval, bval_stack, dwi_hat_bval
    )
    for k in zaxis
    for j in yaxis
    for i in xaxis
)

In [None]:
# Add the b0 volume back
# Indices of the b0 (b == 0) and diffusion-weighted (b != 0) volumes.
ind_b0 = np.squeeze(np.nonzero(og_bval==0))
ind_non_b0 = np.squeeze(np.nonzero(og_bval))

# Full 4-D series: original b0 volumes plus the b-value-corrected volumes.
dwmri_corrected_bval = np.zeros(n.shape[:4])

dwmri_corrected_bval[..., ind_b0] = n[..., ind_b0]
dwmri_corrected_bval[..., ind_non_b0] = dwi_hat_bval
# Replace the NaN/inf values left by the log/division before saving.
dwmri_corrected_bval = np.nan_to_num(dwmri_corrected_bval)

# Saved as float32 with the affine of the input image, in the working directory.
nib.save(
    nib.Nifti1Image(dwmri_corrected_bval.astype(np.float32), vol.affine),
    "dwi_hat_bval_whole_img.nii",
)


In [None]:
import shutil

# Best-effort removal of the temporary directory holding the memory-mapped
# outputs. Catch `Exception` rather than using a bare `except:` so that
# KeyboardInterrupt / SystemExit still propagate, and report why the removal
# failed instead of hiding the cause.
try:
    shutil.rmtree(path)
except Exception as exc:
    print("Couldn't delete folder:", exc)

In [None]:
# Shape of the b0 volumes inside the recombined image.
dwmri_corrected[:,:,:,ind_b0].shape

In [None]:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

# 3-D scatter of the first 72 rows of `vec`, one blue marker per row.
# NOTE(review): no cell in this notebook assigns a module-level `vec` (it is
# only a local inside reconstruct_signal_at_voxel) - confirm where it comes from.
fig = plt.figure(figsize=(7, 7))
ax = plt.axes(projection="3d")
for row in range(72):
    point = vec[row]
    ax.scatter3D(point[0], point[1], point[2], color='b')