In [None]:
%matplotlib inline

# Import dependencies
from glob import glob
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import zarr

# pio.renderers.default = 'iframe'
pio.renderers.default = "notebook"

In [None]:
# Validation tomogram whose label volume is loaded and visualized below.
valid_id = "TS_6_4"
# Directory holding the pre-extracted `train_label_<id>.npy` volumes. Set the
# CZII_NUMPY_DATASET_DIR environment variable to run this notebook elsewhere
# instead of editing a hardcoded absolute path.
NUMPY_DATASET_DIR = os.environ.get(
    "CZII_NUMPY_DATASET_DIR", "/home/naoya/kaggle/czii/input/numpy-dataset"
)
npy_path = os.path.join(NUMPY_DATASET_DIR, f"train_label_{valid_id}.npy")

images = np.load(npy_path)

In [None]:
# Peak voxel value in the loaded volume (for a label map, presumably the highest class id).
print(images.max())

In [None]:
import numpy as np
import plotly.graph_objs as go

def _normalize_array(zarr_array: "zarr.Array | np.ndarray") -> np.ndarray:
    """Min-max scale each slice along axis 0 to uint8 values in [0, 255].

    Every slice ``arr[i]`` is rescaled with its own minimum and maximum
    (reductions over axes 1 and 2), so the input needs at least 3 dimensions.

    Parameters:
        zarr_array: Array-like volume (e.g. a zarr array or a NumPy array);
            it is fully materialized in memory.

    Returns:
        np.ndarray: uint8 array with the same shape as the input. Slices whose
        values are all equal map to 0 instead of the NaN a 0/0 would produce.
    """
    # Work in float64 so `arr - mins` cannot wrap around for integer inputs
    # (e.g. int16 data spanning more than 32767 would overflow).
    arr = np.asarray(zarr_array, dtype=np.float64)
    mins = arr.min(axis=(1, 2), keepdims=True)
    spans = arr.max(axis=(1, 2), keepdims=True) - mins
    # Divide only where the slice has a non-zero range; constant slices stay 0.
    scaled = np.divide(arr - mins, spans, out=np.zeros_like(arr), where=spans > 0)
    return (scaled * 255).astype(np.uint8)

def _plot_slice_surface(
    array: "zarr.Array | np.ndarray",
    z_index: int,
    cmin: float = 0,
    cmax: float = 6,
):
    """Build a flat gray Surface trace showing one z-slice of a 3D volume.

    The slice ``array[z_index]`` is drawn as a horizontal plane at height
    ``z_index``; its values become the surface color.

    Parameters:
        array: 3D array-like indexed as ``array[z, y, x]``; ``array.shape[1:]``
            gives the size of the plane.
        z_index: Slice index along axis 0, also used as the plane's z-coordinate.
        cmin: Lower color-scale limit (default 0, the original fixed value).
        cmax: Upper color-scale limit. The default 6 keeps the original fixed
            range, presumably matching the label ids of the label volumes
            (TODO confirm); pass the data range when plotting other volumes.

    Returns:
        go.Surface: Trace for the slice, with the color bar hidden.
    """
    # Constant-height plane so the slice sits at z = z_index in the 3D scene.
    plane_height = np.full((array.shape[1], array.shape[2]), z_index, dtype=float)
    return go.Surface(
        z=plane_height,
        surfacecolor=array[z_index],
        colorscale="gray",
        cmin=cmin,
        cmax=cmax,
        showscale=False,
    )

def _play_button_menu() -> dict:
    """Return the updatemenus entry holding a single ▶ button that plays all frames."""
    return dict(
        type="buttons",
        buttons=[
            dict(
                label="▶",
                method="animate",
                args=[
                    None,
                    dict(
                        frame=dict(duration=500, redraw=True),
                        fromcurrent=True,
                    ),
                ],
            )
        ],
        font=dict(color="black"),
    )


def _z_slider(z_indices: range, active: int) -> dict:
    """Return a slider with one step per z-index, each jumping to the frame of that name."""
    return dict(
        active=active,
        currentvalue=dict(prefix="z-index: "),
        pad=dict(t=50),
        steps=[
            dict(
                label=str(z),
                method="animate",
                args=[
                    [str(z)],
                    dict(
                        mode="immediate",
                        frame=dict(duration=200, redraw=True),
                        transition=dict(duration=0),
                    ),
                ],
            )
            for z in z_indices
        ],
    )


def plot_animatable_slices(
    arr: np.ndarray,
    step: int = 10,
    init_z: int = 0,
    title: str = "",
    fig: "go.Figure | None" = None,
) -> "go.Figure":
    """
    Plot animatable slices of a 3D NumPy array.

    Each animation frame shows one slice ``arr[z]`` as a horizontal surface at
    height ``z``; a ▶ button plays the frames and a slider jumps between them.

    Parameters:
        arr (np.ndarray): A 3D array indexed as ``arr[z, y, x]``.
        step (int): Spacing between the z-indices that get an animation frame;
            must be positive.
        init_z (int): z-index of the slice shown before any interaction; must
            satisfy ``0 <= init_z < arr.shape[0]``. The slider starts on the
            frame at or below it (exactly on it when ``init_z`` is a multiple
            of ``step``).
        title (str): Title of the plot.
        fig (go.Figure | None): Figure whose traces are kept in every frame
            (its layout is not reused), or None to start from an empty figure.

    Returns:
        go.Figure: A new Plotly figure with animation frames and a z-slider.

    Raises:
        ValueError: If ``arr`` is not 3-dimensional, ``step`` is not positive,
            or ``init_z`` is outside ``[0, arr.shape[0])``.
    """
    if arr.ndim != 3:
        raise ValueError("Input array must be 3-dimensional.")
    z_dim = arr.shape[0]
    # range() rejects 0 and silently yields no frames for negative steps.
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step}.")
    # A negative index would wrap around in arr[init_z] while being plotted
    # at a negative height outside the fixed z-axis range.
    if not 0 <= init_z < z_dim:
        raise ValueError(f"init_z must be in [0, {z_dim}), got {init_z}.")

    z_indices = range(0, z_dim, step)
    base_traces = list(fig.data) if fig is not None else []

    fig = go.Figure(
        data=base_traces + [_plot_slice_surface(arr, init_z)],
        layout=go.Layout(
            title=title,
            scene=dict(
                yaxis=dict(autorange="reversed"),
                # Fixed z-range so the scene does not rescale as the slice moves.
                zaxis=dict(range=[0, z_dim], autorange=False),
                aspectratio=dict(x=1, y=1, z=1),
                camera=dict(eye=dict(x=1.25, y=-1.25, z=1.25)),
            ),
            width=800,
            height=800,
            template="plotly_dark",
            updatemenus=[_play_button_menu()],
        ),
    )

    # Frame names must equal the slider step args so each step finds its frame.
    fig.frames = [
        go.Frame(data=base_traces + [_plot_slice_surface(arr, z)], name=str(z))
        for z in z_indices
    ]

    # Start the slider on the grid step at or below init_z so it agrees with
    # the slice drawn initially instead of always pointing at z = 0.
    fig.update_layout(sliders=[_z_slider(z_indices, active=init_z // step)])
    return fig


In [None]:
# Animate every second z-slice of the validation label volume.
label_fig = plot_animatable_slices(images, step=2, title=valid_id)
label_fig