## Import required modules

In [None]:
from bg_atlasapi import BrainGlobeAtlas
from pprint import pprint
import numpy as np
from matplotlib import pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image
import base64
from io import BytesIO
from matplotlib import cm
from plotly.subplots import make_subplots
import meshio
from skimage import io
from bg_atlasapi import show_atlases
import zarr
from pympler import asizeof
import os
from numba import njit


## (Down-)load the atlas

In [None]:
# Display the atlases available through bg_atlasapi
show_atlases()


In [None]:
# Load (downloading if needed) the Allen mouse atlas at 25 µm; check_latest=True
# asks bg_atlasapi to check whether a newer version of the atlas exists
bg_atlas = BrainGlobeAtlas("allen_mouse_25um", check_latest=True)


In [None]:
# Inspect the atlas metadata (last expression of the cell, so it is displayed)
bg_atlas.metadata


## Display slices with annotations

In [None]:
# Optional short aliases for the atlas space, reference volume and annotation
# volume (disabled)
# space = bg_atlas.space
# stack = bg_atlas.reference
# labels_id = bg_atlas.annotation


In [None]:
class Labels:
    """Lazily map annotation voxels to structure names.

    Index an instance exactly like ``bg_atlas.annotation``. A single voxel
    returns its structure name as a ``str``. A slice returns an array of names
    with the shape of the selected annotation, wrapped in a zarr array when
    ``use_zarr`` is True. Voxels whose id is 0 are reported as ``'undefined'``.
    No full volume of strings is ever built.
    """

    def __init__(self, bg_atlas, use_zarr=True):
        self.bg_atlas = bg_atlas
        self.use_zarr = use_zarr

    def _name(self, structure_id):
        """Return the name of one structure id, or 'undefined' for id 0."""
        if structure_id == 0:
            return "undefined"
        return self.bg_atlas.structures[int(structure_id)]["name"]

    def __getitem__(self, key):
        x = self.bg_atlas.annotation[key]
        # np.ndim is 0 for any scalar, whatever the annotation dtype
        # (an isinstance check on np.uint32 misses uint16/int32 volumes)
        if np.ndim(x) == 0:
            return self._name(x)

        # An array slice has been requested: resolve each distinct id once,
        # then broadcast the names back to every voxel, instead of doing one
        # dict lookup per voxel
        x = np.asarray(x)
        unique_ids, inverse = np.unique(x.ravel(), return_inverse=True)
        names = np.array([self._name(i) for i in unique_ids])
        array = names[inverse].reshape(x.shape)
        if self.use_zarr:
            return zarr.array(array)
        return array
        
class LabelContours:
    """Map annotation ids to consecutive integers, for drawing region contours.

    The distinct ids found in ``bg_atlas.annotation`` are sorted and numbered
    0..K-1; ``unique_id`` holds that id -> rank mapping. Index an instance like
    ``bg_atlas.annotation``. A single voxel returns an ``int``; a slice returns
    an integer array of the same shape, wrapped in a zarr array when
    ``use_zarr`` is True.
    """

    def __init__(self, bg_atlas, use_zarr=True):
        self.bg_atlas = bg_atlas
        # np.unique sorts and deduplicates in C rather than pushing every voxel
        # of the volume through a Python set, and it makes the numbering
        # deterministic (sorted) instead of depending on set iteration order
        self._sorted_ids = np.unique(self.bg_atlas.annotation)
        self.unique_id = {ni: indi for indi, ni in enumerate(self._sorted_ids.tolist())}
        self.use_zarr = use_zarr

    def __getitem__(self, key):
        x = self.bg_atlas.annotation[key]
        if np.ndim(x) == 0:
            return self.unique_id[int(x)]
        # Every value of x comes from the annotation volume, so it is present in
        # the sorted id table and searchsorted returns exactly its rank
        array = np.searchsorted(self._sorted_ids, np.asarray(x))
        if self.use_zarr:
            return zarr.array(array)
        return array

        

In [None]:
# Lazy accessors on top of the annotation volume: structure names (used as
# hover customdata below) and consecutive integer ids (used for region contours)
labels = Labels(bg_atlas)
simplified_labels_id = LabelContours(bg_atlas)


In [None]:
# Gather the structure-name slices of every 25th section along each axis,
# to check how much memory they occupy
test = []
for axis, _ in enumerate(zip(bg_atlas.space.sections, bg_atlas.space.axis_labels)):
    for position in range(0, bg_atlas.reference.shape[axis], 25):
        # full extent on every dimension except `axis`, which is fixed at `position`
        section = tuple(position if dim == axis else slice(None) for dim in range(3))
        test.append(labels[section])

# Size of `test` in MiB: the slices are stored as zarr arrays, so this stays
# small; the memory is spent later, when plotly expands them as customdata
print(asizeof.asizeof(test) / 1024 / 1024)


### Version with binary string

In [None]:
# compute the mask of MOs (secondary motor area) as a test for masks
stack_mask = bg_atlas.get_structure_mask("MOs")


In [None]:
# sanity check on the mask values (presumably 0 outside the structure)
print(np.min(stack_mask))


In [None]:
# Interactive section viewer: for each anatomical plane, a slider over every
# 25th section. Each slider step shows the reference image plus optional overlays.
contour = True  # overlay region boundaries (gold contour lines)
do_stack_mask = True  # overlay `stack_mask` as a semi-transparent image

# number of traces added per slider step: the image, plus one per enabled overlay
multiplier = 1 + int(contour) + int(do_stack_mask)


def _png_data_uri(pil_img):
    """Return `pil_img` encoded as a base64 PNG data URI, usable as a go.Image source."""
    with BytesIO() as stream:
        pil_img.save(stream, format="png", optimize=True)
        return "data:image/png;base64," + base64.b64encode(stream.getvalue()).decode("utf-8")


for idx, (plane, axis_labels) in enumerate(
    zip(bg_atlas.space.sections, bg_atlas.space.axis_labels)
):

    # Create figure
    fig = go.Figure()

    coor = list(range(0, bg_atlas.reference.shape[idx], 25))
    # Add `multiplier` traces per slider step. They all start hidden; the
    # slider makes the traces of the selected step visible.
    for step in coor:
        # section at position `step` along axis `idx`
        section = tuple(step if dim == idx else slice(None) for dim in range(3))

        img = np.uint8(cm.viridis(bg_atlas.reference[section]) * 255)
        base64_string = _png_data_uri(Image.fromarray(img))

        if not contour:
            # hover shows the structure name under the cursor
            customdata = labels[section]
            fig.add_trace(
                go.Image(
                    visible=False,
                    source=base64_string,
                    customdata=customdata,
                    hovertemplate="<br>%{customdata}<br><extra></extra>",
                )
            )
        else:
            fig.add_trace(go.Image(visible=False, source=base64_string))
            # separate name: assigning to `contour` would overwrite the flag above
            contour_ids = simplified_labels_id[section]
            fig.add_trace(
                go.Contour(
                    visible=False,
                    showscale=False,
                    z=contour_ids,
                    contours=dict(coloring="none"),
                    line_width=2,
                    line_color="gold",
                )
            )

        if do_stack_mask:
            img_mask = np.uint8(cm.gray(stack_mask[section]) * 255)
            # make background transparent: black pixels get alpha 0, all others
            # alpha 150 (vectorized instead of a per-pixel Python loop)
            is_black = np.all(img_mask[..., :3] == 0, axis=-1)
            img_mask[..., 3] = np.where(is_black, 0, 150)
            base64_string_mask = _png_data_uri(Image.fromarray(img_mask))
            fig.add_trace(go.Image(visible=False, source=base64_string_mask, hoverinfo="skip"))

    # Show the traces of the step the slider starts on (step index 10 when it exists)
    n_steps = len(coor)
    active_step = min(10, n_steps - 1)
    for offset in range(multiplier):
        fig.data[multiplier * active_step + offset].visible = True

    # Create and add slider: step i shows exactly its own `multiplier` traces
    steps = []
    for i in range(n_steps):
        visible = [False] * len(fig.data)
        visible[multiplier * i : multiplier * (i + 1)] = [True] * multiplier
        steps.append(
            dict(
                method="update",
                args=[{"visible": visible}],
                label=coor[i],
            )
        )

    sliders = [
        dict(
            active=active_step,
            currentvalue={
                "visible": False,
            },
            pad={"t": 50, "l": 100, "r": 100},
            steps=steps,
        )
    ]

    fig.update_layout(sliders=sliders)

    fig.update_layout(
        yaxis=dict(scaleanchor="x"),
        width=1000,
        height=800,
        title={
            "text": f"{plane.capitalize()} view",
            "y": 0.9,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        },
    )

    fig.update_xaxes(title_text=axis_labels[1])
    fig.update_yaxes(title_text=axis_labels[0])
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    fig.show()

    # heavy plots, remove if needed to see other projections
    break


### Version without binary string (slow)

In [None]:
# Disabled alternative to the viewer above: one go.Heatmap per section (slow)
if False:
    # `stack` is only assigned in a commented-out line above; bind it to the
    # reference volume so this cell can actually run when enabled
    stack = bg_atlas.reference

    # Create figure
    fig = go.Figure()

    x_coor = list(range(0, stack.shape[0], 10))
    # Add traces, one for each slider step
    for step in x_coor:
        fig.add_trace(go.Heatmap(visible=False, z=stack[step, :, :], colorscale="Viridis"))

    # Make the trace of step index 10 visible (the step the slider starts on)
    fig.data[10].visible = True

    # Create and add slider
    steps = []
    for i in range(len(fig.data)):
        step = dict(
            method="update",
            args=[
                {"visible": [False] * len(fig.data)},
            ],  # layout attribute
            label=x_coor[i],
        )
        step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
        steps.append(step)

    sliders = [
        dict(
            active=10,
            currentvalue={
                "visible": False,
            },
            pad={"t": 10},
            steps=steps,
        )
    ]

    fig.update_layout(sliders=sliders)
    fig.update_layout(yaxis=dict(scaleanchor="x"))
    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")
    # hide grid and axes, and put row 0 at the top like an image
    fig["layout"]["xaxis"]["gridcolor"] = "rgba(0, 0, 0, 0)"
    fig["layout"]["yaxis"]["gridcolor"] = "rgba(0, 0, 0, 0)"
    fig["layout"]["yaxis"]["color"] = "rgba(0, 0, 0, 0)"
    fig["layout"]["xaxis"]["color"] = "rgba(0, 0, 0, 0)"
    fig["layout"]["yaxis"]["autorange"] = "reversed"

    fig.show()


### Version with plotly express `px.imshow` (impossible to annotate)

In [None]:
# Disabled alternative: px.imshow with one animation frame per 10th section.
# Attaching per-pixel annotations (customdata) did not work this way.
if False:
    # `stack` is only assigned in a commented-out line above; bind it here
    stack = bg_atlas.reference

    fig = px.imshow(
        stack[0:-1:10, :, :],
        animation_frame=0,
        binary_string=True,
        labels=dict(x="y", y="z"),
    )

    fig["layout"].pop("updatemenus")  # optional, drop animation buttons
    fig.show()

## Explore atlas in 3D

Work with the low-resolution (100 µm) atlas, as the 25 µm one is too heavy to render in 3D.

In [None]:
# Low-resolution (100 µm) reference volume, converted to float64
volume = np.array(
    BrainGlobeAtlas("allen_mouse_100um", check_latest=True).reference, dtype=np.float64
)
# Append one all-zero slice at the end of axis 0; without it the last slice
# appears prominent in the 3D rendering
volume = np.pad(volume, ((0, 1), (0, 0), (0, 0)))


In [None]:
# Voxel-index coordinates (as floats) of every point of `volume`, one array per axis
X, Y, Z = np.indices(volume.shape, dtype=np.float64)
# value range of the volume (displayed as the cell output)
volume.min(), volume.max()


In [None]:
# Voxel size (µm) of `volume`, which was loaded from allen_mouse_100um. It is
# used to convert the voxel-index grid X, Y, Z into millimetres; scaling by 25
# would mislabel every axis by a factor of 4.
volume_voxel_um = 100

isosurf = go.Isosurface(
    surface_count=2,  # number of isosurfaces, 2 by default: only min and max
    colorscale="emrld",
    opacity=0.4,
    showscale=False,
    x=X.flatten() * volume_voxel_um / 1000,
    y=Y.flatten() * volume_voxel_um / 1000,
    z=Z.flatten() * volume_voxel_um / 1000,
    value=volume.flatten(),
    isomin=9,
    isomax=516,
    flatshading=True,
    lighting=dict(
        ambient=0.65,
        diffuse=0.5,
        fresnel=0.25,
        specular=0.25,
        roughness=0.25,
        facenormalsepsilon=0,
        vertexnormalsepsilon=0,
    ),
)

fig = go.Figure(isosurf)

fig.update_layout(
    width=700,
    height=700,
    # aspectmode="data" keeps the real proportions of the brain
    scene=dict(aspectmode="data", camera_eye=dict(x=1.45, y=1.45, z=0.5)),
)
fig.show()


In [None]:
# Two orthogonal planes through the volume (x = 39 and z = 51), isosurfaces hidden
slice_x, slice_z = 39, 51

myslices = go.Isosurface(
    surface=dict(show=False),
    colorscale="viridis",
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=volume.flatten(),
    slices=dict(
        x=dict(show=True, fill=1.0, locations=[slice_x]),
        z=dict(show=True, fill=1.0, locations=[slice_z]),
    ),
    isomin=9,
    # upper value: the brightest voxel over the two displayed planes
    isomax=max(volume[slice_x, :, :].max(), volume[:, :, slice_z].max()),
)

fig_slices = go.Figure(
    data=myslices,
    layout=go.Layout(width=700, height=700, scene_camera_eye=dict(x=1.25, y=1.25, z=1)),
)
fig_slices.update_layout(
    scene_xaxis_visible=False, scene_yaxis_visible=False, scene_zaxis_visible=False
)
fig_slices.show()


## Regions hierarchy

In [None]:
# Dictionary-like table of all atlas structures (displayed as the cell output)
bg_atlas.structures


In [None]:
# Full record of the root structure
pprint(bg_atlas.structures["root"])


In [None]:
# Id path of CH through the structure hierarchy (presumably root -> ... -> CH)
bg_atlas.structures["CH"]["structure_id_path"]


In [None]:
# Structures below VISC in the hierarchy
bg_atlas.get_structure_descendants("VISC")


In [None]:
# Acronyms of the ancestors of VISC6a (the last one is its direct parent)
bg_atlas.get_structure_ancestors("VISC6a")


In [None]:
# For every structure, record its name and the name of its direct parent (the
# last entry of its ancestor list, or "" for the root). These are the
# `names`/`parents` inputs of the treemap and sunburst below.
l_nodes = []
l_parents = []
l_id = []
for structure in bg_atlas.structures.values():
    # query the ancestors once per structure
    ancestors = bg_atlas.get_structure_ancestors(structure["acronym"])
    if len(ancestors) > 0:
        ancestor_name = bg_atlas.structures[ancestors[-1]]["name"]
    else:
        ancestor_name = ""

    l_nodes.append(structure["name"])
    l_parents.append(ancestor_name)
    l_id.append(structure["acronym"])


In [None]:
# Hierarchy of the atlas regions: a treemap (4 levels deep), then a sunburst (3 levels)
hierarchy = dict(names=l_nodes, parents=l_parents)

fig = px.treemap(**hierarchy, maxdepth=4)
fig.update_traces(root_color="lightgrey")
fig.show()

fig = px.sunburst(**hierarchy, maxdepth=3)
fig.show()


## Test queries by voxel

In [None]:
# Identify the structure at one voxel of the 25 µm atlas, queried three ways
voxel_index = (200, 200, 300)
atlas_voxel_um = 25
voxel_microns = tuple(i * atlas_voxel_um for i in voxel_index)

# by voxel index
print("By index:", bg_atlas.structure_from_coords(voxel_index, as_acronym=True))

# by position in microns
print(
    "By coordinates:",
    bg_atlas.structure_from_coords(voxel_microns, as_acronym=True, microns=True),
)

# same position, with the answer cut at a higher level of the hierarchy
print(
    "Higher hierarchy level:",
    bg_atlas.structure_from_coords(
        voxel_microns, as_acronym=True, microns=True, hierarchy_lev=2
    ),
)


## Play with region meshes

If we need to access the structure meshes, we can either query for the file (e.g., if we need to load the file through some library like `vedo`):

In [None]:
# path of the mesh file of the CH structure
file = bg_atlas.meshfile_from_structure("CH")


In [None]:
# show where the CH mesh file is stored
print(file)


Or directly obtain the mesh, as a mesh object of the `meshio` library:

In [None]:
def _mesh_arrays(acronym):
    """Return (mesh, vertex coordinates, triangle vertex indices) for structure `acronym`."""
    mesh = bg_atlas.mesh_from_structure(acronym)
    return mesh, mesh.points, mesh.cells[0].data


# CH mesh first, then the root mesh
mesh_data, vertices, triangles = _mesh_arrays("CH")
mesh_data_root, vertices_root, triangles_root = _mesh_arrays("root")


In [None]:
# Shapes of the root mesh arrays: one row per vertex and one row per triangle
print(vertices_root.shape)
print(triangles_root.shape)

# one example vertex (coordinates) and one triangle (vertex indices)
print(vertices_root[1])
print(triangles_root[1])


In [None]:
# Split the CH mesh into per-axis vertex coordinates (x, y, z) and per-corner
# triangle indices (I, J, K), as go.Mesh3d expects; tri_points holds, for
# every triangle, the coordinates of its three vertices
x, y, z = vertices.T
I, J, K = triangles.T
tri_points = vertices[triangles]


# same for the root mesh
x_root, y_root, z_root = vertices_root.T
I_root, J_root, K_root = triangles_root.T
tri_vertices_root = vertices_root[triangles_root]


In [None]:
# Two-stop grey colour scale, used for the semi-transparent root mesh
pl_mygrey = [0, "rgb(153, 153, 153)"], [1.0, "rgb(255,255,255)"]

# CH mesh, coloured by vertex height (intensity=z)
pl_mesh = go.Mesh3d(
    x=x,
    y=y,
    z=z,
    colorscale="Blues",  # pl_mygrey,
    intensity=z,
    flatshading=True,
    i=I,
    j=J,
    k=K,
    name="Mesh CH",
    showscale=False,
)

pl_mesh.update(
    cmin=-7,  # a trick to get a nice plot (z.min()=-3.31909)
    lighting=dict(
        ambient=0.2,
        diffuse=1,
        fresnel=0.1,
        specular=1,
        roughness=0.05,
        facenormalsepsilon=1e-15,
        vertexnormalsepsilon=1e-15,
    ),
    lightposition=dict(x=100, y=200, z=0),
)


# Wireframe of the CH triangles: each triangle is a closed polyline through its
# three vertices (k % 3 revisits the first one), then None to break the line
Xe = []
Ye = []
Ze = []
for T in tri_points:
    Xe.extend([T[k % 3][0] for k in range(4)] + [None])
    Ye.extend([T[k % 3][1] for k in range(4)] + [None])
    Ze.extend([T[k % 3][2] for k in range(4)] + [None])

# define the trace for triangle sides
lines = go.Scatter3d(
    x=Xe, y=Ye, z=Ze, mode="lines", name="", line=dict(color="rgb(70,70,70)", width=1)
)

layout = go.Layout(
    font=dict(size=16, color="white"),
    width=700,
    height=700,
    scene_xaxis_visible=False,
    scene_yaxis_visible=False,
    scene_zaxis_visible=False,
)


# Root mesh, drawn semi-transparent
pl_mesh_root = go.Mesh3d(
    x=x_root,
    y=y_root,
    z=z_root,
    colorscale=pl_mygrey,
    # intensity needs one value per vertex of *this* mesh: use the root heights,
    # not the CH `z`, which belongs to another mesh
    intensity=z_root,
    flatshading=False,
    i=I_root,
    j=J_root,
    k=K_root,
    opacity=0.2,
    name="Mesh root",
    showscale=False,
)


# other combinations: [pl_mesh], [pl_mesh, lines], [pl_mesh_root, pl_mesh]
fig = go.Figure(data=[pl_mesh_root], layout=layout)
fig.show()


## Load our data

In [None]:
# Our data: the experimental slices, and for every pixel its atlas coordinates
images = np.array(io.imread("../app/data/tif_files/slices.tif"))
coors = np.array(io.imread("../app/data/tif_files/coors.tif"))


In [None]:
print(bg_atlas.reference.shape)
print(images.shape)
# largest coordinate along each of the three atlas axes, times 100
for axis in range(3):
    print(np.max(coors[:, :, :, axis]) * 100)


In [None]:
# get slice of the atlas corresponding to our slice idx as an example
idx_slice = 10
slice_exp = images[idx_slice]  # experimental image of that slice
slice_coor_mm = coors[idx_slice]  # per-pixel atlas coordinates (mm, judging by the name)
# same coordinates as integer voxel indices of the 25 µm atlas (mm * 1000 / 25, rounded)
slice_coor_25um = ((slice_coor_mm * 1000 / 25).round(0)).astype(np.int32)


In [None]:
# Sample the atlas at the voxel under each pixel of the experimental slice.
# Pixels whose voxel falls outside the atlas keep the fill values below.
projected_stack = np.full(slice_exp.shape, 0, dtype=np.int32)
# builtin `object`: the `np.object` alias was removed in NumPy 1.24
projected_labels = np.full(slice_exp.shape, "undefined", dtype=object)
projected_simplified_labels_id = np.full(
    slice_exp.shape, simplified_labels_id[0, 0, 0], dtype=np.int32
)

# pixels whose voxel index lies inside the atlas volume on every axis
inside = np.all(
    (slice_coor_25um >= 0) & (slice_coor_25um < np.array(bg_atlas.reference.shape)),
    axis=-1,
)
if inside.any():
    # one fancy-index lookup for all inside pixels instead of a per-pixel loop
    voxel_index = tuple(slice_coor_25um[inside].T)
    projected_stack[inside] = bg_atlas.reference[voxel_index]
    projected_labels[inside] = np.asarray(labels[voxel_index]).tolist()
    projected_simplified_labels_id[inside] = np.asarray(simplified_labels_id[voxel_index])


In [None]:
# Display our slice against the queried atlas: the experimental image, optionally
# overlaid with the boundaries between the atlas regions projected onto it
contour = True

fig = go.Figure()

# experimental slice -> viridis RGBA image -> base64 PNG data URI
img = np.uint8(cm.viridis(slice_exp) * 255)
pil_img = Image.fromarray(img)  # PIL image object
prefix = "data:image/png;base64,"
with BytesIO() as stream:
    pil_img.save(stream, format="png", optimize=True, quality=85)
    base64_string_exp = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")

# the experimental image is the bottom layer in both modes
fig.add_trace(go.Image(visible=True, source=base64_string_exp))

if contour:
    # boundary map: 1 where a pixel's region id differs from that of its
    # upper-left diagonal neighbour, 0 elsewhere (first row/column padded with 0)
    label_ids = projected_simplified_labels_id
    contours = np.pad(
        np.clip((label_ids[1:, 1:] - label_ids[:-1, :-1]) ** 2, 0, 1), ((1, 0), (1, 0))
    )
    # white RGBA overlay whose alpha channel is the 0/1 boundary map
    # (go.Image reads rgba alpha on a 0-1 scale by default)
    opaque_white = np.full(projected_stack.shape + (3,), 255, dtype=np.uint8)
    contour_image = np.concatenate((opaque_white, contours[..., np.newaxis]), axis=-1)
    fig.add_trace(go.Image(z=contour_image, colormodel="rgba", hoverinfo="skip"))

fig.update_layout(yaxis=dict(scaleanchor="x"), width=1000, height=800)

# axis titles taken from the labels of the first section
first_axis_labels = bg_atlas.space.axis_labels[0]
fig.update_xaxes(title_text=first_axis_labels[1], showticklabels=False)
fig.update_yaxes(title_text=first_axis_labels[0], showticklabels=False)

fig.show()


In [None]:
# Extract contour lines of the region-id map with matplotlib (400 levels),
# then redraw every segment of every level in a fresh figure
C = plt.contour(projected_simplified_labels_id, colors="r", linewidths=0.1, levels=400)

plt.figure()
for level_segments in C.allsegs:
    for segment in level_segments:
        xs, ys = segment[:, 0], segment[:, 1]
        plt.plot(xs, ys, lw=0.1, color="r")
plt.show()


In [None]:
# WebGL (Scattergl) rendering of the segments of the first contour level only
fig = go.Figure(
    data=[
        go.Scattergl(
            x=seg[:, 0], y=seg[:, 1], mode="lines", hoverinfo="skip", line_color="white"
        )
        for seg in C.allsegs[0]
    ]
)
fig.update_layout(showlegend=False)
fig.show()


In [None]:
# Region boundaries of the projected id map: 1 where a pixel's id differs from
# its upper-left diagonal neighbour, 0 elsewhere
contours = projected_simplified_labels_id[1:, 1:] - projected_simplified_labels_id[:-1, :-1]
contours = np.clip(contours**2, 0, 1)
contours = np.pad(contours, ((1, 0), (1, 0)))
# do some cleaning on the sides: blank a frame of `border` pixels
border = 10
contours[:, :border] = 0
contours[:, -border:] = 0
contours[:border, :] = 0
contours[-border:, :] = 0

# defined in this cell because it is used just below (not only in an earlier cell)
prefix = "data:image/png;base64,"

# Draw the boundaries with matplotlib on a frameless figure whose size in pixels
# (inches * dpi) equals the array shape, then encode it as a PNG data URI
mpl_fig = plt.figure(frameon=False)
dpi = 100
mpl_fig.set_size_inches(contours.shape[1] / dpi, contours.shape[0] / dpi)
ax = mpl_fig.add_axes([0, 0, 1, 1])
ax.axis("off")
plt.contour(contours, colors="w", antialiased=True, linewidths=0.2, origin="image")
with BytesIO() as stream:
    plt.savefig(stream, format="png", dpi=dpi)
    plt.close(mpl_fig)
    base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")


fig = go.Figure()

# now image from our data
img = np.uint8(cm.viridis(slice_exp) * 255)
pil_img = Image.fromarray(img)  # PIL image object
with BytesIO() as stream:
    pil_img.save(stream, format="png", optimize=True)
    base64_string_exp = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")

fig.add_trace(go.Image(visible=True, source=base64_string_exp))

# matplotlib-rendered boundaries on top of the experimental image; the PNG has
# one pixel per array element, so it maps onto the same image grid
fig.add_trace(go.Image(visible=True, source=base64_string, zsmooth="fast"))

fig.update_xaxes(range=[0, projected_simplified_labels_id.shape[1]])
fig.update_yaxes(range=[projected_simplified_labels_id.shape[0], 0])
fig.show()


## Project initial data onto warped data

In [None]:
# Reload the per-pixel atlas coordinates as float64 and the slices as int16
# (same files as loaded above)
array_coordinates = np.array(io.imread("../app/data/tif_files/coors.tif"), dtype=np.float64)
warped_image = np.array(io.imread("../app/data/tif_files/slices.tif"), dtype=np.int16)


In [None]:
@njit
def slice_to_atlas_transform(a, u, v, lambd, mu):
    """Map slice coordinates (lambd, mu) onto the atlas.

    The slice is the plane  a + lambd * u + mu * v  in atlas space. Here `a` is
    the atlas position of slice pixel (0, 0), and `u`, `v` are the atlas
    displacements of one step along the first and second slice axes.
    Returns the (x, y, z) atlas coordinates as a tuple.
    """
    return (
        a[0] + lambd * u[0] + mu * v[0],
        a[1] + lambd * u[1] + mu * v[1],
        a[2] + lambd * u[2] + mu * v[2],
    )


@njit
def atlas_to_slice_transform(a, u, v, y_atlas, z_atlas):
    """Invert slice_to_atlas_transform from the atlas y and z coordinates.

    Solves the 2x2 linear system
        lambd * u[1] + mu * v[1] = y_atlas - a[1]
        lambd * u[2] + mu * v[2] = z_atlas - a[2]
    for (lambd, mu) with Cramer's rule, and returns (lambd, mu).

    The determinant is the x component of u x v, so the system is singular when
    the slice plane is parallel to the atlas x axis. In that case the
    determinant is clamped to a tiny value of the same sign. The result then
    stays finite and, unlike perturbing the matrix with random noise,
    deterministic.
    """
    b_y = y_atlas - a[1]
    b_z = z_atlas - a[2]
    det = u[1] * v[2] - v[1] * u[2]
    if abs(det) < 1e-12:
        det = 1e-12 if det >= 0 else -1e-12
    lambd = (b_y * v[2] - v[1] * b_z) / det
    mu = (u[1] * b_z - b_y * u[2]) / det
    return lambd, mu


In [None]:
def solve_plane_equation(
    array_coordinates, slice_index, point_1=(150, 151), point_2=(800, 1200), point_3=(100, 101)
):
    """Fit the affine map from slice pixels to atlas coordinates.

    Finds a, u, v such that pixel (i, j) of slice `slice_index` maps to
    a + i * u + j * v. The fit uses the coordinates stored in
    ``array_coordinates[slice_index, i, j, :3]`` at three non-collinear pixels.

    Returns (a_atlas, u_atlas, v_atlas), each a 3-tuple of (x, y, z) components.
    Raises numpy.linalg.LinAlgError if the three pixels are collinear.
    """
    # Points on the image extremities can't be used: in the exported coordinates
    # the origin does not map linearly onto the 3D plane.
    points = (point_1, point_2, point_3)
    # One row [i, j, 1] per pixel. The three atlas axes give three independent
    # 3x3 systems sharing this matrix, so they are solved together, one
    # right-hand-side column per axis.
    design = np.array([[i, j, 1.0] for i, j in points])
    targets = np.array(
        [array_coordinates[slice_index, i, j, :3] for i, j in points], dtype=np.float64
    )
    (u1, u2, u3), (v1, v2, v3), (a1, a2, a3) = np.linalg.solve(design, targets)
    u_atlas = (u1, u2, u3)
    v_atlas = (v1, v2, v3)
    a_atlas = (a1, a2, a3)
    return a_atlas, u_atlas, v_atlas


In [None]:
# Affine slice-to-atlas parameters (a, u, v) of every slice, in slice order
l_transform_parameters = []
for slice_index in range(array_coordinates.shape[0]):
    plane_parameters = solve_plane_equation(array_coordinates, slice_index)
    a_atlas, u_atlas, v_atlas = plane_parameters
    l_transform_parameters.append(plane_parameters)


In [None]:
# Plot used to debug: draw the plane fitted by solve_plane_equation together
# with the three atlas points it was fitted on
if False:
    slice_index = 9
    # get plane going through three points
    point_1 = (150, 151)
    point_2 = (800, 1200)
    point_3 = (100, 101)
    points_atlas = [
        array_coordinates[slice_index, p[0], p[1]] for p in (point_1, point_2, point_3)
    ]

    a, u, v = solve_plane_equation(
        array_coordinates, slice_index, point_1=point_1, point_2=point_2, point_3=point_3
    )

    # Sample the plane on a regular grid of slice coordinates. go.Surface needs
    # x, y and z as 2D arrays of the same shape, not flat lists.
    grid = range(0, 1200, 100)
    l_x = np.zeros((len(grid), len(grid)))
    l_y = np.zeros((len(grid), len(grid)))
    l_z = np.zeros((len(grid), len(grid)))
    for i, lambd in enumerate(grid):
        for j, mu in enumerate(grid):
            l_x[i, j], l_y[i, j], l_z[i, j] = slice_to_atlas_transform(a, u, v, lambd, mu)

    surface = go.Surface(z=l_z, x=l_x, y=l_y)
    scatter = go.Scatter3d(
        x=[p[0] for p in points_atlas],
        y=[p[1] for p in points_atlas],
        z=[p[2] for p in points_atlas],
        mode="markers",
        marker_size=2,
    )
    fig = go.Figure(data=[surface, scatter])
    fig.show()


In [None]:
@njit
def fill_array_projection(
    slice_index,
    array_projection,
    array_projection_filling,
    original_coor,
    atlas_resolution,
    a,
    u,
    v,
    original_slice,
    array_coordinates,
    annotation,
    nearest_neighbour_correction=False,
    atlas_correction=False,
):
    """Project the pixels of an original slice onto plane `slice_index` of `array_projection`.

    Forward pass: every pixel (i0, j0) of `original_coor` holds three coordinates;
    the last two are mapped with `atlas_to_slice_transform(a, u, v, ...)` to a
    pixel (i, j) of the projection. Channel 2 of `original_slice[i0, j0]` is
    written there, and the pixel is marked as 1 in `array_projection_filling`.

    Nearest-neighbour pass (if `nearest_neighbour_correction`): every projection
    pixel is converted to atlas voxel indices with
    `array_coordinates[slice_index, i, j] * 1000 / atlas_resolution`.
    - If that voxel is inside `annotation` and its label is non-zero, and the pixel
      is still unfilled and more than 20 px from every border: it gets the value,
      among the filled pixels of the surrounding 7x7 window, closest to the
      window's mean. These filled-in pixels are NOT marked in the filling array.
    - Otherwise, if `atlas_correction`: the pixel is set to 0 and marked filled.
      `atlas_correction` therefore only has an effect together with
      `nearest_neighbour_correction`.

    `array_projection` and `array_projection_filling` are modified in place;
    `array_projection` is also returned.
    """
    n_rows = array_projection.shape[1]
    n_cols = array_projection.shape[2]

    # Forward pass: push every original pixel to its position in the projection.
    for i_original_slice in range(original_coor.shape[0]):
        for j_original_slice in range(original_coor.shape[1]):
            x_atlas, y_atlas, z_atlas = original_coor[i_original_slice, j_original_slice]
            i_float, j_float = np.array(atlas_to_slice_transform(a, u, v, y_atlas, z_atlas))
            i = int(round(i_float))
            j = int(round(j_float))
            # Row/column 0 are valid targets (a former `> 0` check dropped them).
            if i >= 0 and i < n_rows and j >= 0 and j < n_cols:
                try:
                    array_projection[slice_index, i, j] = original_slice[
                        i_original_slice, j_original_slice, 2
                    ]
                    array_projection_filling[slice_index, i, j] = 1
                except Exception:
                    print(
                        i,
                        j,
                        array_projection.shape,
                        i_original_slice,
                        j_original_slice,
                        original_slice.shape,
                    )

    if nearest_neighbour_correction:
        radius = 3
        window_size = 2 * radius + 1
        # only fill missing areas if at least this far from the sides
        border = 20
        # Allocated once, reset to NaN for every pixel that needs it.
        array_window = np.empty((window_size, window_size), dtype=np.float32)

        for i in range(n_rows):
            for j in range(n_cols):
                voxel = array_coordinates[slice_index, i, j] * 1000 / atlas_resolution
                # ugly but numba doesn't support np.round
                x_voxel = int(round(voxel[0]))
                y_voxel = int(round(voxel[1]))
                z_voxel = int(round(voxel[2]))
                inside_volume = (
                    x_voxel >= 0
                    and x_voxel < annotation.shape[0]
                    and y_voxel >= 0
                    and y_voxel < annotation.shape[1]
                    and z_voxel >= 0
                    and z_voxel < annotation.shape[2]
                )

                if inside_volume and annotation[x_voxel, y_voxel, z_voxel] != 0:
                    if array_projection_filling[slice_index, i, j] != 0:
                        continue
                    if i <= border or i >= n_rows - border or j <= border or j >= n_cols - border:
                        continue

                    # Collect the filled neighbours; unfilled or out-of-range cells stay NaN.
                    array_window[:, :] = np.nan
                    for x in range(-radius, radius + 1):
                        for y in range(-radius, radius + 1):
                            i_neighbour = i + x
                            j_neighbour = j + y
                            if (
                                i_neighbour >= 0
                                and i_neighbour < n_rows
                                and j_neighbour >= 0
                                and j_neighbour < n_cols
                                and array_projection_filling[slice_index, i_neighbour, j_neighbour]
                                != 0
                            ):
                                array_window[x + radius, y + radius] = array_projection[
                                    slice_index, i_neighbour, j_neighbour
                                ]

                    avg = np.nanmean(array_window)
                    if np.isnan(avg):
                        continue

                    # Pick the neighbour closest to the window mean. numba doesn't
                    # support nanargmin; NaN comparisons are False so NaN cells are
                    # skipped. np.inf (not a fixed 10000) guarantees a real pick.
                    best_distance = np.inf
                    selected_pixel_x = 0
                    selected_pixel_y = 0
                    for x in range(window_size):
                        for y in range(window_size):
                            distance = abs(array_window[x, y] - avg)
                            if distance < best_distance:
                                best_distance = distance
                                selected_pixel_x = x
                                selected_pixel_y = y

                    array_projection[slice_index, i, j] = array_window[
                        selected_pixel_x, selected_pixel_y
                    ]
                elif atlas_correction:
                    # Outside the annotation volume or in label 0: blank the pixel.
                    array_projection[slice_index, i, j] = 0
                    array_projection_filling[slice_index, i, j] = 1

    return array_projection


In [None]:
# Sanity check: dimensions of the atlas annotation volume.
annotation_shape = bg_atlas.annotation.shape
print(annotation_shape)

In [None]:
# process each slice independently as, due to tilting, two slices can map to the same atlas coordinate...
DEFORMATION_FIELD_DIR = "../app/data/tif_files/deformation_field/"
ORIGINAL_SLICES_DIR = "../app/data/tif_files/original_slices/"
# map back the pixel from the atlas coordinates; 25 matches the
# "allen_mouse_25um" atlas loaded at the top of the notebook.
atlas_resolution = 25


def _find_slice_file(directory, slice_number):
    """Return the path of the file in `directory` whose name ends in 'slice_<slice_number>.tiff'.

    Files whose name does not contain "slice_" are skipped (the former inline
    lookup raised IndexError on them).

    Raises:
        FileNotFoundError: if no file in `directory` matches.
    """
    target = str(slice_number)
    for name in sorted(os.listdir(directory)):
        _, found, tail = name.partition("slice_")
        if found and tail.split(".tiff")[0] == target:
            return directory + name
    raise FileNotFoundError(f"no file for slice {slice_number} in {directory}")


projection_shape = warped_image.shape
array_projection = np.zeros(projection_shape, dtype=np.int16)
array_projection_filling = np.zeros(projection_shape, dtype=np.int16)

array_projection_no_atlas = np.zeros(projection_shape, dtype=np.int16)
array_projection_filling_no_atlas = np.zeros(projection_shape, dtype=np.int16)

array_projection_uncorrected = np.zeros(projection_shape, dtype=np.int16)
array_projection_filling_uncorrected = np.zeros(projection_shape, dtype=np.int16)

for i in range(array_projection.shape[0]):
    print(f"slice {i} getting processed")
    a, u, v = l_transform_parameters[i]

    # load corresponding slice and coor (files are numbered from 1)
    original_coor = np.array(
        io.imread(_find_slice_file(DEFORMATION_FIELD_DIR, i + 1)), dtype=np.float32
    )
    original_slice = np.array(
        io.imread(_find_slice_file(ORIGINAL_SLICES_DIR, i + 1)), dtype=np.int16
    )

    # Arguments shared by the three projection variants.
    common_args = (
        original_coor,
        atlas_resolution,
        a,
        u,
        v,
        original_slice,
        array_coordinates,
        bg_atlas.annotation,
    )

    array_projection = fill_array_projection(
        i,
        array_projection,
        array_projection_filling,
        *common_args,
        nearest_neighbour_correction=True,
        atlas_correction=True,
    )
    array_projection_no_atlas = fill_array_projection(
        i,
        array_projection_no_atlas,
        array_projection_filling_no_atlas,
        *common_args,
        nearest_neighbour_correction=True,
        atlas_correction=False,
    )
    array_projection_uncorrected = fill_array_projection(
        i,
        array_projection_uncorrected,
        array_projection_filling_uncorrected,
        *common_args,
        nearest_neighbour_correction=False,
        atlas_correction=False,
    )


## Plot a comparison of the processing variants for the first 20 slices

In [None]:
def _to_base64_png(array):
    """Colour `array` with matplotlib's viridis colormap and return it as a PNG data URI."""
    img = np.uint8(cm.viridis(array) * 255)
    with BytesIO() as stream:
        # PNG is lossless: the former quality=85 argument was ignored.
        Image.fromarray(img).save(stream, format="png", optimize=True)
        return "data:image/png;base64," + base64.b64encode(stream.getvalue()).decode("utf-8")


# One row per slice, one column per processing variant.
if True:
    n_slices_to_plot = 20
    panel_names = [
        "original",
        "uncorrected",
        "corrected without atlas",
        "corrected with atlas",
        "warped image",
    ]
    # Size the grid to the slices actually plotted (it used to have 32 rows for 20 slices).
    fig = make_subplots(
        rows=n_slices_to_plot,
        cols=len(panel_names),
        subplot_titles=tuple(
            context + " slice " + str(x)
            for x in range(1, n_slices_to_plot + 1)
            for context in panel_names
        ),
    )
    for index_image in range(n_slices_to_plot):
        # Channel 2 of the original slice, the channel projected by fill_array_projection.
        array_original_image = np.array(
            io.imread(
                "../app/data/tif_files/original_slices/slice_" + str(index_image + 1) + ".tiff"
            ),
            dtype=np.int32,
        )[:, :, 2]
        panels = (
            array_original_image,
            array_projection_uncorrected[index_image],
            array_projection_no_atlas[index_image],
            array_projection[index_image],
            warped_image[index_image],
        )
        for col, panel in enumerate(panels, start=1):
            fig.add_trace(
                go.Image(visible=True, source=_to_base64_png(panel)),
                row=index_image + 1,
                col=col,
            )

    # These settings apply to every subplot, so set them once after the loop.
    fig.update_xaxes(title_text=bg_atlas.space.axis_labels[0][1], showticklabels=False)
    fig.update_yaxes(title_text=bg_atlas.space.axis_labels[0][0], showticklabels=False)
    fig.update_layout(
        margin=dict(t=30, r=0, b=0, l=0), height=1200 * n_slices_to_plot, width=6000
    )

    # fig.show()
    fig.write_image("comparison.png")

        

## Plot in 3D the atlas voxels covered by the slices (annotated regions only)

In [None]:
def cube_points(position3d):
    """Return the 8 corners of a unit cube translated by `position3d`.

    Args:
        position3d: a 3-list or an array of shape (3,); it becomes the cube's
            (0, 0, 0) corner.

    Returns:
        A float array of shape (8, 3); row k is the k-th vertex, listed as the
        bottom face (z offset 0) then the top face (z offset 1), each face
        counter-clockwise starting from its origin corner.
    """
    unit_corners = (
        (0, 0, 0),
        (1, 0, 0),
        (1, 1, 0),
        (0, 1, 0),
        (0, 0, 1),
        (1, 0, 1),
        (1, 1, 1),
        (0, 1, 1),
    )
    offset = np.asarray(position3d, dtype=float)
    return np.array(unit_corners, dtype=float) + offset


def triangulate_cube_faces(positions):
    """Triangulate the faces of a set of unit voxels for a plotly Mesh3d.

    Every voxel is expanded into its 8 corners with `cube_points`. Corners shared
    by neighbouring voxels are merged with `np.unique`. Each voxel's 6 faces are
    then split into 12 triangles.

    Args:
        positions: array-like of shape (N, 3), the position of every voxel.

    Returns:
        (vertices, I, J, K), where:
        - `vertices` is the (M, 3) array of unique corners.
        - I, J and K are lists of 12 * N indices into `vertices`; triangle t is
          (I[t], J[t], K[t]), matching the i/j/k arguments of go.Mesh3d.
        Empty input yields an empty (0, 3) array and three empty lists.

    Raises:
        ValueError: if `positions` is not of shape (N, 3).
    """
    positions = np.asarray(positions)
    if positions.size == 0:
        return np.empty((0, 3)), [], [], []
    if positions.ndim != 2 or positions.shape[1] != 3:
        raise ValueError("Wrong shape for positions of cubes in your data")

    all_cubes = np.array([cube_points(pos) for pos in positions])
    n_cubes, n_corners, n_dims = all_cubes.shape
    vertices, ixr = np.unique(
        all_cubes.reshape(n_cubes * n_corners, n_dims), return_inverse=True, axis=0
    )
    # Some NumPy 2.0.x releases return a 2D inverse when `axis` is given.
    ixr = ixr.ravel()

    # Corner numbers (0-7, as laid out by cube_points) of the 12 triangles of one
    # cube; triangle t is (i_corners[t], j_corners[t], k_corners[t]).
    i_corners = np.array([0, 2, 4, 6, 5, 2, 4, 3, 6, 3, 4, 1])
    j_corners = np.array([1, 3, 5, 7, 1, 6, 0, 7, 2, 7, 5, 0])
    k_corners = np.array([2, 0, 6, 4, 2, 5, 3, 4, 3, 6, 1, 4])

    # Row c holds the flattened corner indices of cube c; ravel keeps cube order.
    cube_offsets = 8 * np.arange(n_cubes)[:, None]
    I = ixr[(cube_offsets + i_corners).ravel()].tolist()
    J = ixr[(cube_offsets + j_corners).ravel()].tolist()
    K = ixr[(cube_offsets + k_corners).ravel()].tolist()

    return vertices, I, J, K


In [None]:
# Collect, for a few slices, the atlas voxels they cover that carry a non-zero
# annotation label; the stacked (N, 3) voxel indices feed the mesh plot below.
path = "../app/data/tif_files/deformation_field/"
annotation = bg_atlas.annotation
annotation_shape = np.array(annotation.shape)
l_array_position = []
for i in range(0, 32, 15):
    print(f"slice {i} getting processed")
    # load corresponding coor (files are numbered from 1)
    filename = (
        path
        + [x for x in os.listdir(path) if str(i + 1) == x.split("slice_")[1].split(".tiff")[0]][0]
    )
    original_coor = np.array(io.imread(filename), dtype=np.float32)

    # Same "* 1000 / 25" conversion to atlas voxel indices as used for the projections.
    original_coor_rescaled = np.array(np.round(original_coor * 1000 / 25), dtype=np.int32)

    # Vectorised replacement of the per-pixel loops; row-major order is kept.
    voxels = original_coor_rescaled.reshape(-1, 3)
    in_volume = np.all((voxels >= 0) & (voxels < annotation_shape), axis=1)
    array_to_keep = np.zeros(len(voxels), dtype=bool)
    inside = voxels[in_volume]
    array_to_keep[in_volume] = annotation[inside[:, 0], inside[:, 1], inside[:, 2]] != 0
    # Always (k, 3), even when no voxel is kept, so the final vstack works.
    l_array_position.append(voxels[array_to_keep])

array_position = np.vstack(l_array_position)
    
                    

In [None]:
# the plot is extremely heavy so we better not plot it everytime
if True:
    vertices, I, J, K = triangulate_cube_faces(array_position)
    X, Y, Z = vertices.T

    # Translucent grey mesh of all kept voxels; lighting set directly on the trace.
    mesh3d = go.Mesh3d(
        x=X,
        y=Y,
        z=Z,
        i=I,
        j=J,
        k=K,
        flatshading=False,
        color="grey",
        opacity=0.2,
        lighting=dict(ambient=0.5, diffuse=1, fresnel=4, specular=0.5, roughness=0.5),
    )

    layout = go.Layout(width=650, height=750, title_x=0.5)
    fig = go.Figure(data=[mesh3d], layout=layout)

    # Hide ticks, tick labels and titles on all three scene axes.
    bare_axis = dict(showticklabels=False, ticks="", title="")
    fig.update_layout(
        scene=dict(
            xaxis=bare_axis,
            yaxis=bare_axis,
            zaxis=bare_axis,
            camera=dict(eye=dict(x=1.4, y=-2.5, z=1)),
            aspectmode="data",
        )
    )
    fig.show()


## Plot our slices in 3D

In [None]:
# array_coordinates: three coordinates per slice pixel (float64).
# warped_image: the slice image stack (int16).
array_coordinates = np.asarray(
    io.imread("../app/data/tif_files/coors.tif"), dtype=np.float64
)
warped_image = np.asarray(
    io.imread("../app/data/tif_files/slices.tif"), dtype=np.int16
)


In [None]:
# Shapes of the coordinate stack, the image stack and the projection.
for arr in (array_coordinates, warped_image, array_projection):
    print(arr.shape)


In [None]:
# decrease array size because undoable otherwise
# `scipy.ndimage.interpolation` is a deprecated alias namespace; import from
# `scipy.ndimage` directly.
from scipy.ndimage import map_coordinates

# Keep every slice and shrink the in-plane size about 7-fold.
# NOTE(review): 918 x 1311 are presumably the in-plane dimensions of
# array_projection; confirm.
target_shape = (array_projection.shape[0], int(round(918 / 7)), int(round(1311 / 7)))

# Evenly spaced sample positions along each axis, resampled with
# map_coordinates' default (cubic spline) interpolation.
new_dims = [
    np.linspace(0, original_length - 1, new_length)
    for original_length, new_length in zip(array_projection.shape, target_shape)
]

coords = np.meshgrid(*new_dims, indexing="ij")
array_projection_small = map_coordinates(array_projection, coords)


In [None]:
def get_surface(slice_index):
    """Return a go.Surface placing downsampled slice `slice_index` in 3D.

    The slice plane (a, u, v) is fitted by `solve_plane_equation` through three
    fixed pixels of the slice. Every pixel (lambd, mu) of
    `array_projection_small[slice_index]` is mapped with `slice_to_atlas_transform`
    to a 3D point, and the pixel intensities colour the surface.
    """
    # get plane going through three points ((row, column) pixel indices)
    a, u, v = solve_plane_equation(
        array_coordinates,
        slice_index,
        point_1=(150, 151),
        point_2=(800, 1200),
        point_3=(100, 101),
    )

    slice_small = array_projection_small[slice_index]
    n_rows, n_cols = slice_small.shape
    # go.Surface needs x, y and z as grids of identical 2D shape; x and y used to
    # be flat lists of n_rows * n_cols elements.
    l_x = np.zeros((n_rows, n_cols), dtype=np.float64)
    l_y = np.zeros((n_rows, n_cols), dtype=np.float64)
    l_z = np.zeros((n_rows, n_cols), dtype=np.float32)

    for lambd in range(n_rows):
        for mu in range(n_cols):
            # NOTE(review): indices are scaled by 3 although the volume was
            # downsampled ~7-fold; confirm the intended factor.
            x_atlas, y_atlas, z_atlas = slice_to_atlas_transform(a, u, v, lambd * 3, mu * 3)
            # Axes are permuted (plot x <- z, y <- x, z <- y) and rescaled.
            l_x[lambd, mu] = z_atlas * 25 / 1000 * 20
            l_y[lambd, mu] = x_atlas * 25 / 1000 * 20
            l_z[lambd, mu] = y_atlas * 25 / 1000 * 20

    surface = go.Surface(
        z=l_z,
        x=l_x,
        y=l_y,
        surfacecolor=slice_small.astype(np.int32),
        cmin=0,
        cmax=255,
        colorscale="Viridis",
        # Fully transparent for the lowest intensities (background).
        opacityscale=[[0, 0], [0.1, 1], [1, 1]],
        showscale=False,
    )
    return surface


layout = go.Layout()

# NOTE(review): `isosurf` must be defined by an earlier cell; confirm it exists.
fig = go.Figure(data=[get_surface(i) for i in range(32)] + [isosurf], layout=layout)
fig.update_scenes(zaxis_autorange="reversed")
fig.update_scenes(aspectmode="data")


fig.show()


In [None]:
# RdYlBu-like colorscale whose lowest stop is fully transparent (alpha 0), so
# background pixels are hidden.
colorscale = [
    [0.0, "rgba(165,0,38, 0.)"],
    [0.1111111111111111, "rgba(215,48,39,1)"],
    [0.2222222222222222, "rgba(244,109,67,1)"],
    [0.3333333333333333, "rgba(253,174,97,1)"],
    [0.4444444444444444, "rgba(254,224,144,1)"],
    [0.5555555555555556, "rgba(224,243,248,1)"],
    [0.6666666666666666, "rgba(171,217,233,1)"],
    [0.7777777777777778, "rgba(116,173,209,1)"],
    [0.8888888888888888, "rgba(69,117,180,1)"],
    [1.0, "rgba(49,54,149,1)"],
]

# Define frames: one per downsampled slice
nb_frames = array_projection_small.shape[0]


def get_surface(slice_index):
    """Return a go.Surface placing downsampled slice `slice_index` in 3D.

    The plane (a, u, v) comes from `solve_plane_equation` through three fixed
    pixels. Each pixel (lambd, mu) of `array_projection_small[slice_index]` is
    mapped with `slice_to_atlas_transform`, and the result is multiplied by 25 / 1000.
    The plot axes are permuted (x <- z, y <- x, z <- y), and the surface is
    coloured by the pixel intensities.
    """
    # get plane going through three points ((row, column) pixel indices)
    a, u, v = solve_plane_equation(
        array_coordinates,
        slice_index,
        point_1=(150, 151),
        point_2=(800, 1200),
        point_3=(100, 101),
    )

    slice_small = array_projection_small[slice_index]
    ll_x = []
    ll_y = []
    ll_z = []

    for lambd in range(slice_small.shape[0]):
        l_x = []
        l_y = []
        l_z = []
        for mu in range(slice_small.shape[1]):
            # NOTE(review): indices are scaled by 3 although the volume was
            # downsampled ~7-fold; confirm the intended factor.
            x_atlas, y_atlas, z_atlas = (
                np.array(slice_to_atlas_transform(a, u, v, lambd * 3, mu * 3)) * 25 / 1000
            )
            l_x.append(z_atlas)
            l_y.append(x_atlas)
            l_z.append(y_atlas)

        if l_x:
            ll_x.append(l_x)
            ll_y.append(l_y)
            ll_z.append(l_z)

    surface = go.Surface(
        z=np.array(ll_z),
        x=np.array(ll_x),
        y=np.array(ll_y),
        surfacecolor=slice_small.astype(np.int32),
        cmin=0,
        cmax=255,
        colorscale=colorscale,
        showscale=False,
    )
    return surface


def _frame_slice_index(frame_index):
    """Slice displayed in animation frame `frame_index` (0-based)."""
    # Frame 12 shows slice 11 instead of slice 12.
    # NOTE(review): presumably slice 12 is unusable; confirm why it is skipped.
    return frame_index - 1 if frame_index == 12 else frame_index


# Frame names are 1-based slice numbers.
fig = go.Figure(
    frames=[
        go.Frame(data=[get_surface(_frame_slice_index(i))], name=str(i + 1))
        for i in range(nb_frames)
    ]
)
# Trace shown before the animation starts.
fig.add_trace(get_surface(9))


def frame_args(duration):
    """Animation options: `duration` ms per frame and per linear transition."""
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                # Label each step with its frame name (1-based slice number).
                "label": f.name,
                "method": "animate",
            }
            for f in fig.frames
        ],
    }
]

# Layout
fig.update_layout(
    title="Slices in volumetric data",
    scene=dict(
        aspectratio=dict(x=1.5, y=1, z=1),
        yaxis=dict(range=[0.001 * 40, 0.008 * 40], autorange=False),
        # Range given high-to-low so the z axis is displayed reversed.
        zaxis=dict(range=[0.002 * 40, -221 * 10**-6 * 40], autorange=False),
        xaxis=dict(range=[0.004 * 40, 0.008 * 40], autorange=False),
    ),
    updatemenus=[
        {
            "buttons": [
                {
                    "args": [None, frame_args(50)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
)


fig.show()
