In [1]:
# ---- Imports ----
import vtk
import numpy as np
import tkinter as tk
from tkinter import ttk
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.widgets import RectangleSelector
from scipy.ndimage import label
from skimage.measure import regionprops
import pandas as pd
from tkinter import filedialog
import os
import sys

sys.path.append(r"C:\Users\adamd\Downloads\bachelor-1")

from pypore3d.p3dSkelPy import (
    py_p3dLKCSkeletonization,
    py_p3dSkeletonPruning,
    py_p3dSkeletonLabeling,
    py_p3dSkeletonAnalysis,
)
from pypore3d.p3dSITKPy import py_p3dReadRaw8 , py_p3dWriteRaw8

In [2]:
def get_color_map(name):
    """Build the VTK colour transfer function for the colormap called ``name``.

    Recognised names are "Grayscale", "Hot", "Cool" and "Jet"; any other value
    falls back to the grayscale ramp. Control points sit on the 0-255 scalar
    range.
    """
    grayscale = ((0, 0, 0, 0), (255, 1, 1, 1))
    # Each preset lists its (scalar, r, g, b) control points in insertion order.
    presets = {
        "Grayscale": grayscale,
        "Hot": ((0, 0.1, 0, 0), (128, 1, 0.5, 0), (255, 1, 1, 1)),
        "Cool": ((0, 0, 1, 1), (255, 1, 0, 1)),
        "Jet": ((0, 0, 0, 0.5), (128, 0, 1, 0), (255, 1, 0, 0)),
    }

    transfer = vtk.vtkColorTransferFunction()
    for scalar, red, green, blue in presets.get(name, grayscale):
        transfer.AddRGBPoint(scalar, red, green, blue)
    return transfer


def numpy2VTK(img, spacing=(1.0, 1.0, 1.0)):
    """Wrap a 3D NumPy array in a ``vtkImageImport`` source.

    Parameters
    ----------
    img : array-like
        3D volume. It is cast to ``uint8`` and its bytes are copied into the
        importer, so the caller's array is not referenced afterwards.
        Array axis 2 is used for the first VTK extent pair, axis 1 for the
        second and axis 0 for the third.
    spacing : sequence of three floats
        Forwarded to ``SetDataSpacing`` in the given order.

    Returns
    -------
    vtkImageImport
        The configured importer (its output port can be connected to a mapper).

    Raises
    ------
    ValueError
        If ``img`` is not three-dimensional.
    """
    volume = np.asarray(img)
    if volume.ndim != 3:
        # Fail early with a clear message instead of an IndexError on the
        # shape lookup or a buffer/extent mismatch inside VTK.
        raise ValueError(f"numpy2VTK expects a 3D array, got shape={volume.shape}")

    n0, n1, n2 = volume.shape
    raw_bytes = volume.astype('uint8').tobytes()

    importer = vtk.vtkImageImport()
    importer.SetDataExtent(0, n2 - 1, 0, n1 - 1, 0, n0 - 1)
    importer.SetWholeExtent(0, n2 - 1, 0, n1 - 1, 0, n0 - 1)
    importer.CopyImportVoidPointer(raw_bytes, len(raw_bytes))
    importer.SetDataScalarTypeToUnsignedChar()
    importer.SetNumberOfScalarComponents(1)
    importer.SetDataSpacing(spacing[0], spacing[1], spacing[2])
    importer.SetDataOrigin(0, 0, 0)

    return importer


def volume_render(data, colormap="Grayscale", spacing=[1.0, 1.0, 1.0]):
    """Create a shaded ``vtkVolume`` for ``data`` coloured with ``colormap``.

    The array is wrapped with ``numpy2VTK`` and ray-cast on the GPU mapper.
    Opacity ramps linearly from 0 at scalar 0 to 1 at scalar 255. The caller
    is responsible for adding the returned volume to a renderer.
    """
    source = numpy2VTK(data, spacing)

    # Linear opacity ramp over the 8-bit scalar range.
    opacity_ramp = vtk.vtkPiecewiseFunction()
    for scalar, alpha in ((0, 0.0), (255, 1.0)):
        opacity_ramp.AddPoint(scalar, alpha)
    colours = get_color_map(colormap)

    ray_mapper = vtk.vtkGPUVolumeRayCastMapper()
    ray_mapper.SetInputConnection(source.GetOutputPort())

    appearance = vtk.vtkVolumeProperty()
    appearance.SetColor(colours)
    appearance.SetScalarOpacity(opacity_ramp)
    appearance.ShadeOn()
    appearance.SetInterpolationTypeToLinear()
    appearance.SetAmbient(0.3)
    appearance.SetDiffuse(0.7)

    rendered = vtk.vtkVolume()
    rendered.SetMapper(ray_mapper)
    rendered.SetProperty(appearance)
    return rendered

def _wheel_scroll_units(delta, platform=sys.platform):
    """Convert a Tk ``<MouseWheel>`` delta into canvas scroll units.

    Windows reports multiples of 120 per wheel notch, so the delta is divided
    by 120. On macOS, Tk 8.6 reports small raw values (typically +/-1), which
    the division would always truncate to 0; those are used directly.
    A positive delta (wheel up) yields a negative unit count (scroll up).
    """
    if platform == "darwin" and abs(delta) < 120:
        return -int(delta)
    return int(-1 * (delta / 120))


def make_scrollable(parent, *, enable_x=False):
    """
    Creates a scrollable area inside 'parent' using a Canvas + Scrollbar(s).
    Returns: (container_frame, scrollable_inner_frame)
      - Put your real UI widgets inside scrollable_inner_frame.

    parent   : Tk container that receives the scrollable area (packed to fill it).
    enable_x : also add a horizontal scrollbar; the inner frame is then allowed
               to be wider than the visible canvas.

    Mouse-wheel bindings are installed globally (bind_all) while the pointer is
    over the canvas and removed again when it leaves.
    """
    container = ttk.Frame(parent)
    container.pack(fill="both", expand=True)

    canvas = tk.Canvas(container, highlightthickness=0)

    # Pack the scrollbars BEFORE the expanding canvas: pack hands out space in
    # packing order, so the bars keep their strip and the canvas gets the rest.
    vbar = ttk.Scrollbar(container, orient="vertical", command=canvas.yview)
    vbar.pack(side="right", fill="y")
    canvas.configure(yscrollcommand=vbar.set)

    xbar = None
    if enable_x:
        xbar = ttk.Scrollbar(container, orient="horizontal", command=canvas.xview)
        xbar.pack(side="bottom", fill="x")
        canvas.configure(xscrollcommand=xbar.set)

    canvas.pack(side="left", fill="both", expand=True)

    inner = ttk.Frame(canvas)
    window_id = canvas.create_window((0, 0), window=inner, anchor="nw")

    def _on_inner_configure(event=None):
        # Update scroll region to match inner frame size
        canvas.configure(scrollregion=canvas.bbox("all"))

    def _on_canvas_configure(event):
        # Make inner frame match canvas width (so it behaves like a page).
        # With horizontal scrolling enabled the content may be wider than the
        # canvas; never shrink it below its requested width or the x scrollbar
        # would have nothing to scroll.
        width = event.width
        if enable_x:
            width = max(width, inner.winfo_reqwidth())
        canvas.itemconfigure(window_id, width=width)

    inner.bind("<Configure>", _on_inner_configure)
    canvas.bind("<Configure>", _on_canvas_configure)

    # Mouse wheel scrolling (Windows / macOS / Linux)
    def _on_mousewheel(event):
        # Windows/mac: event.delta, Linux uses Button-4/5
        if event.delta:
            units = _wheel_scroll_units(event.delta)
            if units:
                canvas.yview_scroll(units, "units")

    def _bind_mousewheel(_):
        canvas.bind_all("<MouseWheel>", _on_mousewheel)    # Windows/mac
        canvas.bind_all("<Button-4>", lambda e: canvas.yview_scroll(-1, "units"))  # Linux up
        canvas.bind_all("<Button-5>", lambda e: canvas.yview_scroll(1, "units"))   # Linux down

    def _unbind_mousewheel(_):
        canvas.unbind_all("<MouseWheel>")
        canvas.unbind_all("<Button-4>")
        canvas.unbind_all("<Button-5>")

    canvas.bind("<Enter>", _bind_mousewheel)
    canvas.bind("<Leave>", _unbind_mousewheel)

    return container, inner



In [3]:
def render_blobs(centroids, renderer, sphere_radius=5, color=(1, 0, 0)):
    """Draw one sphere glyph per centroid, add the actor to ``renderer``, return it.

    centroids     : iterable of 3D positions, each passed to vtkPoints as-is
    renderer      : VTK renderer that receives the new actor
    sphere_radius : radius of every sphere glyph
    color         : RGB colour of the actor (default red)
    """
    centres = vtk.vtkPoints()
    for position in centroids:
        centres.InsertNextPoint(position)

    cloud = vtk.vtkPolyData()
    cloud.SetPoints(centres)

    ball = vtk.vtkSphereSource()
    ball.SetRadius(sphere_radius)

    # Stamp a copy of the sphere at every point of the cloud.
    glyphs = vtk.vtkGlyph3D()
    glyphs.SetInputData(cloud)
    glyphs.SetSourceConnection(ball.GetOutputPort())
    glyphs.Update()

    blob_mapper = vtk.vtkPolyDataMapper()
    blob_mapper.SetInputConnection(glyphs.GetOutputPort())

    blob_actor = vtk.vtkActor()
    blob_actor.SetMapper(blob_mapper)
    blob_actor.GetProperty().SetColor(*color)

    renderer.AddActor(blob_actor)
    return blob_actor

#Convert a 3D skeleton volume into a VTK actor made of tubes.
#skel_vol : 3D numpy array (non-zero voxels = skeleton)
#spacing  : voxel spacing (x, y, z)
#tube_radius : tube radius in world units
#color    : RGB color of tubes
#max_voxels : safety cap; if the skeleton has more voxels than this, a regular
#             stride sample of the voxels is used for visualization instead.

def skeleton_volume_to_tube_actor(skel_vol, spacing=(1.0, 1.0, 1.0),
                                  tube_radius=0.15, color=(0.0, 1.0, 0.0),
                                  max_voxels=250_000, connectivity=6):
    """Convert a 3D skeleton volume into a VTK actor made of tubes.

    Every non-zero voxel becomes a point; neighbouring points are joined by
    line segments, which a ``vtkTubeFilter`` turns into tubes.

    Parameters
    ----------
    skel_vol : array-like, 3D
        Skeleton volume; voxels with a value > 0 belong to the skeleton.
    spacing : tuple of 3 floats
        World size of one voxel along array axes 0, 1 and 2. A voxel at index
        ``(i, j, k)`` is placed at ``(i*sx, j*sy, k*sz)``.
        NOTE(review): ``numpy2VTK`` in this file uses array axis 2 for the
        first VTK extent - confirm that tubes and rendered volume share the
        same orientation.
    tube_radius : float
        Tube radius in world units.
    color : tuple of 3 floats
        RGB colour of the tubes.
    max_voxels : int
        Safety cap. If the skeleton has more voxels than this, only every
        ``stride``-th voxel is kept; links between voxels that are no longer
        both present are lost.
    connectivity : {6, 18, 26}
        Which neighbours are linked: faces only (6, the historical default),
        faces + edges (18) or faces + edges + corners (26). Use 26 for
        skeletons whose voxels may touch only diagonally.

    Returns
    -------
    vtkActor or None
        ``None`` when the volume contains no skeleton voxels.

    Raises
    ------
    ValueError
        If ``skel_vol`` is not 3D or ``connectivity`` is not 6, 18 or 26.
    """
    # Maximum number of non-zero offset components allowed per connectivity.
    max_axes = {6: 1, 18: 2, 26: 3}.get(connectivity)
    if max_axes is None:
        raise ValueError("connectivity must be 6, 18 or 26")

    skel_vol = np.asarray(skel_vol)
    if skel_vol.ndim != 3:
        raise ValueError("skeleton_volume_to_tube_actor expects a 3D volume")

    # Locations of skeleton voxels, one (i, j, k) row per voxel
    indices = np.argwhere(skel_vol > 0)
    total_vox = indices.shape[0]
    print(f"[skeleton_volume_to_tube_actor] skeleton voxels = {total_vox}")

    if total_vox == 0:
        print("No skeleton voxels found.")
        return None

    # If too dense, take a regular stride sample for visualization
    if total_vox > max_voxels:
        stride = int(np.ceil(total_vox / max_voxels))
        print(f"[skeleton_volume_to_tube_actor] downsampling skeleton voxels "
              f"by stride {stride} for visualization.")
        indices = indices[::stride]

    sx, sy, sz = spacing

    # One VTK point per retained voxel; id_map finds a voxel's point id.
    points = vtk.vtkPoints()
    id_map = {}
    for pid, (ix, iy, iz) in enumerate(indices.tolist()):
        points.InsertNextPoint(float(ix) * sx, float(iy) * sy, float(iz) * sz)
        id_map[(ix, iy, iz)] = pid

    # Neighbour offsets for the requested connectivity. Iterating (1, -1, 0)
    # keeps the six face offsets in their original order.
    neighbours = [
        (dx, dy, dz)
        for dx in (1, -1, 0)
        for dy in (1, -1, 0)
        for dz in (1, -1, 0)
        if 0 < (dx != 0) + (dy != 0) + (dz != 0) <= max_axes
    ]

    lines = vtk.vtkCellArray()
    for (ix, iy, iz), pid0 in id_map.items():
        for dx, dy, dz in neighbours:
            pid1 = id_map.get((ix + dx, iy + dy, iz + dz))
            # Only link towards the higher id so each pair is added once.
            if pid1 is not None and pid1 > pid0:
                line = vtk.vtkLine()
                line.GetPointIds().SetId(0, pid0)
                line.GetPointIds().SetId(1, pid1)
                lines.InsertNextCell(line)

    poly = vtk.vtkPolyData()
    poly.SetPoints(points)
    poly.SetLines(lines)

    # Turn lines into thin tubes
    tube = vtk.vtkTubeFilter()
    tube.SetInputData(poly)
    tube.SetRadius(tube_radius)
    tube.SetNumberOfSides(6)
    tube.CappingOff()
    tube.Update()

    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputConnection(tube.GetOutputPort())

    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(*color)

    return actor

#Add a Skeletonization tab using PyPore3D LKC skeletonization.
#- root: main Tk window
#- notebook: ttk.Notebook
#- renderer3D: shared VTK renderer (same as volume)
#- get_active_volume: function returning current 3D numpy volume (ROI or full)
#- spacing: voxel spacing (x, y, z)
def add_skeletonization_tab(root, notebook, renderer3D, get_active_volume,
                            spacing=(1.0, 1.0, 1.0),
                            world_dimensions=(700, 700, 700),
                            npy_default_dir=None):
    """Add a "Skeletonization" tab that runs PyPore3D LKC skeletonization.

    The tab offers an input selector (active VTK volume or an .npy file),
    sliders for binarization / pruning / downsampling / tube radius, a log
    panel and a bubble chart of branch nodes (one bubble per index along
    array axis 2). Skeleton tubes and branch-node highlights are added to
    ``renderer3D``.

    Parameters
    ----------
    root : Tk window
        Main window; used to flush pending redraws before a run starts.
    notebook : ttk.Notebook
        Notebook that receives the new tab.
    renderer3D : VTK renderer
        Shared renderer (the same one that shows the volume).
    get_active_volume : callable
        Zero-argument function returning the current 3D NumPy volume
        (ROI or full), or ``None`` when nothing is loaded.
    spacing : tuple of 3 floats
        Voxel spacing, applied to array axes 0, 1 and 2 respectively.
    world_dimensions : tuple of 3 ints
        Size of the main scene in voxels; NPY inputs are centred inside it.
    npy_default_dir : str or None
        Folder initially scanned for .npy files (current directory if None).

    Returns
    -------
    None
    """
    tab_skel = ttk.Frame(notebook)
    notebook.add(tab_skel, text="Skeletonization")

    # Make the tab scrollable so no controls are hidden on small windows
    _, scroll_inner = make_scrollable(tab_skel, enable_x=False)

    # Layout: left controls, right info + bubble chart
    control_frame = tk.Frame(scroll_inner)
    control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    right_frame = tk.Frame(scroll_inner)
    right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=10, pady=10)

    # ---------- STATUS + LOG ----------
    status_label = tk.Label(right_frame, text="Status: idle", anchor="w", justify="left")
    status_label.pack(fill=tk.X, pady=(0, 5))

    log_text = tk.Text(right_frame, height=10)
    log_text.pack(fill=tk.X)

    def log(msg):
        """Append one line to the log panel and scroll it into view."""
        log_text.insert(tk.END, msg + "\n")
        log_text.see(tk.END)
        log_text.update_idletasks()

    # ---------- BUBBLE CHART (branch nodes) ----------
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

    fig_nodes = Figure(figsize=(5, 3), dpi=100)
    canvas_nodes = FigureCanvasTkAgg(fig_nodes, master=right_frame)
    canvas_nodes_widget = canvas_nodes.get_tk_widget()
    canvas_nodes_widget.pack(fill=tk.BOTH, expand=True, pady=(5, 0))
    ax_nodes = fig_nodes.add_subplot(111)

    # Shared state lives in dicts so the nested callbacks can rebind values.
    skel_actor_holder = {"actor": None}
    skeleton_volume = {"data": None}      # full labeled skeleton volume
    branch_nodes = {"nodes": []}          # list of (ix, iy, iz, degree)
    # Bubble-chart data cache + selection state
    bubble_data = {
        "z": None,         # np.array of z positions (one per bubble)
        "y": None,         # np.array of y values (avg degree per bubble)
        "sizes": None,     # np.array of bubble sizes
        "counts": None,    # np.array of node-count per z
        "scatter": None,   # matplotlib PathCollection
        "selected_mask": None
    }

    selected = {
        "z_set": set(),     # selected z slices
        "highlight_actor": None
    }

    # Remember current downsample factor used to build the skeleton
    current_ds = {"ds": 1}

    # Tube radius (slider-controlled)
    tube_radius_var = tk.DoubleVar(value=0.5)

    # Volume visibility state
    volume_visibility = {"visible": True}

    # Neighbour offsets for degree computation (26-connectivity)
    neighbour_offsets_26 = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if not (dx == 0 and dy == 0 and dz == 0)
    ]

    def compute_branch_nodes(skel):
        """Return ``(ix, iy, iz, degree)`` for every skeleton voxel with >= 3 neighbours.

        Neighbours are counted with 26-connectivity among voxels > 0.
        """
        coords = np.argwhere(skel > 0).tolist()
        occupied = set(map(tuple, coords))
        nodes = []

        for ix, iy, iz in coords:
            deg = sum(
                (ix + dx, iy + dy, iz + dz) in occupied
                for dx, dy, dz in neighbour_offsets_26
            )
            if deg >= 3:  # branch point: 3 or more neighbours
                nodes.append((ix, iy, iz, deg))
        return nodes

    def outline_selected_bubbles(sel_mask, redraw=True):
        """Give selected bubbles a red edge and all other bubbles a black edge."""
        sc = bubble_data["scatter"]
        if sc is None:
            return
        try:
            ec = np.array([[0, 0, 0, 1]] * len(bubble_data["z"]), dtype=float)
            ec[sel_mask] = np.array([1, 0, 0, 1], dtype=float)
            sc.set_edgecolors(ec)
            if redraw:
                canvas_nodes.draw()
        except Exception as exc:
            # Outlining is cosmetic: report the problem instead of hiding it,
            # but do not abort the selection itself.
            log(f"Could not update bubble outlines: {exc}")

    def update_branch_node_bubble_chart():
        """Update bubble chart (one bubble per Z slice).

        X-axis  = Z slice index
        Y-axis  = average branch degree in that slice
        Size    = number of branch nodes in that slice
        """
        # NOTE(review): ax.clear() may also discard artists that the
        # RectangleSelector added to these axes - confirm the selection
        # rectangle is still drawn after a chart refresh.
        ax_nodes.clear()
        nodes = branch_nodes["nodes"]

        if not nodes:
            # Forget the previous chart so the click / rectangle handlers do
            # not act on bubbles that are no longer displayed.
            for key in ("z", "y", "sizes", "counts", "scatter", "selected_mask"):
                bubble_data[key] = None
            ax_nodes.set_title("No branch nodes found")
            ax_nodes.set_xlabel("Z slice index")
            ax_nodes.set_ylabel("Branch degree (# branches)")
            canvas_nodes.draw()
            return

        # nodes: list of (ix, iy, iz, degree)
        zs = np.array([n[2] for n in nodes], dtype=int)
        degrees = np.array([n[3] for n in nodes], dtype=float)

        unique_z = np.unique(zs)
        avg_deg = []
        counts = []

        for z_val in unique_z:
            slice_degrees = degrees[zs == z_val]
            counts.append(slice_degrees.size)
            avg_deg.append(slice_degrees.mean())

        avg_deg = np.array(avg_deg)
        counts = np.array(counts)

        # Scale bubble size so the largest bubble has area 300 and the chart
        # is not covered by a single slice.
        size_scale = 300.0 / counts.max()
        bubble_sizes = counts * size_scale

        sc = ax_nodes.scatter(unique_z, avg_deg, s=bubble_sizes, alpha=0.6, edgecolors="k")
        ax_nodes.set_xlabel("Z slice index")
        ax_nodes.set_ylabel("Average branch degree")
        ax_nodes.set_title("Skeleton branch nodes (one bubble per Z-slice)")

        # store bubble data for rectangle/click selection
        bubble_data["z"] = unique_z
        bubble_data["y"] = avg_deg
        bubble_data["sizes"] = bubble_sizes
        bubble_data["counts"] = counts
        bubble_data["scatter"] = sc

        # If there is an active selection already, keep it (and show it on the
        # freshly created scatter); otherwise make sure no stale mask remains.
        if selected["z_set"]:
            bubble_data["selected_mask"] = np.isin(unique_z, list(selected["z_set"]))
            outline_selected_bubbles(bubble_data["selected_mask"], redraw=False)
        else:
            bubble_data["selected_mask"] = None

        canvas_nodes.draw()

    def apply_selection_from_mask(sel_mask):
        """Apply selection from a boolean mask over bubble_data arrays."""
        if bubble_data["z"] is None:
            return

        z_vals = bubble_data["z"]
        chosen_z = set(map(int, z_vals[sel_mask].tolist()))

        selected["z_set"] = chosen_z
        bubble_data["selected_mask"] = sel_mask

        # Highlight and filter in VTK
        rebuild_skeleton_actor()
        set_branch_highlight(chosen_z)

        # Visually outline selected bubbles (selected -> red, others -> black)
        outline_selected_bubbles(sel_mask)

    def on_bubble_rectangle_select(eclick, erelease):
        """
        Tab1-style rectangle selection:
        - clear old highlights
        - read rectangle coords (data space)
        - select bubbles inside rectangle
        - apply selection to VTK + bubble outlines
        """
        # Must have bubble data
        if bubble_data["z"] is None or bubble_data["y"] is None:
            return
        if eclick.xdata is None or eclick.ydata is None or erelease.xdata is None or erelease.ydata is None:
            return

        # A plain click produces a zero-area rectangle. That case belongs to
        # on_bubble_click; handling it here would immediately wipe the
        # selection the click has just made.
        if eclick.xdata == erelease.xdata and eclick.ydata == erelease.ydata:
            return

        # Clear previous selection/highlights first (Tab1 behavior)
        selected["z_set"] = set()
        bubble_data["selected_mask"] = None
        set_branch_highlight(set())

        # Rectangle bounds in data coords
        x0, x1 = sorted([eclick.xdata, erelease.xdata])
        y0, y1 = sorted([eclick.ydata, erelease.ydata])

        z_vals = bubble_data["z"]
        y_vals = bubble_data["y"]

        # Select bubbles whose centers are inside rectangle
        sel_mask = (z_vals >= x0) & (z_vals <= x1) & (y_vals >= y0) & (y_vals <= y1)

        # Apply selection (this rebuilds skeleton, highlights, outlines)
        apply_selection_from_mask(sel_mask)

    def on_bubble_click(event):
        """Click bubbles: single-select, Ctrl=toggle, Shift=range add."""
        if event.inaxes != ax_nodes:
            return
        if bubble_data["z"] is None or bubble_data["y"] is None:
            return
        if event.xdata is None or event.ydata is None:
            return

        z_vals = bubble_data["z"]
        y_vals = bubble_data["y"]

        # Find nearest bubble in data coords
        dx = z_vals - event.xdata
        dy = y_vals - event.ydata
        idx = int(np.argmin(dx*dx + dy*dy))
        z_clicked = int(z_vals[idx])

        # Modifier keys (matplotlib gives these in event.key sometimes)
        key = (event.key or "").lower()

        # Ctrl (or control) -> toggle add/remove
        if "control" in key or "ctrl" in key:
            if z_clicked in selected["z_set"]:
                selected["z_set"].remove(z_clicked)
            else:
                selected["z_set"].add(z_clicked)

        # Shift -> add every z between the largest selected z and this one
        elif "shift" in key and selected["z_set"]:
            last = max(selected["z_set"])
            a, b = sorted([last, z_clicked])
            selected["z_set"].update(range(a, b + 1))

        # Normal click -> single select
        else:
            selected["z_set"] = {z_clicked}

        # Convert selected z_set to mask and apply
        sel_mask = np.isin(z_vals, list(selected["z_set"]))
        apply_selection_from_mask(sel_mask)

    # Create the rectangle selector on the bubble chart axes
    rect_selector = RectangleSelector(
        ax_nodes,
        on_bubble_rectangle_select,
        useblit=True,
        button=[1],         # left mouse
        interactive=True,
        spancoords="data"   # IMPORTANT: rectangle in data coords
    )
    # Matplotlib only keeps weak references to widget callbacks, so a selector
    # held solely by this local variable can be garbage-collected once the
    # function returns. Park it on the tab widget, which lives as long as the
    # GUI does.
    tab_skel._bubble_rect_selector = rect_selector

    # Enable clicking bubbles
    canvas_nodes.mpl_connect("button_press_event", on_bubble_click)

    def rebuild_skeleton_actor():
        """Rebuild skeleton tubes from current skeleton volume, tube radius, and selection."""
        skel = skeleton_volume["data"]
        if skel is None:
            return

        ds = current_ds["ds"]
        eff_spacing = (spacing[0] * ds, spacing[1] * ds, spacing[2] * ds)

        # Apply Z-filter if selection exists: keep only the selected indices
        # along array axis 2 and blank everything else.
        if selected["z_set"]:
            z_mask = np.zeros(skel.shape[2], dtype=bool)
            for zz in selected["z_set"]:
                if 0 <= zz < skel.shape[2]:
                    z_mask[zz] = True
            skel_show = np.zeros_like(skel)
            skel_show[:, :, z_mask] = skel[:, :, z_mask]
        else:
            skel_show = skel

        actor = skeleton_volume_to_tube_actor(
            skel_show,
            spacing=eff_spacing,
            tube_radius=float(tube_radius_var.get()),
            color=(0.0, 1.0, 0.0)
        )

        # Always drop the previous tubes, even when the new skeleton is empty;
        # otherwise geometry from an earlier run would stay on screen.
        if skel_actor_holder["actor"] is not None:
            renderer3D.RemoveActor(skel_actor_holder["actor"])
            skel_actor_holder["actor"] = None

        if actor is not None:
            # Position skeleton actor depending on selected input (center NPY in the scene)
            actor.SetPosition(*skel_input_state["offset_world"])
            skel_actor_holder["actor"] = actor
            renderer3D.AddActor(actor)

        renderer3D.GetRenderWindow().Render()

    def set_branch_highlight(z_set):
        """Show selected branch nodes as red points in VTK (only nodes with iz in z_set)."""
        # Remove old highlight actor
        if selected["highlight_actor"] is not None:
            renderer3D.RemoveActor(selected["highlight_actor"])
            selected["highlight_actor"] = None

        if not z_set:
            renderer3D.GetRenderWindow().Render()
            return

        nodes = branch_nodes["nodes"]  # list of (ix, iy, iz, degree)
        if not nodes:
            renderer3D.GetRenderWindow().Render()
            return

        ds = current_ds["ds"]
        sx, sy, sz = spacing
        eff_spacing = (sx * ds, sy * ds, sz * ds)

        pts = vtk.vtkPoints()
        for ix, iy, iz, deg in nodes:
            if iz in z_set:
                pts.InsertNextPoint(ix * eff_spacing[0], iy * eff_spacing[1], iz * eff_spacing[2])

        if pts.GetNumberOfPoints() == 0:
            renderer3D.GetRenderWindow().Render()
            return

        poly = vtk.vtkPolyData()
        poly.SetPoints(pts)

        # Render points efficiently
        mapper = vtk.vtkPointGaussianMapper()
        mapper.SetInputData(poly)
        mapper.SetScaleFactor(max(0.5, float(tube_radius_var.get())))  # reuse radius slider
        mapper.EmissiveOff()

        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.GetProperty().SetColor(1.0, 0.0, 0.0)  # red highlight

        # Match skeleton offset (needed when input is an NPY volume)
        actor.SetPosition(*skel_input_state["offset_world"])

        selected["highlight_actor"] = actor
        renderer3D.AddActor(actor)
        renderer3D.GetRenderWindow().Render()

    def on_tube_radius_change(value):
        """Callback when tube-radius slider moves."""
        rebuild_skeleton_actor()

    def reset_view():
        """Reset camera to default view."""
        renderer3D.ResetCamera()
        renderer3D.GetRenderWindow().Render()

    def toggle_volume_visibility():
        """Hide/show all VTK volumes (leave skeletons visible)."""
        vols = renderer3D.GetVolumes()
        vols.InitTraversal()
        n = vols.GetNumberOfItems()
        for _ in range(n):
            vol = vols.GetNextVolume()
            if vol is not None:
                vol.SetVisibility(0 if volume_visibility["visible"] else 1)
        volume_visibility["visible"] = not volume_visibility["visible"]
        renderer3D.GetRenderWindow().Render()

    # ---------- CONTROLS ----------

    # ---------- INPUT VOLUME (VTK / NPY) ----------
    # The skeletonization pipeline can run either on the active VTK volume
    # (RAW/subvolume) or on a selected .npy volume.
    input_source_var = tk.StringVar(value="VTK Active Volume (RAW/subvolume)")
    npy_dir_var = tk.StringVar(value=(npy_default_dir if npy_default_dir else os.getcwd()))

    # Keep current input info so we can position skeleton + highlights correctly
    skel_input_state = {"is_npy": False, "offset_world": (0.0, 0.0, 0.0)}

    # Cache last-loaded NPY to avoid reloading on every click
    loaded_npy_cache = {"path": None, "vol8": None}

    # NPY names that are always offered in the dropdown, whether or not they
    # exist in the chosen folder (a missing file is reported at run time).
    expected_npy_files = [
        "spherepack_MIS.npy", "spherepack_EDM.npy", "spherepack_ToF_in.npy",
        "spherepack_d1_MIS.npy", "spherepack_d1_EDM.npy", "spherepack_d1_ToF_in.npy",
        "spherepack_d2_MIS.npy", "spherepack_d2_EDM.npy", "spherepack_d2_ToF_in.npy",
        "spherepack.npy", "spherepack_0.npy", "spherepack_d1.npy", "spherepack_d2.npy",
    ]

    def _scan_npy_folder(folder):
        """Return the expected NPY names followed by any other .npy files in ``folder``."""
        try:
            files = sorted([f for f in os.listdir(folder) if f.lower().endswith(".npy")])
        except Exception:
            # Unreadable / missing folder: fall back to the expected names only.
            files = []
        # Merge expected list first, then discovered files (no duplicates)
        merged = []
        for f in expected_npy_files:
            if f not in merged:
                merged.append(f)
        for f in files:
            if f not in merged:
                merged.append(f)
        return merged

    def _refresh_input_dropdown():
        """Rebuild the input dropdown from the current NPY folder."""
        folder = npy_dir_var.get().strip()
        values = ["VTK Active Volume (RAW/subvolume)"]
        values += [f"NPY: {fn}" for fn in _scan_npy_folder(folder)]
        input_combo["values"] = values
        if input_source_var.get() not in values:
            input_source_var.set(values[0])

    def _browse_npy_folder():
        """Let the user pick the NPY folder, then refresh the dropdown."""
        folder = filedialog.askdirectory(initialdir=npy_dir_var.get() or os.getcwd())
        if folder:
            npy_dir_var.set(folder)
            _refresh_input_dropdown()

    tk.Label(control_frame, text="Skeleton input volume", font=("Arial", 11, "bold")).pack(pady=(0, 2))
    input_combo = ttk.Combobox(control_frame, textvariable=input_source_var, state="readonly", width=36)
    input_combo.pack(fill=tk.X, pady=(0, 4))

    row_npy = tk.Frame(control_frame)
    row_npy.pack(fill=tk.X, pady=(0, 6))
    tk.Entry(row_npy, textvariable=npy_dir_var, width=26).pack(side=tk.LEFT, padx=(0, 6))
    tk.Button(row_npy, text="Browse", command=_browse_npy_folder).pack(side=tk.LEFT)

    tk.Button(control_frame, text="Refresh NPY list", command=_refresh_input_dropdown).pack(fill=tk.X, pady=(0, 10))

    ttk.Separator(control_frame, orient="horizontal").pack(fill=tk.X, pady=(4, 10))

    # Init dropdown values
    _refresh_input_dropdown()
    tk.Label(control_frame, text="Binarization threshold", font=("Arial", 11)).pack(pady=(0, 2))
    thr_var = tk.IntVar(value=128)
    tk.Scale(
        control_frame, from_=0, to=255, orient="horizontal",
        variable=thr_var
    ).pack(fill=tk.X, pady=(0, 8))

    tk.Label(control_frame, text="Skeleton pruning threshold", font=("Arial", 11)).pack(pady=(0, 2))
    prune_var = tk.IntVar(value=5)
    tk.Scale(
        control_frame, from_=0, to=20, orient="horizontal",
        variable=prune_var
    ).pack(fill=tk.X, pady=(0, 8))

    tk.Label(control_frame, text="Downsample factor", font=("Arial", 11)).pack(pady=(0, 2))
    down_var = tk.IntVar(value=1)  # 1 = no downsample, 2 = half, 4 = quarter, etc.
    tk.Scale(
        control_frame, from_=1, to=4, orient="horizontal",
        variable=down_var
    ).pack(fill=tk.X, pady=(0, 8))

    tk.Label(control_frame, text="Tube radius", font=("Arial", 11)).pack(pady=(0, 2))
    tk.Scale(
        control_frame,
        from_=0.0,
        to=5.0,
        resolution=0.1,
        orient="horizontal",
        variable=tube_radius_var,
        command=on_tube_radius_change,  # live update when slider moves
    ).pack(fill=tk.X, pady=(0, 8))

    def run_skeletonization():
        """Run PyPore3D skeletonization on current active volume (optionally downsampled)."""
        import time

        status_label.config(text="Status: running skeletonization...")
        log_text.delete("1.0", tk.END)
        root.update_idletasks()
        # Decide input volume from dropdown (VTK active volume OR selected NPY)
        selected_input = input_source_var.get()
        if selected_input.startswith("NPY:"):
            fn = selected_input.replace("NPY:", "").strip()
            folder = npy_dir_var.get().strip()
            path = os.path.join(folder, fn)

            if not os.path.exists(path):
                status_label.config(text="Status: ERROR – missing NPY file")
                log(f"Missing NPY file: {path}")
                return

            try:
                # Load once and cache
                if loaded_npy_cache["path"] != path:
                    loaded_npy_cache["vol8"] = _load_npy_as_uint8(path, target_shape=(300, 300, 300))
                    loaded_npy_cache["path"] = path
                vol = loaded_npy_cache["vol8"]
            except Exception as e:
                status_label.config(text="Status: ERROR – failed to load NPY")
                log(f"NPY load failed: {e}")
                return

            skel_input_state["is_npy"] = True

            # Center the NPY volume inside the main world so it appears in the
            # middle of the scene (offset is clamped at 0 if it is larger).
            bigx, bigy, bigz = world_dimensions
            ox = max(0, int((bigx - vol.shape[0]) // 2))
            oy = max(0, int((bigy - vol.shape[1]) // 2))
            oz = max(0, int((bigz - vol.shape[2]) // 2))
            skel_input_state["offset_world"] = (ox * spacing[0], oy * spacing[1], oz * spacing[2])

            log(f"Input = NPY ({fn}), offset_world = {skel_input_state['offset_world']}")

        else:
            vol = get_active_volume()
            if vol is None:
                status_label.config(text="Status: ERROR – no active volume")
                log("Error: get_active_volume() returned None")
                return

            skel_input_state["is_npy"] = False
            skel_input_state["offset_world"] = (0.0, 0.0, 0.0)
            log("Input = VTK active volume (RAW/subvolume)")

        vol = np.asarray(vol, dtype=np.uint8)
        if vol.ndim != 3:
            # Report the problem in the UI instead of raising inside the Tk callback.
            status_label.config(text="Status: ERROR – input volume is not 3D")
            log(f"Expected a 3D volume, got shape={vol.shape}")
            return

        # 1) Binarize full-resolution volume
        thr = int(thr_var.get())
        log(f"Binarizing volume with threshold = {thr}")
        mask_full = (vol >= thr).astype(np.uint8)
        # 2) Downsample factor
        ds = int(down_var.get())
        if ds < 1:
            ds = 1
        log(f"Using pruning threshold = {int(prune_var.get())}")
        log(f"Using downsample factor = {ds}")
        # 3) Build the working mask (downsampled grid)
        if ds == 1:
            mask_work = mask_full
        else:
            # simple stride downsample (nearest-neighbour)
            mask_work = mask_full[::ds, ::ds, ::ds]
        x, y, z = mask_work.shape
        log(f"Working volume shape for skeletonization: {x} x {y} x {z}")
        current_ds["ds"] = ds
        skel_thresh = int(prune_var.get())

        # 4) LKC skeletonization + pruning + labeling (via temp RAW files)
        try:
            start = time.time()
            # ---- 4.1 Save current working mask to a temporary RAW file ----
            tmp_dir = os.path.join(os.getcwd(), "p3d_tmp")
            os.makedirs(tmp_dir, exist_ok=True)
            tmp_mask_path = os.path.join(tmp_dir, "active_mask.raw")
            tmp_skel_path = os.path.join(tmp_dir, "active_skel.raw")
            # Make sure mask is uint8 and written in C order
            mask_uint8 = mask_work.astype(np.uint8, copy=False)
            mask_uint8.ravel(order="C").tofile(tmp_mask_path)
            log(f"Saved mask to: {tmp_mask_path}")
            # ---- 4.2 Read mask using PyPore3D helper (correct type) ----
            log("Reading mask with py_p3dReadRaw8...")
            mask_raw = py_p3dReadRaw8(tmp_mask_path, x, y, dimz=z)
            log(f"mask_raw type = {type(mask_raw)}")
            # ---- 4.3 LKC skeletonization + pruning + labeling ----
            log("Running py_p3dLKCSkeletonization...")
            skl_raw = py_p3dLKCSkeletonization(mask_raw, x, y, dimz=z)
            log("Running py_p3dSkeletonPruning...")
            skl_raw = py_p3dSkeletonPruning(skl_raw, x, y, dimz=z, thresh=skel_thresh)
            log("Running py_p3dSkeletonLabeling...")
            skl_raw = py_p3dSkeletonLabeling(skl_raw, x, y, dimz=z)
            log(f"skl_raw type = {type(skl_raw)}")
            # ---- 4.4 Write skeleton to RAW via PyPore3D, then read with NumPy ----
            py_p3dWriteRaw8(skl_raw, tmp_skel_path, x, y, dimz=z)
            log(f"Saved skeleton to: {tmp_skel_path}")
            skl_array_1d = np.fromfile(tmp_skel_path, dtype=np.uint8)
            if skl_array_1d.size != x * y * z:
                # The reshape below then raises and is reported by the handler.
                log(f"WARNING: skeleton array size {skl_array_1d.size} != x*y*z = {x*y*z}")
            sklImg = skl_array_1d.reshape((x, y, z), order="C")
            end = time.time()
            log(f"Skeletonization / pruning / labeling time: {end - start:.2f} s")
            # Save skeleton volume for further analysis / display (downsampled grid)
            skeleton_volume["data"] = sklImg
            # ---- 4.5 Optional: skeleton analysis (stats) ----
            try:
                voxel_size = 1.0 * ds  # effective voxel size after downsample
                stats = py_p3dSkeletonAnalysis(mask_raw, skl_raw, x, y, dimz=z,
                                               resolution=voxel_size)
                log(f"\nConnectivityDensity: {getattr(stats, 'ConnectivityDensity', 'N/A')}")
            except Exception as e_ana:
                log(f"SkeletonAnalysis error (non-fatal): {e_ana}")

        except Exception as e:
            status_label.config(text="Status: ERROR in PyPore3D call")
            log(f"PyPore3D skeletonization error: {e}")
            return

        # 5) Convert skeleton to VTK tubes and add to renderer
        log("Converting skeleton to VTK tubes (downsampled grid)...")
        rebuild_skeleton_actor()

        # 6) Compute branch nodes and update bubble chart
        log("Computing branch nodes (where multiple branches meet)...")
        nodes = compute_branch_nodes(skeleton_volume["data"])
        branch_nodes["nodes"] = nodes
        log(f"Found {len(nodes)} branch nodes (on downsampled grid).")
        update_branch_node_bubble_chart()

        # A selection survives a re-run; rebuild its highlight from the new
        # nodes so the red points do not show positions from the previous run.
        if selected["z_set"]:
            set_branch_highlight(selected["z_set"])

        status_label.config(text="Status: skeleton ready (downsampled)")

    def toggle_skeleton():
        """Show/hide skeleton tubes."""
        actor = skel_actor_holder["actor"]
        if actor is None:
            return
        vis = actor.GetVisibility()
        actor.SetVisibility(0 if vis else 1)
        renderer3D.GetRenderWindow().Render()

    def clear_skeleton():
        """Remove skeleton actor, branch highlights and selection, and clear chart."""
        actor = skel_actor_holder["actor"]
        if actor is not None:
            renderer3D.RemoveActor(actor)
            skel_actor_holder["actor"] = None
            renderer3D.GetRenderWindow().Render()

        skeleton_volume["data"] = None
        branch_nodes["nodes"] = []
        # The selection refers to the skeleton that has just been removed:
        # drop it together with its red highlight points.
        selected["z_set"] = set()
        bubble_data["selected_mask"] = None
        set_branch_highlight(set())
        update_branch_node_bubble_chart()
        status_label.config(text="Status: skeleton cleared")
        log("Skeleton actor removed and data cleared.")

    # ---------- BUTTONS ----------
    tk.Button(control_frame, text="Run skeletonization",
              command=run_skeletonization).pack(fill=tk.X, pady=(10, 5))

    tk.Button(control_frame, text="Show / Hide skeleton",
              command=toggle_skeleton).pack(fill=tk.X, pady=5)

    tk.Button(control_frame, text="Clear skeleton",
              command=clear_skeleton).pack(fill=tk.X, pady=5)

    tk.Button(control_frame, text="Reset View",
              command=reset_view).pack(fill=tk.X, pady=5)

    tk.Button(control_frame, text="Hide / Show Volume",
              command=toggle_volume_visibility).pack(fill=tk.X, pady=5)

    def clear_selection():
        """Drop the bubble selection and show the complete skeleton again."""
        selected["z_set"] = set()
        bubble_data["selected_mask"] = None
        rebuild_skeleton_actor()
        set_branch_highlight(set())
        # reset bubble outline
        sc = bubble_data.get("scatter")
        if sc is not None:
            sc.set_edgecolors("k")
            canvas_nodes.draw()

    tk.Button(control_frame, text="Clear selection", command=clear_selection).pack(fill=tk.X, pady=5)

    # Initial empty chart
    update_branch_node_bubble_chart()

# ----------------------------------------------
# Tab 4 (NPY volume loader) - Jannah's part
# ----------------------------------------------
def _load_npy_as_uint8(path, target_shape=(300, 300, 300)):
    """Load a ``.npy`` file and return it as a 3D ``uint8`` volume.

    Parameters
    ----------
    path : str
        Path of the ``.npy`` file to load.
    target_shape : tuple of int, optional
        Required output shape. An array with a different shape but the same
        total number of elements is reshaped to it.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``target_shape``:
        - bool input -> 0 / 255;
        - uint8 input whose maximum is <= 1 -> multiplied by 255, otherwise
          returned unchanged;
        - any other dtype -> values already inside [0, 1] are scaled by 255,
          everything else is min-max normalised to [0, 255]; a constant (or
          entirely non-finite) array yields all zeros.
        Non-finite voxels (NaN / +-inf) are ignored when computing the value
        range and are written as 0 in the output.

    Raises
    ------
    ValueError
        If the array is not 3D (after squeezing arrays with more than three
        dimensions) or its size does not match ``target_shape``.
    """
    arr = np.asarray(np.load(path))

    # Squeeze trivial dimensions, but only when there are more than three axes
    if arr.ndim > 3:
        arr = np.squeeze(arr)

    if arr.ndim != 3:
        raise ValueError(f"{os.path.basename(path)} must be 3D, got shape={arr.shape}")

    if tuple(arr.shape) != tuple(target_shape):
        if arr.size == (target_shape[0] * target_shape[1] * target_shape[2]):
            arr = arr.reshape(target_shape)
        else:
            raise ValueError(
                f"{os.path.basename(path)} has shape {arr.shape}, expected {target_shape} "
                f"(or same total size)."
            )

    # Boolean masks become 0 / 255 so they are visible in the uint8 pipeline
    if arr.dtype == np.bool_:
        return arr.astype(np.uint8) * 255

    if arr.dtype == np.uint8:
        # Make binary 0/1 visible
        if arr.max() <= 1:
            return arr * 255
        return arr

    # Normalize numeric arrays
    arr_f = arr.astype(np.float32)

    # Compute the value range from finite voxels only: a single NaN would
    # otherwise turn min/max (and therefore the whole normalised volume) into NaN.
    finite = np.isfinite(arr_f)
    if not finite.any():
        return np.zeros(target_shape, dtype=np.uint8)
    all_finite = bool(finite.all())

    valid = arr_f if all_finite else arr_f[finite]
    vmin = float(np.min(valid))
    vmax = float(np.max(valid))
    if vmax <= vmin:
        return np.zeros(target_shape, dtype=np.uint8)

    if not all_finite:
        # Placeholder value keeps the arithmetic and the uint8 cast well defined;
        # these voxels are overwritten with 0 below.
        arr_f = np.where(finite, arr_f, np.float32(vmin))

    if vmax <= 1.0 and vmin >= 0.0:
        # If it's binary-ish 0/1 in float, scale directly without stretching
        vol8 = (arr_f * 255.0).clip(0, 255).astype(np.uint8)
    else:
        norm = (arr_f - vmin) / (vmax - vmin)
        vol8 = (norm * 255.0).clip(0, 255).astype(np.uint8)

    if not all_finite:
        vol8[~finite] = 0
    return vol8


def add_npy_volume_tab(
    root,
    notebook,
    renderer3D,
    original_volume=None,
    spacing=(1.0, 1.0, 1.0),
    expected_files=None,
    default_dir=None,
    original_dimensions=(700, 700, 700)   # IMPORTANT for centering in 700^3 scene
):
    """
    Tab 4: Load 300^3 .npy volume, render it in the SAME VTK renderer, and provide
    a separate 'Skeletonize NPY' pipeline (Option B).

    - NPY is translated to the center of the 700^3 world using offset.
    - Skeleton tubes are also translated using the same offset.

    Parameters
    ----------
    root :
        Not used inside this function; kept in the signature for callers.
    notebook :
        Notebook widget that receives the new tab.
    renderer3D :
        Renderer to which the NPY volume and the skeleton actor are added and
        from which they are removed.
    original_volume : optional
        Object whose ``SetVisibility`` is driven by the "Show Original Volume"
        checkbox. When None, that checkbox is not created.
    spacing : sequence of 3 floats
        Voxel spacing passed to ``numpy2VTK`` and used to convert the centering
        offset from voxels to world units.
    expected_files : list of str, optional
        File names that are always listed in the dropdown (before the files
        found on disk) and reported as missing when absent from the folder.
    default_dir : str, optional
        Initial folder; the current working directory is used when omitted.
    original_dimensions : sequence of 3 ints
        Size (in voxels) of the scene the NPY volume is centered in.
    """

    tab_npy = ttk.Frame(notebook)
    notebook.add(tab_npy, text="Tab 4: NPY + NPY Skeleton")

    # Make Tab 4 scrollable
    _, scroll_inner = make_scrollable(tab_npy, enable_x=False)

    # Put your existing layout inside scroll_inner instead of tab_npy
    left = tk.Frame(scroll_inner)
    left.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    right = tk.Frame(scroll_inner)
    right.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=10, pady=10)

    # --- State ---
    npy_state = {
        "volume_actor": None,     # vtkVolume for NPY
        "opacity_tf": None,       # vtkPiecewiseFunction
        "last_path": None,
        "vol8": None,             # numpy uint8 (300^3)
        "offset_world": (0.0, 0.0, 0.0),

        "skel_vol": None,         # numpy uint8 skeleton volume
        "skel_actor": None,       # vtkActor tubes
    }

    # --- UI vars ---
    dir_var = tk.StringVar(value=(default_dir if default_dir else os.getcwd()))
    status_var = tk.StringVar(value="Choose a folder and a file, then click Load.")
    dataset_var = tk.StringVar()

    cmap_var = tk.StringVar(value="Grayscale")
    opacity_var = tk.DoubleVar(value=1.0)

    show_npy_var = tk.BooleanVar(value=True)
    show_original_var = tk.BooleanVar(value=True)

    # Skeleton controls
    npy_thr_var = tk.IntVar(value=1)         # for binary volumes, >0 is enough
    npy_prune_var = tk.IntVar(value=5)       # pruning strength (tune)
    npy_down_var = tk.IntVar(value=1)        # downsample factor
    npy_tube_radius_var = tk.DoubleVar(value=0.6)
    show_npy_skel_var = tk.BooleanVar(value=True)

    if expected_files is None:
        expected_files = []

    def _scan_dir_for_npy(folder):
        """Return the sorted names of the ``.npy`` entries in ``folder`` ([] if unreadable)."""
        try:
            entries = os.listdir(folder)
        except (OSError, ValueError):
            # Best effort: a missing / unreadable / malformed folder simply lists nothing.
            return []
        return sorted(fn for fn in entries if fn.lower().endswith(".npy"))

    def _compute_center_offset_world(npy_shape=(300, 300, 300)):
        """Return the world-space translation that centers ``npy_shape`` in ``original_dimensions``."""
        # Center 300^3 inside 700^3 => offset = (700-300)/2 = 200 (voxels)
        # The offset is clamped at 0 when the NPY volume is larger than the scene.
        ox = max(0, int((original_dimensions[0] - npy_shape[0]) // 2))
        oy = max(0, int((original_dimensions[1] - npy_shape[1]) // 2))
        oz = max(0, int((original_dimensions[2] - npy_shape[2]) // 2))
        return (ox * spacing[0], oy * spacing[1], oz * spacing[2])

    def _clamped_opacity():
        """Return the opacity slider value limited to [0, 1]."""
        return max(0.0, min(1.0, float(opacity_var.get())))

    def _refresh_dropdown():
        """Rebuild the dataset dropdown from ``expected_files`` plus the files in the folder."""
        folder = dir_var.get().strip()
        found = _scan_dir_for_npy(folder)

        # Expected files first, then files found on disk, without duplicates.
        merged = []
        for fn in [*expected_files, *found]:
            if fn not in merged:
                merged.append(fn)

        dataset_combo["values"] = merged

        # A selection left over from a previously browsed folder is no longer
        # offered by the (read-only) combobox, so discard it and pick a new default.
        current = dataset_var.get()
        if current and current not in merged:
            dataset_var.set("")

        if not dataset_var.get():
            # Prefer the first listed file that actually exists on disk.
            for fn in merged:
                if os.path.exists(os.path.join(folder, fn)):
                    dataset_var.set(fn)
                    break
            if not dataset_var.get() and merged:
                dataset_var.set(merged[0])

        missing = [fn for fn in expected_files if not os.path.exists(os.path.join(folder, fn))]
        if missing:
            status_var.set(f"Found {len(found)} .npy files. Missing {len(missing)} expected files.")
        else:
            status_var.set(f"Found {len(found)} .npy files. All expected files available.")

    def _browse_folder():
        """Ask for a folder and, if one was chosen, refresh the file list."""
        folder = filedialog.askdirectory(initialdir=dir_var.get() or os.getcwd())
        if folder:
            dir_var.set(folder)
            _refresh_dropdown()

    def _remove_npy_volume():
        """Remove the NPY volume from the renderer and clear its stored state."""
        if npy_state["volume_actor"] is not None:
            renderer3D.RemoveVolume(npy_state["volume_actor"])
            npy_state["volume_actor"] = None
            npy_state["opacity_tf"] = None
            npy_state["vol8"] = None
            renderer3D.GetRenderWindow().Render()

    def _remove_npy_skeleton():
        """Remove the NPY skeleton actor from the renderer and clear its stored state."""
        if npy_state["skel_actor"] is not None:
            renderer3D.RemoveActor(npy_state["skel_actor"])
            npy_state["skel_actor"] = None
            npy_state["skel_vol"] = None
            renderer3D.GetRenderWindow().Render()

    def _build_volume_from_uint8(vol8):
        """Build a vtkVolume for ``vol8``; return ``(volume, opacity_function)``."""
        # Linear opacity ramp: value 0 is transparent, 255 uses the slider opacity.
        opacity_tf = vtk.vtkPiecewiseFunction()
        opacity_tf.AddPoint(0, 0.0)
        opacity_tf.AddPoint(255, _clamped_opacity())

        color_tf = get_color_map(cmap_var.get())

        mapper = vtk.vtkGPUVolumeRayCastMapper()
        importer = numpy2VTK(vol8, spacing=spacing)
        mapper.SetInputConnection(importer.GetOutputPort())

        prop = vtk.vtkVolumeProperty()
        prop.SetColor(color_tf)
        prop.SetScalarOpacity(opacity_tf)
        prop.ShadeOn()
        prop.SetInterpolationTypeToLinear()

        vol_actor = vtk.vtkVolume()
        vol_actor.SetMapper(mapper)
        vol_actor.SetProperty(prop)
        return vol_actor, opacity_tf

    def _load_selected_npy():
        """Load the selected file, replace any previous NPY volume/skeleton, and render it."""
        folder = dir_var.get().strip()
        fn = dataset_var.get().strip()
        if not fn:
            status_var.set("No dataset selected.")
            return

        path = os.path.join(folder, fn)
        if not os.path.exists(path):
            status_var.set(f"Missing file: {path}")
            return

        try:
            vol8 = _load_npy_as_uint8(path, target_shape=(300, 300, 300))
        except Exception as e:
            status_var.set(f"Load failed: {e}")
            return

        # Replace old npy volume + skeleton
        _remove_npy_volume()
        _remove_npy_skeleton()

        vol_actor, opacity_tf = _build_volume_from_uint8(vol8)

        # Compute and apply centering offset in the 700^3 world
        offset_world = _compute_center_offset_world(npy_shape=vol8.shape)
        vol_actor.SetPosition(*offset_world)

        npy_state["volume_actor"] = vol_actor
        npy_state["opacity_tf"] = opacity_tf
        npy_state["last_path"] = path
        npy_state["vol8"] = vol8
        npy_state["offset_world"] = offset_world

        renderer3D.AddVolume(vol_actor)

        vol_actor.SetVisibility(bool(show_npy_var.get()))
        if original_volume is not None:
            original_volume.SetVisibility(bool(show_original_var.get()))

        renderer3D.ResetCamera()
        renderer3D.GetRenderWindow().Render()

        status_var.set(f"Loaded: {fn} | shape={vol8.shape} | dtype=uint8 | centered in 700^3")

    def _apply_visual_changes():
        """Push colormap, opacity and visibility settings to the scene and re-render."""
        vol_actor = npy_state["volume_actor"]
        if vol_actor is not None:
            prop = vol_actor.GetProperty()
            prop.SetColor(get_color_map(cmap_var.get()))

            # Rebuild the opacity ramp in place so the existing property keeps using it.
            opacity_tf = npy_state["opacity_tf"]
            if opacity_tf is not None:
                opacity_tf.RemoveAllPoints()
                opacity_tf.AddPoint(0, 0.0)
                opacity_tf.AddPoint(255, _clamped_opacity())

            vol_actor.SetVisibility(bool(show_npy_var.get()))

        if original_volume is not None:
            original_volume.SetVisibility(bool(show_original_var.get()))

        # Skeleton visibility
        if npy_state["skel_actor"] is not None:
            npy_state["skel_actor"].SetVisibility(bool(show_npy_skel_var.get()))

        renderer3D.GetRenderWindow().Render()

    def _reset_view():
        """Reset the camera and re-render."""
        renderer3D.ResetCamera()
        renderer3D.GetRenderWindow().Render()

    def _skeletonize_npy():
        """
        Option B skeletonization:
        - Uses the loaded npy_state['vol8'] (300^3)
        - Builds a binary mask via threshold
        - Runs PyPore3D skeletonization + pruning + labeling
        - Visualizes as tubes, positioned with the SAME offset as NPY volume

        Any failure while writing the mask or running the pipeline is reported
        through the status line and leaves the previous skeleton in place.
        """
        if npy_state["vol8"] is None:
            status_var.set("Load an NPY file first.")
            return

        vol8 = npy_state["vol8"]
        if tuple(vol8.shape) != (300, 300, 300):
            status_var.set(f"Unexpected NPY shape {vol8.shape}. Expected (300,300,300).")
            return

        thr = int(npy_thr_var.get())
        prune_t = int(npy_prune_var.get())
        ds = int(npy_down_var.get())
        ds = max(1, ds)

        tube_r = float(npy_tube_radius_var.get())

        # Build mask
        mask_full = (vol8 >= thr).astype(np.uint8)

        # Downsample for speed if requested (keeps every ds-th voxel along each axis)
        if ds > 1:
            mask_work = mask_full[::ds, ::ds, ::ds]
        else:
            mask_work = mask_full

        x, y, z = mask_work.shape

        # If downsampled, we visualize in downsample space; adjust spacing accordingly
        skel_spacing = (spacing[0] * ds, spacing[1] * ds, spacing[2] * ds)

        # Temp folder (PyPore3D exchanges data through raw files)
        tmp_dir = os.path.join(os.getcwd(), "p3d_tmp_npy")
        mask_raw_path = os.path.join(tmp_dir, "npy_mask.raw")
        skel_raw_path = os.path.join(tmp_dir, "npy_skeleton.raw")

        try:
            # Folder creation and the mask write are inside the try so that disk
            # errors are shown in the status line instead of escaping the callback.
            os.makedirs(tmp_dir, exist_ok=True)

            # Write mask to raw
            mask_work.astype(np.uint8).tofile(mask_raw_path)

            # Read with PyPore3D
            mask_raw = py_p3dReadRaw8(mask_raw_path, x, y, dimz=z)

            # Skeleton pipeline
            skl_raw = py_p3dLKCSkeletonization(mask_raw, x, y, dimz=z)
            skl_raw = py_p3dSkeletonPruning(skl_raw, x, y, dimz=z, thresh=prune_t)
            skl_raw = py_p3dSkeletonLabeling(skl_raw, x, y, dimz=z)

            # Write skeleton out
            py_p3dWriteRaw8(skl_raw, skel_raw_path, x, y, dimz=z)

            skel_img = np.fromfile(skel_raw_path, dtype=np.uint8).reshape((x, y, z))

        except Exception as e:
            status_var.set(f"Skeletonization failed: {e}")
            return

        # Remove previous skeleton actor
        _remove_npy_skeleton()

        # Build tube actor from skeleton volume
        skel_actor = skeleton_volume_to_tube_actor(
            skel_img,
            spacing=skel_spacing,
            tube_radius=tube_r,
            color=(0.0, 1.0, 0.0),
            max_voxels=120000
        )

        # Apply SAME world offset as NPY volume (center in 700^3)
        # BUT note: if downsampled, the skeleton is smaller; still offset is correct in world coordinates.
        skel_actor.SetPosition(*npy_state["offset_world"])

        npy_state["skel_vol"] = skel_img
        npy_state["skel_actor"] = skel_actor

        renderer3D.AddActor(skel_actor)
        skel_actor.SetVisibility(bool(show_npy_skel_var.get()))
        renderer3D.GetRenderWindow().Render()

        status_var.set(
            f"NPY skeleton ready | mask thr={thr} | prune={prune_t} | ds={ds} | tube_r={tube_r:.2f}"
        )

    # ---------------- LEFT UI ----------------
    tk.Label(left, text="NPY Folder:", font=("Arial", 10, "bold")).pack(anchor="w")
    row_dir = tk.Frame(left)
    row_dir.pack(fill=tk.X, pady=4)
    tk.Entry(row_dir, textvariable=dir_var, width=38).pack(side=tk.LEFT, padx=(0, 6))
    tk.Button(row_dir, text="Browse", command=_browse_folder).pack(side=tk.LEFT)

    tk.Button(left, text="Refresh File List", command=_refresh_dropdown).pack(fill=tk.X, pady=(0, 10))

    tk.Label(left, text="Dataset (.npy) [300×300×300]:", font=("Arial", 10, "bold")).pack(anchor="w")
    dataset_combo = ttk.Combobox(left, textvariable=dataset_var, state="readonly", width=40)
    dataset_combo.pack(pady=4)

    tk.Button(left, text="Load & Render NPY", command=_load_selected_npy).pack(fill=tk.X, pady=(6, 12))

    # Render controls
    tk.Label(left, text="Colormap:", font=("Arial", 10, "bold")).pack(anchor="w")
    cmap_combo = ttk.Combobox(left, textvariable=cmap_var, state="readonly",
                             values=["Grayscale", "Hot", "Cool", "Jet"], width=20)
    cmap_combo.pack(anchor="w", pady=4)
    cmap_combo.bind("<<ComboboxSelected>>", lambda e: _apply_visual_changes())

    tk.Label(left, text="Opacity (0→transparent, 1→opaque):", font=("Arial", 10, "bold")).pack(anchor="w", pady=(10, 0))
    tk.Scale(left, variable=opacity_var, from_=0.0, to=1.0, resolution=0.05,
             orient=tk.HORIZONTAL, length=260, command=lambda v: _apply_visual_changes()).pack(anchor="w")

    tk.Checkbutton(left, text="Show NPY Volume", variable=show_npy_var, command=_apply_visual_changes).pack(anchor="w", pady=(10, 0))
    if original_volume is not None:
        tk.Checkbutton(left, text="Show Original Volume", variable=show_original_var, command=_apply_visual_changes).pack(anchor="w")

    # ---------------- Skeleton controls (Option B) ----------------
    sep = ttk.Separator(left, orient="horizontal")
    sep.pack(fill=tk.X, pady=12)

    tk.Label(left, text="NPY Skeletonization (Option B)", font=("Arial", 10, "bold")).pack(anchor="w")

    tk.Label(left, text="Mask Threshold (uint8):", font=("Arial", 9, "bold")).pack(anchor="w", pady=(6, 0))
    tk.Scale(left, variable=npy_thr_var, from_=0, to=255, resolution=1,
             orient=tk.HORIZONTAL, length=260).pack(anchor="w")
    tk.Label(left, text="Tip: for binary NPY use threshold=1", fg="gray").pack(anchor="w")

    tk.Label(left, text="Prune Threshold:", font=("Arial", 9, "bold")).pack(anchor="w", pady=(6, 0))
    tk.Scale(left, variable=npy_prune_var, from_=0, to=50, resolution=1,
             orient=tk.HORIZONTAL, length=260).pack(anchor="w")

    tk.Label(left, text="Downsample Factor:", font=("Arial", 9, "bold")).pack(anchor="w", pady=(6, 0))
    ttk.Combobox(left, textvariable=npy_down_var, state="readonly", values=[1,2,3,4,5], width=6).pack(anchor="w")

    tk.Label(left, text="Tube Radius (0.0 → 5.0):", font=("Arial", 9, "bold")).pack(anchor="w", pady=(6, 0))
    tk.Scale(left, variable=npy_tube_radius_var, from_=0.0, to=5.0, resolution=0.1,
             orient=tk.HORIZONTAL, length=260).pack(anchor="w")

    tk.Button(left, text="Skeletonize NPY", command=_skeletonize_npy).pack(fill=tk.X, pady=(8, 4))
    tk.Checkbutton(left, text="Show NPY Skeleton", variable=show_npy_skel_var, command=_apply_visual_changes).pack(anchor="w")

    tk.Button(left, text="Remove NPY Skeleton", command=_remove_npy_skeleton).pack(fill=tk.X, pady=(6, 2))

    # Other utility buttons
    tk.Button(left, text="Reset View", command=_reset_view).pack(fill=tk.X, pady=(10, 6))
    tk.Button(left, text="Remove NPY Volume", command=_remove_npy_volume).pack(fill=tk.X)

    # ---------------- RIGHT UI ----------------
    tk.Label(right, text="Status", font=("Arial", 12, "bold")).pack(anchor="w")
    tk.Label(right, textvariable=status_var, justify="left", wraplength=520).pack(anchor="w", pady=(4, 10))

    tk.Label(
        right,
        text=(
            "Notes:\n"
            "• All NPY files must be 300×300×300.\n"
            "• NPY is centered into your 700×700×700 scene using a translation offset.\n"
            "• Skeletonize NPY uses the loaded NPY only (separate from your main Skeleton tab).\n"
        ),
        justify="left",
        wraplength=520
    ).pack(anchor="w")

    # Populate the dropdown once the combobox exists.
    _refresh_dropdown()

In [4]:
def combined_gui_with_interactive_filter(filename, dimensions):
    """
    GUI providing:
      - Subvolume selection (X/Y/Z sliders)
      - Scatter plot with rectangle selection → highlight blobs in 3D
      - Histogram with rectangle selection → set intensity range & filter
      - Heatmap change button
      - Reset view button
      - Blob visibility toggle
      - Skeletonization tab access
    """

    # -------------------------------------------------------------------------
    # 1) GLOBALS & DATA LOADING
    # -------------------------------------------------------------------------
    global current_subvolume, renderer3D, render_window3D, blob_actor, highlighted_actors, blobs_visible

    # Load 3D RAW volume
    data = np.fromfile(filename, dtype=np.uint8).reshape(dimensions)

    # ACTIVE_VOLUME = volume currently rendered in 3D
    global ACTIVE_VOLUME
    ACTIVE_VOLUME = data

    # Load blob attribute table + centroids
    blob_test_data = pd.read_csv(r"C:\Users\adamd\Downloads\bachelor-1\blob_test.csv")
    centroids = blob_test_data[["Centroid_x", "Centroid_y", "Centroid_z"]].values.astype(float)
    attributes = list(blob_test_data.columns)

    # -------------------------------------------------------------------------
    # 2) VTK 3D SETUP
    # -------------------------------------------------------------------------
    renderer3D = vtk.vtkRenderer()
    render_window3D = vtk.vtkRenderWindow()
    render_window3D.AddRenderer(renderer3D)
    render_window3D.SetSize(800, 800)

    interactor3D = vtk.vtkRenderWindowInteractor()
    interactor3D.SetRenderWindow(render_window3D)

    # Initial full-volume rendering
    volume = volume_render(data, colormap="Grayscale")
    renderer3D.AddVolume(volume)
    renderer3D.ResetCamera()
    render_window3D.Render()

    # Base blob actor (all blobs in red, initially hidden)
    highlighted_actors = []
    blobs_visible = False
    blob_actor = render_blobs(centroids, renderer3D, sphere_radius=5, color=(1, 0, 0))
    blob_actor.SetVisibility(False)

    # -------------------------------------------------------------------------
    # 3) TKINTER ROOT & NOTEBOOK
    # -------------------------------------------------------------------------
    root = tk.Tk()
    root.title("Interactive 3D Visualization and Filtering")

    notebook = ttk.Notebook(root)
    notebook.pack(fill=tk.BOTH, expand=True)

    # -------------------------------------------------------------------------
    # 4) TAB 1 — VISUALIZATION & FILTERING
    # -------------------------------------------------------------------------
    tab_3d = ttk.Frame(notebook)
    notebook.add(tab_3d, text="3D Visualization & Filtering")

    # Left side: controls
    control_frame = tk.Frame(tab_3d)
    control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    # Right side: plots
    plot_frame = tk.Frame(tab_3d)
    plot_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

    # Matplotlib figure (histogram + scatter)
    fig = Figure(figsize=(5, 6), dpi=100)
    canvas = FigureCanvasTkAgg(fig, master=plot_frame)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    # Top axes: histogram
    ax_histogram = fig.add_axes([0.1, 0.55, 0.8, 0.35])
    ax_histogram.hist(data.ravel(), bins=50, alpha=0.7)
    ax_histogram.set_title("Intensity Histogram")
    ax_histogram.set_xlabel("Intensity")
    ax_histogram.set_ylabel("Frequency")

    # Bottom axes: scatter
    ax_scatter = fig.add_axes([0.1, 0.1, 0.8, 0.35])
    ax_scatter.set_title("Scatter Plot (blob attributes)")
    ax_scatter.set_xlabel("X attribute")
    ax_scatter.set_ylabel("Y attribute")
    # -------------------------------------------------------------------------
    # 4.1) Scatter attribute selectors
    # -------------------------------------------------------------------------
    tk.Label(control_frame, text="Scatter X-Axis Attribute", font=("Arial", 10)).pack(pady=(10, 2))
    scatter_x_var = tk.StringVar(value=attributes[0])
    ttk.Combobox(control_frame, textvariable=scatter_x_var, values=attributes).pack(pady=2)
    tk.Label(control_frame, text="Scatter Y-Axis Attribute", font=("Arial", 10)).pack(pady=(6, 2))
    scatter_y_var = tk.StringVar(value=attributes[1])
    ttk.Combobox(control_frame, textvariable=scatter_y_var, values=attributes).pack(pady=2)
    # -------------------------------------------------------------------------
    # 4.2) Heatmap selection
    # -------------------------------------------------------------------------
    tk.Label(control_frame, text="Select Heatmap Color", font=("Arial", 12)).pack(pady=5)
    colormap_var = tk.StringVar(value="Grayscale")
    ttk.Combobox(
        control_frame,
        textvariable=colormap_var,
        values=["Grayscale", "Hot", "Cool", "Jet"]
    ).pack(pady=5)
    def change_heatmap():
        """Re-render the current ACTIVE_VOLUME with the selected colormap."""
        new_volume = volume_render(ACTIVE_VOLUME, colormap=colormap_var.get())
        renderer3D.RemoveAllViewProps()
        renderer3D.AddVolume(new_volume)
        render_window3D.Render()
    tk.Button(control_frame, text="Change Heatmap", command=change_heatmap).pack(pady=5)
    # -------------------------------------------------------------------------
    # 4.3) Subvolume sliders
    # -------------------------------------------------------------------------
    tk.Label(control_frame, text="Select X Range:", font=("Arial", 12)).pack(pady=5)
    x_min = tk.IntVar(value=0)
    x_max = tk.IntVar(value=dimensions[0] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[0] - 1, variable=x_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[0] - 1, variable=x_max, orient="horizontal").pack()
    tk.Label(control_frame, text="Select Y Range:", font=("Arial", 12)).pack(pady=5)
    y_min = tk.IntVar(value=0)
    y_max = tk.IntVar(value=dimensions[1] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[1] - 1, variable=y_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[1] - 1, variable=y_max, orient="horizontal").pack()
    tk.Label(control_frame, text="Select Z Range:", font=("Arial", 12)).pack(pady=5)
    z_min = tk.IntVar(value=0)
    z_max = tk.IntVar(value=dimensions[2] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[2] - 1, variable=z_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[2] - 1, variable=z_max, orient="horizontal").pack()
    # -------------------------------------------------------------------------
    # 4.4) Internal min/max intensity variables
    # -------------------------------------------------------------------------
    # These are controlled only by the histogram rectangle selector.
    min_intensity_var = tk.IntVar(value=0)
    max_intensity_var = tk.IntVar(value=255)
    # -------------------------------------------------------------------------
    # 4.5) Scatter & histogram update functions
    # -------------------------------------------------------------------------
    def update_scatter_plot():
        """Draw scatter for current selected X/Y attributes."""
        ax_scatter.clear()
        x_attr = scatter_x_var.get()
        y_attr = scatter_y_var.get()
        if x_attr not in attributes or y_attr not in attributes:
            print("Invalid scatter attribute selection!")
            return
        ax_scatter.scatter(blob_test_data[x_attr], blob_test_data[y_attr], alpha=0.6)
        ax_scatter.set_title(f"Scatter Plot: {x_attr} vs {y_attr}")
        ax_scatter.set_xlabel(x_attr)
        ax_scatter.set_ylabel(y_attr)
        canvas.draw()

    def update_histogram_and_scatter():
        """Update histogram for the current subvolume and redraw scatter."""
        ax_histogram.clear()
        ax_histogram.hist(current_subvolume.ravel(), bins=50, alpha=0.7)
        ax_histogram.set_title("Subvolume Intensity Histogram")
        ax_histogram.set_xlabel("Intensity")
        ax_histogram.set_ylabel("Frequency")
        update_scatter_plot()

    # -------------------------------------------------------------------------
    # 4.6) Intensity filter logic (called by histogram rectangle selector)
    # -------------------------------------------------------------------------
    def apply_intensity_filter():
        """
        Apply intensity filter to current_subvolume using internal
        min_intensity_var and max_intensity_var, then re-render.
        """
        global ACTIVE_VOLUME

        min_intensity = min_intensity_var.get()
        max_intensity = max_intensity_var.get()

        if min_intensity < 0 or max_intensity > 255 or min_intensity > max_intensity:
            print("Invalid intensity range!")
            return

        # Filter the CURRENT subvolume directly
        subvolume = current_subvolume
        mask = (subvolume >= min_intensity) & (subvolume <= max_intensity)
        filtered_data = np.where(mask, subvolume, 0).astype(np.uint8)

        if filtered_data.size == 0 or filtered_data.sum() == 0:
            print("Warning: Filtered subvolume is empty!")

        ACTIVE_VOLUME = filtered_data

        # Re-render filtered volume
        new_volume = volume_render(filtered_data, colormap=colormap_var.get())
        renderer3D.RemoveAllViewProps()
        renderer3D.AddVolume(new_volume)
        render_window3D.Render()

        # Update histogram to show filtered data
        ax_histogram.clear()
        ax_histogram.hist(filtered_data.ravel(), bins=50, alpha=0.7)
        ax_histogram.set_title("Filtered Intensity Histogram")
        ax_histogram.set_xlabel("Intensity")
        ax_histogram.set_ylabel("Frequency")
        canvas.draw()

    # -------------------------------------------------------------------------
    # 4.7) Apply subvolume + Reset view
    # -------------------------------------------------------------------------
    def apply_subvolume():
        """Use sliders to choose ROI, set current_subvolume, and render it."""
        global ACTIVE_VOLUME, current_subvolume

        current_subvolume = data[
            x_min.get():x_max.get() + 1,
            y_min.get():y_max.get() + 1,
            z_min.get():z_max.get() + 1,
        ]
        if current_subvolume.size == 0:
            print("Warning: Subvolume is empty!")
            return

        ACTIVE_VOLUME = current_subvolume
        new_volume = volume_render(current_subvolume, colormap=colormap_var.get())
        renderer3D.RemoveAllViewProps()
        renderer3D.AddVolume(new_volume)
        render_window3D.Render()

        update_histogram_and_scatter()

    def reset_view():
        """Reset to full volume view (no ROI, no intensity filtering)."""
        global ACTIVE_VOLUME, current_subvolume

        current_subvolume = data
        ACTIVE_VOLUME = data

        new_volume = volume_render(data, colormap=colormap_var.get())
        renderer3D.RemoveAllViewProps()
        renderer3D.AddVolume(new_volume)
        renderer3D.ResetCamera()
        render_window3D.Render()

        ax_histogram.clear()
        ax_histogram.hist(data.ravel(), bins=50, alpha=0.7)
        ax_histogram.set_title("Full Volume Histogram")
        ax_histogram.set_xlabel("Intensity")
        ax_histogram.set_ylabel("Frequency")
        update_scatter_plot()

    # -------------------------------------------------------------------------
    # 4.8) Blob visibility toggle
    # -------------------------------------------------------------------------
    def toggle_blobs():
        """Toggle visibility of base red blobs + highlighted yellow blobs."""
        global blobs_visible
        blobs_visible = not blobs_visible
        blob_actor.SetVisibility(blobs_visible)
        for actor in highlighted_actors:
            actor.SetVisibility(blobs_visible)
        render_window3D.Render()

    # Initialize current_subvolume and plots
    current_subvolume = data
    update_scatter_plot()

    # Buttons
    tk.Button(control_frame, text="Apply Subvolume & Render", command=apply_subvolume).pack(pady=10)
    tk.Button(control_frame, text="Reset View", command=reset_view).pack(pady=5)
    tk.Button(control_frame, text="Toggle Blobs", command=toggle_blobs).pack(pady=5)
    # -------------------------------------------------------------------------
    # 4.9) Rectangle selectors (scatter + histogram)
    # -------------------------------------------------------------------------
    def on_scatter_rectangle_select(eclick, erelease):
        """
        When user draws a rectangle on the SCATTER:
          1) Clear previous highlight actors
          2) Read rectangle coordinates in data space
          3) Find matching rows in blob_test_data
          4) Highlight selected blobs in 3D as yellow spheres
        """
        global highlighted_actors

        # 1) Clear previous highlights
        for actor in highlighted_actors:
            renderer3D.RemoveActor(actor)
        highlighted_actors = []

        # 2) Read rectangle coordinates
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        if None in (x1, y1, x2, y2):
            print("Invalid scatter rectangle.")
            return

        x_min_sel, x_max_sel = sorted([x1, x2])
        y_min_sel, y_max_sel = sorted([y1, y2])

        x_attr = scatter_x_var.get()
        y_attr = scatter_y_var.get()

        # 3) Find indices inside rectangle
        selected_indices = blob_test_data[
            (blob_test_data[x_attr] >= x_min_sel) &
            (blob_test_data[x_attr] <= x_max_sel) &
            (blob_test_data[y_attr] >= y_min_sel) &
            (blob_test_data[y_attr] <= y_max_sel)
        ].index

        selected_centroids = centroids[selected_indices]

        # 4) Add yellow spheres for selected blobs
        for c in selected_centroids:
            actor = render_blobs([c], renderer3D, sphere_radius=8, color=(1, 1, 0))
            actor.SetVisibility(blobs_visible)
            highlighted_actors.append(actor)

        render_window3D.Render()

    def on_hist_rectangle_select(eclick, erelease):
        """
        RectangleSelector callback for the HISTOGRAM axes.

        The existing 3D highlight actors are always removed first. If either
        mouse event lies outside the axes (any coordinate is None) a message
        is printed and nothing else happens. Otherwise the horizontal extent
        of the rectangle is rounded, clamped to [0, 255], written to
        min_intensity_var / max_intensity_var, and apply_intensity_filter()
        is invoked.
        """
        global highlighted_actors

        def _to_intensity(x_value):
            # Nearest integer, clamped to [0, 255].
            return max(0, min(255, int(round(x_value))))

        # Drop whatever was highlighted by a previous selection.
        for stale_actor in highlighted_actors:
            renderer3D.RemoveActor(stale_actor)
        highlighted_actors = []

        # Press / release positions in data coordinates.
        corners = (eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata)
        if None in corners:
            print("Invalid histogram rectangle.")
            return

        # Only the horizontal extent matters for an intensity range.
        x_press, _, x_release, _ = corners
        lower_x, upper_x = sorted([x_press, x_release])
        lower_intensity = _to_intensity(lower_x)
        upper_intensity = _to_intensity(upper_x)

        min_intensity_var.set(lower_intensity)
        max_intensity_var.set(upper_intensity)

        # Re-filter using the range that was just stored.
        apply_intensity_filter()

    # Attach rectangle selectors. Both use the same configuration: interactive
    # handles, blitting, left mouse button only, and a minimum drag of 5 pixels
    # in each direction before the callback fires.
    def _make_rect_selector(target_axes, on_select):
        # A fresh keyword set (including the button list) is built per selector.
        return RectangleSelector(
            target_axes,
            on_select,
            interactive=True,
            useblit=True,
            button=[1],
            minspanx=5,
            minspany=5,
            spancoords="pixels",
        )

    # The selectors are bound to names so the objects stay referenced.
    rect_selector_scatter = _make_rect_selector(ax_scatter, on_scatter_rectangle_select)
    rect_selector_hist = _make_rect_selector(ax_histogram, on_hist_rectangle_select)
    # -------------------------------------------------------------------------
    # TAB 2 — PARALLEL COORDINATE PLOT (styled + linked to 3D blobs)
    # -------------------------------------------------------------------------
    # Builds the widgets for the parallel-coordinate tab: a left column with an
    # attribute list and a right area holding an embedded Matplotlib figure.
    # The names created here (ax_parallel, canvas_parallel, attr_listbox,
    # parallel_lines, color_cycle) are captured by update_parallel_plot() and
    # on_parallel_pick() defined below.
    tab_parallel = ttk.Frame(notebook)
    notebook.add(tab_parallel, text="Parallel Coordinate Plot")

    # Left side: attribute selection (fixed width, stretches vertically only)
    parallel_left = tk.Frame(tab_parallel)
    parallel_left.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    # Right side: Matplotlib figure for the parallel plot (takes remaining space)
    parallel_right = tk.Frame(tab_parallel)
    parallel_right.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

    # 6x5 inch figure at 100 dpi, embedded in Tk through FigureCanvasTkAgg
    fig_parallel = Figure(figsize=(6, 5), dpi=100)
    canvas_parallel = FigureCanvasTkAgg(fig_parallel, master=parallel_right)
    canvas_parallel_widget = canvas_parallel.get_tk_widget()
    canvas_parallel_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    # Single axes that is cleared and redrawn on every plot update
    ax_parallel = fig_parallel.add_subplot(111)

    # Caption above the attribute list
    tk.Label(
        parallel_left,
        text="Select attributes for\nParallel Coordinates:",
        font=("Arial", 10)
    ).pack(pady=5)

    # Multi-select listbox with all blob attributes.
    # exportselection=False keeps the selection local to this widget, so
    # selecting text in another widget does not clear it.
    attr_listbox = tk.Listbox(
        parallel_left,
        selectmode=tk.MULTIPLE,
        exportselection=False,
        height=10
    )
    # Entries are inserted in the order of `attributes`, so a listbox index can
    # be used directly as an index into `attributes`.
    for attr in attributes:
        attr_listbox.insert(tk.END, attr)
    attr_listbox.pack(fill=tk.Y, expand=False, pady=5)

    # Line2D objects of the current plot, used for picking + highlighting.
    # Rebound by update_parallel_plot() and read by on_parallel_pick().
    parallel_lines = []

    # Colors assigned to the polylines; indexed modulo its length, so the
    # sequence repeats when there are more blobs than colors.
    color_cycle = [
        "tab:blue", "tab:orange", "tab:green", "tab:red",
        "tab:purple", "tab:brown", "tab:pink", "tab:gray",
        "tab:olive", "tab:cyan"
    ]

    def update_parallel_plot():
        """
        Draw a parallel coordinate plot for the selected attributes.

        - One polyline per blob (row in blob_test_data)
        - Each axis is one selected attribute (normalized to [0,1])
        - Lines are colored using a color cycle
        - Legend shows the first few blobs (Blob #1, Blob #2, ...)
        - Each line gets an attached _blob_row_index so clicks can
          be mapped back to a centroid for 3D highlighting.
        """
        nonlocal parallel_lines

        sel_indices = attr_listbox.curselection()
        if len(sel_indices) < 2:
            print("Select at least 2 attributes for the parallel plot.")
            return

        selected_attrs = [attributes[i] for i in sel_indices]

        # Take only the selected columns and convert to float
        df = blob_test_data[selected_attrs].astype(float)
        values = df.values  # shape: (n_blobs, n_selected_attrs)

        # Normalize each attribute to [0, 1]
        mins = values.min(axis=0)
        maxs = values.max(axis=0)
        ranges = maxs - mins
        ranges[ranges == 0] = 1.0  # avoid division by zero
        norm_vals = (values - mins) / ranges

        ax_parallel.clear()
        parallel_lines = []

        # X positions for axes: 0, 1, 2, ..., n_attrs-1
        x_positions = np.arange(len(selected_attrs))

        legend_lines = []
        legend_labels = []
        legend_limit = 6  # show at most 6 blobs in the legend

        # Draw one polyline per blob
        for row_idx, row in enumerate(norm_vals):
            color = color_cycle[row_idx % len(color_cycle)]

            line, = ax_parallel.plot(
                x_positions,
                row,
                color=color,
                alpha=0.7,
                linewidth=1.8,
                picker=True,     # enable picking
                pickradius=5     # pixels around the line that count as a "pick"
            )

            # Attach the blob row index so we can map click -> centroid
            line._blob_row_index = row_idx
            parallel_lines.append(line)

            # Add a few lines to the legend
            if row_idx < legend_limit:
                legend_lines.append(line)
                legend_labels.append(f"Blob {row_idx + 1}")

        # X axis: attribute names
        ax_parallel.set_xticks(x_positions)
        ax_parallel.set_xticklabels(selected_attrs, rotation=45, ha="right")

        # Y axis: 0–100% style (like the example image)
        ax_parallel.set_ylim(0.0, 1.0)
        ax_parallel.set_yticks(np.linspace(0.0, 1.0, 6))
        ax_parallel.set_yticklabels([f"{int(v*100)} %" for v in np.linspace(0.0, 1.0, 6)])
        ax_parallel.set_ylabel("Normalized Value")

        ax_parallel.set_xlim(0, len(selected_attrs) - 1)
        ax_parallel.set_title("Parallel Coordinate Plot (linked to 3D blobs)")

        # Small legend for a few example lines
        if legend_lines:
            ax_parallel.legend(
                legend_lines,
                legend_labels,
                loc="upper right",
                fontsize=8,
                title="Line by:"
            )

        fig_parallel.tight_layout()
        canvas_parallel.draw()

    def on_parallel_pick(event):
        """
        Handles clicks on lines in the parallel plot.
        Logic:
          1) Confirm the clicked artist is one of our lines.
          2) Clear previous 3D highlights and reset line style.
          3) Highlight the clicked line (thicker, more opaque).
          4) Use stored _blob_row_index to find the centroid.
          5) Render the corresponding blob as a yellow sphere in 3D.
        """
        nonlocal parallel_lines
        global highlighted_actors

        artist = event.artist

        # Only act if this is one of our polylines
        if artist not in parallel_lines:
            return

        # 2) Clear previous 3D highlight actors
        for actor in highlighted_actors:
            renderer3D.RemoveActor(actor)
        highlighted_actors = []

        # Reset style for all lines
        for ln in parallel_lines:
            ln.set_linewidth(1.8)
            ln.set_alpha(0.7)

        # 3) Make the clicked line stand out
        artist.set_linewidth(3.0)
        artist.set_alpha(1.0)

        # 4) Get which blob this line corresponds to
        row_idx = getattr(artist, "_blob_row_index", None)
        if row_idx is None:
            canvas_parallel.draw()
            return

        centroid = centroids[row_idx]

        # 5) Highlight this blob in 3D as a yellow sphere
        actor = render_blobs([centroid], renderer3D, sphere_radius=8, color=(1, 1, 0))
        actor.SetVisibility(blobs_visible)  # follows the Toggle Blobs state
        highlighted_actors.append(actor)

        render_window3D.Render()
        canvas_parallel.draw()

    # Connect pick events on the parallel figure: clicking near a polyline
    # (drawn with picker=True) calls on_parallel_pick. The connection id
    # returned by mpl_connect is not kept, so the handler is never disconnected.
    fig_parallel.canvas.mpl_connect("pick_event", on_parallel_pick)

    # Button to update the parallel plot when attributes change; the plot is
    # only redrawn on this button press, not on every listbox selection change.
    tk.Button(
        parallel_left,
        text="Update Parallel Plot",
        command=update_parallel_plot
    ).pack(pady=10)

    # Pre-select first few attributes (up to 4) and draw an initial plot.
    # With fewer than two attributes update_parallel_plot() only prints a
    # message and draws nothing.
    default_count = min(4, len(attributes))
    for i in range(default_count):
        attr_listbox.selection_set(i)
    update_parallel_plot()
    # -------------------------------------------------------------------------
    # 5) SKELETONIZATION TAB
    # -------------------------------------------------------------------------
    # Make sure current_subvolume exists: if the name is not bound yet, fall
    # back to None so the lambda below can test it safely.
    # NOTE(review): because current_subvolume is assigned in this block, Python
    # treats it as a local of the enclosing function unless it was declared
    # global/nonlocal earlier (not visible here). Confirm that the ROI code
    # updates the same variable that the lambda below reads.
    try:
        current_subvolume
    except NameError:
        current_subvolume = None

    # Add the Skeletonization tab (works on active ROI or full volume).
    # get_active_volume is evaluated lazily on each call: it returns
    # current_subvolume when that is not None, otherwise the full `data` array.
    add_skeletonization_tab(
        root,
        notebook,
        renderer3D,
        get_active_volume=lambda: current_subvolume if current_subvolume is not None else data,
        spacing=(1.0, 1.0, 1.0),  # unit voxel spacing; change if your voxel spacing is different
        world_dimensions=dimensions,  # full-volume dimensions from the caller; original note: (700,700,700) so NPY (300^3) can be centered correctly
        npy_default_dir=os.getcwd()  # default directory is the current working directory
    )
    # -------------------------------------------------------------------------
    # TAB 4 — NPY (300^3) VOLUME VIEWER
    # -------------------------------------------------------------------------
    # File names handed to add_npy_volume_tab as `expected_files`: first the
    # MIS / EDM / ToF_in variants of each spherepack data set, then the
    # additional (optional) plain spherepack names.
    expected_npy_files = [
        "spherepack_MIS.npy",
        "spherepack_EDM.npy",
        "spherepack_ToF_in.npy",
        "spherepack_d1_MIS.npy",
        "spherepack_d1_EDM.npy",
        "spherepack_d1_ToF_in.npy",
        "spherepack_d2_MIS.npy",
        "spherepack_d2_EDM.npy",
        "spherepack_d2_ToF_in.npy",
        "spherepack.npy",
        "spherepack_d1.npy",
        "spherepack_d2.npy",
        "spherepack_0.npy",
    ]

    # `dimensions` is the size of the original volume as passed by the caller
    # (the original comment notes (700,700,700)).
    add_npy_volume_tab(
        root,
        notebook,
        renderer3D,
        original_volume=volume,
        spacing=(1.0, 1.0, 1.0),
        expected_files=expected_npy_files,
        default_dir=os.getcwd(),
        original_dimensions=dimensions,
    )

    # -------------------------------------------------------------------------
    # 6) START TK + VTK LOOPS
    # -------------------------------------------------------------------------
    # root.mainloop() blocks until the Tk event loop ends, so
    # interactor3D.Start() is only reached after that; the two loops run one
    # after the other, not concurrently.
    # NOTE(review): presumably interactor3D is the VTK render window
    # interactor created earlier in this function — confirm.
    root.mainloop()
    interactor3D.Start()

In [None]:
def main(
    filename=r"C:\Users\adamd\Downloads\SC1_700x700x700.raw",
    dimensions=(700, 700, 700),
):
    """
    Launch the combined GUI on a raw volume file.

    Parameters
    ----------
    filename : str
        Path of the RAW volume handed to combined_gui_with_interactive_filter.
        Defaults to the data set this notebook was developed with; pass your
        own path to run on another machine or file.
    dimensions : tuple of int
        Volume dimensions handed to combined_gui_with_interactive_filter
        together with the file name. Defaults to (700, 700, 700).
    """
    combined_gui_with_interactive_filter(filename, dimensions)

main()

Success. 

