In [2]:
import json
import torch
import numpy as np
import base64
import matplotlib.pyplot as plt
from io import BytesIO
import random



In [4]:
def generate_random_binary_tree(num_nodes, tensor_size=784, dist_size=10):
    """
    Generates a random binary tree with grayscale data.

    Args:
        num_nodes (int): Total nodes in the tree.
        tensor_size (int): Size of the 1D tensor (for 2D image).
        dist_size (int): Size of the 1D probability distribution tensor.

    Returns:
        dict: Root node of the generated binary tree.
    """
    if num_nodes < 1:
        return None

    nodes = [{"node_id": str(random.randint(0, 1000000)), 
              "data": torch.nn.functional.sigmoid(torch.rand(tensor_size)), 
              "distribution": torch.nn.functional.softmax(torch.rand(dist_size)),
              "children": []} for i in range(num_nodes)]
    # arrange it into balanced binary tree
    for i in range(num_nodes):
        if 2*i+1 < num_nodes:
            nodes[i]["children"].append(nodes[2*i+1])
        if 2*i+2 < num_nodes:
            nodes[i]["children"].append(nodes[2*i+2])
    return nodes[0]
    

# Function to convert tensor data into a base64 grayscale image
def tensor_to_base64(tensor, shape):
    array = tensor.numpy().reshape(shape)
    plt.imshow(array, cmap="gray", aspect="auto", interpolation="nearest")
    plt.axis("off")

    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()
    
    buf.seek(0)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Process the tree
num_nodes = 500  # Adjust for a larger or smaller tree
tree = generate_random_binary_tree(num_nodes)

image_shape = (28, 28)  # Shape for 2D tensor visualization
dist_shape = (1, 10)   # Shape for 1D probability distribution

def process_tree(node):
    if "data" in node:
        node["image"] = tensor_to_base64(node["data"], image_shape)
        del node["data"]

    if "distribution" in node:
        node["dist_image"] = tensor_to_base64(node["distribution"], dist_shape)
        del node["distribution"]

    for child in node.get("children", []):
        process_tree(child)

    return node

# Convert and save the tree
tree_json = process_tree(tree)

with open("tree_data.json", "w") as f:
    json.dump(tree_json, f, indent=2)

print("Saved tree data to tree_data.json")

  "distribution": torch.nn.functional.softmax(torch.rand(dist_size)),


Saved tree data to tree_data.json
