This notebook is used to understand what Michael did and to subsequently clean and restructure his code, use available functionalities from existing packages and document it.

First of all I collect all the code which is needed in this notebook to create a working example.

In [1]:
import os
import dipy as dp
import nibabel as nib
import numpy as np
import numpy.linalg as la
from scipy.special import genlaguerre, gamma
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table, reorient_bvecs, GradientTable
from math import  pi, sqrt, sin, cos, factorial
import math
from tqdm import tqdm, tqdm_notebook

os.chdir("/home/olivier/Devel/test/samuel/in4michi/")


# Load fractional anisotropy
dti_fa = nib.load("dti_FA.nii.gz")

# Load DTI mask
dti_mask = nib.load("mask.nii.gz")

# Load and adjust issue segmentation masks
csf_mask = nib.load("fast_pve_0.nii.gz")
gm_mask = nib.load("fast_pve_1.nii.gz")
wm_mask = nib.load("fast_pve_2.nii.gz")

dti_vecs = nib.load("dti_V1.nii.gz")

data = nib.load("data.nii")

bvals, bvecs = read_bvals_bvecs("bvals", "bvecs")
gtab = gradient_table(bvals, bvecs)


In [2]:
class ResponseFunction():
    def __init__(self, wm_signal, gm_signal, csf_signal, zeta, tau,  *args, ** kwargs):
        """
    
    
        :param args:
        :param kwargs:
        """
    
        self.csf_signal = csf_signal
        self.wm_signal = wm_signal
        self.gm_signal = gm_signal
        self.zeta = zeta
        self.tau = tau



    def save(self, filename = 'response.npz'):
        """

        :param filename:
        """
        np.savez(filename, csf=self.csf_signal, gm=self.gm_signal, wm=self.wm_signal,
                 zeta=self.zeta, tau=self.tau)

In [3]:
def get_response(data, gtab, mask, radial_order, angular_order, zeta, tau):
    shore_coeff = fit_shore(gtab, data, mask, radial_order, angular_order, zeta, tau)
    return accumulate_shore(shore_coeff, mask, radial_order, angular_order, zeta, tau)

In [4]:
V_v_0 = np.array((0,0,0))
V_e_x = np.array((1,0,0))
V_e_y = np.array((0,1,0))
V_e_z = np.array((0,0,1))
def dot(a, b):
    return a[0]*b[0] + a[1]*b[1] + a[2]*b[2]

def cross(a, b):
    return np.array([a[1]*b[2] - a[2]*b[1], a[2]*b[0] - a[0]*b[2], a[0]*b[1] - a[1]*b[0]])


def norm(v):
    return math.sqrt(v[0]**2 + v[1]**2 + v[2]**2)

# find an ortho normal basis including a
def ortho_normal(a):
    if abs(a[2]) < 0.8:
        b = cross(a, V_e_z)
    else:
        b = cross(a, V_e_x)
    b = b / norm(b)
    c = cross(a, b)
    return b,c

def normalized(v):
    l = norm(v)
    if l > 0:
        return v / l
    return e_z

In [5]:
# matrix M for   M*w = v x w
def cross_matrix(v):
    return np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])

# assuming |n| = 1
def rotation(n, phi):
    nn = np.outer(n, n)
    e = np.eye(3)
    return nn + math.cos(phi) * (e - nn) - math.sin(phi) * cross_matrix(n)

In [6]:
# create a rotation matrix R, so that R*a = b
# assume a,b normalized
def vector_rotation_onto(a, b):
    d = dot(a, b)
    if d > 0.999999:
        # almost a = b   => identity
        return np.eye(3)
    if d < -0.999999:
        # almost a = -b
        axis, _ = ortho_normal(a)
    else:
        axis = normalized(cross(a, b))
    return rotation(axis, -math.acos(d))

In [7]:
def gtab_find_shells(gtab):
    NGRADS = len(gtab.bvals)
    index = np.zeros((NGRADS,), dtype=int) - 1

    # b0
    index[gtab.bvals < max(gtab.bvals)/100] = 0
    shell_bvals = [0.0]
    shells = 1

    # higher shells
    while sum(index < 0) > 0:
        bmin = min(gtab.bvals[index < 0])
        shell_bvals += [bmin]
        index[np.logical_and(index < 0, gtab.bvals < (bmin*1.2))] = shells
        shells += 1

    # index contains the shell-index for each bvec
    return index, np.array(shell_bvals)

In [8]:
def gtab_rotate(gtab, R):
    N = len(gtab.bvals)
    rot_bvecs = np.zeros((N,3))
    for i in range(N):
        rot_bvecs[i,:] = np.dot(R, gtab.bvecs[i,:])
    return gradient_table(gtab.bvals, rot_bvecs)


In [9]:
def reorient_gtab(u, gtab):
    NGRADS = len(gtab.bvals)
    shell_index, bvals = gtab_find_shells(gtab)
    N0 = len(np.ones(NGRADS)[shell_index == 0])

    # rotate gradients to align 1st eigenvector to (0,0,1)
    #R = np.eye(3)
    R = vector_rotation_onto(u, V_e_z)
    return gtab_rotate(gtab, R)

In [10]:
def get_response_reorient(data, gtab, mask, vecs, radial_order, angular_order, zeta, tau):
    """
    vecs: the first principal direction of diffusion for every voxel
    """ 
    shore_coeff = np.zeros(data.shape[:3] + (shore_get_size(radial_order, angular_order),))

    count = 0

    #for i in space_range(data, Progress):
    for i in tqdm(space_range(data), total=np.prod(data.shape[:3])):
        if mask[i] == 0:
            continue

        gtab2 = reorient_gtab(vecs[i], gtab)
        M = shore_matrix(radial_order, angular_order, zeta, gtab2, tau)

        r = la.lstsq(M, data[i])
        shore_coeff[i] = r[0]
        count += 1

    return accumulate_shore(shore_coeff, mask,radial_order, angular_order, zeta, tau)

In [11]:
# v: 1 vector
#        returns (r,phi,theta)
#    or list of vectors [N,3]
#        returns [N,3]
def cart_to_sphere(v):
    v = np.array(v)
    if len(v.shape) == 1:
        r = norm(v)
        if r == 0:
            return v_0
        phi = math.atan2(v[1], v[0])
        theta = math.atan2(math.sqrt(v[0]*v[0] + v[1]*v[1]), v[2])
        return np.array([r, theta, phi])
    elif len(v.shape) == 2:
        r = la.norm(v, axis=1)
        phi = np.arctan2(v[:,1], v[:,0])
        theta = np.arctan2(np.sqrt(v[:,0]*v[:,0] + v[:,1]*v[:,1]), v[:,2])
        return np.stack([r, theta, phi]).T

In [12]:
def zero(order):
    return np.zeros(LENGTH[order])

In [13]:

LEGENDRE_A = [
[2],
[4,-2,4],
[16,-8,8,-8,16],
[32, -16, 64, -32, 32, -32, 64],
[256, -64/3.0, 128/3.0, -64, 128/3.0, -64/3.0, 128, -64/3.0, 256/3.0],
[512, -256, 512/3.0, -256/3.0, 256/3.0, -256/3.0, 1024/3.0, -512/3.0, 512, -512, 1024],
[2048/5.0, -1024/5.0, 1024/5.0, -1024/5.0, 4096/15.0, -1024/15.0, 2048/5.0, -1024/5.0, 2048/5.0, -1024/5.0, 2048/5.0, -1024/5.0, 4096/5.0]]
LEGENDRE_B = [
[1],
[5, 15, 15],
[9, 45*2, 45, 315*2, 315],
[13, 273, 2730, 2730, 91*9, 2002*9, 6006],
[17, 17, 2*595, 19635, 1309, 17017, 2*7293, 12155, 12155],
[21, 1155, 385, 2*5005, 5005, 2002, 2*5005, 2*85085, 255255, 2*4849845, 2*969969],
[1, 2*39, 3003, 2*1001, 2*1001, 17017, 2*2431, 138567, 138567, 323323, 2*88179, 2028117, 2*676039]]
LEGENDRE_C = []
for i in range(len(LEGENDRE_A)):
    LEGENDRE_C += [[sqrt(b / pi) / a for a,b in zip(LEGENDRE_A[i], LEGENDRE_B[i])]]

# symmetric for +-m
def legendre(order, theta):
    res = zero(order)

    res[0] = LEGENDRE_C[0][0]
    if order < 2:
        return res

    st = sin(theta)
    ct = cos(theta)
    st2 = st * st
    ct2 = ct * ct
    res[5] = res[1] = LEGENDRE_C[1][2]*st2
    res[4] = res[2] = LEGENDRE_C[1][1]*ct*st
    res[3] = LEGENDRE_C[1][0]*(3.0*ct2-1.0)
    if order < 4:
        return res

    st4 = st2 * st2
    ct4 = ct2 * ct2
    res[14] = res[6] = LEGENDRE_C[2][4]*st4
    res[13] = res[7] = LEGENDRE_C[2][3]*ct*st*st2
    res[12] = res[8] = LEGENDRE_C[2][2]*(7*ct2-1)*st2
    res[11] = res[9] = LEGENDRE_C[2][1]*(7*ct2*ct-3*ct)*st
    res[10] = LEGENDRE_C[2][0]*(35.0*ct4-30.0*ct2+3.0)
    if order < 6:
        return res

    st6 = st4 * st2
    ct6 = ct4 * ct2
    res[27] = res[15] = LEGENDRE_C[3][6]*st6
    res[26] = res[16] = LEGENDRE_C[3][5]*st4*st*ct
    res[25] = res[17] = LEGENDRE_C[3][4]*st4*(11*ct2-1.0)
    res[24] = res[18] = LEGENDRE_C[3][3]*st2*st*(11*ct2*ct-3*ct)
    res[23] = res[19] = LEGENDRE_C[3][2]*st2*(33*ct4-18*ct2+1.0)
    res[22] = res[20] = LEGENDRE_C[3][1]*st*(33*ct4*ct-30.0*ct2*ct+5*ct)
    res[21] = LEGENDRE_C[3][0]*(231*ct6-315*ct4+105*ct2-5.0)
    if order < 8:
        return res

    st8 = st4 * st4
    ct8 = ct4 * ct4
    res[44] = res[28] = LEGENDRE_C[4][8]*st8
    res[43] = res[29] = LEGENDRE_C[4][7]*st6*st*ct
    res[42] = res[30] = LEGENDRE_C[4][6]*st6*(15*ct2-1)
    res[41] = res[31] = LEGENDRE_C[4][5]*st4*st*ct*(5*ct2-1)
    res[40] = res[32] = LEGENDRE_C[4][4]*st4*(65*ct4-26*ct2+1)
    res[39] = res[33] = LEGENDRE_C[4][3]*st2*st*ct*(39*ct4-26*ct2+3)
    res[38] = res[34] = LEGENDRE_C[4][2]*st2*(143*ct6-143*ct4+33*ct2-1)
    res[37] = res[35] = LEGENDRE_C[4][1]*st*ct*(715*ct6-1001*ct4+385*ct2-35)
    res[36] = LEGENDRE_C[4][0]*(6435*ct8-12012*ct6+6930*ct4-1260*ct2+35)
    if order < 10:
        return res

    st10 = st8 * st2
    ct10 = ct8 * ct2
    res[65] = res[45] = LEGENDRE_C[5][10] * st10
    res[64] = res[46] = LEGENDRE_C[5][9] * st8 * st*ct
    res[63] = res[47] = LEGENDRE_C[5][8] * st8 * (19*ct2 - 1)
    res[62] = res[48] = LEGENDRE_C[5][7] * st6 * st*ct * (19*ct2 - 3)
    res[61] = res[49] = LEGENDRE_C[5][6] * st6 * (323*ct4 - 102*ct2 + 3)
    res[60] = res[50] = LEGENDRE_C[5][5] * st4 * st*ct * (323*ct4 - 170*ct2 + 15)
    res[59] = res[51] = LEGENDRE_C[5][4] * st4 * (323*ct6 - 255*ct4 + 45*ct2 - 1)
    res[58] = res[52] = LEGENDRE_C[5][3] * st2 * st*ct * (323*ct6 - 357*ct4 + 105*ct2 - 7)
    res[57] = res[53] = LEGENDRE_C[5][2] * st2 * (4199*ct8 - 6188*ct6 + 2730*ct4 - 364*ct2 + 7)
    res[56] = res[54] = LEGENDRE_C[5][1] * st*ct * (4199*ct8 - 7956*ct6 + 4914*ct4 - 1092*ct2 + 63)
    res[55] = LEGENDRE_C[5][0] * (46189*ct10 - 109395*ct8 + 90090*ct6 - 30030*ct4 + 3465*ct2 - 63)
    if order < 12:
        return res

    st12 = st6 * st6
    ct12 = ct6 * ct6
    res[90] = res[66] = LEGENDRE_C[6][12] * st12
    res[89] = res[67] = LEGENDRE_C[6][11] * st10 * st*ct
    res[88] = res[68] = LEGENDRE_C[6][10] * st10 * (23*ct2 - 1)
    res[87] = res[69] = LEGENDRE_C[6][9] * st8 * st*ct * (23*ct2 - 3)
    res[86] = res[70] = LEGENDRE_C[6][8] * st8 * (161*ct4 - 42*ct2 + 1)
    res[85] = res[71] = LEGENDRE_C[6][7] * st6 * st*ct * (161*ct4 - 70*ct2 + 5)
    res[84] = res[72] = LEGENDRE_C[6][6] * st6 * (3059*ct6 - 1995*ct4 + 285*ct2 - 5)
    res[83] = res[73] = LEGENDRE_C[6][5] * st4 * st*ct * (437*ct6 - 399*ct4 + 95*ct2 - 5)
    res[82] = res[74] = LEGENDRE_C[6][4] * st4 * (7429*ct8 - 9044*ct6 + 3230*ct4 - 340*ct2 + 5)
    res[81] = res[75] = LEGENDRE_C[6][3] * st2 * st*ct * (7429*ct8 - 11628*ct6 + 5814*ct4 - 1020*ct2 + 45)
    res[80] = res[76] = LEGENDRE_C[6][2] * st2 * (7429*ct10 - 14535*ct8 + 9690*ct6 - 2550*ct4 + 225*ct2 - 3)
    res[79] = res[77] = LEGENDRE_C[6][1] * st*ct * (52003*ct10 - 124355*ct8 + 106590*ct6 - 39270*ct4 + 5775*ct2 - 231)
    res[78] = LEGENDRE_C[6][0] * (676039*ct12 - 1939938*ct10 + 2078505*ct8 - 1021020*ct6 + 225225*ct4 - 18018*ct2 + 231)
    return res

In [14]:
# index(l,m) = INDEX_OFFSET[l] + m
INDEX_OFFSET = [0, 0, 3, 0, 10, 0, 21, 0, 36, 0, 55, 0, 78]
def eval_basis(order, theta, phi):
    res = legendre(order, theta)

    cos_m_phi = [cos(m*phi) for m in range(order+1)]
    sin_m_phi = [sin(m*phi) for m in range(order+1)]

    for l in range(0, order+1, 2):
        for m in range(1, l+1):
            res[INDEX_OFFSET[l] - m] *= cos_m_phi[m]
        for m in range(1, l+1):
            res[INDEX_OFFSET[l] + m] *= sin_m_phi[m]
    return res
def esh_index(l, m):
    return INDEX_OFFSET[l] + m

In [15]:
LENGTH = [1, 0, 6, 0, 15, 0, 28, 0, 45, 0, 66, 0, 91]
def esh_matrix(order, angles):
    N = angles.shape[0]
    sh = np.zeros((N, LENGTH[order]))
    for i in range(N):
        sh[i] = eval_basis(order, angles[i,0], angles[i,1])
    return sh

In [16]:
def _kappa(zeta, n, l):
    return np.sqrt((2 * factorial(n - l)) / (zeta ** 1.5 * gamma(n + 1.5)))

In [17]:
def shore_matrix(radial_order, angular_order, zeta, gtab, tau=1 / (4 * np.pi ** 2)):
    assert(radial_order >= angular_order)

    NGRADS = len(gtab.bvals)

    q = np.sqrt(gtab.bvals / (4 * np.pi**2 * tau))
    q[gtab.bvals < 40] = 0

    qgradients = q[:, None] * gtab.bvecs

    # r, theta, phi
    rtp = cart_to_sphere(qgradients)
    rsqrz = rtp[:,0]**2 / zeta
    angles = rtp[:,1:]

    sh = esh_matrix(angular_order, angles)

    size = shore_get_size(radial_order, angular_order)
    M = np.zeros((NGRADS, size))

    counter = 0
    for l in range(0, angular_order + 1, 2):
        for n in range(l, (radial_order + l) // 2 + 1):
            c = genlaguerre(n - l, l + 0.5)(rsqrz) * np.exp(- rsqrz / 2.0) * _kappa(zeta, n, l) * rsqrz ** (l / 2)
            for m in range(-l, l + 1):
                M[:, counter] = sh[:,esh_index(l, m)] * c
                counter += 1
    return M

In [18]:
MAX_ORDER = 12

def _get_size(a,r):
    if (a%2)==1 or (r%2)==1:
        return 0
    return sum([l*r - l*l + r//2 + l//2*3 + 1 for l in range(0,a+1,2)])

SIZES = [[_get_size(a,r) if r >= a else 0 for a in range(MAX_ORDER+1)] for r in range(MAX_ORDER+1)]
#SIZES = [[1, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [2, 0,  7, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [3, 0, 13, 0, 22, 0,   0, 0,   0, 0,   0, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [4, 0, 19, 0, 37, 0,  50, 0,   0, 0,   0, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [5, 0, 25, 0, 52, 0,  78, 0,  95, 0,   0, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [6, 0, 31, 0, 67, 0, 106, 0, 140, 0, 161, 0,   0],
#         [0, 0,  0, 0,  0, 0,   0, 0,   0, 0,   0, 0,   0],
#         [7, 0, 37, 0, 82, 0, 134, 0, 185, 0, 227, 0, 252]]

def shore_get_size(radial_order, angular_order):
    return SIZES[radial_order][angular_order]

In [19]:
def space_dims(array):
    return array.shape[:3]

In [20]:
# P can be the helper.progress.Progress class
def space_range(array, P=None):
    NX, NY, NZ = space_dims(array)
    if P:
        p = P(NX*NY*NZ)
    for x in range(NX):
        for y in range(NY):
            for z in range(NZ):
                if P:
                    p.step()
                yield x,y,z
    return

In [22]:
def fit_shore(gtab, data, mask, radial_order, angular_order, zeta, tau):
    M = shore_matrix(radial_order, angular_order, zeta, gtab, tau)
    shore_coeff = np.zeros(data.shape[:3] + (shore_get_size(radial_order, angular_order),))
    for i in tqdm(space_range(data), total=np.prod(data.shape[:3])):
    #for i in space_range(data, Progress):    
        if mask[i] == 0:
            continue
        r = la.lstsq(M, data[i])
        shore_coeff[i] = r[0]
    return shore_coeff

In [23]:
def accumulate_shore(shore_coeff, mask, radial_order, angular_order, zeta, tau):
    shore_accum = np.zeros(shore_get_size(radial_order, angular_order))
    accum_count = 0
    nan_count = 0
    for i in space_range(data):
        if mask[i] == 0:
            continue

        #if array_has_nan(shore_coeff[i]):
        if np.isnan(shore_coeff[i]).any():
            nan_count += 1
        else:
            shore_accum += shore_coeff[i]
            accum_count += 1

    print(accum_count, "voxel")
    if nan_count > 0:
        print(nan_count, "nans")
    if accum_count == 0:
        return shore_accum
    return shore_accum / accum_count

In [24]:
def order(coeff):
    size = len(coeff)
    for i in range(len(SIZES)):
        for j in range(len(SIZES)):
            if size == SIZES[i][j]:
                return i,j
    raise Exception("shore order can not be determined for size " + str(size))
get_order = order

In [25]:
def _get_kernel_size(a,r):
    if (a%2)==1 or (r%2)==1:
        return 0
    return sum([(r-l)//2 + 1 for l in range(0,a+1,2)])

KERNEL_SIZES = [[_get_kernel_size(a,r) if r >= a else 0 for a in range(MAX_ORDER+1)] for r in range(MAX_ORDER+1)]
#KERNEL_SIZES = [[1, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [2, 0,  3, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [3, 0,  5, 0,  6, 0,  0, 0,  0, 0,  0, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [4, 0,  7, 0,  9, 0, 10, 0,  0, 0,  0, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [5, 0,  9, 0, 12, 0, 14, 0, 15, 0,  0, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [6, 0, 11, 0, 15, 0, 18, 0, 20, 0, 21, 0,  0],
#                [0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0],
#                [7, 0, 13, 0, 18, 0, 22, 0, 25, 0, 27, 0, 28]]

def get_kernel_size(radial_order, angular_order):
    return KERNEL_SIZES[radial_order][angular_order]

In [26]:
# "kernel": only use z-rotational part
def shore_compress(s):
    radial_order, angular_order = get_order(s)
    r = np.zeros(get_kernel_size(radial_order, angular_order))
    counter = 0
    ccounter = 0
    for l in range(0, angular_order + 1, 2):
        for n in range(l, (radial_order - l) // 2 + 1):
            r[ccounter] = s[counter+l]
            counter += 2*l+1
            ccounter += 1
    return r

In [27]:
def estimate_response_function(data, wm_mask, gm_mask, csf_mask, gtab, dti_fa, dti_vecs, dti_mask=None, order=4, zeta=700,
                      tau=1 / (4 * np.pi ** 2), fawm=0.7, verbose = False):
    """
    This function calculates the response Function needed for the deconvolution of the diffusion imaging signal.

    :param data: The diffusion weighted data
    :param wm_mask: The white matter mask
    :param gm_mask: The grey matter mask
    :param csf_mask: The cerebrospinal fluid mask
    :param dti_fa:
    :param dti_vecs:
    :param gtab: This has to be a GradientTable object from dipy. It can be created from b-values and b-vectors using dipy.core.gradients.gradient_table(bvals, bvecs).
    :param dti_mask:
    :param order: The order of the shore basis. The default is 4.
    :param zeta: The radial scaling factor. The default is 700.
    :param tau: The q-scaling factor. The default is 1 / (4 * np.pi ** 2)
    :param fawm: The threshold for the white matter fractional anisotropy. The default is 0.7
    :return: ResponseFunction -- holding all important values.
    """

    # Load DTI fa map
    fa = dti_fa.get_data()

    # Load DTI vecs
    vecs = dti_vecs.get_data()

    # Load DTI mask if available
    if dti_mask is None:
        NX, NY, NZ = fa.shape
        mask = np.ones((NX, NY, NZ))
    else:
        mask = dti_mask.get_data()

    # Create masks
    # CSF
    csf = csf_mask.get_data()
    mask_csf = np.logical_and(mask, np.logical_and(csf > 0.95, fa < 0.2)).astype('int')
    # GM
    gm = gm_mask.get_data()
    mask_gm = np.logical_and(mask, np.logical_and(gm > 0.95, fa < 0.2)).astype('int')
    # WM
    wm = wm_mask.get_data()
    mask_wm = np.logical_and(mask, np.logical_and(wm > 0.95, fa > float(fawm))).astype('int')

    # Load data
    data = data.get_data()

    # Reshape data
    #NX, NY, NZ = data.shape[0:3]
    #N = NX * NY * NZ
    #data = data.reshape((N, -1))
    #vecs = vecs.reshape((N, 3))
    #mask_csf = mask_csf.flatten()
    #mask_gm = mask_gm.flatten()
    #mask_wm = mask_wm.flatten()

    radial_order = order
    angular_order = order
    response_parameters = (radial_order, angular_order, zeta, tau)
    # Calculate csf response
    shore_coeff = get_response(data, gtab, mask_csf, *response_parameters)
    signal_csf = shore_compress(shore_coeff)

    # Calculate gm response
    shore_coeff = get_response(data, gtab, mask_gm, *response_parameters)
    signal_gm = shore_compress(shore_coeff)

    # Calculate wm response
    shore_coeff = get_response_reorient(data, gtab, mask_wm, vecs, *response_parameters)
    signal_wm = shore_compress(shore_coeff)

    return ResponseFunction(signal_wm, signal_gm, signal_csf, zeta, tau)

In [28]:
my_response = estimate_response_function(data, wm_mask, gm_mask, csf_mask, gtab, dti_fa, dti_vecs, dti_mask)
my_response.save()

100%|██████████| 460800/460800 [00:09<00:00, 49520.20it/s] 
 18%|█▊        | 82973/460800 [00:00<00:00, 829554.77it/s]

13169 voxel


100%|██████████| 460800/460800 [00:20<00:00, 22037.78it/s] 
 24%|██▍       | 111964/460800 [00:00<00:00, 1073551.28it/s]

32432 voxel


100%|██████████| 460800/460800 [00:26<00:00, 17230.63it/s]  


5319 voxel
