In [None]:
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import Rbf
from plotly.subplots import make_subplots

## NIfTI data

In [None]:
# Location of the compressed NIfTI volume; only the (disabled) loading cells
# below refer to it.
nii_gz_path = (
    'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/mri/diff_preproc.nii.gz'
)

In [None]:
# nii_img  = nib.load(nii_gz_path)

In [None]:
# nii_data = nii_img.get_fdata()

In [None]:
# nii_data.shape

## b-values

In [None]:
# Text file holding the b-value table read by the next cell.
b_values_path = (
    'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/bvals.txt'
)

In [None]:
# Parse the b-value text file into a NumPy array.
# Later cells index it as `b_values[:, np.newaxis]` and multiply it with
# `direction_vectors`, so it is presumably 1-D with one entry per direction
# vector -- TODO confirm against the file layout.
b_values = np.genfromtxt(b_values_path)

In [None]:
# Print every distinct b-value next to its number of occurrences, as
# (value, count) pairs.
print(*zip(*np.unique(b_values, return_counts=True)))

In [None]:
# Collapse every b-value within 50 of 10 000 (bounds inclusive) onto exactly
# 10 000. The array is modified in place; re-running the cell is harmless.
b_values[np.abs(b_values - 10_000) <= 50] = 10_000

In [None]:
# Distinct b-values strictly greater than zero (np.unique returns them sorted);
# every per-b-value loop below iterates over this array.
# The matching occurrence counts are bound to `_`, which the next cell prints.
# NOTE(review): `_` is also IPython's shortcut for the last cell output, so a
# descriptive name would be safer -- kept because the next cell reads `_`.
b_values_nz_unique, _ = np.unique(b_values[b_values > 0], return_counts=True)

In [None]:
# Print each non-zero b-value next to its number of occurrences, as
# (value, count) pairs.
# The pairs are recomputed instead of being read from `_`: IPython also uses
# `_` for the previous cell output, so relying on it ties the result to the
# order in which cells happened to be executed.
print(*zip(*np.unique(b_values[b_values > 0], return_counts=True)))

## direction vectors

In [None]:
# Text file holding the direction-vector table read by the next cell.
direction_vectors_path = (
    'C:/Users/panag/Desktop/Test/mgh_1001/diff/preproc/bvecs_moco_norm.txt'
)

In [None]:
# Parse the direction-vector text file into a NumPy array.
# Later cells read columns 0, 1 and 2 of each row as the x, y and z components.
direction_vectors = np.genfromtxt(direction_vectors_path)

In [None]:
# Shape of the table after removing duplicate rows:
# (number of distinct rows, number of columns).
print(np.unique(direction_vectors, axis=0).shape)

In [None]:
def create_sphere(radius, resolution=100, opacity=0.1, colorscale='Greys'):
    """Build a translucent sphere surface centred on the origin.

    Parameters
    ----------
    radius : float
        Radius of the sphere, in the same units as the plotted points.
    resolution : int, optional
        Number of samples along each of the two angles. The default (100)
        is the value that used to be hard-coded.
    opacity : float, optional
        Opacity passed to the surface trace (default 0.1).
    colorscale : str, optional
        Plotly colorscale passed to the surface trace (default 'Greys').

    Returns
    -------
    plotly.graph_objects.Surface
        Surface trace without a colour bar, ready to be added to a figure.
    """
    # Azimuth sweeps the full circle, the polar angle runs pole to pole.
    azimuth = np.linspace(0, 2. * np.pi, resolution)
    polar = np.linspace(0, np.pi, resolution)

    # Outer products turn the two 1-D angle vectors into 2-D coordinate grids.
    x = radius * np.outer(np.cos(azimuth), np.sin(polar))
    y = radius * np.outer(np.sin(azimuth), np.sin(polar))
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    return go.Surface(
        x=x,
        y=y,
        z=z,
        colorscale=colorscale,
        showscale=False,
        opacity=opacity,
    )

In [None]:
# Reference sphere of radius 1 drawn behind the points.
sphere = create_sphere(radius=1)

# One marker per direction vector, coloured by the vector's Euclidean norm
# (blue at the low end of the colour range, red at the high end).
scatter = go.Scatter3d(
    x=direction_vectors[:, 0],
    y=direction_vectors[:, 1],
    z=direction_vectors[:, 2],
    mode='markers',
    marker={
        'size': 3,
        'color': np.linalg.norm(direction_vectors, axis=1),
        'colorscale': [(0, 'blue'), (1, 'red')],
    },
)

fig = go.Figure(data=[sphere, scatter])

# Fixed 800 x 800 canvas, no hover labels and no axis spike lines.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    hovermode=False,
    scene={
        axis: {'showspikes': False}
        for axis in ('xaxis', 'yaxis', 'zaxis')
    },
)

fig.show()

## magnitude vectors

In [None]:
# Scale every direction vector by its b-value: the extra axis turns `b_values`
# into a column so it broadcasts across the components of each row.
magnitude_vectors = b_values[:, np.newaxis] * direction_vectors

In [None]:
# For every non-zero b-value draw a reference sphere of that radius together
# with the scaled vectors that belong to it. All traces share one 3-D scene.
spheres = []
scatters = []

for b_value in b_values_nz_unique:

    # Underscore as thousands separator, e.g. 'b-value: 10_000'.
    label = f'b-value: {b_value:_.0f}'

    # The sphere has no legend entry of its own, but sharing the legend group
    # makes it toggle together with the points.
    sphere = create_sphere(b_value)
    sphere.update(showlegend=False, legendgroup=label)
    spheres.append(sphere)

    magnitude_vectors_subset = magnitude_vectors[b_values == b_value, :]

    scatter = go.Scatter3d(
        x=magnitude_vectors_subset[:, 0],
        y=magnitude_vectors_subset[:, 1],
        z=magnitude_vectors_subset[:, 2],
        mode='markers',
        marker=dict(
            size=4,
        ),
        name=label,
        legendgroup=label,
    )
    scatters.append(scatter)


fig = go.Figure(data=scatters+spheres)

# The figure has a single scene, so only `scene` is configured. The previous
# version also created `scene2`, `scene3`, ... entries that no trace used.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    hovermode=False,
    scene=dict(
        xaxis=dict(showspikes=False),
        yaxis=dict(showspikes=False),
        zaxis=dict(showspikes=False)
    ),
)

fig.show()

## voxel data

In [None]:
# One-off extraction that produced './voxel_data.npy': it takes every value
# along the last axis of `nii_data` at index (44, 53, 20) and saves it.
# Disabled here; it needs `nii_data`, whose loading cells are commented out
# too -- presumably so the notebook can run without the full volume.
# voxel_data = nii_data[44, 53, 20, :]
# np.save('./voxel_data.npy', voxel_data)

In [None]:
# Load the cached voxel values from the working directory.
# Below it is masked with `b_values == b_value`, so it is expected to have one
# entry per b-value -- TODO confirm.
voxel_data = np.load('./voxel_data.npy')

#### Together (all b-values in one scene)

In [None]:
# Same layout as the magnitude-vector figure, but every marker's size and
# colour now encode the voxel value measured for that vector, min-max scaled
# within its own b-value. All traces share one 3-D scene.
spheres = []
scatters = []

for b_value in b_values_nz_unique:

    # Underscore as thousands separator, e.g. 'b-value: 10_000'.
    label = f'b-value: {b_value:_.0f}'

    sphere = create_sphere(b_value)
    sphere.update(showlegend=False, legendgroup=label)
    spheres.append(sphere)

    magnitude_vectors_subset = magnitude_vectors[b_values == b_value, :]
    voxel_data_subset = voxel_data[b_values == b_value]

    # Min-max scale to [0, 1]. If every value in the subset is identical the
    # range is zero and the plain formula would divide by zero (NaN sizes and
    # colours), so fall back to a constant mid-scale value instead.
    subset_min = voxel_data_subset.min()
    subset_range = voxel_data_subset.max() - subset_min
    if subset_range > 0:
        voxel_data_subset_normalized = (voxel_data_subset - subset_min) / subset_range
    else:
        voxel_data_subset_normalized = np.full(voxel_data_subset.shape, 0.5)

    # Data trace: hidden from the legend, size and colour follow the value.
    # Note that the smallest value of a subset maps to marker size 0.
    scatter = go.Scatter3d(
        x=magnitude_vectors_subset[:, 0],
        y=magnitude_vectors_subset[:, 1],
        z=magnitude_vectors_subset[:, 2],
        mode='markers',
        marker=dict(
            size=15 * voxel_data_subset_normalized,
            color=voxel_data_subset_normalized,
            colorscale=[(0, 'blue'), (1, 'red')],
        ),
        name=label,
        legendgroup=label,
        showlegend=False,
    )
    scatters.append(scatter)

    # Legend-only proxy: a trace without visible points that carries the legend
    # entry for this group (the data trace above has showlegend=False).
    scatter = go.Scatter3d(
        x=[None],
        y=[None],
        z=[None],
        mode='markers',
        marker=dict(
            size=0,
            color='white',
        ),
        name=label,
        legendgroup=label,
    )
    scatters.append(scatter)


fig = go.Figure(data=scatters+spheres)

# The figure has a single scene, so only `scene` is configured. The previous
# version also created `scene2`, `scene3`, ... entries that no trace used.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    hovermode=False,
    scene=dict(
        xaxis=dict(showspikes=False),
        yaxis=dict(showspikes=False),
        zaxis=dict(showspikes=False)
    ),
)

fig.show()

#### Separate (one scene per b-value)

In [None]:
num_b_values = len(b_values_nz_unique)

# One 3-D scene per non-zero b-value, laid out side by side in a single row.
# Each column gets its own spec dict (no aliasing of one shared dict), and the
# legacy `is_3d` flag is dropped because `type='scene'` already says the same.
fig = make_subplots(
    rows=1,
    cols=num_b_values,
    specs=[[{'type': 'scene'} for _col in range(num_b_values)]],
    subplot_titles=[f'b-value: {b_value:_.0f}' for b_value in b_values_nz_unique],
)

for i, b_value in enumerate(b_values_nz_unique):

    # Underscore as thousands separator, e.g. 'b-value: 10_000'.
    label = f'b-value: {b_value:_.0f}'

    # Reference sphere whose radius is the b-value
    sphere = create_sphere(b_value)
    sphere.update(showlegend=False, legendgroup=label)

    # Vectors and voxel values that belong to this b-value
    magnitude_vectors_subset = magnitude_vectors[b_values == b_value, :]
    voxel_data_subset = voxel_data[b_values == b_value]

    # Min-max scale to [0, 1]. If every value in the subset is identical the
    # range is zero and the plain formula would divide by zero (NaN sizes and
    # colours), so fall back to a constant mid-scale value instead.
    subset_min = voxel_data_subset.min()
    subset_range = voxel_data_subset.max() - subset_min
    if subset_range > 0:
        voxel_data_subset_normalized = (voxel_data_subset - subset_min) / subset_range
    else:
        voxel_data_subset_normalized = np.full(voxel_data_subset.shape, 0.5)

    # Marker size and colour both follow the scaled voxel value.
    scatter = go.Scatter3d(
        x=magnitude_vectors_subset[:, 0],
        y=magnitude_vectors_subset[:, 1],
        z=magnitude_vectors_subset[:, 2],
        mode='markers',
        marker=dict(
            size=15 * voxel_data_subset_normalized,
            color=voxel_data_subset_normalized,
            colorscale=[(0, 'blue'), (1, 'red')],
        ),
        name=label,
        legendgroup=label,
        showlegend=False,
    )

    # Subplot columns are 1-based.
    fig.add_trace(scatter, row=1, col=i+1)
    fig.add_trace(sphere, row=1, col=i+1)

# Disable hover labels and axis spike lines in every scene
# ('scene1' is accepted as an alias for the first scene, 'scene').
fig.update_layout(
    autosize=True,
    hovermode=False,
    **{
        f'scene{i+1}': dict(
            xaxis=dict(showspikes=False),
            yaxis=dict(showspikes=False),
            zaxis=dict(showspikes=False)
        ) for i in range(num_b_values)
    },
)

fig.show()

#### Synchronized (one camera shared by all subplots)

In [None]:
import dash
from dash.dependencies import Input, Output
from dash import dcc
from dash import html
from copy import deepcopy

# Small Dash app that serves the subplot figure built in the previous cell and
# keeps the cameras of all its scenes in step.
app = dash.Dash(__name__)

# The graph gets its own copy, so camera changes never touch the notebook's `fig`.
app.layout = html.Div([
    dcc.Graph(
        id='3d-plot',
        figure=deepcopy(fig),
    )
])

@app.callback(
    Output('3d-plot', 'figure'),
    [Input('3d-plot', 'relayoutData')]
)
def update_subplots(relayoutData):
    """Copy the camera of the scene that was just moved onto every scene.

    Parameters
    ----------
    relayoutData : dict or None
        Relayout event reported by the graph; may be empty or None.

    Returns
    -------
    The figure stored in the app layout, updated in place when the event
    contains a '<scene>.camera' entry and unchanged otherwise.
    """
    if relayoutData:
        for key in relayoutData:
            # Camera changes arrive under keys such as 'scene.camera' or 'scene2.camera'.
            if key.startswith('scene') and key.endswith('.camera'):
                app.layout.children[0].figure.update_layout(
                    **{
                        f'scene{i+1}': dict(
                            camera=relayoutData[key],
                        ) for i in range(num_b_values)
                    },
                )
                # One camera entry is enough; it is applied to all scenes.
                break
    return app.layout.children[0].figure

# Prefer `app.run` where it exists and fall back to `run_server` on older Dash
# releases; newer releases reject `run_server`.
# NOTE(review): this call blocks the kernel until the server is stopped, so the
# cells below only run after an interrupt.
launch = getattr(app, 'run', None) or app.run_server
launch(debug=False)

## spherical density

In [None]:
# Sample the sphere on a 100 x 100 grid of angles: theta covers [0, 2*pi],
# phi covers [0, pi].
theta, phi = np.meshgrid(
    np.linspace(0, 2*np.pi, 100),
    np.linspace(0, np.pi, 100),
)

# Radius of the sphere on which the interpolant is evaluated
radius = 1

# Cartesian coordinates of every grid node
x_grid = radius * np.sin(phi) * np.cos(theta)
y_grid = radius * np.sin(phi) * np.sin(theta)
z_grid = radius * np.cos(phi)

# Values to interpolate and the points at which they were measured.
# NOTE(review): every entry of `voxel_data` is used here, regardless of its
# b-value -- confirm whether a per-b-value subset was intended.
intensity = voxel_data
x, y, z = (direction_vectors[:, axis] for axis in range(3))

# Fit a smoothed multiquadric radial basis function to the scattered samples,
# then evaluate it on the grid.
rbf = Rbf(x, y, z, intensity, function='multiquadric', smooth=1)
intensity_grid = rbf(x_grid, y_grid, z_grid)

# Sphere surface coloured by the interpolated values
surface = go.Surface(
    x=x_grid,
    y=y_grid,
    z=z_grid,
    surfacecolor=intensity_grid,
    colorscale='Viridis',
)

fig = go.Figure(data=[surface])

# Fixed 800 x 800 canvas, no hover labels and no axis spike lines.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    hovermode=False,
    scene={
        axis: {'showspikes': False}
        for axis in ('xaxis', 'yaxis', 'zaxis')
    },
)

fig.show()