# In [1]:
import math
from typing import Dict, Tuple, List
from brancharchitect.io import parse_newick

from IPython.display import display

# pythreejs imports
from pythreejs import (
    PerspectiveCamera,
    DirectionalLight,
    AmbientLight,
    Scene,
    Mesh,
    MeshPhongMaterial,
    CylinderGeometry,
    SphereGeometry,
    Group,
    OrbitControls,
    Renderer,
)


################################################################################
# 1) Node class (or from brancharchitect.tree import Node)
################################################################################
class Node:
    """Minimal rooted-tree node: a label, a branch length and child nodes."""

    def __init__(self, name=None, length=0.0, children=None):
        # A falsy ``children`` (None or an empty sequence) always yields a
        # fresh list, so instances never share a child container.
        self.children = children or []
        self.length = length
        self.name = name

    def traverse(self):
        """Yield this node, then every descendant, in pre-order."""
        yield self
        for child in self.children:
            yield from child.traverse()

    def is_internal(self) -> bool:
        """Return True when the node has at least one child."""
        return bool(self.children)

    def __repr__(self):
        return "Node('{}')".format(self.name)


################################################################################
# 2) Radial layout in Y–Z plane (returns {node: (r, angle)})
################################################################################
def compute_radial_layout(root: "Node") -> Dict["Node", Tuple[float, float]]:
    """
    Assign each node a polar coordinate (radius, angle) for a radial layout.

      - radius = distance from root (sum of branch lengths along the path);
        a missing or zero ``length`` contributes 0.0
      - leaves get equally spaced angles in [0, 2*pi), in left-to-right
        depth-first order
      - an internal node's angle is the mean of its direct children's angles

    The tree is walked with an explicit stack instead of recursion, so very
    deep (ladder-like) trees do not hit Python's recursion limit.

    Args:
        root: Root of the tree. Nodes must be hashable and expose
            ``children`` (a sequence) and ``length``.

    Returns:
        Mapping of every node reachable from ``root`` to ``(radius, angle)``.
        A root without children maps to ``(0.0, 0.0)``.
    """
    dist = {root: 0.0}
    leaves = []
    internal_post_order = []

    # Each stack entry is (node, expanded). An internal node is pushed back
    # with expanded=True before its children, so it is popped again only once
    # its whole subtree has been processed -- i.e. in post-order.
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if not node.children:
            leaves.append(node)
        elif expanded:
            internal_post_order.append(node)
        else:
            stack.append((node, True))
            base = dist[node]
            for child in node.children:
                dist[child] = base + (child.length if child.length else 0.0)
            # Reversed so the first child ends up on top of the stack and
            # leaves are collected left to right.
            stack.extend((child, False) for child in reversed(node.children))

    # The root itself is a leaf when it has no children, so there is always
    # at least one leaf and the division below is safe.
    total_leaves = len(leaves)
    layout = {
        leaf: (dist[leaf], 2 * math.pi * i / total_leaves)
        for i, leaf in enumerate(leaves)
    }

    # Children are always laid out before their parent in post-order.
    for node in internal_post_order:
        child_angles = [layout[child][1] for child in node.children]
        layout[node] = (dist[node], sum(child_angles) / len(child_angles))

    return layout


################################################################################
# 3) Build pythreejs group for a single tree (nodes + edges)
################################################################################
def build_tree_group(
    root: Node,
    layout: Dict[Node, Tuple[float, float]],
    node_radius=0.2,
    edge_radius=0.05,
    node_color="green",
    edge_color="gray",
) -> Group:
    """
    Build a pythreejs Group holding one sphere per node and one cylinder
    per parent-child edge of a tree.

    Args:
        root: Root of the tree; edges are found via ``root.traverse()`` and
            each node's ``children``.
        layout: Polar coordinates ``(r, angle)`` for every node to draw.
        node_radius: Radius of each node sphere.
        edge_radius: Radius of each edge cylinder.
        node_color: Colour of the material shared by all spheres.
        edge_color: Colour of the material shared by all cylinders.

    Returns:
        The Group, with all spheres added first and the edge cylinders after.
    """
    tree_group = Group()

    sphere_material = MeshPhongMaterial(color=node_color)
    cylinder_material = MeshPhongMaterial(color=edge_color)

    # Polar (r, angle) -> Cartesian point in the x = 0 plane.
    positions = {
        tree_node: (0.0, radius * math.cos(theta), radius * math.sin(theta))
        for tree_node, (radius, theta) in layout.items()
    }

    # One sphere per laid-out node.
    for coords in positions.values():
        sphere = Mesh(
            geometry=SphereGeometry(
                radius=node_radius, widthSegments=12, heightSegments=12
            ),
            material=sphere_material,
        )
        sphere.position = list(coords)
        tree_group.add(sphere)

    # One cylinder per parent -> child edge.
    for parent in root.traverse():
        for child in parent.children:
            tree_group.add(
                make_cylinder_between_points(
                    positions[parent],
                    positions[child],
                    edge_radius,
                    cylinder_material,
                )
            )

    return tree_group


################################################################################
# 4) Cylinder creation between two 3D points
################################################################################
def make_cylinder_between_points(
    start: Tuple[float, float, float],
    end: Tuple[float, float, float],
    radius: float,
    material: MeshPhongMaterial,
) -> Mesh:
    """
    Create a cylinder mesh spanning the segment from ``start`` to ``end``.

    The cylinder is created along +Y, centred on the segment midpoint, and
    rotated by a quaternion so that its axis follows ``end - start``.

    Args:
        start: (x, y, z) of one end of the segment.
        end: (x, y, z) of the other end of the segment.
        radius: Cylinder radius (same at both ends).
        material: Material assigned to the mesh.

    Returns:
        The positioned and oriented Mesh. If the two points coincide
        (distance below 1e-9) a tiny cylinder of height 0.0001 is returned
        at that point with its default orientation.
    """
    sx, sy, sz = start
    ex, ey, ez = end
    dx, dy, dz = ex - sx, ey - sy, ez - sz

    span = math.sqrt(dx * dx + dy * dy + dz * dz)
    is_degenerate = span < 1e-9

    # Cylinder oriented along +Y, so we create it, then rotate/translate.
    cyl_geo = CylinderGeometry(
        radiusTop=radius,
        radiusBottom=radius,
        height=0.0001 if is_degenerate else span,  # avoid zero-length
        radialSegments=8,
    )
    mesh = Mesh(geometry=cyl_geo, material=material)

    # Midpoint of the segment.
    mesh.position = [sx + dx / 2.0, sy + dy / 2.0, sz + dz / 2.0]

    if is_degenerate:
        # Coincident endpoints have no direction; normalising would divide
        # by zero, so keep the default orientation.
        return mesh

    ux, uy, uz = dx / span, dy / span, dz / span
    # Rounding can push uy marginally outside acos's domain.
    uy = max(-1.0, min(1.0, uy))

    # Rotation axis = cross((0, 1, 0), (ux, uy, uz)) = (uz, 0, -ux).
    ax, ay, az = uz, 0.0, -ux
    angle = math.acos(uy)

    # The axis vanishes when the direction is parallel or anti-parallel to
    # +Y; the cylinder is symmetric end-to-end, so no rotation is needed.
    axis_len = math.sqrt(ax * ax + ay * ay + az * az)
    if axis_len > 1e-9 and abs(angle) > 1e-9:
        ax /= axis_len
        ay /= axis_len
        az /= axis_len

        # Axis-angle -> quaternion (x, y, z, w).
        half_angle = angle * 0.5
        sin_half = math.sin(half_angle)
        mesh.quaternion = [
            ax * sin_half,
            ay * sin_half,
            az * sin_half,
            math.cos(half_angle),
        ]

    return mesh


################################################################################
# 5) Stack multiple trees along X axis (2.5D)
################################################################################
def visualize_stacked_trees_2_5d(
    trees: List[Node], x_separation=10.0, width=800, height=500
) -> Renderer:
    """
    Render several trees side by side along the X axis in one 3D scene.

    Each tree is laid out radially in the Y-Z plane and its group is shifted
    to ``x = index * x_separation``.

    Args:
        trees: Tree roots, drawn in list order.
        x_separation: Distance along X between consecutive trees.
        width: Renderer width.
        height: Renderer height.

    Returns:
        A pythreejs Renderer with a perspective camera and orbit controls.
    """
    scene = Scene()

    # Lighting: a soft ambient fill plus one directional light.
    scene.add(AmbientLight(color="#ffffff", intensity=0.6))
    sun = DirectionalLight(color="#ffffff", intensity=0.6)
    sun.position = [20, 50, 10]
    scene.add(sun)

    # Camera and mouse-driven orbit controls.
    cam = PerspectiveCamera(position=[0, 15, 50], fov=45)
    orbit = OrbitControls(controlling=cam)

    # One group per tree, offset along X by its index.
    for index, tree_root in enumerate(trees):
        tree_group = build_tree_group(
            tree_root,
            compute_radial_layout(tree_root),
            node_radius=0.3,
            edge_radius=0.05,
            node_color="lightgreen",
            edge_color="gray",
        )
        tree_group.position = [index * x_separation, 0, 0]
        scene.add(tree_group)

    return Renderer(
        camera=cam, scene=scene, controls=[orbit], width=width, height=height
    )


################################################################################
# 6) Example usage (in a Jupyter cell)
################################################################################
# Example trees: A-D are four independent parses of the same Newick string;
# E has a different topology and is parsed but not passed to the renderer.
rootA, rootB, rootC, rootD = (
    parse_newick("(A:1,(B:1,C:1):1,(D:1,E:1):1);") for _ in range(4)
)
rootE = parse_newick("(A:1,B:1,(D:1,(C:1,E:1)):1);")

# Stack the four identical trees 12 units apart along X and show the widget.
stacked_renderer = visualize_stacked_trees_2_5d(
    trees=[rootA, rootB, rootC, rootD], x_separation=12.0
)
display(stacked_renderer)

# Out[1]: Renderer(camera=PerspectiveCamera(fov=45.0, position=(0.0, 15.0, 50.0), projectionMatrix=(1.0, 0.0, 0.0, 0.0, …