In [None]:
import nibabel as nib
import numpy as np
import plotly.graph_objects as go

## nii data

In [None]:
# Location of the gzipped NIfTI file (`diff_preproc.nii.gz`).
# NOTE(review): machine-specific absolute path; in the visible notebook it is
# only referenced by the commented-out `nib.load(...)` cell.
nii_gz_path = 'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/mri/diff_preproc.nii.gz'

In [None]:
# nii_img  = nib.load(nii_gz_path)

In [None]:
# nii_data = nii_img.get_fdata()

In [None]:
# nii_data.shape

## b-values

In [None]:
# Location of the text file holding the b-values.
# NOTE(review): machine-specific absolute path — presumably one b-value per
# acquired volume; confirm against the dataset description.
b_values_path = 'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/bvals.txt'

In [None]:
# Parse the b-values text file into a NumPy array.
b_values = np.genfromtxt(b_values_path)

In [None]:
# Show the distinct b-values currently present in `b_values`.
print(np.unique(b_values))

In [None]:
# Snap every b-value in the closed interval [9_950, 10_050] to exactly 10_000.
# The array is modified in place, so every later cell sees the snapped values.
near_ten_thousand = (b_values >= 9_950) & (b_values <= 10_050)
b_values[near_ten_thousand] = 10_000

In [None]:
# Show the distinct b-values again, to compare with the earlier listing.
print(np.unique(b_values))

In [None]:
# Distinct b-values strictly greater than zero (zero entries are excluded);
# np.unique returns them sorted in ascending order.
b_values_unique = np.unique(b_values[b_values > 0])

In [None]:
# Show the non-zero b-values that the plotting loops iterate over.
print(b_values_unique)

## direction vectors

In [None]:
# Location of the text file holding the direction vectors.
# NOTE(review): machine-specific absolute path; the file name suggests
# motion-corrected, normalised b-vectors — confirm against the dataset docs.
direction_vectors_path = 'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/bvecs_moco_norm.txt'

In [None]:
# Parse the direction-vector text file into a NumPy array. Later cells index
# columns 0, 1 and 2 as x, y and z, so rows are expected to be one vector each.
direction_vectors = np.genfromtxt(direction_vectors_path)

In [None]:
# Shape of the array of distinct rows: (number of unique vectors, columns).
print(np.unique(direction_vectors, axis=0).shape)

In [None]:
def create_sphere(radius, resolution=100):
    """Build a translucent grey sphere surface centred on the origin.

    Parameters
    ----------
    radius : float
        Radius of the sphere, in the same units as the points drawn with it.
    resolution : int, optional
        Number of samples taken along each of the two angular axes. The mesh
        has ``resolution x resolution`` vertices. Defaults to 100, the value
        that used to be hard-coded.

    Returns
    -------
    plotly.graph_objects.Surface
        Surface trace using the 'Greys' colorscale, no colour bar and an
        opacity of 0.1, so points plotted on the sphere remain visible.
    """
    # theta sweeps the full circle (azimuth); phi sweeps pole to pole.
    theta = np.linspace(0, 2. * np.pi, resolution)
    phi = np.linspace(0, np.pi, resolution)

    # Spherical -> Cartesian on the (theta, phi) grid via outer products.
    x = radius * np.outer(np.cos(theta), np.sin(phi))
    y = radius * np.outer(np.sin(theta), np.sin(phi))
    z = radius * np.outer(np.ones(np.size(theta)), np.cos(phi))

    surface = go.Surface(x=x, y=y, z=z, colorscale='Greys', showscale=False, opacity=0.1)

    return surface

In [None]:

# Unit sphere drawn as a visual reference behind the direction vectors.
sphere = create_sphere(radius=1)

# Each marker is coloured by the Euclidean norm of its direction vector
# (blue at the low end of the scale, red at the high end).
direction_norms = np.linalg.norm(direction_vectors, axis=1)

scatter = go.Scatter3d(
    x=direction_vectors[:, 0],
    y=direction_vectors[:, 1],
    z=direction_vectors[:, 2],
    mode='markers',
    marker={
        'size': 3,
        'color': direction_norms,
        'colorscale': [(0, 'blue'), (1, 'red')],
    },
)

fig = go.Figure(data=[sphere, scatter])

# Fixed 800x800 canvas.
fig.update_layout(autosize=False, width=800, height=800)

fig.show()

## magnitude vectors

In [None]:
# Scale each direction vector (row) by its b-value: the b-values are turned
# into a column so they broadcast across the vector components.
magnitude_vectors = b_values[:, None] * direction_vectors

In [None]:
# One reference sphere and one scatter trace per non-zero b-value; each pair
# shares a legend group so they toggle together from the legend.
spheres = []
scatters = []

for b_value in b_values_unique:

    # Legend text with '_' as the thousands separator.
    label = f'b-value: {b_value:,.0f}'.replace(',', '_')

    # Sphere whose radius equals the b-value; it has no legend entry of its own.
    sphere = create_sphere(b_value)
    sphere.update(showlegend=False, legendgroup=label)
    spheres.append(sphere)

    # Rows of the scaled vectors that belong to this b-value.
    same_b_value = b_values == b_value
    magnitude_vectors_subset = magnitude_vectors[same_b_value, :]
    print(f'b-value {b_value:.0f} - {magnitude_vectors_subset.shape[0]} directions')

    scatter = go.Scatter3d(
        x=magnitude_vectors_subset[:, 0],
        y=magnitude_vectors_subset[:, 1],
        z=magnitude_vectors_subset[:, 2],
        mode='markers',
        marker={'size': 4},
        name=label,
        legendgroup=label,
    )
    scatters.append(scatter)


# Scatter traces first, then the spheres.
fig = go.Figure(data=scatters + spheres)

# Fixed 800x800 canvas.
fig.update_layout(autosize=False, width=800, height=800)

fig.show()

## voxel data

In [None]:
# voxel_data = nii_data[44, 53, 20, :]
# np.save('./voxel_data.npy', voxel_data)

In [None]:
# Load the cached array from disk. The commented-out cell above shows how the
# file was produced: `nii_data[44, 53, 20, :]` saved with `np.save`.
voxel_data = np.load('./voxel_data.npy')

In [None]:
# Distribution of the values stored in `voxel_data`.
histogram = go.Histogram(x=voxel_data)
fig = go.Figure(data=[histogram])
fig.show()

In [None]:
# For every non-zero b-value, draw the voxel values at their b-value-scaled
# direction vectors. Marker size and colour encode the value, min-max
# normalised within that b-value group.
spheres = []
scatters = []

for b_value in b_values_unique:

    # Legend text with '_' as the thousands separator.
    label = f'b-value: {b_value:,.0f}'.replace(',', '_')

    # Reference sphere whose radius equals the b-value (no legend entry).
    sphere = create_sphere(b_value)
    sphere.update(showlegend=False, legendgroup=label)
    spheres.append(sphere)

    same_b_value = b_values == b_value
    magnitude_vectors_subset = magnitude_vectors[same_b_value, :]
    voxel_data_subset = voxel_data[same_b_value]

    # Min-max normalise to [0, 1] within this b-value group. If every value in
    # the group is identical (or the group has a single sample) the range is 0
    # and the plain formula would yield 0/0 = NaN marker sizes and colours, so
    # fall back to a neutral mid-scale value instead.
    voxel_min = voxel_data_subset.min()
    voxel_range = voxel_data_subset.max() - voxel_min
    if voxel_range == 0:
        voxel_data_subset_normalized = np.full(voxel_data_subset.shape, 0.5)
    else:
        voxel_data_subset_normalized = (voxel_data_subset - voxel_min) / voxel_range

    # Data trace: hidden from the legend because its markers vary per point.
    scatter = go.Scatter3d(
        x=magnitude_vectors_subset[:, 0],
        y=magnitude_vectors_subset[:, 1],
        z=magnitude_vectors_subset[:, 2],
        mode='markers',
        marker=dict(
            size=15 * voxel_data_subset_normalized,
            color=voxel_data_subset_normalized,
            colorscale=[(0, 'blue'), (1, 'red')],
        ),
        name=label,
        legendgroup=label,
        showlegend=False,
    )
    scatters.append(scatter)

    # Empty placeholder trace with the same name and legend group: it supplies
    # the legend entry that toggles the data trace and sphere above.
    scatter = go.Scatter3d(
        x=[None],
        y=[None],
        z=[None],
        mode='markers',
        marker=dict(
            size=0,
            color='white',
        ),
        name=label,
        legendgroup=label,
    )
    scatters.append(scatter)


# Scatter traces first, then the spheres.
fig = go.Figure(data=scatters+spheres)

# Fixed 800x800 canvas.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
)

fig.show()