In [10]:
import plotly.graph_objects as go
import networkx as nx
import random
import datetime
from collections import defaultdict

# Function to generate a tree structure using networkx
def generate_tree_networkx(root_children=4, levels=4):
    """
    Generate a tree structure using networkx.
    root_children: max children per node
    levels: number of levels in the tree
    """
    G = nx.DiGraph()
    node_id = 0
    G.add_node(node_id, level=0, infection_time=datetime.datetime.now())

    current_level_nodes = [node_id]
    node_id += 1

    for level in range(1, levels):
        next_level_nodes = []
        for parent_id in current_level_nodes:
            num_children = random.randint(1, root_children)
            for _ in range(num_children):
                G.add_node(
                    node_id,
                    level=level,
                    infection_time=datetime.datetime.now()
                    - datetime.timedelta(days=random.randint(1, 30)),
                )
                G.add_edge(parent_id, node_id)
                next_level_nodes.append(node_id)
                node_id += 1
        current_level_nodes = next_level_nodes

    for node in G.nodes():
        G.nodes[node]["characteristics"] = {"color": "steelblue"}

    return G


def compute_tree_layout(G, root, x_spacing=2.0, y_spacing=1.5):
    """
    Compute positions for tree nodes using a hierarchical layout algorithm.
    """
    pos = {}

    children = defaultdict(list)
    for parent, child in G.edges():
        children[parent].append(child)

    def calc_widths(node, width_map):
        if node not in children or len(children[node]) == 0:
            width_map[node] = 1
        else:
            width_map[node] = sum(calc_widths(child, width_map) for child in children[node])
        return width_map[node]

    def position_nodes(node, x, y, width_map):
        pos[node] = (x, y)
        if node not in children or len(children[node]) == 0:
            return
        child_nodes = children[node]
        total_width = sum(width_map[child] for child in child_nodes)
        x_offset = x - (total_width * x_spacing) / 2
        for child in child_nodes:
            child_width = width_map[child]
            child_x = x_offset + (child_width * x_spacing) / 2
            position_nodes(child, child_x, y - y_spacing, width_map)
            x_offset += child_width * x_spacing

    width_map = {}
    calc_widths(root, width_map)
    position_nodes(root, 0, 0, width_map)

    return pos


# Generate the tree
G = generate_tree_networkx(root_children=4, levels=4)

# Find root node (node with no incoming edges)
root = None
for node in G.nodes():
    if G.in_degree(node) == 0:
        root = node
        break

# Compute hierarchical layout
pos = compute_tree_layout(G, root, x_spacing=2.0, y_spacing=1.5)

# Extract edge coordinates
edge_x = []
edge_y = []
for edge in G.edges():
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

# Extract node coordinates
node_x = []
node_y = []
node_colors = []
node_labels = []
for node in G.nodes():
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)
    node_colors.append("steelblue")
    node_labels.append(f"Node {node}")

# Create Plotly figure with better sizing for larger trees
num_nodes = G.number_of_nodes()
fig_width = max(1200, 300 + num_nodes // 2)
fig_height = max(600, 300 + num_nodes // 3)

fig = go.Figure()

# Add edges
fig.add_trace(
    go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=2, color="#888"),
        hoverinfo="none",
        showlegend=False,
    )
)

# Add nodes
fig.add_trace(
    go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_labels,
        textposition="top center",
        textfont=dict(size=9),
        marker=dict(size=40, color=node_colors, line=dict(width=2, color="#333")),
        hoverinfo="text",
        hovertext=node_labels,
        showlegend=False,
    )
)

# Update layout
fig.update_layout(
    showlegend=False,
    hovermode="closest",
    margin=dict(b=20, l=20, r=20, t=20),
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    plot_bgcolor="white",
    width=fig_width,
    height=fig_height,
)

fig.show()

# Save the figure as an image and SVG
fig.write_image("transmission_tree.png")
fig.write_image("transmission_tree.svg")

ValueError: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido


In [7]:
G.number_of_nodes()

16

In [9]:
# Sample nodes and create a colored visualization
sample_size = min(15, int(G.number_of_nodes()*0.3))
sampled_nodes = set(random.sample(list(G.nodes()), sample_size))

# Create node colors based on sampling
node_x_sampled = []
node_y_sampled = []
node_colors_sampled = []
node_labels_sampled = []
for node in G.nodes():
    x, y = pos[node]
    node_x_sampled.append(x)
    node_y_sampled.append(y)
    if node in sampled_nodes:
        node_colors_sampled.append("red")
    else:
        node_colors_sampled.append("lightgray")
    node_labels_sampled.append(f"Node {node}")

# Create Plotly figure for sampled visualization
fig_sampled = go.Figure()

# Add edges
fig_sampled.add_trace(
    go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=2, color="#888"),
        hoverinfo="none",
        showlegend=False,
    )
)

# Add nodes
fig_sampled.add_trace(
    go.Scatter(
        x=node_x_sampled,
        y=node_y_sampled,
        mode="markers+text",
        text=node_labels_sampled,
        textposition="top center",
        textfont=dict(size=9),
        marker=dict(size=28, color=node_colors_sampled, line=dict(width=2, color="#333")),
        hoverinfo="text",
        hovertext=node_labels_sampled,
        showlegend=False,
    )
)

# Update layout
fig_sampled.update_layout(
    showlegend=False,
    hovermode="closest",
    margin=dict(b=20, l=20, r=20, t=20),
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    plot_bgcolor="white",
    width=fig_width,
    height=fig_height,
)

fig_sampled.show()