In [1]:
import os
import scipy
import logging
import tempfile
import numpy as np
import nibabel as nib
import multiprocessing
from dipy.data import get_sphere
from dipy.io import read_bvals_bvecs
from joblib import Parallel, delayed
from dipy.core.sphere import Sphere
from dipy.reconst.shm import sf_to_sh
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from scilpy.reconst.multi_processes import fit_from_model, convert_sh_basis
from scilpy.reconst.raw_signal import compute_sh_coefficients
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
from dipy.reconst.shm import real_sh_descoteaux_from_index, sh_to_sf
from scilpy.utils.bvec_bval_tools import (DEFAULT_B0_THRESHOLD,
                                          check_b0_threshold, identify_shells,
                                          is_normalized_bvecs, normalize_bvecs)

  from .autonotebook import tqdm as notebook_tqdm


In [36]:
# Input locations for one subject/session; every path below is derived from these.
SUBJECT_DIR = '/nfs/masi/kanakap/projects/LR_tract/MASiVar_kids/sub-cIVs002/ses-s1Bx2'
DWI_PREFIX = os.path.join(
    SUBJECT_DIR, 'prequal_dwi_cat',
    'sub-cIVs002_ses-s1Bx2_acq-b1000b2000n96r21x21x22peAPP_run-1_dwi')
CORR_DIR = os.path.join(SUBJECT_DIR, 'tracto_ip_lr_corr_1')

# Load the preprocessed DWI once: `vol` keeps the image (header/affine), `n` its data.
# caching='unchanged' keeps `n` an independent array rather than one cached inside `vol`.
vol = nib.load(DWI_PREFIX + '.nii.gz')
n = vol.get_fdata(caching='unchanged')

# Original (volume-wide) gradient files.
og_file = DWI_PREFIX + '.bvec'
ob_file = DWI_PREFIX + '.bval'

# Folders of per-gradient, per-voxel corrected b-vector / b-value maps.
vec_folder = os.path.join(CORR_DIR, 'Lemp', 'emp_Lcorrected_bvec')
val_folder = os.path.join(CORR_DIR, 'Lemp', 'emp_Lcorrected_bval')

# Corrected signal used for the volume-wide (scilpy) SH fit.
pk_corr_dwi1 = nib.load(os.path.join(CORR_DIR, 'emp_Lcorrected_sig.nii.gz')).get_fdata()

In [None]:
def compute_dwi_attenuation(dwi_weights: np.ndarray, b0: np.ndarray):
    """Compute signal attenuation by dividing the DWI signal by the b0.

    Parameters
    ----------
    dwi_weights : np.ndarray of shape (..., #gradients)
        Diffusion-weighted signal; the last axis indexes gradient directions
        (typically (X, Y, Z, #gradients)).
    b0 : np.ndarray of shape (...)
        B0 signal, with the same leading shape as ``dwi_weights`` without its
        last axis (typically (X, Y, Z)).

    Returns
    -------
    dwi_attenuation : np.ndarray
        ``dwi_weights / b0``, where each weight is first clipped to at most the
        b0 value, and every non-finite ratio (e.g. where b0 == 0) is set to 0.
    """
    # Add a trailing axis so the b0 broadcasts over the gradient axis.
    b0 = np.asarray(b0)[..., None]

    # The weights should always be lower than the b0, but noise can break
    # that; clip them so the ratio never exceeds 1 where b0 > 0.
    erroneous_voxels = np.any(dwi_weights > b0, axis=-1)
    nb_erroneous_voxels = np.sum(erroneous_voxels)
    if nb_erroneous_voxels != 0:
        logging.info("# of voxels where `dwi_signal > b0` in any direction: "
                     "{}".format(nb_erroneous_voxels))
        dwi_weights = np.minimum(dwi_weights, b0)

    # b0 == 0 yields inf/nan; silence the RuntimeWarning because those
    # entries are zeroed right below.
    with np.errstate(divide='ignore', invalid='ignore'):
        dwi_attenuation = dwi_weights / b0

    # Make sure we didn't keep the results of a division by 0.
    dwi_attenuation[np.logical_not(np.isfinite(dwi_attenuation))] = 0.

    return dwi_attenuation

In [11]:
def val_emp(i, j, k, n, bvec_stack, bval_stack, emp_sh, sh_order=10,
            basis_type='tournier07', smooth=0.0, use_attenuation=True,
            force_b0_threshold=False):
    """Fit SH coefficients to one voxel using that voxel's own gradient table.

    Follows the same steps as scilpy's ``compute_sh_coefficients`` (b-vector
    normalization check, b0 removal, optional b0 attenuation, ``sf_to_sh``
    fit), but the b-values/b-vectors are read per voxel instead of being
    shared by the whole volume.

    Parameters
    ----------
    i, j, k : int
        Index of the voxel to fit.
    n : np.ndarray of shape (X, Y, Z, #volumes)
        Diffusion-weighted signal.
    bvec_stack : np.ndarray of shape (X, Y, Z, #volumes, 3)
        Per-voxel b-vectors (last-axis size presumably 3 — see loading cell).
    bval_stack : np.ndarray of shape (X, Y, Z, #volumes)
        Per-voxel b-values.
    emp_sh : np.ndarray of shape (X, Y, Z, #coefficients)
        Output array, written in place at ``emp_sh[i, j, k, :]``. Its last
        axis must match the number of coefficients produced for ``sh_order``
        (66 for the default order 10, as allocated by the calling cells).
    sh_order : int
        Maximum SH order passed to ``sf_to_sh``.
    basis_type : str
        SH basis name passed to ``sf_to_sh``.
    smooth : float
        Regularization weight passed to ``sf_to_sh``.
    use_attenuation : bool
        If True, fit the signal divided by the voxel's mean b0 signal.
    force_b0_threshold : bool
        Forwarded to scilpy's ``check_b0_threshold``.

    Returns
    -------
    None
        The coefficients are only written into ``emp_sh``; nothing is returned,
        so joblib does not ship one array per voxel back to the parent process.
    """
    gradient_table = gradient_table_from_bvals_bvecs(bval_stack[i, j, k, :],
                                                     bvec_stack[i, j, k, :, :])
    b0_mask = gradient_table.b0s_mask
    dwi_mask = np.logical_not(b0_mask)
    bvecs = gradient_table.bvecs
    bvals = gradient_table.bvals

    # Treat the voxel as a 1x1x1 volume so the 4-D code path below applies.
    dwi = np.reshape(n[i, j, k], [1, 1, 1, bvals.shape[0]])

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    # Called for its validation of the smallest b-value (it may warn or raise);
    # the returned threshold itself is not needed here.
    check_b0_threshold(force_b0_threshold, bvals.min())

    # NOTE(review): unlike compute_sh_coefficients, no single-shell check is
    # made; callers are expected to pass b0s plus a single shell.

    # Keep only the diffusion-weighted directions for the fit.
    bvecs = bvecs[dwi_mask]
    weights = dwi[..., dwi_mask]

    if use_attenuation:
        b0 = dwi[..., b0_mask].mean(axis=3)
        weights = compute_dwi_attenuation(weights, b0)

    # The acquisition directions themselves define the sampling sphere.
    sphere = Sphere(xyz=bvecs)
    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth=smooth)
    emp_sh[i, j, k, :] = sh[0, 0, 0]


In [32]:
def val_pk(dwi1, og_bval, og_bvec, sh_order=10, basis_type='tournier07',
           smooth=0.0, use_attenuation=True):
    """Fit SH coefficients to a whole volume with one shared gradient table.

    Thin wrapper around scilpy's ``compute_sh_coefficients``. The defaults
    (order 10, 'tournier07' basis, no smoothing, b0 attenuation) are the same
    settings the voxel-wise ``val_emp`` fit uses, so the two outputs can be
    compared coefficient by coefficient.

    Parameters
    ----------
    dwi1 : np.ndarray
        Diffusion-weighted signal; last axis indexes volumes.
    og_bval : np.ndarray of shape (#volumes,)
        B-values shared by every voxel.
    og_bvec : np.ndarray of shape (#volumes, 3)
        B-vectors shared by every voxel.
    sh_order, basis_type, smooth, use_attenuation :
        Forwarded to ``compute_sh_coefficients``.

    Returns
    -------
    np.ndarray
        SH coefficients as returned by ``compute_sh_coefficients``.
    """
    og_gradient_table = gradient_table_from_bvals_bvecs(og_bval, og_bvec)
    return compute_sh_coefficients(dwi1, og_gradient_table,
                                   sh_order=sh_order,
                                   basis_type=basis_type,
                                   use_attenuation=use_attenuation,
                                   smooth=smooth)

In [37]:
og_bval, og_bvec = read_bvals_bvecs(ob_file, og_file)

# Boolean masks per shell. Exact equality assumes the .bval file stores the
# nominal values 0 / 1000 / 2000 — TODO confirm for this acquisition.
is_b0 = og_bval == 0
is_b1000 = og_bval == 1000
is_b2000 = og_bval == 2000

ind_1000 = np.where(is_b1000)
ind_2000 = np.where(is_b2000)
ind_b0 = np.squeeze(np.nonzero(is_b0))
ind_0_1000 = np.where(is_b0 | is_b1000)
ind_0_2000 = np.where(is_b0 | is_b2000)

# Volume-wide SH fit on the b0 + b=1000 volumes.
len1 = ind_0_1000[0]
pk_sh1000 = val_pk(pk_corr_dwi1[:, :, :, len1], og_bval[len1], og_bvec[len1, :])

# Volume-wide SH fit on the b0 + b=2000 volumes.
len2 = ind_0_2000[0]
pk_sh2000 = val_pk(pk_corr_dwi1[:, :, :, len2], og_bval[len2], og_bvec[len2, :])

In [9]:
def _load_sorted_volumes(folder):
    """Load every ``*.nii.gz`` file in ``folder``, in filename order, as float arrays."""
    names = sorted(name for name in os.listdir(folder) if name.endswith('.nii.gz'))
    if not names:
        raise FileNotFoundError('No .nii.gz files found in {}'.format(folder))
    return [nib.load(os.path.join(folder, name)).get_fdata() for name in names]


# One b-vector map per gradient, presumably each (X, Y, Z, 3); stacking on a
# new axis 3 gives (X, Y, Z, #volumes, 3). Stacking directly (instead of
# expand/transpose/squeeze) cannot accidentally drop a unit-length axis.
bvec_stack = np.stack(_load_sorted_volumes(vec_folder), axis=3)

# One b-value map per gradient, each (X, Y, Z) -> (X, Y, Z, #volumes).
bval_stack = np.stack(_load_sorted_volumes(val_folder), axis=3)

# Both stacks are indexed with the same volume indices later on.
assert bvec_stack.shape[:4] == bval_stack.shape, (
    'bvec/bval stacks disagree: {} vs {}'.format(bvec_stack.shape, bval_stack.shape))

In [13]:
# Voxel-wise SH fit with the corrected per-voxel gradients, b0 + b=1000 volumes.
num_cores = 10
# Coefficient count of a symmetric SH basis of order 10, the order val_emp
# fits by default: (10 + 1) * (10 + 2) / 2 = 66.
n_sh_coeffs = (10 + 1) * (10 + 2) // 2
# NOTE: this temporary directory is never removed; delete it (e.g. with
# shutil.rmtree) once emp_sh1000 is no longer needed.
path = tempfile.mkdtemp()
xaxis = range(n.shape[0])
yaxis = range(n.shape[1])
zaxis = range(n.shape[2])

len1 = ind_0_1000[0]
dwi_hat_path1 = os.path.join(path, 'emp_sh1000.mmap')
# val_emp writes its coefficients into this array in place and returns
# nothing, so the array is file-backed (np.memmap) for the joblib workers'
# writes to reach this process.
emp_sh1000 = np.memmap(dwi_hat_path1, dtype=float,
                       shape=n.shape[:3] + (n_sh_coeffs,), mode='w+')
data = n[:, :, :, len1]
corr_bvec = bvec_stack[:, :, :, len1, :]
corr_bval = bval_stack[:, :, :, len1]
results = Parallel(n_jobs=num_cores)(
    delayed(val_emp)(i, j, k, data, corr_bvec, corr_bval, emp_sh1000)
    for k in zaxis for j in yaxis for i in xaxis)



In [14]:

# Voxel-wise SH fit with the corrected per-voxel gradients, b0 + b=2000 volumes.
num_cores = 10
# Coefficient count of a symmetric SH basis of order 10, the order val_emp
# fits by default: (10 + 1) * (10 + 2) / 2 = 66.
n_sh_coeffs = (10 + 1) * (10 + 2) // 2
# NOTE: this temporary directory is never removed; delete it (e.g. with
# shutil.rmtree) once emp_sh2000 is no longer needed.
path = tempfile.mkdtemp()
xaxis = range(n.shape[0])
yaxis = range(n.shape[1])
zaxis = range(n.shape[2])

len2 = ind_0_2000[0]
dwi_hat_path2 = os.path.join(path, 'emp_sh2000.mmap')
# val_emp writes its coefficients into this array in place and returns
# nothing, so the array is file-backed (np.memmap) for the joblib workers'
# writes to reach this process.
emp_sh2000 = np.memmap(dwi_hat_path2, dtype=float,
                       shape=n.shape[:3] + (n_sh_coeffs,), mode='w+')
data = n[:, :, :, len2]
# Original (uncorrected) gradients for these volumes; not used by the fit below.
org_bvec = og_bvec[len2, :]
org_bval = og_bval[len2]
corr_bvec = bvec_stack[:, :, :, len2, :]
corr_bval = bval_stack[:, :, :, len2]

results = Parallel(n_jobs=num_cores)(
    delayed(val_emp)(i, j, k, data, corr_bvec, corr_bval, emp_sh2000)
    for k in zaxis for j in yaxis for i in xaxis)



In [15]:
np.shape(emp_sh2000)  # (X, Y, Z, #SH coefficients) of the voxel-wise fit

(112, 112, 54, 66)

In [28]:
np.shape(pk_sh2000)  # should match the voxel-wise emp_sh2000 shape

(112, 112, 54, 66)

In [40]:
emp_sh2000[45, 45, 30, 45]  # coefficient 45 of voxel (45, 45, 30), voxel-wise fit

-0.009226272515282363

In [39]:
pk_sh2000[45, 45, 30, 45]  # same voxel/coefficient from the volume-wide fit

-0.008742835394595416