In [None]:
%pip install vedo

In [None]:
import numpy as np
from vedo import Plotter, Line

class FiberSelectionGUI:
    """Interactive vedo viewer for toggling the selection of fiber streamlines.

    Each streamline is drawn as a ``vedo.Line``. A mouse click in the scene
    toggles the fiber nearest to the picked 3D point between
    ``unselected_color`` and ``selected_color`` and adds/removes its index in
    ``selected_fibers``.

    Attributes:
        streamlines: The streamlines passed to the constructor.
        selected_fibers: Indices of the currently selected streamlines, in the
            order they were selected.
        lines: One ``vedo.Line`` per streamline, in the same order.
        plotter: The ``vedo.Plotter`` hosting the scene.
    """

    # Palette for unselected / selected fibers; override on a subclass or an
    # instance (before construction for a subclass) to change it.
    unselected_color = "blue"
    selected_color = "red"

    def __init__(self, streamlines, *, auto_show=True):
        """Build the scene and, by default, open the interactive window.

        Args:
            streamlines: Iterable of point sequences, one per fiber; each is
                passed to ``vedo.Line`` (presumably ``(N, 3)`` coordinates —
                the nearest-fiber search assumes the same dimensionality as
                the picked 3D point).
            auto_show: If true (the default, preserving the original
                behavior) open the window immediately, which typically blocks
                until it is closed. Pass ``False`` and call :meth:`show`
                later to defer it.
        """
        self.streamlines = streamlines
        self.selected_fibers = []
        self.lines = []
        # Float copies of each streamline, used for nearest-fiber queries.
        self._polylines = []
        self.plotter = Plotter(title="Fiber Tracts Selector", axes=1, interactive=True)
        self._build_scene()
        if auto_show:
            self.show()

    def show(self):
        """Open the interactive window and return ``self``."""
        self.plotter.show(interactive=True)
        return self

    def _build_scene(self):
        """Add one line per streamline to the plotter and register the click handler."""
        for i, streamline in enumerate(self.streamlines):
            line = Line(streamline, c=self.unselected_color)
            line.name = str(i)
            self.lines.append(line)
            self._polylines.append(np.atleast_2d(np.asarray(streamline, dtype=float)))
            self.plotter += line
        self.plotter.add_callback("mouse click", self._on_click)

    def _on_click(self, evt):
        """Toggle the selection state of the fiber nearest to the clicked point.

        Args:
            evt: The vedo event passed to the callback; only ``evt.picked3d``
                is read.
        """
        # Use the pick carried by this event: Plotter.picked3d may still hold the
        # previous successful pick, so a click on empty background would
        # otherwise re-toggle the last fiber.
        picked = getattr(evt, "picked3d", None)
        if picked is None:
            return
        closest_idx = self._find_closest_line(picked)
        if closest_idx is None:
            return
        line = self.lines[closest_idx]
        if closest_idx in self.selected_fibers:
            self.selected_fibers.remove(closest_idx)
            line.color(self.unselected_color)
        else:
            self.selected_fibers.append(closest_idx)
            line.color(self.selected_color)
        self.plotter.render()

    def _find_closest_line(self, point):
        """Return the index of the streamline closest to ``point``, or ``None``.

        Ties resolve to the lowest index. ``None`` is returned when there are
        no streamlines.
        """
        point = np.asarray(point, dtype=float)
        closest_idx = None
        min_dist = float("inf")
        for idx, vertices in enumerate(self._polylines):
            dist = self._distance_to_polyline(point, vertices)
            if dist < min_dist:
                min_dist = dist
                closest_idx = idx
        return closest_idx

    @staticmethod
    def _distance_to_polyline(point, vertices):
        """Return the shortest distance from ``point`` to the polyline through ``vertices``.

        The distance is measured to the segments between consecutive vertices,
        not only to the vertices, so a click in the middle of a long segment is
        attributed to that segment's fiber. An empty polyline yields ``inf``.
        """
        point = np.asarray(point, dtype=float)
        vertices = np.atleast_2d(np.asarray(vertices, dtype=float))
        if vertices.size == 0:
            return float("inf")
        if len(vertices) == 1:
            return float(np.linalg.norm(vertices[0] - point))
        starts = vertices[:-1]
        segments = vertices[1:] - starts
        seg_len_sq = np.sum(segments * segments, axis=1)
        # Projection parameter onto each segment, clamped to the segment.
        # Zero-length segments get t = 0, i.e. their start point.
        safe_len_sq = np.where(seg_len_sq > 0, seg_len_sq, 1.0)
        t = np.clip(np.sum((point - starts) * segments, axis=1) / safe_len_sq, 0.0, 1.0)
        nearest = starts + t[:, None] * segments
        return float(np.min(np.linalg.norm(nearest - point, axis=1)))


# Example usage with dummy data: three random-walk "streamlines" with 20, 25
# and 30 points, each scaled by 2.
streamlines = [
    np.cumsum(np.random.randn(n_points, 3), axis=0) * 2
    for n_points in (20, 25, 30)
]

# Launch GUI (the constructor opens the viewer window)
gui = FiberSelectionGUI(streamlines)


In [None]:
# Echo the FiberSelectionGUI instance created above as this cell's output.
gui

In [None]:
from pathlib import Path
import numpy as np

# Example: simulate three random fibers (random walks of 20, 25 and 30 points).
fibers = [np.cumsum(np.random.randn(n, 3), 0) for n in (20, 25, 30)]

# FiberSelectionGUI takes only the streamlines and opens its window from the
# constructor. No MergeHandler is defined in this notebook, so the
# merge_handler argument (and the separate show() call) cannot be used here.
gui = FiberSelectionGUI(fibers)
