In [137]:
import os
import scipy
import logging
import tempfile
import numpy as np
import nibabel as nib
import multiprocessing
from dipy.data import get_sphere
from dipy.io import read_bvals_bvecs
from joblib import Parallel, delayed
from dipy.core.sphere import Sphere
from dipy.reconst.shm import sf_to_sh
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from scilpy.reconst.multi_processes import fit_from_model, convert_sh_basis
from scilpy.reconst.raw_signal import compute_sh_coefficients
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
from dipy.reconst.shm import real_sh_descoteaux_from_index, sh_to_sf
from scilpy.utils.bvec_bval_tools import (DEFAULT_B0_THRESHOLD,
                                          check_b0_threshold, identify_shells,
                                          is_normalized_bvecs, normalize_bvecs)

In [138]:
# Input locations (subject/session specific).
DATA_DIR = '/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs001/ses-s1Bx2/prequal_dwi_cat'
FIELD_DIR = '/home/local/VANDERBILT/kanakap/try_fx_emp'

# Load the DWI image once; keep the image object for its affine and the array for the data.
vol = nib.load(os.path.join(DATA_DIR, 'dwi_fodf.nii.gz'))
n = vol.get_fdata()
og_file = os.path.join(DATA_DIR, 'bvec_fodf.bvec')
ob_file = os.path.join(DATA_DIR, 'bval_fodf.bval')
vec_folder = os.path.join(FIELD_DIR, 'bvec')
val_folder = os.path.join(FIELD_DIR, 'bval')
mask = nib.load(os.path.join(DATA_DIR, 'mask.nii.gz')).get_fdata()
mean_b0_vol = nib.load(os.path.join(DATA_DIR, 'b0_mean.nii.gz')).get_fdata()
og_bval, og_bvec = read_bvals_bvecs(ob_file, og_file)


def _load_sorted_volumes(folder):
    """Return the data arrays of every ``.nii.gz`` file in *folder*, in sorted filename order."""
    # NOTE(review): sorted() is lexicographic. If the files are numbered without
    # zero padding (1, 2, ..., 10), the stacking order will not follow og_bval /
    # og_bvec. Confirm the naming scheme.
    return [nib.load(os.path.join(folder, name)).get_fdata()
            for name in sorted(os.listdir(folder))
            if name.endswith('.nii.gz')]


# One file per gradient. Stacking along axis 3 puts the gradient index before
# the trailing vector axis of each bvec volume (e.g. (X, Y, Z, G, 3)).
# Stacking directly avoids the old expand/transpose/squeeze, which could also drop a
# legitimate size-1 axis.
bvec_stack = np.stack(_load_sorted_volumes(vec_folder), axis=3)
bval_stack = np.stack(_load_sorted_volumes(val_folder), axis=3)

In [139]:
def compute_dwi_attenuation(dwi_weights: np.ndarray, b0: np.ndarray):
    """Compute signal attenuation by dividing the DWI signal by the b0.

    Parameters
    ----------
    dwi_weights : np.ndarray of shape (..., #gradients)
        Diffusion weighted signal. The last axis indexes the gradients; typically
        (X, Y, Z, #gradients).
    b0 : np.ndarray of shape (...)
        B0 signal with the leading shape of ``dwi_weights`` (typically (X, Y, Z)).

    Returns
    -------
    dwi_attenuation : np.ndarray
        ``dwi_weights / b0``, after clipping every weight so it does not exceed
        its b0. Entries whose ratio is not finite (e.g. ``b0 == 0``) are set to 0.
    """
    # Add a trailing gradient axis so b0 broadcasts against dwi_weights.
    b0 = np.asarray(b0)[..., None]

    # Weights should always be lower than the b0, but noise can break this.
    # Count the offending voxels and clip them to the b0.
    erroneous_voxels = np.any(dwi_weights > b0, axis=-1)
    nb_erroneous_voxels = np.sum(erroneous_voxels)
    if nb_erroneous_voxels != 0:
        logging.info("# of voxels where `dwi_signal > b0` in any direction: "
                     "{}".format(nb_erroneous_voxels))
        dwi_weights = np.minimum(dwi_weights, b0)

    # A zero b0 gives inf/nan. These are zeroed just below, so the numpy
    # divide/invalid warnings are only noise.
    with np.errstate(divide='ignore', invalid='ignore'):
        dwi_attenuation = dwi_weights / b0

    dwi_attenuation[~np.isfinite(dwi_attenuation)] = 0.

    return dwi_attenuation

In [147]:
# DWI -> SH (voxel gradients) -> SF/DWI (nominal gradients)
def reconstruct_signal_at_voxel(i, j, k, mask, n, og_bval, og_bvec, bvec_stack, bval_stack, dwi_hat):
    """Re-sample one voxel's DW signal from its voxel-specific gradients onto the nominal ones.

    Steps:
    1. The signal ``n[i, j, k]`` is normalised by the voxel's mean b0.
    2. It is fitted with spherical harmonics (order 8, 'tournier07' basis,
       no smoothing) on the voxel's own b-vectors.
    3. The SH are evaluated on the nominal (og) non-b0 b-vectors.
    4. The result is multiplied back by the b0.

    Parameters
    ----------
    i, j, k : int
        Voxel index.
    mask : array-like
        Accepted for call compatibility but NOT used: every voxel is reconstructed.
    n : np.ndarray, 4D
        DWI volume; ``n[i, j, k]`` holds one value per gradient of ``bval_stack[i, j, k]``.
    og_bval, og_bvec : np.ndarray
        Nominal b-values / b-vectors for all volumes (b0s included).
    bvec_stack : np.ndarray, (X, Y, Z, G, 3)
        Voxel-specific b-vectors.
    bval_stack : np.ndarray, (X, Y, Z, G)
        Voxel-specific b-values.
    dwi_hat : np.ndarray or np.memmap, (X, Y, Z, n_dw)
        Output. ``dwi_hat[i, j, k, :]`` receives one value per nominal non-b0 volume.
    """
    sh_order = 8
    basis_type = 'tournier07'
    smooth = 0.00
    force_b0_threshold = False

    og_gradient_table = gradient_table_from_bvals_bvecs(og_bval, og_bvec)

    # Voxel-specific acquisition scheme.
    gradient_table = gradient_table_from_bvals_bvecs(bval_stack[i, j, k, :],
                                                     bvec_stack[i, j, k, :, :])
    b0_mask = gradient_table.b0s_mask
    dw_mask = np.logical_not(b0_mask)
    bvecs = gradient_table.bvecs
    bvals = gradient_table.bvals

    # Treat the voxel as a 1x1x1 volume so the volume-oriented helpers apply.
    dwi = np.reshape(n[i, j, k], [1, 1, 1, bvals.shape[0]])

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    # Called for its validation side effect on the smallest b-value; the
    # returned threshold itself is not needed here.
    check_b0_threshold(force_b0_threshold, bvals.min())

    # NOTE(review): single-shell input is assumed but not enforced here.

    # Attenuation relative to this voxel's mean b0.
    b0 = dwi[..., b0_mask].mean(axis=3)
    weights = compute_dwi_attenuation(dwi[..., dw_mask], b0)

    # SF -> SH on the voxel's own (non-b0) gradient directions.
    sphere = Sphere(xyz=bvecs[dw_mask])
    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth=smooth)

    # SH -> SF on the nominal non-b0 directions. Select them with the nominal
    # table's own b0 mask so the output columns follow the nominal layout.
    og_bvecs = og_gradient_table.bvecs
    if not is_normalized_bvecs(og_bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        og_bvecs = normalize_bvecs(og_bvecs)
    og_dw_mask = np.logical_not(og_gradient_table.b0s_mask)
    og_sphere = Sphere(xyz=og_bvecs[og_dw_mask])
    sf = sh_to_sf(sh, og_sphere, sh_order=sh_order, basis_type=basis_type)

    # SF -> DWI: undo the b0 normalisation (inverse of compute_dwi_attenuation).
    weights_hat = sf * b0[..., None]
    dwi_hat[i, j, k, :] = weights_hat.reshape(-1)


In [157]:
num_cores = 10
path = tempfile.mkdtemp()
dwi_hat_path = os.path.join(path, 'dwi_hat.mmap')

# One output column per nominal non-b0 volume. This must match the
# `np.nonzero(og_bval)` indices used when re-inserting dwi_hat into the 4D image.
n_dw = int(np.count_nonzero(og_bval))
dwi_hat = np.memmap(dwi_hat_path, dtype=float, shape=n.shape[:3] + (n_dw,), mode='w+')

# Workers write into the shared memmap; the returned list only holds None.
voxels = ((i, j, k)
          for k in range(n.shape[2])
          for j in range(n.shape[1])
          for i in range(n.shape[0]))
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel)(i, j, k, mask, n, og_bval, og_bvec,
                                         bvec_stack, bval_stack, dwi_hat)
    for i, j, k in voxels)



In [None]:
# Raw DWI signal (all volumes) at the example voxel.
n[28, 38, 23]

In [None]:
# SH-reconstructed signal (non-b0 volumes only) at the same voxel.
dwi_hat[28, 38, 23]

In [149]:
# Rebuild a full 4D image: original b0 volumes + reconstructed non-b0 volumes.
# flatnonzero always yields a 1-D index array, even when there is a single match.
ind_b0 = np.flatnonzero(og_bval == 0)
ind_non_b0 = np.flatnonzero(og_bval)
assert dwi_hat.shape[-1] == ind_non_b0.size, (
    "dwi_hat has {} columns but og_bval has {} non-b0 volumes".format(dwi_hat.shape[-1], ind_non_b0.size))

dwmri_corrected = np.zeros(n.shape)
dwmri_corrected[..., ind_b0] = n[..., ind_b0]
dwmri_corrected[..., ind_non_b0] = dwi_hat

nib.save(nib.Nifti1Image(dwmri_corrected.astype(np.float32), vol.affine), "dwi_hat_whole_img8.nii")


In [165]:
# Mean of the b0 volumes, for comparison with mean_b0_vol.
m2 = n[..., ind_b0].mean(axis=3)
m2.shape

(112, 112, 54)

In [174]:
# Per-voxel b-vector stack: (X, Y, Z, gradients, 3).
np.shape(bvec_stack)

(112, 112, 54, 72, 3)

In [177]:
# Rescale the reconstructed DW signal from actual to nominal b-values
def reconstruct_signal_at_voxel_bval(i, j, k, mask, dwi_hat, mean_b0_vol, og_bval, bval_stack, dwi_hat_bval,
                                     b0_threshold=50):
    """Convert one voxel's DW signal from its actual b-values to the nominal ones.

    The output is ``S0 * exp(log(S / S0) * b_nom / b_act)``, where ``b_act`` is
    the voxel's b-value and ``b_nom`` the nominal one. This is exact for a
    mono-exponential decay ``S = S0 * exp(-b * D)``.

    Parameters
    ----------
    i, j, k : int
        Voxel index.
    mask : array-like
        Accepted for call compatibility but NOT used: every voxel is processed.
    dwi_hat : np.ndarray or np.memmap, (X, Y, Z, n_dw)
        DW signal (non-b0 volumes only). The last axis must line up with the
        non-b0 entries of ``bval_stack[i, j, k]``.
    mean_b0_vol : np.ndarray, (X, Y, Z)
        Mean b0 signal.
    og_bval : array-like, (G,)
        Nominal b-values for all volumes (b0s included).
    bval_stack : np.ndarray, (X, Y, Z, G)
        Voxel-specific (actual) b-values.
    dwi_hat_bval : np.ndarray or np.memmap, (X, Y, Z, n_dw)
        Output; ``dwi_hat_bval[i, j, k, :]`` is overwritten.
    b0_threshold : float, optional
        Volumes whose actual b-value is <= this are treated as b0s. The default
        of 50 matches dipy's default GradientTable b0_threshold, which the
        previous implementation relied on.
    """
    bvals = np.asarray(bval_stack[i, j, k, :])
    dw_mask = bvals > b0_threshold
    og_bval = np.asarray(og_bval)

    # Actual / nominal b-value ratio for each DW volume.
    norm_gg = bvals[dw_mask] / og_bval[dw_mask]

    dwi = np.asarray(dwi_hat[i, j, k, :])
    b0 = mean_b0_vol[i, j, k]

    # A zero b0 or non-positive signal produces nan/inf. These are written to
    # the output as-is; the warnings are silenced so they do not flood the log.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        dwi_hat_bval[i, j, k, :] = b0 * np.exp(np.log(dwi / b0) / norm_gg)


In [168]:
# Actual / nominal b-value ratio of the DW volumes at the example voxel.
# norm_gg only exists inside reconstruct_signal_at_voxel_bval, so recompute it
# here rather than relying on a leftover kernel variable.
B0_THRESHOLD = 50  # b-values <= this are treated as b0 (dipy's default threshold)
example_bvals = bval_stack[28, 38, 23, :]
example_dw = example_bvals > B0_THRESHOLD
norm_gg = example_bvals[example_dw] / og_bval[example_dw]
norm_gg

array([1.00497333, 1.02325439, 1.00533148, 1.00797021, 0.99779193,
       1.014242  , 1.0291272 , 1.00150208, 1.00191901, 1.01033093,
       1.00468115, 1.02697974, 1.02649463, 1.00022925, 0.99933575,
       1.03295435, 1.01247272, 0.99942987, 1.00787476, 1.0068573 ,
       1.01738007, 1.01761212, 1.00151904, 1.01684814, 1.01957715,
       1.0041463 , 1.00347589, 0.9993457 , 1.03361938, 1.00044495,
       1.01470038, 1.00236511, 1.01195819, 1.02218335, 1.02510632,
       1.00618542, 1.0014198 , 1.01929724, 1.00750201, 0.99846039,
       1.02021674, 1.02898572, 0.99989355, 1.00500726, 1.0050578 ,
       1.03038123, 1.00043726, 1.01392291, 1.00754797, 1.00039288,
       1.01215039, 1.01764911, 1.00268896, 1.02779272, 1.02729333,
       1.00233063])

In [179]:
num_cores = 10

dwi_hat_bval_path = os.path.join(path, 'dwi_hat_bval.mmap')
# Same layout as dwi_hat: one column per nominal non-b0 volume.
dwi_hat_bval = np.memmap(dwi_hat_bval_path, dtype=float, shape=dwi_hat.shape, mode='w+')

# Workers write into the shared memmap; the returned list only holds None.
voxels = ((i, j, k)
          for k in range(n.shape[2])
          for j in range(n.shape[1])
          for i in range(n.shape[0]))
results = Parallel(n_jobs=num_cores)(
    delayed(reconstruct_signal_at_voxel_bval)(i, j, k, mask, dwi_hat, mean_b0_vol,
                                              og_bval, bval_stack, dwi_hat_bval)
    for i, j, k in voxels)



In [181]:
# Add the b0 volumes back: original b0 volumes + b-value-corrected non-b0 volumes.
# flatnonzero always yields a 1-D index array, even when there is a single match.
ind_b0 = np.flatnonzero(og_bval == 0)
ind_non_b0 = np.flatnonzero(og_bval)
assert dwi_hat_bval.shape[-1] == ind_non_b0.size, (
    "dwi_hat_bval has {} columns but og_bval has {} non-b0 volumes".format(dwi_hat_bval.shape[-1], ind_non_b0.size))

# Build directly in float32 and sanitise in float32. Running nan_to_num in float64
# maps inf to the float64 max, which overflows back to inf on the float32 cast.
dwmri_corrected_bval = np.zeros(n.shape, dtype=np.float32)
dwmri_corrected_bval[..., ind_b0] = n[..., ind_b0]
dwmri_corrected_bval[..., ind_non_b0] = dwi_hat_bval
dwmri_corrected_bval = np.nan_to_num(dwmri_corrected_bval)

nib.save(nib.Nifti1Image(dwmri_corrected_bval, vol.affine), "dwi_hat_bval_whole_img8.nii")


In [None]:
import shutil

# Remove the temporary folder holding the memmaps. Failure is reported but not fatal.
# Catch only OSError so KeyboardInterrupt and real bugs still propagate.
try:
    shutil.rmtree(path)
except OSError as exc:
    print("Couldn't delete folder", exc)

In [None]:
# Shape of the b0 part of the reassembled image.
dwmri_corrected[..., ind_b0].shape

In [None]:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

# b-vectors of the example voxel (28, 38, 23) inspected earlier.
# `vec` is not defined at notebook level (only inside the reconstruct functions),
# so take the directions explicitly from bvec_stack.
example_vecs = bvec_stack[28, 38, 23, :, :]

fig = plt.figure(figsize=(7, 7))
ax = plt.axes(projection="3d")
# One vectorised scatter call instead of one call per gradient.
ax.scatter3D(example_vecs[:, 0], example_vecs[:, 1], example_vecs[:, 2], color='b')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('b-vectors at voxel (28, 38, 23)')
plt.show()