In [None]:
import os

import blosc2
import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Location of the Blosc2 NDArray to visualise. Set the GAIA_B2ND_PATH
# environment variable to point elsewhere instead of editing this cell.
path = os.environ.get('GAIA_B2ND_PATH', '/home/blosc/gaia/gaia-3d-windows-int8-3.b2nd')

In [None]:
# Open read-only: this notebook only slices and plots the data, so there is
# no reason to open it with blosc2.open's default read/write mode ('a').
arr = blosc2.open(path, mode='r')
arr.info

In [None]:
shape = arr.shape
# Edge lengths (in array elements) of the cube window shown in each frame.
cube_shape = np.array([50, 50, 50])
# Number of elements the window advances along the travel axis per frame.
step = 10


def _origin_widget(dim, label, value=5_000):
    """Return an integer box for the window origin along array dimension ``dim``.

    ``BoundedIntText`` is used because plain ``IntText`` has no ``min``/``max``
    traits, so bounds passed to it are never enforced. The upper bound keeps
    the whole cube inside the array. It is clamped at 0 so an array smaller
    than the cube does not make ``min > max``.
    """
    upper = max(0, int(shape[dim] - cube_shape[dim]))
    return widgets.BoundedIntText(
        min=0,
        max=upper,
        step=1,
        description=f'{label} Origin:',
        value=value,
    )


x_origin = _origin_widget(0, 'X')
y_origin = _origin_widget(1, 'Y')
z_origin = _origin_widget(2, 'Z')

In [None]:
# Dropdown choosing which array axis the cube window slides along; the
# selected value is the axis index (0, 1 or 2) shown as 'x', 'y' or 'z'.
axis_widget = widgets.Dropdown(
    options=[(label, index) for index, label in enumerate('xyz')],
    value=0,
    description='Travel axis:',
    disabled=False,
)

In [None]:
def init(obj):
    """Recompute the first cube window from the current widget values.

    Used as an ``observe`` callback on the axis and origin widgets, where
    ``obj`` is the change notification. It can also be called directly.
    ``obj`` is not used in either case.

    Sets the module-level globals ``axis``, ``start``, ``stop`` and
    ``axis_origin``.
    """
    global axis, start, stop, axis_origin
    axis = axis_widget.value
    # Inclusive lower corner and exclusive upper corner of the first window.
    start = np.array([x_origin.value, y_origin.value, z_origin.value])
    stop = start + np.array(cube_shape)
    # Travel-axis coordinate of the first window (used for slider labels).
    axis_origin = start[axis]

In [None]:
display(axis_widget)
coordinates_widgets = widgets.HBox([x_origin, y_origin, z_origin])
display(coordinates_widgets)

# Recompute the window globals whenever the travel axis or an origin changes.
# NOTE: the cells below read those globals, so re-run them after a change.
for control in (axis_widget, x_origin, y_origin, z_origin):
    control.observe(init, names=['value'])

# Initialise the globals from the widgets' current values.
init(None)

In [None]:
# Slide the cube window along `axis` by `step` elements at a time. Each window
# becomes one DataFrame (one animation frame) of its non-zero elements.
axis_names = ["x", "y", "z"]
names = axis_names + ["marker_size"]

# Iterate on copies. Advancing the globals set by `init` in place meant that a
# second run of this cell started past the end and built no frames.
win_start = np.array(start)
win_stop = np.array(stop)

dfs = []
npa = None
# `win_stop` is an exclusive bound, so a window that ends exactly at the array
# edge is still complete. `<=` keeps that last frame instead of dropping it.
while win_stop[axis] <= shape[axis]:
    npa = arr[win_start[0]:win_stop[0],
              win_start[1]:win_stop[1],
              win_start[2]:win_stop[2]]
    # Window-local coordinates (0 .. edge length - 1). These match the fixed
    # scene ranges used by the plotting cell.
    x, y, z = (coords.flatten() for coords in np.indices(npa.shape))
    df_step = pd.DataFrame(dict(zip(names, (x, y, z, npa.flatten()))))

    # Remove dots with no stars.
    df_step = df_step[df_step["marker_size"] != 0].reset_index(drop=True)

    if df_step.empty:
        # Keep one zero-size point so the frame still has data to draw.
        placeholder = [0] * len(names)
        placeholder[axis] = win_start[axis]
        df_step = pd.DataFrame([placeholder], columns=names)
    # Widen before the plot multiplies sizes, which could overflow a narrow
    # integer dtype (the file name suggests int8 data; TODO confirm).
    df_step = df_step.astype({"marker_size": "int32"})
    dfs.append(df_step)

    win_start[axis] += step
    win_stop[axis] += step

nframes = len(dfs)
if npa is None:
    raise ValueError(
        f"No complete window of shape {cube_shape.tolist()} fits along axis "
        f"{axis} from origin {start.tolist()} within array shape {shape}; "
        "choose a smaller origin."
    )
# Extent of the last window along x and y (equivalent to the last x/y index + 1).
xmax, ymax = npa.shape[0], npa.shape[1]

In [None]:
# Multiplier turning raw element values into visible marker sizes.
MARKER_SCALE = 10

if not dfs:
    raise ValueError(
        "No frames to animate: adjust the origin widgets and re-run the "
        "previous cells."
    )

# Label each slider step with the travel-axis coordinate where its window starts.
sliders_labels = [str(axis_origin + k * step) for k in range(nframes)]


def _stars_scatter(df):
    """Return the 3-D marker trace for one frame's DataFrame."""
    return go.Scatter3d(
        x=df['x'], y=df['y'], z=df['z'],
        mode='markers',
        marker=dict(size=df['marker_size'] * MARKER_SCALE),
    )


# Frames must be named so the slider's "animate" calls can refer to them.
fig = go.Figure(frames=[
    go.Frame(data=_stars_scatter(dfs[k]), name=str(k))
    for k in range(nframes)
])

# Data displayed before the animation starts.
fig.add_trace(_stars_scatter(dfs[0]))


def frame_args(duration):
    """Return plotly ``animate`` options using ``duration`` ms per frame and transition."""
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                "label": sliders_labels[k],
                "method": "animate",
            }
            for k, f in enumerate(fig.frames)
        ],
    }
]

# Fixed axis ranges in window-local coordinates, so the view does not rescale
# from frame to frame.
scene = dict(
    xaxis=dict(range=[0, cube_shape[0]], autorange=False),
    yaxis=dict(range=[0, cube_shape[1]], autorange=False),
    zaxis=dict(range=[0, cube_shape[2]], autorange=False),
    aspectratio=dict(x=1, y=1, z=1),
)

fig.update_layout(
    title='Slices in volumetric data',
    width=800,
    height=800,
    scene=scene,
    updatemenus=[
        {
            "buttons": [
                {
                    "args": [None, frame_args(50)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
)

fig.show()