In [2]:
import torch
import networkx as nx
from pyvis.network import Network
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO


In [3]:

# Sample tree structure. Every node is a dict with an integer "node_id",
# a flat "data" tensor of 16 random values, and a list of "children" nodes.
# The seed makes the random node data (and therefore the rendered
# thumbnails) identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

tree = {
    "node_id": 1,
    "data": torch.rand(16),  # Example 1D tensor
    "children": [
        {
            "node_id": 2,
            "data": torch.rand(16),
            "children": []
        },
        {
            "node_id": 3,
            "data": torch.rand(16),
            "children": []
        }
    ]
}

# 2-D shape each node's data is reshaped to before rendering; its product
# must equal the tensor length (4x4 for 16 values).
shape = (4, 4)


# Function to convert tensor data into a grayscale image and return Base64 encoding
def tensor_to_image_base64(tensor, shape):
    """Render a tensor as a grayscale PNG and return it Base64-encoded.

    Args:
        tensor: torch tensor whose element count equals the product of
            ``shape``; it is reshaped to ``shape`` before plotting.
        shape: target shape for ``reshape`` (2-D, e.g. ``(4, 4)``), drawn
            with ``imshow``.

    Returns:
        The PNG bytes as a Base64 ASCII string, without a ``data:`` prefix.

    Raises:
        ValueError: if the tensor cannot be reshaped to ``shape``.
    """
    # .numpy() raises for tensors that require grad or live on a GPU;
    # detach().cpu() makes the conversion work for both.
    array = tensor.detach().cpu().numpy().reshape(shape)

    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure, so we never paint into (and then close) a figure that another
    # cell left open. Close it in `finally` so a failed save doesn't leak it.
    fig, ax = plt.subplots()
    try:
        ax.imshow(array, cmap="gray", interpolation="nearest")
        ax.axis("off")
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

    # getvalue() returns the whole buffer regardless of position; no seek needed.
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Function to build tree graph in NetworkX
def build_graph(graph, node, parent=None):
    """Recursively add ``node`` and all of its descendants to ``graph``.

    Each node's ``data`` tensor is rendered with ``tensor_to_image_base64``
    using the module-level ``shape``. The resulting PNG data URI is stored
    both as the node's ``image`` attribute and inside an ``<img>`` tag in its
    ``title`` (hover tooltip). ``image`` is the attribute vis-network draws
    when a node's shape is ``"image"``. An edge ``parent -> node_id`` is added
    for every non-root node.

    Args:
        graph: a NetworkX graph (a ``DiGraph`` in this notebook), mutated
            in place.
        node: dict with ``"node_id"``, ``"data"`` and optionally
            ``"children"`` (list of nodes of the same form; missing means leaf).
        parent: node id of ``node``'s parent, or ``None`` for the root.
    """
    node_id = node["node_id"]
    img_data = tensor_to_image_base64(node["data"], shape)
    data_uri = f"data:image/png;base64,{img_data}"

    # `image` is what the node is drawn as once shape="image" is set;
    # `title` keeps the enlarged preview as a hover tooltip.
    # NOTE(review): newer vis-network versions may show string titles as
    # plain text rather than HTML; confirm against the pyvis version in use.
    graph.add_node(node_id, title=f'<img src="{data_uri}" width="100">', image=data_uri)

    # Add edge if there's a parent
    if parent is not None:
        graph.add_edge(parent, node_id)

    # Recursively add children; tolerate leaves that omit the key entirely.
    for child in node.get("children", []):
        build_graph(graph, child, node_id)


# Build the NetworkX view of the tree from its root.
G = nx.DiGraph()
build_graph(G, tree)

# Convert to pyvis; from_nx copies each node's attributes onto the pyvis node.
net = Network(notebook=True, directed=True)
net.from_nx(G)

# Draw nodes as their rendered tensor image. vis-network needs an `image`
# URL for shape="image", so only switch nodes that actually carry one. The
# rest keep pyvis's default shape instead of becoming image nodes with no
# image source.
for node in net.nodes:
    if node.get("image"):
        node["shape"] = "image"
    node["size"] = 20  # Node size

# Write the HTML file and display it inline in the notebook.
net.show("tree_visualization.html")

tree_visualization.html
