In [27]:
import os
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from dipy.core.gradients import gradient_table
from dipy.core.sphere import disperse_charges, Sphere, HemiSphere
from dipy.data import get_sphere
from dipy.reconst.shm import CsaOdfModel, QballModel, sh_to_sf_matrix
from dipy.sims.voxel import multi_tensor, multi_tensor_odf
from fury import actor, window
from IPython.core.display import Image
from PIL import Image as PILImage
from tqdm import tqdm

sys.path.insert(0,'..')
import odfs

In [28]:
# Directory holding the precomputed .npy inputs. Set the DIFFUSION_OUTPUT_DIR
# environment variable to run the notebook on another machine; the default is
# the location the notebook was originally written against.
OUTPUT_DIR = os.environ.get('DIFFUSION_OUTPUT_DIR',
                            '/home/brysongray/diffusion_analysis/outputs')

J = np.load(os.path.join(OUTPUT_DIR, 'hist_to_mri_jacobian.npy'))
disp = np.load(os.path.join(OUTPUT_DIR, 'hist_to_mri_disp.npy'))
# NOTE(review): this rebinds `odfs`, which the import cell bound to the local
# `odfs` module; after this line the module is no longer reachable by that name.
odfs = np.load(os.path.join(OUTPUT_DIR, 'human_amyg_csd_odfs_patch-55.npy'))
sh = np.load(os.path.join(OUTPUT_DIR, 'human_amyg_csd_shm_coeff_patch-55.npy'))

In [8]:
# helper functions for visualization
WINDOW_SIZE = (400, 400)   # size handed to fury when showing or snapshotting a scene
SAVEIM_FOLDER = 'images'   # folder (relative to the working directory) for saved renders
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race
# of a separate os.path.exists() test followed by os.mkdir().
os.makedirs(SAVEIM_FOLDER, exist_ok=True)

def screenshot_animated_sf(sf, sphere, B=None, rot=True, norm=True, scale=True, title='Modeling', theta_step=30):
    '''
    Spin one spherical function about the Z axis and save the frames as a looping GIF.

    sf : amplitudes of the spherical function; indexed as sf[None, None, None, :]
         before being handed to actor.odf_slicer
    sphere : sphere forwarded to actor.odf_slicer
    B, scale : accepted for the caller's convenience but not used by this body
    rot : when true, the actor is first rotated -90 about X and 180 about Z
    norm : forwarded to actor.odf_slicer
    title : file stem of the GIF written inside SAVEIM_FOLDER
    theta_step : angle passed to RotateZ before each frame; 360 // theta_step
                 frames are captured

    Returns the path of the GIF that was written.
    '''
    # Frame count and per-frame delay both follow from the rotation step.
    n_frames = 360 // theta_step
    frame_duration = 15000 // theta_step

    scn = window.Scene()
    scn.background(window.colors.white)

    odf_actor = actor.odf_slicer(
        sf[None, None, None, :],
        sphere=sphere,
        colormap='jet',
        norm=norm,
    )

    if rot:
        # Initial orientation applied once, before the spin starts.
        odf_actor.RotateX(-90)
        odf_actor.RotateZ(180)
    scn.add(odf_actor)

    frames = []
    for _ in np.arange(n_frames):
        odf_actor.RotateZ(theta_step)
        scn.reset_clipping_range()
        shot = window.snapshot(scn, size=WINDOW_SIZE)
        frames.append(PILImage.fromarray(shot))

    gif_path = os.path.join(SAVEIM_FOLDER, '{0}.gif'.format(title))
    # The first frame owns the file; the remaining frames are appended to it.
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        optimize=False,
        loop=0,
    )
    scn.clear()

    return gif_path
    
def screenshot_gradients(sph_gtab, title='Modeling'):
    '''
    Render the vertices of sph_gtab as green points on a white background
    and save the snapshot as a PNG. Returns path to image.

    sph_gtab : object exposing a `vertices` attribute with the points to draw
    title : file stem of the PNG written inside SAVEIM_FOLDER
    '''
    scn = window.Scene()
    scn.background(window.colors.white)

    points = actor.point(sph_gtab.vertices, window.colors.green, point_radius=0.05)
    scn.add(points)

    png_path = os.path.join(SAVEIM_FOLDER, '{0}.png'.format(title))
    window.snapshot(scn, size=WINDOW_SIZE, fname=png_path)
    scn.clear()

    return png_path

def display_slice(odfs, sphere, norm=True, B=None):
    '''
    Build a scene containing an ODF slicer for `odfs`, show it with
    window.show, and return the scene.

    odfs : odf field as spherical harmonics
    sphere : dipy Sphere
    norm : normalizes so the maximum ODF amplitude per voxel is 1
    B : spherical harmonic to spherical function matrix (passed as B_matrix)
    '''
    slicer = actor.odf_slicer(
        odfs,
        sphere=sphere,
        colormap='jet',
        norm=norm,
        B_matrix=B,
    )

    scn = window.Scene()
    scn.background(window.colors.white)
    scn.add(slicer)

    window.show(scn, size=WINDOW_SIZE)

    return scn

In [11]:
# Load the sphere named 'symmetric724' from dipy's bundled data.
sphere = get_sphere('symmetric724')
# Inspect which attributes the sphere exposes (vertices, theta, phi, x, y, z, ...).
dir(sphere)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'edges',
 'faces',
 'find_closest',
 'phi',
 'subdivide',
 'theta',
 'vertices',
 'x',
 'y',
 'z']

In [37]:
def cart_to_polar(X):
    """
    Convert Cartesian coordinates to spherical coordinates (r, theta, phi).

    Parameters
    ----------
    X : array_like, shape (3, ...)
        The x, y and z components, stored along the FIRST axis.

    Returns
    -------
    P : ndarray, shape (..., 3)
        r, theta and phi stacked along the LAST axis, where
        r is the Euclidean norm, theta is the polar angle measured from +z
        (in [0, pi]) and phi is the azimuth in the x-y plane (in [-pi, pi]).
    """
    x, y, z = np.asarray(X, dtype=float)

    rho = np.hypot(x, y)  # distance from the z axis
    r = np.sqrt(x**2 + y**2 + z**2)
    # arctan2(rho, z) equals arccos(z / r) for r > 0 but stays finite (0) at
    # the origin instead of producing NaN from 0 / 0.
    theta = np.arctan2(rho, z)
    # arctan2 keeps the quadrant and handles x == 0; arctan(y / x) folds the
    # x < 0 half-plane onto the x > 0 one and divides by zero on the y axis.
    phi = np.arctan2(y, x)

    return np.stack((r, theta, phi), -1)

def get_rescale(J, X, phi):
    """
    Compute the amplitude rescaling factor
    (sin(theta) / sin(theta')) / |det(J_polar)|.

    J_polar is the Jacobian of the transformation expressed in spherical
    coordinates, obtained by the chain rule
        J_polar = d(r', theta', phi')/d(x', y', z')  @  J  @  d(x, y, z)/d(r, theta, phi)
    with the first factor evaluated at `phi` and the last one at `X`.

    Parameters
    ----------
    J : array_like, shape (..., 3, 3)
        Cartesian Jacobian of the transformation (broadcast against the grid).
    X : array_like, shape (3, ...)
        Cartesian coordinates of the source points, components on the first
        axis (the layout `cart_to_polar` expects).
    phi : array_like, shape (3, ...)
        Cartesian coordinates of the transformed points, same layout as `X`.
        NOTE(review): assumed to be positions, not displacements, because the
        code unpacks it directly as x', y', z' - confirm against the caller.

    Returns
    -------
    scale_factor : ndarray, shape (...)
        One factor per point. Points on the z axis (sin(theta') == 0 or
        x'^2 + y'^2 == 0) yield inf/NaN, as the spherical chart is singular there.
    """
    J = np.asarray(J, dtype=float)
    # Keep the deformation under its own name so the azimuth angle below
    # cannot overwrite it.
    warped = np.asarray(phi, dtype=float)

    # cart_to_polar returns (r, theta, phi) on the LAST axis; move it to the
    # front so the three components can be unpacked.
    r, theta, azimuth = np.moveaxis(cart_to_polar(np.asarray(X, dtype=float)), -1, 0)
    x_, y_, z_ = warped
    r_, theta_, _ = np.moveaxis(cart_to_polar(warped), -1, 0)

    sin_t, cos_t = np.sin(theta), np.cos(theta)
    sin_p, cos_p = np.sin(azimuth), np.cos(azimuth)

    # d(x, y, z) / d(r, theta, phi) at the source points. np.stack needs every
    # entry to have the same shape, hence zeros_like rather than a scalar 0.
    J_1 = np.stack((sin_t*cos_p, r*cos_t*cos_p, -r*sin_t*sin_p,
                    sin_t*sin_p, r*cos_t*sin_p, r*sin_t*cos_p,
                    cos_t, -r*sin_t, np.zeros_like(r)),
                   axis=-1).reshape(np.shape(r) + (3, 3))

    # d(r', theta', phi') / d(x', y', z') at the transformed points.
    rho2 = x_**2 + y_**2
    rho = np.sqrt(rho2)
    J_3 = np.stack((x_/r_, y_/r_, z_/r_,
                    x_*z_/(rho*r_**2), y_*z_/(rho*r_**2), -rho/r_**2,
                    -y_/rho2, x_/rho2, np.zeros_like(r_)),
                   axis=-1).reshape(np.shape(r_) + (3, 3))

    J_polar = J_3 @ J @ J_1
    scale_factor = sin_t/np.sin(theta_) * 1/np.abs(np.linalg.det(J_polar))

    return scale_factor


def odf_transform(odf, gtab, J, xI, phi, sh_order=8):
    """
    Push an ODF field through a spatial transformation and refit it to
    spherical harmonics on the transformed sampling directions.

    Parameters
    ----------
    odf : ndarray, shape (..., N)
        ODF amplitudes sampled on the N directions in `gtab`.
    gtab : array_like, shape (N, 3)
        Sampling directions (e.g. `sphere.vertices`).
    J : ndarray, shape (..., 3, 3)
        Jacobian of the transformation at every voxel.
    xI : sequence of 1-D arrays
        Voxel coordinates along each spatial axis; expanded with
        np.meshgrid(*xI, indexing='ij') into a (3, ...) coordinate grid.
    phi : array_like
        Transformed coordinates, forwarded unchanged to `get_rescale`.
    sh_order : int, optional
        Maximum (even) spherical-harmonic order of the fit. The default 8
        gives (8 + 1) * (8 + 2) / 2 = 45 coefficients.

    Returns
    -------
    sh : ndarray, shape J.shape[:-2] + (n_coeffs,)
        Spherical-harmonic coefficients of the transformed ODF per voxel.
    """
    gtab = np.asarray(gtab, dtype=float)
    J = np.asarray(J, dtype=float)

    # first apply transformation (J) to the sphere vertices:
    # (..., 3, 3) @ (3, N) -> (..., 3, N). Only the last two axes are swapped
    # to get (..., N, 3); a bare .T would reverse the voxel axes as well.
    vertices = np.swapaxes(J @ gtab.T, -1, -2)

    # now rescale odf
    # scaling factor = (sin(theta) / sin(theta')) / abs(det(J_polar))
    # we need to convert from cartesian to polar coordinates
    XI = np.stack(np.meshgrid(*xI, indexing='ij'))
    scale_factor = get_rescale(J, XI, phi)
    # One factor per voxel: add a trailing axis so it broadcasts over the N directions.
    signal = np.asarray(scale_factor)[..., None] * odf
    # The sphere below is completed with the antipodal points -v. The even-order
    # SH basis can only represent f(-v) == f(v), so the mirrored half carries the
    # same amplitudes (a negated copy would fit to ~0).
    signal_sph = np.concatenate((signal, signal), axis=-1)

    # fit signal and v to spherical harmonics with new sphere
    signal_sph = signal_sph.reshape(-1, signal_sph.shape[-1])
    vertices = vertices.reshape(-1, vertices.shape[-2], vertices.shape[-1])
    n_coeffs = (sh_order + 1) * (sh_order + 2) // 2
    coeffs = np.zeros((vertices.shape[0], n_coeffs))
    for i in range(len(vertices)):
        # Every voxel has its own deformed set of directions, so the SH design
        # matrix has to be rebuilt per voxel.
        sph_gtab = Sphere(xyz=np.vstack((vertices[i], -vertices[i])))
        _, invB = sh_to_sf_matrix(sph_gtab, sh_order)
        coeffs[i] = np.dot(invB.T, signal_sph[i])

    return coeffs.reshape(J.shape[:-2] + (n_coeffs,))

In [None]:
# Report the shape and dtype of the ODF array and of the sphere vertices.
print(f"odfs shape & dtype: {odfs.shape}, {odfs.dtype}")
print(f"sphere.vertices shape & dtype: {sphere.vertices.shape}, {sphere.vertices.dtype}")
print()

In [35]:
# NOTE(review): odf_transform as defined in this notebook requires
# (odf, gtab, J, xI, phi); this call supplies only the first three, so it raises
# TypeError against that definition. `xI` is not defined in any visible cell, and
# `phi` is presumably the loaded `disp` array - confirm both before re-running.
sh_transformed = odf_transform(odfs, sphere.vertices, J, )

[[[[[1 3]
    [2 4]]]]]
