## Import required modules

In [None]:
# Standard import
import os
import numpy as np
import plotly.graph_objects as go
from numba import njit
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Move to root directory for easier module import
# NOTE: the path is relative to the current working directory, so executing this cell a second time
# moves two more levels up and changes what relative paths resolve to afterwards
os.chdir("../../")

# Homemade modules
from lbae.modules.maldi_data import MaldiData
from lbae.modules.figures import Figures
from lbae.modules.atlas import Atlas
from lbae.modules.tools.misc import logmem

# Objects containing our data as well as the atlas data
data = MaldiData()
# NOTE(review): resolution is presumably expressed in micrometers - confirm against Atlas
atlas = Atlas(data, resolution=25)
figures = Figures(data, atlas)

In [None]:
# Only one voxel out of `decrease_dimensionality_factor` is kept along each axis
decrease_dimensionality_factor = 4

# Subsampling every axis by this factor (4) yields a volume of okay size for 3D rendering
array_atlas = atlas.bg_atlas.reference[
    ::decrease_dimensionality_factor,
    ::decrease_dimensionality_factor,
    ::decrease_dimensionality_factor,
]
array_atlas = np.array(array_atlas, dtype=np.int32)

# Bug correction for the last slice: append one all-zero slice along the first axis. The padding is
# created with the dtype of array_atlas, since np.zeros defaults to float64 and would otherwise
# silently promote the whole concatenated volume from int32 to float64.
array_atlas = np.concatenate(
    (
        array_atlas,
        np.zeros((1, array_atlas.shape[1], array_atlas.shape[2]), dtype=array_atlas.dtype),
    ),
    axis=0,
)



# Coordinate grids with the same shape as array_atlas. A complex step (n * 1j) makes np.mgrid
# generate n points per axis, with the stop value included.
# NOTE(review): the extent is computed as n_voxels * 25 / 1000 per axis, which does not account for
# decrease_dimensionality_factor - confirm that this is the intended scale
_atlas_grid_slices = tuple(
    slice(0, n_voxels / 1000 * 25, n_voxels * 1j) for n_voxels in array_atlas.shape[:3]
)
X, Y, Z = np.mgrid[_atlas_grid_slices]

# Volume rendering of the subsampled reference atlas (disabled by default)
plot_fig = False
if plot_fig:
    fig = go.Figure(
        data=go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=array_atlas.flatten(),
            isomin=5.0,
            isomax=300,
            opacity=0.1,  # max opacity
            # The opacityscale maps *normalized* values to opacities, so its keys must span 0 to 1
            # (the previous upper key, 516, was a raw intensity and not a normalized value)
            opacityscale=[[0, 0], [1, 1]],
            surface_count=30,
            colorscale="RdBu",
        )
    )

    # Remove the margins and draw the three axes in grey over a transparent background
    fig.update_layout(
        margin=dict(t=0, r=0, b=0, l=0),
        scene=dict(
            xaxis=dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey"),
            yaxis=dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey"),
            zaxis=dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey"),
        ),
    )
    fig.show()

In [None]:
# Subsample the annotation volume the same way array_atlas was subsampled. Slicing before the
# conversion avoids building a full-resolution int32 copy of the annotation first.
array_annotation = atlas.bg_atlas.annotation[
    ::decrease_dimensionality_factor,
    ::decrease_dimensionality_factor,
    ::decrease_dimensionality_factor,
]
array_annotation = np.array(array_annotation, dtype=np.int32)

# Bug correction for the last slice: append one all-zero slice along the first axis. The padding is
# created with the dtype of array_annotation, since np.zeros defaults to float64 and would otherwise
# silently turn the integer labels into float64 values.
array_annotation = np.concatenate(
    (
        array_annotation,
        np.zeros(
            (1, array_annotation.shape[1], array_annotation.shape[2]),
            dtype=array_annotation.dtype,
        ),
    ),
    axis=0,
)


# Compute an array of boundaries 
#print(round(atlas.bg_atlas.annotation.nbytes/ 1024 / 1024,2))
@njit
def fill_array_borders(array_annotation):
    """Classify each voxel of an annotation volume as outside, border or inside of the brain.

    Args:
        array_annotation (np.ndarray): 3D array of annotation labels. A value of 0 is treated as
            outside of the brain, and any strictly positive value as inside of the brain.

    Returns:
        np.ndarray: A float32 array with the same shape as array_annotation, containing:
            * -2 for voxels that are not strictly positive, and for every voxel lying on an outer
              face of the array (those are never examined, so that the 3x3x3 neighbourhood always
              stays inside the array),
            * -0.1 for brain voxels having at least one zero-valued voxel in their 3x3x3
              neighbourhood (border voxels),
            * -0.01 for all the other brain voxels.
    """
    array_atlas_borders = np.full_like(array_annotation, -2.0, dtype=np.float32)
    for x in range(1, array_annotation.shape[0] - 1):
        for y in range(1, array_annotation.shape[1] - 1):
            for z in range(1, array_annotation.shape[2] - 1):
                if array_annotation[x, y, z] > 0:
                    # Look for a zero-valued voxel in the 3x3x3 neighbourhood, and stop scanning
                    # as soon as one is found
                    found = False
                    for xt in range(x - 1, x + 2):
                        for yt in range(y - 1, y + 2):
                            for zt in range(z - 1, z + 2):
                                if array_annotation[xt, yt, zt] == 0:
                                    found = True
                                    break
                            if found:
                                break
                        if found:
                            break

                    if found:
                        # There's a border around
                        array_atlas_borders[x, y, z] = -0.1
                    else:
                        # Inside the brain but not a border
                        array_atlas_borders[x, y, z] = -0.01

    return array_atlas_borders

# Classify every voxel of the subsampled annotation volume (outside, border or inside of the brain)
array_atlas_borders = fill_array_borders(array_annotation)

In [None]:
# Display the slice of index 300 (along the first axis) of the full-resolution annotation. With
# the color range limited to [0, 1], every strictly positive label is drawn with the same color.
array_to_display = np.array(
    atlas.bg_atlas.annotation[300, :, :],
    dtype=np.int32,
)
plt.imshow(array_to_display, vmin=0, vmax=1)
plt.show()

In [None]:
# Display the slice of index 30 (along the first axis) of the border volume, with a color range
# of [-0.2, 0]
array_to_display = np.array(
    array_atlas_borders[30, :, :],
    dtype=np.float32,
)
plt.imshow(array_to_display, vmin=-0.2, vmax=0)
plt.show()

In [None]:
# Coordinate grids with the same shape as array_atlas_borders. A complex step (n * 1j) makes
# np.mgrid generate n points per axis, with the stop value included.
# NOTE(review): the extent is computed as n_voxels * 25 / 1000 per axis, which does not account for
# decrease_dimensionality_factor - confirm that this is the intended scale
_borders_grid_slices = tuple(
    slice(0, n_voxels / 1000 * 25, n_voxels * 1j) for n_voxels in array_atlas_borders.shape[:3]
)
X, Y, Z = np.mgrid[_borders_grid_slices]

# Custom colorscale given as [position, color] pairs, with positions going from 0 to 1
# NOTE: the volume plot just below passes the named colorscale 'Bluyl' and not this list
colorscale=[
            [0.0, "rgb(69,117,180)"],
            [0.11, "rgb(254,224,144)"],
            [0.22, "rgb(254,224,144)"],
            [0.33, "rgb(254,224,144)"],
            [0.44, "rgb(253,174,97)"],
            [0.55, "rgb(253,174,97)"],
            [0.66, "rgb(253,174,97)"],
            [0.77, "rgb(244,109,67)"],
            [0.88, "rgb(215,48,39)"],
            [1.0, "rgb(165,0,38)"],
            ]



# Volume rendering of the border volume (disabled by default)
plot_fig = False
if plot_fig:
    # The named colorscale 'Bluyl' is used here rather than the custom `colorscale` list
    volume_trace = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=array_atlas_borders.flatten(),
        isomin=-0.11,
        isomax=2.55,
        opacity=0.9,  # max opacity
        opacityscale="uniform",
        surface_count=20,
        colorscale="Bluyl",
    )
    fig = go.Figure(data=volume_trace)

    # The three axes share the same style: grey lines over a transparent background
    axis_style = dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey")
    fig.update_layout(
        margin=dict(t=0, r=0, b=0, l=0),
        scene={axis_name: dict(axis_style) for axis_name in ("xaxis", "yaxis", "zaxis")},
    )
    fig.show()

In [None]:
# One (lower, upper) pair of bounds per entry, or None when the entry has no bounds
# NOTE(review): presumably one entry per slice, the pairs being m/z boundaries - confirm
_l_first_bounds = [
    None,
    (759.6073, 759.6117),
    (759.6071000000001, 759.6112),
    (759.6068, 759.6106000000001),
    (759.6069, 759.6111000000001),
    (759.6064, 759.6108),
    (759.6066000000001, 759.611),
    (759.6067, 759.611),
    (759.6066000000001, 759.611),
    (759.6067, 759.6111000000001),
    (759.6067, 759.6115000000001),
    (759.6067, 759.611),
    (759.6066000000001, 759.6111000000001),
    (759.6067, 759.6112),
    (759.6069, 759.6111000000001),
    (759.6070000000001, 759.6111000000001),
    (759.6069, 759.6112),
    (759.6069, 759.611),
    (759.6064, 759.611),
    (759.6066117365762, 759.6111773745573),
    (759.6067594505706, 759.6111700259931),
    (759.6064842656122, 759.6111548755334),
    (759.6074, 759.6112),
    (759.6067, 759.6109),
    (759.6067, 759.611),
    (759.6067, 759.611),
    (759.6067, 759.6115000000001),
    (759.6068, 759.6114),
    (759.6073, 759.6114),
    (759.6071000000001, 759.6113),
    None,
    (759.6069, 759.6112),
]

# Each entry is a list of three elements: only the first one may hold a list with one pair of
# bounds, the last two are always None
ll_t_bounds = [
    [None, None, None] if t_bounds is None else [[t_bounds], None, None]
    for t_bounds in _l_first_bounds
]

# Get the coordinates and the values of the points as arrays
array_x, array_y, array_z, array_c = figures.compute_figure_bubbles_3D(
    ll_t_bounds,
    normalize_independently=True,
    return_arrays=True,
    high_res=True,
)


In [None]:
# Start from the border volume, and overwrite the voxels for which a data point exists
array_slices = np.copy(array_atlas_borders)

# Convert the coordinates of the points into voxel indices of the subsampled atlas
# NOTE(review): the 1000000 factor presumably converts meters into micrometers - confirm
array_x_scaled = array_x * 1000000 / atlas.resolution / decrease_dimensionality_factor
array_y_scaled = array_y * 1000000 / atlas.resolution / decrease_dimensionality_factor
array_z_scaled = array_z * 1000000 / atlas.resolution / decrease_dimensionality_factor

print(np.min(array_x_scaled), np.max(array_x_scaled))
print(np.min(array_y_scaled), np.max(array_y_scaled))
print(np.min(array_z_scaled), np.max(array_z_scaled))
print(np.min(array_c), np.max(array_c))
print(array_slices.shape)

# The axes are permuted: y indexes the first axis of the volume, z the second one and x the third
# one. np.rint rounds half to even, like the builtin round().
array_idx_0 = np.rint(np.asarray(array_y_scaled))
array_idx_1 = np.rint(np.asarray(array_z_scaled))
array_idx_2 = np.rint(np.asarray(array_x_scaled))

# Only keep the points falling inside the volume: a negative index would silently wrap around to
# the opposite side of the array, and a too large one would raise an IndexError. Comparisons with
# NaN are False, so points with undefined coordinates are discarded as well.
array_in_bounds = (
    (array_idx_0 >= 0)
    & (array_idx_0 < array_slices.shape[0])
    & (array_idx_1 >= 0)
    & (array_idx_1 < array_slices.shape[1])
    & (array_idx_2 >= 0)
    & (array_idx_2 < array_slices.shape[2])
)
n_ignored = int(array_in_bounds.size - np.count_nonzero(array_in_bounds))
if n_ignored > 0:
    print(n_ignored, "points fall outside of the atlas volume and have been ignored")

# Points are written in their original order, so the last point wins when several points share
# the same voxel
for idx_0, idx_1, idx_2, c in zip(
    array_idx_0[array_in_bounds].astype(np.int64),
    array_idx_1[array_in_bounds].astype(np.int64),
    array_idx_2[array_in_bounds].astype(np.int64),
    np.asarray(array_c)[array_in_bounds],
):
    array_slices[idx_0, idx_1, idx_2] = c / 100

In [None]:
# Volume rendering of the border volume filled with the data points (disabled by default)
plot = False
if plot:
    volume_trace = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=array_slices.flatten(),
        isomin=-1,
        isomax=2.55,
        opacity=0.5,  # max opacity
        opacityscale="uniform",
        surface_count=10,
        colorscale="RdBu_r",
    )
    fig = go.Figure(data=volume_trace)

    # The three axes share the same style: grey lines over a transparent background
    axis_style = dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey")
    fig.update_layout(
        margin=dict(t=0, r=0, b=0, l=0),
        scene={axis_name: dict(axis_style) for axis_name in ("xaxis", "yaxis", "zaxis")},
    )
    fig.show()


In [None]:
# Do interpolation between the slices
@njit
def fill_array_interpolation(array_annotation, array_slices):
    """Interpolate the data of array_slices inside each annotated structure.

    Every voxel of array_slices whose value is either -0.01 (up to a 10**-4 tolerance) or positive
    or null is replaced by a weighted average of the voxels that hold data (value >= 0), share the
    same annotation, and lie in a cubic window centered on the voxel. The weight of a voxel at
    Euclidean distance d (in voxels) is exp(-d). Since voxels already holding data are part of
    their own window, they are averaged with their neighbours as well. Voxels for which no
    matching data is found keep their original value.

    Args:
        array_annotation (np.ndarray): 3D array of annotation labels, with the same shape as
            array_slices.
        array_slices (np.ndarray): 3D array of values to interpolate. It is only read, never
            modified, so the result does not depend on the order in which voxels are processed.

    Returns:
        np.ndarray: A copy of array_slices in which the selected voxels have been interpolated.
    """
    array_interpolated = np.copy(array_slices)

    # Half-width of the cubic search window. It only depends on the shape of the volume, so it is
    # computed once for all voxels.
    size_radius = int(array_annotation.shape[0] / 5)

    for x in range(0, array_annotation.shape[0]):
        for y in range(0, array_annotation.shape[1]):
            for z in range(0, array_annotation.shape[2]):
                # Process unfilled voxels inside the brain (-0.01) and voxels already holding data
                if (np.abs(array_slices[x, y, z] - (-0.01)) < 10 ** -4) or array_slices[x, y, z] >= 0:
                    # Check all datapoints in the same structure, and do a distance-weighted average
                    value_voxel = 0.0
                    sum_weights = 0.0
                    for xt in range(max(0, x - size_radius), min(array_annotation.shape[0], x + size_radius + 1)):
                        for yt in range(max(0, y - size_radius), min(array_annotation.shape[1], y + size_radius + 1)):
                            for zt in range(max(0, z - size_radius), min(array_annotation.shape[2], z + size_radius + 1)):
                                # The voxel has data
                                if array_slices[xt, yt, zt] >= 0:
                                    # The structure is identical
                                    if np.abs(array_annotation[x, y, z] - array_annotation[xt, yt, zt]) < 10 ** -4:
                                        d = np.sqrt((x - xt) ** 2 + (y - yt) ** 2 + (z - zt) ** 2)
                                        weight = np.exp(-d)
                                        value_voxel += weight * array_slices[xt, yt, zt]
                                        sum_weights += weight

                    # Without any matching datapoint in the window, the voxel is left untouched
                    if sum_weights != 0:
                        array_interpolated[x, y, z] = value_voxel / sum_weights

    return array_interpolated


# Fill the voxels of array_slices by interpolating the data inside each annotated structure
array_interpolated = fill_array_interpolation(array_annotation, array_slices)


In [None]:
# Slider selecting an index along the first axis of array_interpolated
size = widgets.IntSlider(
    description='Slice',
    value=5,
    min=0,
    max=array_interpolated.shape[0] - 1,
    step=1,
)


def hist1(size):
    """Display the slice of array_interpolated located at index `size` along the first axis."""
    plt.imshow(array_interpolated[size, :, :], vmin=-0.1, vmax=2.55)


# Redraw the slice each time the value of the slider changes
out = widgets.interactive_output(hist1, {'size': size})

display(size, out)

In [None]:
# Slider selecting an index along the first axis of array_slices
size = widgets.IntSlider(
    description='Slice',
    value=5,
    min=0,
    max=array_slices.shape[0] - 1,
    step=1,
)


def hist1(size):
    """Display the slice of array_slices located at index `size` along the first axis."""
    plt.imshow(array_slices[size, :, :], vmin=-0.1, vmax=2.55)


# Redraw the slice each time the value of the slider changes
out = widgets.interactive_output(hist1, {'size': size})

display(size, out)

In [None]:
# Volume rendering of the interpolated volume
plot = True
if plot:
    volume_trace = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=array_interpolated.flatten(),
        isomin=-0.1,
        isomax=2.55,
        opacity=0.5,  # max opacity
        opacityscale="uniform",
        surface_count=5,
        colorscale="Bluyl",
    )
    fig = go.Figure(data=volume_trace)

    # The three axes share the same style: grey lines over a transparent background
    axis_style = dict(backgroundcolor="rgba(0,0,0,0)", color="grey", gridcolor="grey")
    fig.update_layout(
        margin=dict(t=0, r=0, b=0, l=0),
        scene={axis_name: dict(axis_style) for axis_name in ("xaxis", "yaxis", "zaxis")},
    )
    fig.show()