In [1]:
import numpy as np
import scipy.optimize as opt
import matplotlib.pyplot as plt
%matplotlib inline

import dipy.data as dpd
import dipy.core.sphere as dps 
import dipy.sims as sims
import dipy.core.gradients as grad
import dipy.core.geometry as geo
import nibabel as nib




In [2]:
# Ask dipy.data for its example dataset: the image file plus the matching
# b-value and b-vector files (the three values are passed straight to
# nib.load / gradient_table below).
fdata, fbval, fbvec = dpd.get_data()
# Build the gradient table describing every acquired volume.
gtab = grad.gradient_table(fbval, fbvec)
# Divide the b-values by 1000 so the diffusion-weighted volumes sit near 1;
# presumably this pairs them with the eigenvalue scale (e.g. 1.5, 0.5, 0.5)
# used by the model functions below — TODO confirm the intended units.
gtab.bvals = gtab.bvals/1000. # Make the units work out

  bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol


In [3]:
# Open the example image file with nibabel.
data_ni = nib.load(fdata)

In [4]:
# Show the loaded image object as a quick sanity check.
data_ni

<nibabel.nifti1.Nifti1Image at 0x10494b080>

In [5]:
# Pull the voxel array out of the nibabel image.
# NOTE(review): get_data() is deprecated in recent nibabel releases in favour
# of get_fdata() / np.asanyarray(img.dataobj) — confirm against the installed
# nibabel version before re-running this notebook.
data = data_ni.get_data()

In [6]:
# Inspect the rescaled b-values: 0 for the b0 volume, close to 1 elsewhere.
gtab.bvals

array([ 0.        ,  0.99287978,  1.00102157,  0.99096331,  1.00036425,
        0.99425127,  0.99397785,  0.98918899,  0.99691968,  0.99116247,
        0.9974664 ,  0.99540734,  0.99196243,  0.99312528,  0.99407712,
        0.98797313,  0.99766205,  0.99000655,  0.98992256,  0.99831979,
        0.99480385,  0.99683871,  0.99164865,  0.99447198,  0.99401681,
        0.98796076,  1.00299124,  0.99949294,  0.98761528,  0.99806216,
        0.99473041,  0.99171649,  0.98772035,  0.98694619,  0.98959543,
        0.99598116,  0.99306809,  1.00057236,  0.99671579,  0.99046733,
        0.98969681,  0.9961932 ,  0.99811798,  0.99063137,  0.99385377,
        0.99680916,  0.99249767,  1.001038  ,  0.99337285,  0.99531404,
        0.99264883,  0.99840489,  0.99726554,  0.99255517,  0.98901376,
        0.98857206,  1.0002916 ,  0.99189215,  1.0011105 ,  0.99051374,
        1.00148146,  0.98846483,  0.99037331,  0.99445498,  1.00169366])

In [7]:
# Sphere whose vertices serve as the candidate fiber directions (the columns
# of the design matrix built further down).
sph1 = dpd.get_sphere()
# Debugging alternative: a minimal sphere with only the three coordinate axes.
#sph1 = dps.Sphere(xyz=[[1,0,0], [0,1,0], [0,0,1]])


In [8]:
def l2norm(vector):
    """
    Scale a vector to unit Euclidean (L2) length.

    Parameters
    ----------
    vector : array-like, 1-D
        Vector to normalize.

    Returns
    -------
    ndarray
        ``vector / ||vector||_2`` as a float array. A zero-length vector is
        returned as zeros instead of NaN.
    """
    vector = np.asarray(vector, dtype=float)
    # Divide by the norm itself: dividing by dot(v, v) (the *squared* norm)
    # would return a vector of length 1/||v|| rather than 1.
    length = np.sqrt(np.dot(vector, vector))
    if length == 0:
        return np.zeros_like(vector)
    return vector / length
    

In [9]:
import dipy.reconst.dti as dti

In [10]:
def single_tensor(evecs, evals, bvec, bval=1, S0=1):
    """
    Signal predicted by one diffusion tensor along a gradient direction.

    The tensor is assembled as ``D = R diag(evals) R^T`` with ``R = evecs``
    (eigenvectors in the columns), and the returned value is
    ``S0 * exp(-bval * bvec D bvec^T)``.

    Parameters
    ----------
    evecs : array-like, square
        Eigenvector matrix ``R``.
    evals : array-like
        Eigenvalues placed on the diagonal of the tensor.
    bvec : ndarray
        Gradient direction(s) the signal is evaluated for.
    bval : float, optional
        b-value scaling the exponent.
    S0 : float, optional
        Signal without diffusion weighting.
    """
    rotation = np.asarray(evecs)
    tensor = rotation.dot(np.diag(evals)).dot(rotation.T)
    # Quadratic form bvec . D . bvec^T: apparent diffusivity along bvec.
    diffusivity = np.dot(bvec, tensor).dot(bvec.T)
    return S0 * np.exp(-bval * diffusivity)

In [11]:
# Scratch values for a canonical tensor: principal direction along x,
# eigenvalues (1.5, 0.5, 0.5) and unit b-value.
# NOTE(review): the functions defined below take parameters with these same
# names, so these module-level values do not appear to be read by them —
# confirm before relying on them.
out_dir=np.array([1, 0, 0]) 
evals=np.array([1.5, 0.5, 0.5]) 
bval=1

In [12]:
def out_signal(in_dir, out_dir, evals=np.array([1.5, 0.5, 0.5]), bval=1):
    """
    Signal measured along `in_dir` for a tensor whose principal axis is `out_dir`.

    The canonical tensor ``diag(evals)`` (principal axis along x) is rotated so
    that its first eigenvector points along `out_dir`; the single-tensor signal
    is then evaluated for the gradient direction `in_dir`.

    Parameters
    ----------
    in_dir : ndarray
        Gradient (measurement) direction.
    out_dir : ndarray
        Direction the tensor's principal axis is rotated to. Presumably a unit
        vector, as expected by ``vec2vec_rotmat`` — verify against callers.
    evals : array-like, optional
        Tensor eigenvalues; ``evals[0]`` belongs to the principal axis.
    bval : float, optional
        b-value of the measurement.

    Returns
    -------
    float
        Predicted signal with ``S0 = 1``.
    """
    evals = np.asarray(evals)
    # Rotation R taking the x-axis onto out_dir. The rotated tensor is
    # R diag(evals) R^T, whose eigenvectors are exactly the columns of R, so R
    # is handed to single_tensor directly. (Forming only R . diag(evals) gives a
    # non-symmetric matrix, and decomposing that yields the wrong eigenvectors.)
    rot_matrix = geo.vec2vec_rotmat(np.array([1, 0, 0]), out_dir)
    return single_tensor(rot_matrix, evals, in_dir, bval=bval, S0=1)

In [13]:
def distance_weight(dist, tau=1):
    """
    Exponentially decaying weight ``exp(-dist / tau)``.

    Parameters
    ----------
    dist : float or ndarray
        Distance-like quantity; larger values give smaller weights.
    tau : float, optional
        Decay constant.
    """
    decay_exponent = -dist / tau
    return np.exp(decay_exponent)

In [14]:
def weighting(location, out_dir):
    """
    Weight for a neighboring voxel at offset `location` and direction `out_dir`.

    The weight is the product of a distance decay,
    ``distance_weight(location . location)``, and an alignment term, the dot
    product of ``l2norm(location)`` with `out_dir`.

    Parameters
    ----------
    location : ndarray
        Offset of the neighbor relative to the center voxel.
    out_dir : ndarray
        Direction the offset is compared against.
    """
    # NOTE(review): the decay receives location . location, i.e. the squared
    # length of the offset rather than its length — confirm this is intended.
    decay = distance_weight(np.dot(location, location))
    alignment = np.dot(l2norm(location), out_dir)
    return decay * alignment

In [15]:
def design_signal(location, in_dir, out_dir, evals=np.array([1.5, 0.5, 0.5]), bval=1):
    """
    One design-matrix entry: the signal a voxel at `location` contributes.

    Parameters
    ----------
    location : ndarray
        Center-to-center offset of the voxel, relative to (0, 0, 0).
    in_dir : ndarray
        Measurement direction (the observation this entry belongs to).
    out_dir : ndarray
        Tensor direction (the parameter this entry belongs to).
    evals : array-like, optional
        Tensor eigenvalues forwarded to ``out_signal``.
    bval : float, optional
        b-value forwarded to ``out_signal``.

    Returns
    -------
    float
        ``out_signal(...)`` for the center voxel; for any other location the
        same signal multiplied by ``weighting(location, out_dir)``.
    """
    tensor_signal = out_signal(in_dir, out_dir, evals=evals, bval=bval)
    is_center = np.all(location == np.array([0, 0, 0]))

    # The center voxel contributes its signal unchanged.
    if is_center:
        return tensor_signal

    # Neighbors are scaled by the distance/angle weight.
    return weighting(location, out_dir) * tensor_signal
    
    

In [16]:
def test_out_signal():
    """With the tensor left along x, out_signal must equal the plain
    single-tensor signal computed with identity eigenvectors."""
    x_axis = np.array([1, 0, 0])
    canonical_evals = np.array([1.5, 0.5, 0.5])

    rotated = out_signal(x_axis, x_axis, canonical_evals, 1)
    reference = single_tensor(np.eye(3), canonical_evals, x_axis)
    assert rotated == reference


test_out_signal()

In [17]:
def test_design_signal():
    """Check both branches of design_signal.

    At the center location the design signal is the unweighted tensor signal;
    at any other location it is that signal scaled by ``weighting``.
    """
    in_dir = np.array([1, 0, 0])
    out_dir = np.array([1, 0, 0])
    unweighted = out_signal(in_dir, out_dir, np.array([1.5, 0.5, 0.5]), 1)

    # Center voxel: no weighting is applied.
    center = design_signal(np.array([0, 0, 0]), in_dir, out_dir)
    np.testing.assert_allclose(center, unweighted)

    # Neighboring voxel: the signal is scaled by the distance/angle weight, so
    # it must NOT be compared against the unweighted signal directly.
    location = np.array([1, 0, 0])
    neighbor = design_signal(location, in_dir, out_dir)
    np.testing.assert_allclose(neighbor, weighting(location, out_dir) * unweighted)


# Run the test defined in this cell (not test_out_signal again).
test_design_signal()

In [18]:
def design_matrix(gtab, sphere, evals=np.array([1.5, 0.5, 0.5])):
    """
    Build one design matrix per voxel of the 3x3x3 neighborhood.

    Parameters
    ----------
    gtab : gradient table
        Provides ``bvecs``, ``bvals`` and ``b0s_mask``; only the non-b0
        measurements are used.
    sphere : sphere object
        Its ``vertices`` are the candidate tensor directions (one column each).
    evals : array-like, optional
        Tensor eigenvalues forwarded to ``design_signal``.

    Returns
    -------
    list of ndarray
        27 arrays of shape (number of non-b0 measurements, number of sphere
        vertices), one per offset in ``{0, 1, -1}^3`` (x outermost, z
        innermost, so the first entry is the center voxel). Every column has
        its mean over the measurements subtracted.
    """
    # Boolean-mask indexing copies the array, so do it once up front instead of
    # once per matrix entry.
    dwi_mask = ~gtab.b0s_mask
    dwi_bvecs = gtab.bvecs[dwi_mask]
    dwi_bvals = gtab.bvals[dwi_mask]
    vertices = sphere.vertices

    n_rows = np.sum(dwi_mask)
    n_cols = sphere.x.shape[0]

    dm = []
    coords = [0, 1, -1]
    for x in coords:
        for y in coords:
            for z in coords:
                location = np.array([x, y, z])
                this_dm = np.empty((n_rows, n_cols))
                for row in range(n_rows):
                    for col in range(n_cols):
                        this_dm[row, col] = design_signal(
                            location, dwi_bvecs[row], vertices[col],
                            bval=dwi_bvals[row], evals=evals)

                # Remove the mean of every column (taken over measurements).
                dm.append(this_dm - np.mean(this_dm, 0))
    return dm
            
      
    
    
    

In [None]:
def preprocess_signal(data, gtab, i, j, k):
    """
    Collect the normalized, weighted, demeaned signals around voxel (i, j, k).

    For every voxel of the 3x3x3 neighborhood the non-b0 measurements are
    divided by the voxel's mean b0 signal. The center voxel is then demeaned;
    each neighbor is first multiplied, measurement by measurement, by
    ``weighting(location, bvec)`` and then demeaned.

    Parameters
    ----------
    data : ndarray
        Indexed as ``data[i, j, k]`` to obtain one voxel's measurements
        (presumably a 4-D diffusion volume — confirm against the caller).
    gtab : gradient table
        Provides ``bvecs`` and ``b0s_mask``.
    i, j, k : int
        Index of the center voxel. It must have a neighbor on every side.

    Returns
    -------
    list of ndarray
        27 one-dimensional arrays, ordered over offsets ``{0, 1, -1}^3`` with x
        outermost and z innermost (the center voxel comes first).

    Raises
    ------
    IndexError
        If the center voxel lies on the border of the volume. Without this
        check an index of 0 would silently wrap around to the opposite face
        through negative indexing.
    """
    for axis, (center, size) in enumerate(zip((i, j, k), np.shape(data)[:3])):
        if not 1 <= center <= size - 2:
            raise IndexError(
                "voxel index %d on axis %d has no full 3x3x3 neighborhood "
                "in a volume of size %d" % (center, axis, size))

    # Masked copies are loop-invariant; compute them once.
    b0_mask = gtab.b0s_mask
    dwi_mask = ~b0_mask
    dwi_bvecs = gtab.bvecs[dwi_mask]

    sig = []
    coords = [0, 1, -1]
    for x in coords:
        for y in coords:
            for z in coords:
                location = np.array([x, y, z])
                voxel = data[i + x, j + y, k + z]
                # Relative signal: non-b0 measurements over the mean b0 signal.
                relative = voxel[dwi_mask] / np.mean(voxel[b0_mask])
                if np.all(location == 0):
                    sig.append(relative - np.mean(relative))
                else:
                    # NOTE(review): the weights here use the gradient
                    # directions, whereas design_matrix weights by the sphere
                    # vertices — confirm the two are meant to differ.
                    weights = np.array([weighting(location, bvec)
                                        for bvec in dwi_bvecs])
                    weighted = weights * relative
                    sig.append(weighted - np.mean(weighted))
    return sig
    

In [None]:
# Build the per-neighbor design matrices on the sphere loaded above.
# Note the eigenvalues passed here, (1, 0.5, 0.5), differ from the function
# default of (1.5, 0.5, 0.5).
dm = design_matrix(gtab, sph1, evals=np.array([1, 0.5, 0.5]))

In [None]:
# Compare the last design-matrix column (last sphere vertex) for the center
# voxel and for the first neighbor generated by design_matrix.
fig, ax = plt.subplots()
ax.plot(dm[0][:, -1], label="location (0, 0, 0)")
ax.plot(dm[1][:, -1], label="location (0, 0, 1)")
ax.set_xlabel("non-b0 measurement index")
ax.set_ylabel("demeaned design signal (last sphere vertex)")
ax.legend();

In [None]:
# Preprocess the 3x3x3 neighborhood centered on voxel (5, 5, 5).
pp_sig = preprocess_signal(data, gtab, 5, 5, 5)

In [None]:
# Plot the preprocessed signals of all neighborhood voxels back to back.
fig, ax = plt.subplots()
ax.plot(np.concatenate(pp_sig))
ax.set_xlabel("stacked measurement index (voxel by voxel)")
ax.set_ylabel("weighted, demeaned relative signal");

In [None]:
# Stack the per-neighbor blocks along the first axis so that every row of the
# stacked design matrix lines up with one entry of the stacked signal.
new_dm = np.concatenate(dm)
new_sig = np.concatenate(pp_sig)

In [None]:
# Non-negative least squares fit: beta holds the non-negative weights (one per
# design-matrix column) and rnorm the norm of the residual.
beta, rnorm = opt.nnls(new_dm, new_sig)

In [None]:
# Signal predicted by the fitted weights.
y_hat = np.dot(new_dm, beta)

In [None]:
# Predicted versus measured signal; points on the diagonal are fit perfectly.
fig, ax = plt.subplots()
ax.scatter(y_hat, new_sig)
ax.set_xlabel("predicted signal (new_dm . beta)")
ax.set_ylabel("preprocessed measured signal");