In [1]:
# ---- Imports ----
import vtk
import numpy as np
import tkinter as tk
from tkinter import ttk
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.widgets import RectangleSelector
from scipy.ndimage import label
from skimage.measure import regionprops
import pandas as pd
from tkinter import filedialog
from pypore_3D_skeleton_integration import add_skeleton_tab


In [2]:
def get_color_map(name):
    """Build a ``vtkColorTransferFunction`` for a named colormap preset.

    Known presets are "Grayscale", "Hot", "Cool" and "Jet". Every control
    point sits on the 0..255 intensity scale. Any other name falls back to
    the grayscale ramp.
    """
    grayscale = ((0, 0, 0, 0), (255, 1, 1, 1))
    # (intensity, red, green, blue) control points for each preset.
    presets = {
        "Grayscale": grayscale,
        "Hot": ((0, 0.1, 0, 0), (128, 1, 0.5, 0), (255, 1, 1, 1)),
        "Cool": ((0, 0, 1, 1), (255, 1, 0, 1)),
        "Jet": ((0, 0, 0, 0.5), (128, 0, 1, 0), (255, 1, 0, 0)),
    }
    chosen = next((points for key, points in presets.items() if key == name), grayscale)

    color_tf = vtk.vtkColorTransferFunction()
    for intensity, red, green, blue in chosen:
        color_tf.AddRGBPoint(intensity, red, green, blue)
    return color_tf


def numpy2VTK(img, spacing=(1.0, 1.0, 1.0)):
    """Wrap a 3D NumPy array in a ``vtkImageImport`` as unsigned-char scalars.

    The array is serialized in C order. Numpy axis 2 therefore becomes VTK x,
    axis 1 becomes VTK y and axis 0 becomes VTK z. Integer or float input
    outside 0..255 is clipped before the uint8 cast, so it saturates instead
    of wrapping around (e.g. 300 would otherwise become 44).

    Args:
        img: 3D array-like.
        spacing: Voxel spacing in VTK (x, y, z) order, i.e. for numpy axes
            (2, 1, 0). A tuple default avoids a shared mutable default.

    Returns:
        The configured ``vtkImageImport``; connect it via ``GetOutputPort()``.

    Raises:
        ValueError: if ``img`` is not 3-dimensional.
    """
    img = np.asarray(img)
    if img.ndim != 3:
        raise ValueError(f"numpy2VTK expects a 3D array, got shape {img.shape}")

    if img.dtype != np.uint8:
        if np.issubdtype(img.dtype, np.integer):
            # Bounds must be representable in the source dtype (e.g. int8 tops out at 127).
            info = np.iinfo(img.dtype)
            img = np.clip(img, max(0, int(info.min)), min(255, int(info.max)))
        elif np.issubdtype(img.dtype, np.floating):
            img = np.clip(img, 0, 255)

    # tobytes() always emits C order, so non-contiguous slices are serialized correctly.
    img_bytes = img.astype(np.uint8).tobytes()
    depth, height, width = img.shape

    importer = vtk.vtkImageImport()
    importer.SetDataExtent(0, width - 1, 0, height - 1, 0, depth - 1)
    importer.SetWholeExtent(0, width - 1, 0, height - 1, 0, depth - 1)
    # CopyImportVoidPointer makes VTK keep its own copy of the buffer, so the
    # local bytes object may be released after this function returns.
    importer.CopyImportVoidPointer(img_bytes, len(img_bytes))
    importer.SetDataScalarTypeToUnsignedChar()
    importer.SetNumberOfScalarComponents(1)
    importer.SetDataSpacing(spacing[0], spacing[1], spacing[2])
    importer.SetDataOrigin(0, 0, 0)

    return importer


def volume_render(data, colormap="Grayscale", spacing=[1.0, 1.0, 1.0]):
    """Build a ``vtkVolume`` for ``data`` using the named colormap.

    Opacity ramps linearly from fully transparent at intensity 0 to fully
    opaque at 255. Shading and linear interpolation are enabled.

    Args:
        data: 3D array forwarded to ``numpy2VTK`` (which casts it to uint8).
        colormap: Preset name understood by ``get_color_map``.
        spacing: Voxel spacing forwarded to ``numpy2VTK``.

    Returns:
        A ``vtkVolume`` driven by a ``vtkGPUVolumeRayCastMapper``. The caller
        is responsible for adding it to a renderer.
    """
    source = numpy2VTK(data, spacing)

    mapper = vtk.vtkGPUVolumeRayCastMapper()
    mapper.SetInputConnection(source.GetOutputPort())

    # Linear alpha ramp over the uint8 range.
    ramp = vtk.vtkPiecewiseFunction()
    for intensity, alpha in ((0, 0.0), (255, 1.0)):
        ramp.AddPoint(intensity, alpha)

    props = vtk.vtkVolumeProperty()
    props.SetColor(get_color_map(colormap))
    props.SetScalarOpacity(ramp)
    props.ShadeOn()
    props.SetInterpolationTypeToLinear()

    vol = vtk.vtkVolume()
    vol.SetMapper(mapper)
    vol.SetProperty(props)
    return vol


In [3]:
def render_blobs(centroids, renderer, sphere_radius=5, color=(1, 0, 0)):
    """Add one sphere per centroid to ``renderer`` as a single glyph actor.

    Args:
        centroids: Iterable of 3-component points in VTK world coordinates.
        renderer: The ``vtkRenderer`` that receives the new actor.
        sphere_radius: Radius of every sphere glyph.
        color: RGB triple in 0..1 applied to the actor (default red).

    Returns:
        The ``vtkActor`` that was added, so callers can hide, recolor or
        remove it later.
    """
    # Glyph template: a single sphere stamped at every seed point.
    sphere = vtk.vtkSphereSource()
    sphere.SetRadius(sphere_radius)

    seed_points = vtk.vtkPoints()
    for xyz in centroids:
        seed_points.InsertNextPoint(xyz)
    seeds = vtk.vtkPolyData()
    seeds.SetPoints(seed_points)

    glyphs = vtk.vtkGlyph3D()
    glyphs.SetSourceConnection(sphere.GetOutputPort())
    glyphs.SetInputData(seeds)
    glyphs.Update()

    glyph_mapper = vtk.vtkPolyDataMapper()
    glyph_mapper.SetInputConnection(glyphs.GetOutputPort())

    glyph_actor = vtk.vtkActor()
    glyph_actor.SetMapper(glyph_mapper)
    glyph_actor.GetProperty().SetColor(*color)

    renderer.AddActor(glyph_actor)
    return glyph_actor

In [4]:
def combined_gui_with_interactive_filter(
    filename,
    dimensions,
    blob_csv_path=r"C:\Users\adamd\Downloads\bachelor-1\blob_test.csv",
):
    """Build and run the combined Tk + VTK GUI for exploring a raw 3D volume.

    The Tk notebook has three tabs:

    1. "3D Visualization & Filtering":
       - X/Y/Z sliders crop a subvolume.
       - A combobox picks the colormap.
       - An intensity histogram lets you drag a rectangle to keep only an
         intensity range.
       - A per-slice mean-intensity scatter plot lets you click a point to
         show only that Z slice.
    2. "Bubble Chart & Blob Highlighting":
       - A bubble chart is drawn over the blob CSV columns.
       - Dragging a rectangle highlights the enclosed blobs in yellow in the
         VTK window.
       - "Toggle Blobs" shows or hides all blob centroids in red.
    3. The skeletonization tab added by ``add_skeleton_tab``.

    Args:
        filename: Path to a raw ``uint8`` volume file.
        dimensions: Shape passed to ``reshape``; its product must equal the
            number of bytes in ``filename``.
        blob_csv_path: CSV with ``Centroid_x``, ``Centroid_y`` and
            ``Centroid_z`` columns and at least three columns in total (the
            first three seed the bubble-chart selectors).

    Side effects:
        Sets the module globals ``current_subvolume``, ``renderer3D``,
        ``render_window3D``, ``blob_actor``, ``highlighted_actors`` and
        ``blobs_visible``. Blocks in ``root.mainloop()`` and then in
        ``interactor3D.Start()``.
    """
    global current_subvolume, renderer3D, render_window3D, blob_actor, highlighted_actors, blobs_visible

    # ---- Data ----
    data = np.fromfile(filename, dtype=np.uint8).reshape(dimensions)

    # Blob table; centroids are used directly as VTK world coordinates.
    # NOTE(review): numpy2VTK maps numpy axis 2 -> VTK x and axis 0 -> VTK z;
    # confirm Centroid_x/y/z follow the same convention.
    blob_test_data = pd.read_csv(blob_csv_path)
    centroids = blob_test_data[["Centroid_x", "Centroid_y", "Centroid_z"]].values.astype(float)
    attributes = list(blob_test_data.columns)

    highlighted_actors = []
    # Must match the initial visibility of blob_actor below; otherwise the
    # first "Toggle Blobs" click changes nothing.
    blobs_visible = False

    # ---- VTK renderer ----
    renderer3D = vtk.vtkRenderer()
    render_window3D = vtk.vtkRenderWindow()
    render_window3D.AddRenderer(renderer3D)
    render_window3D.SetSize(800, 800)

    interactor3D = vtk.vtkRenderWindowInteractor()
    interactor3D.SetRenderWindow(render_window3D)

    # `volume` always refers to the volume prop currently in renderer3D.
    volume = volume_render(data, colormap="Grayscale")
    renderer3D.AddVolume(volume)
    renderer3D.ResetCamera()
    render_window3D.Render()

    # Every blob centroid as one red glyph actor, hidden until toggled on.
    blob_actor = render_blobs(centroids, renderer3D, sphere_radius=5, color=(1, 0, 0))
    blob_actor.SetVisibility(blobs_visible)

    def show_volume(new_volume):
        """Swap only the displayed volume prop.

        Blob, highlight and skeleton actors in the same renderer stay in
        place; RemoveAllViewProps() would delete them too.
        """
        nonlocal volume
        renderer3D.RemoveVolume(volume)
        renderer3D.AddVolume(new_volume)
        volume = new_volume
        render_window3D.Render()

    # ---- Tk window ----
    root = tk.Tk()
    root.title("Interactive 3D Visualization and Filtering")

    notebook = ttk.Notebook(root)
    notebook.pack(fill=tk.BOTH, expand=True)

    # ---- Tab 1: subvolume selection and 3D visualization ----
    tab_3d = ttk.Frame(notebook)
    notebook.add(tab_3d, text="3D Visualization & Filtering")

    control_frame = tk.Frame(tab_3d)
    control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    plot_frame = tk.Frame(tab_3d)
    plot_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

    fig = Figure(figsize=(5, 6), dpi=100)
    canvas = FigureCanvasTkAgg(fig, master=plot_frame)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    ax_histogram = fig.add_axes([0.1, 0.55, 0.8, 0.35])
    ax_scatter = fig.add_axes([0.1, 0.1, 0.8, 0.35])

    def update_histogram_and_scatter(title_prefix="Updated Subvolume "):
        """Redraw the intensity histogram and the per-Z mean plot from current_subvolume."""
        ax_histogram.clear()
        ax_histogram.hist(current_subvolume.ravel(), bins=50, color="blue", alpha=0.7)
        ax_histogram.set_title(f"{title_prefix}Intensity Histogram")
        ax_histogram.set_xlabel("Intensity")
        ax_histogram.set_ylabel("Frequency")

        ax_scatter.clear()
        # One point per index along axis 2: the mean over axes 0 and 1.
        z_positions = np.arange(current_subvolume.shape[2])
        scatter_data = current_subvolume.mean(axis=(0, 1))
        ax_scatter.scatter(z_positions, scatter_data, c=scatter_data, cmap="viridis", picker=True)
        ax_scatter.set_title(f"{title_prefix}Scatter Plot (Intensity vs. Z)")
        ax_scatter.set_xlabel("Z Position")
        ax_scatter.set_ylabel("Mean Intensity")
        canvas.draw()

    current_subvolume = data
    update_histogram_and_scatter(title_prefix="")

    # Subvolume selection sliders (each range is inclusive of both ends).
    tk.Label(control_frame, text="Select X Range:", font=("Arial", 12)).pack(pady=5)
    x_min = tk.IntVar(value=0)
    x_max = tk.IntVar(value=dimensions[0] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[0] - 1, variable=x_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[0] - 1, variable=x_max, orient="horizontal").pack()

    tk.Label(control_frame, text="Select Y Range:", font=("Arial", 12)).pack(pady=5)
    y_min = tk.IntVar(value=0)
    y_max = tk.IntVar(value=dimensions[1] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[1] - 1, variable=y_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[1] - 1, variable=y_max, orient="horizontal").pack()

    tk.Label(control_frame, text="Select Z Range:", font=("Arial", 12)).pack(pady=5)
    z_min = tk.IntVar(value=0)
    z_max = tk.IntVar(value=dimensions[2] - 1)
    tk.Scale(control_frame, from_=0, to=dimensions[2] - 1, variable=z_min, orient="horizontal").pack()
    tk.Scale(control_frame, from_=0, to=dimensions[2] - 1, variable=z_max, orient="horizontal").pack()

    # Colormap selection
    tk.Label(control_frame, text="Select Heatmap Color", font=("Arial", 12)).pack(pady=5)
    colormap_var = tk.StringVar(value="Grayscale")
    heatmap_menu = ttk.Combobox(
        control_frame,
        textvariable=colormap_var,
        values=["Grayscale", "Hot", "Cool", "Jet"],
        font=("Arial", 10),
    )
    heatmap_menu.pack(pady=5)

    def apply_subvolume():
        """Crop `data` to the slider ranges, then re-render and re-plot."""
        global current_subvolume
        candidate = data[
            x_min.get():x_max.get() + 1,
            y_min.get():y_max.get() + 1,
            z_min.get():z_max.get() + 1,
        ]
        # Validate before committing, so an empty crop (e.g. min > max) does
        # not replace the last good subvolume used by the other callbacks.
        if candidate.size == 0:
            print("Warning: Subvolume is empty!")
            return
        current_subvolume = candidate
        show_volume(volume_render(current_subvolume, colormap=colormap_var.get()))
        update_histogram_and_scatter()

    def reset_view():
        """Show the full volume in grayscale again and refresh the plots."""
        global current_subvolume
        current_subvolume = data
        colormap_var.set("Grayscale")
        show_volume(volume_render(data, colormap="Grayscale"))
        update_histogram_and_scatter(title_prefix="")

    def change_heatmap():
        """Re-render the current subvolume with the selected colormap."""
        show_volume(volume_render(current_subvolume, colormap=colormap_var.get()))

    tk.Button(control_frame, text="Apply Subvolume", command=apply_subvolume).pack(pady=10)
    tk.Button(control_frame, text="Reset View", command=reset_view).pack(pady=10)
    tk.Button(control_frame, text="Change Heatmap", command=change_heatmap).pack(pady=10)

    def on_histogram_select(eclick, erelease):
        """Render only the voxels whose intensity lies in the dragged range."""
        if eclick.xdata is None or erelease.xdata is None:
            print("Invalid selection area!")
            return

        min_intensity = int(min(eclick.xdata, erelease.xdata))
        max_intensity = int(max(eclick.xdata, erelease.xdata))
        print(f"Selected Intensity Range: {min_intensity} to {max_intensity}")

        # current_subvolume is already cropped by apply_subvolume(); slicing it
        # again with the slider values would select the wrong region.
        mask = (current_subvolume >= min_intensity) & (current_subvolume <= max_intensity)
        filtered_data = np.where(mask, current_subvolume, 0).astype(np.uint8)

        if not filtered_data.any():
            print("Warning: Filtered subvolume is empty!")

        show_volume(volume_render(filtered_data, colormap=colormap_var.get()))

    # Keep a reference: matplotlib widgets stop responding once garbage-collected.
    histogram_selector = RectangleSelector(
        ax_histogram, onselect=on_histogram_select, useblit=True, button=[1], interactive=True
    )

    def on_pick(event):
        """Show only the picked Z slice of the current subvolume, using the "Hot" colormap."""
        if len(event.ind) == 0:
            return
        selected_z = event.ind[0]
        if selected_z >= current_subvolume.shape[2]:
            print("Warning: Z-index out of bounds!")
            return
        mask = np.zeros(current_subvolume.shape, dtype=np.uint8)
        mask[:, :, selected_z] = current_subvolume[:, :, selected_z]
        show_volume(volume_render(mask, colormap="Hot"))

    canvas.mpl_connect("pick_event", on_pick)

    # ---- Tab 2: bubble chart and blob highlighting ----
    tab_bubble_chart = ttk.Frame(notebook)
    notebook.add(tab_bubble_chart, text="Bubble Chart & Blob Highlighting")

    frame_bubble_controls = tk.Frame(tab_bubble_chart)
    frame_bubble_controls.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=10)

    frame_bubble_plot = tk.Frame(tab_bubble_chart)
    frame_bubble_plot.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

    fig_bubble = Figure(figsize=(5, 6), dpi=100)
    canvas_bubble = FigureCanvasTkAgg(fig_bubble, master=frame_bubble_plot)
    canvas_bubble_widget = canvas_bubble.get_tk_widget()
    canvas_bubble_widget.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    ax_bubble = fig_bubble.add_subplot(111)

    # (x_attr, y_attr) of the chart currently drawn, or None if nothing is plotted.
    plotted_bubble_attrs = None

    def update_bubble_chart():
        """Redraw the bubble chart from the three attribute comboboxes."""
        nonlocal plotted_bubble_attrs
        ax_bubble.clear()
        plotted_bubble_attrs = None
        x_attr = x_axis_var_bubble.get()
        y_attr = y_axis_var_bubble.get()
        size_attr = size_axis_var_bubble.get()

        if x_attr not in attributes or y_attr not in attributes or size_attr not in attributes:
            print("Invalid attribute selection!")
            canvas_bubble.draw()
            return

        x_data = blob_test_data[x_attr]
        y_data = blob_test_data[y_attr]
        size_data = blob_test_data[size_attr]

        ax_bubble.scatter(
            x_data,
            y_data,
            s=size_data * 100,
            c="blue",
            alpha=0.6,
            picker=True,
        )
        ax_bubble.set_title(f"Bubble Chart: {x_attr} vs {y_attr} (Size: {size_attr})")
        ax_bubble.set_xlabel(x_attr)
        ax_bubble.set_ylabel(y_attr)
        canvas_bubble.draw()
        plotted_bubble_attrs = (x_attr, y_attr)

    tk.Label(frame_bubble_controls, text="X-Axis Attribute", font=("Arial", 10)).pack(pady=5)
    x_axis_var_bubble = tk.StringVar(value=attributes[0])
    x_axis_menu_bubble = ttk.Combobox(
        frame_bubble_controls, textvariable=x_axis_var_bubble, values=attributes, font=("Arial", 10)
    )
    x_axis_menu_bubble.pack(pady=5)

    tk.Label(frame_bubble_controls, text="Y-Axis Attribute", font=("Arial", 10)).pack(pady=5)
    y_axis_var_bubble = tk.StringVar(value=attributes[1])
    y_axis_menu_bubble = ttk.Combobox(
        frame_bubble_controls, textvariable=y_axis_var_bubble, values=attributes, font=("Arial", 10)
    )
    y_axis_menu_bubble.pack(pady=5)

    tk.Label(frame_bubble_controls, text="Size Attribute", font=("Arial", 10)).pack(pady=5)
    size_axis_var_bubble = tk.StringVar(value=attributes[2])
    size_axis_menu_bubble = ttk.Combobox(
        frame_bubble_controls, textvariable=size_axis_var_bubble, values=attributes, font=("Arial", 10)
    )
    size_axis_menu_bubble.pack(pady=5)

    tk.Button(frame_bubble_controls, text="Update Chart", command=update_bubble_chart).pack(pady=10)

    def toggle_blobs():
        """Flip the visibility of the red all-blobs actor in the 3D view."""
        global blobs_visible
        blobs_visible = not blobs_visible
        blob_actor.SetVisibility(blobs_visible)
        render_window3D.Render()

    toggle_button = tk.Button(frame_bubble_controls, text="Toggle Blobs", command=toggle_blobs)
    toggle_button.pack(pady=10)

    def on_rectangle_select(eclick, erelease):
        """Highlight, in yellow, the blobs inside the dragged bubble-chart rectangle."""
        global highlighted_actors

        # Drop the previous highlight before drawing a new one.
        for actor in highlighted_actors:
            renderer3D.RemoveActor(actor)
        highlighted_actors = []

        if None in (eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata):
            print("Invalid selection rectangle.")
            render_window3D.Render()
            return
        if plotted_bubble_attrs is None:
            print("Update the bubble chart before selecting blobs.")
            render_window3D.Render()
            return

        # The drag can start from any corner, so normalise the bounds.
        sel_x_min, sel_x_max = sorted((eclick.xdata, erelease.xdata))
        sel_y_min, sel_y_max = sorted((eclick.ydata, erelease.ydata))

        # Filter on the columns actually drawn, not on combobox values that
        # may have changed since the last "Update Chart".
        x_attr, y_attr = plotted_bubble_attrs
        inside = (
            blob_test_data[x_attr].between(sel_x_min, sel_x_max)
            & blob_test_data[y_attr].between(sel_y_min, sel_y_max)
        ).to_numpy()
        selected_centroids = centroids[inside]

        # One glyph actor for the whole selection. The red all-blobs actor is
        # left alone, so its visibility stays under "Toggle Blobs".
        if len(selected_centroids) > 0:
            highlighted_actors.append(
                render_blobs(selected_centroids, renderer3D, sphere_radius=8, color=(1, 1, 0))
            )

        render_window3D.Render()

    # Keep a reference: matplotlib widgets stop responding once garbage-collected.
    rect_selector_bubble = RectangleSelector(
        ax_bubble,
        on_rectangle_select,
        interactive=True,
        useblit=True,
        button=[1],
        minspanx=5,
        minspany=5,
        spancoords="pixels",
    )

    # ---- Tab 3: skeletonization (uses the current subvolume, else the full volume) ----
    add_skeleton_tab(
        root,
        notebook,
        renderer3D,
        get_active_volume=lambda: current_subvolume if current_subvolume is not None else data,
        sphere_radius=0.5,  # adjust if your voxel scale is large/small
    )

    root.mainloop()
    # NOTE(review): the VTK interactor is only started after the Tk window is
    # closed, so the 3D view is presumably not interactive while the Tk GUI
    # runs. Consider vtkTkRenderWindowInteractor, or calling Initialize()
    # before mainloop.
    interactor3D.Start()


In [None]:
def main():
    """Launch the combined viewer on the SC1 700x700x700 raw volume."""
    volume_path = r"C:\Users\adamd\Downloads\SC1_700x700x700.raw"
    volume_shape = (700,) * 3
    combined_gui_with_interactive_filter(volume_path, volume_shape)


main()