In [1]:
import networkx as nx
from pyvis.network import Network
from math import sqrt
import json

class Node:
    """A named node whose redo (self-loop) value is tracked frame by frame.

    ``redo`` is a history list: index ``t`` holds the value for frame ``t``,
    and the constructor's ``redo`` argument becomes frame 0.
    """

    def __init__(self, name, redo=0, id=0):
        self.name = name
        # Not modified anywhere in the visible cells; shown in the node tooltip.
        self.amount = 0
        self.redo = list((redo,))
        # Tree.add_node overwrites this with a sequential id.
        self.id = id

    def change_redo(self, redo):
        """Record ``redo`` as the value for the next frame."""
        self.redo += [redo]

class Link:
    """Connection between two nodes carrying one value per direction per frame.

    ``p_state1`` and ``p_state2`` are history lists (index = frame). When the
    graph is drawn, ``p_state2`` labels the edge node1 -> node2 and ``p_state1``
    labels the edge node2 -> node1.
    """

    def __init__(self, p_state1, p_state2, node1=None, node2=None):
        self.p_state1, self.p_state2 = [p_state1], [p_state2]
        # Endpoints are normally filled in later by Tree.add_link.
        self.node1, self.node2 = node1, node2

    def change_state_value(self, p_state1, p_state2):
        """Record the pair of directional values for the next frame."""
        for history, value in ((self.p_state1, p_state1), (self.p_state2, p_state2)):
            history.append(value)

class Tree:
    """Collection of ``Node`` objects connected by ``Link`` objects.

    Node ids are assigned sequentially by ``add_node``, starting at 0. Nothing
    here enforces an actual tree shape: any pair of stored nodes may be linked.
    """

    def __init__(self):
        self.nodes = []
        self.links = []
        # Next id that add_node will hand out.
        self.current_id = 0

    def add_node(self, node: "Node") -> int:
        """Give ``node`` the next sequential id, store it, and return that id."""
        current_id = self.current_id
        self.current_id += 1
        node.id = current_id
        self.nodes.append(node)
        return current_id

    def add_link(self, link: "Link", node_id1, node_id2):
        """Connect ``link`` to the nodes with ids ``node_id1`` / ``node_id2`` and store it.

        Raises:
            ValueError: if either id does not belong to a node in this tree.
                Failing here, instead of storing a link with a ``None``
                endpoint, surfaces the bad id before ``to_networkx`` trips on it.
        """
        node1 = node2 = None
        # Linear scan; if ids were ever duplicated, the last match wins.
        for node in self.nodes:
            if node.id == node_id1:
                node1 = node
            if node.id == node_id2:
                node2 = node

        missing = [nid for nid, found in ((node_id1, node1), (node_id2, node2)) if found is None]
        if missing:
            raise ValueError(f"No node with id(s) {missing} in this tree")

        link.node1 = node1
        link.node2 = node2
        self.links.append(link)

    def to_networkx(self):
        """Build a ``networkx.DiGraph`` of the latest state.

        Each node gets a self-loop labelled with its most recent redo value
        (added even when that value is 0). Each link becomes two directed edges:
        node1 -> node2 labelled with the latest ``p_state2`` and node2 -> node1
        labelled with the latest ``p_state1``.
        """
        G = nx.DiGraph()
        for node in self.nodes:
            G.add_node(node.id, label=f"{node.name}")

            G.add_edge(node.id, node.id, label=f"{node.redo[-1]}")

        for link in self.links:
            G.add_edge(link.node1.id, link.node2.id, label=f"{link.p_state2[-1]}")
            G.add_edge(link.node2.id, link.node1.id, label=f"{link.p_state1[-1]}")

        return G

In [2]:
def _value_at(history, time):
    """Return ``history[time]``, or the last entry when ``time`` is past the end."""
    return history[time] if time < len(history) else history[-1]


def visualize_tree_at_time(tree, time, output_dir="./visualization/frames"):
    """Render ``tree`` as it was at frame ``time`` into an interactive pyvis HTML file.

    Args:
        tree: a ``Tree`` whose nodes (``redo``) and links (``p_state1``/``p_state2``)
            hold per-frame history lists.
        time: frame index to render. A history shorter than ``time + 1`` falls
            back to its last value.
        output_dir: directory the frame is written into; created if missing.

    Returns:
        The written path, ``f"{output_dir}/tree_visualization_time_{time}.html"``.
    """
    import os

    G = nx.DiGraph()
    for node in tree.nodes:
        redo_value = _value_at(node.redo, time)
        G.add_node(node.id, label=f"{node.name}", title=f"Amount: {node.amount}\nRedo at time {time}: {redo_value}")

        # Only draw the self-loop when the redo value is positive.
        if redo_value > 0:
            G.add_edge(node.id, node.id, label=f"{redo_value}")

    for link in tree.links:
        p_state1_value = _value_at(link.p_state1, time)
        p_state2_value = _value_at(link.p_state2, time)

        # sqrt scaling makes small values comparatively thicker than linear scaling would.
        G.add_edge(link.node1.id, link.node2.id, label=f"{p_state2_value}", width=10 * sqrt(p_state2_value))
        G.add_edge(link.node2.id, link.node1.id, label=f"{p_state1_value}", width=10 * sqrt(p_state1_value))

    net = Network(notebook=True, height="720px", width="100%", directed=True, cdn_resources='remote')
    net.from_nx(G)
    net.force_atlas_2based(spring_length=300)

    # Writing fails on a fresh checkout unless the frames folder exists.
    os.makedirs(output_dir, exist_ok=True)
    filename = f"{output_dir}/tree_visualization_time_{time}.html"
    net.show(filename)
    return filename

In [3]:
import os
import re

def delete_old_frames(directory="./visualization/frames"):
    """Delete previously rendered frame files from ``directory``, printing each one.

    Only regular files whose *entire* name is ``tree_visualization_time_<N>.html``
    are removed. Names such as ``tree_visualization_time_3.html.bak`` or
    ``index.html`` are left alone. A missing directory means there is nothing
    to delete, so it is not an error.

    Args:
        directory: folder that holds the rendered frame files.
    """
    pattern = re.compile(r'tree_visualization_time_\d+\.html')

    if not os.path.isdir(directory):
        return

    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)
        # fullmatch, not match: match only anchors at the start of the name.
        if pattern.fullmatch(filename) and os.path.isfile(path):
            os.remove(path)
            print(f"Deleted {filename}")

In [4]:
import re

def update_html_max_frame(new_max_frame, html_path="./visualization/index.html"):
    """Rewrite the viewer page so its time slider spans frames ``0..new_max_frame``.

    Both the ``var maxFrame = N;`` script assignment and the ``max`` attribute
    of the ``timeSlider`` range input are replaced in place.

    Args:
        new_max_frame: index of the last rendered frame.
        html_path: HTML page to edit in place (read and written as UTF-8).

    Warns:
        UserWarning: when a pattern is not found, since the page would otherwise
            silently keep its old frame count.
    """
    import warnings

    with open(html_path, 'r', encoding='utf-8') as file:
        content = file.read()

    replacements = (
        (r'var maxFrame = \d+;', f'var maxFrame = {new_max_frame};'),
        (r'<input type="range" id="timeSlider" min="0" max="\d+"',
         f'<input type="range" id="timeSlider" min="0" max="{new_max_frame}"'),
    )
    for pattern, replacement in replacements:
        content, count = re.subn(pattern, replacement, content)
        if count == 0:
            warnings.warn(f"Pattern {pattern!r} not found in {html_path}; it was not updated.")

    with open(html_path, 'w', encoding='utf-8') as file:
        file.write(content)

In [5]:
import random
from tqdm import tqdm

# Each node/link starts with one hand-picked value and then receives
# frames_amount random values, giving frames 0..frames_amount.
frames_amount = 5
# Seed the RNG so Restart & Run All reproduces the same frames.
random.seed(42)

tree = Tree()

ns = [
    Node("Attack", 0.1),
    Node("Block", 0.4),
    Node("Dodge", 0.78),
]

# Random redo values (2 decimal places) for every subsequent frame.
for _ in range(frames_amount):
    for node in ns:
        node.change_redo(round(random.random(), 2))

n1_id, n2_id, n3_id = [tree.add_node(node) for node in ns]

ls = [
    Link(0.3, 0.7),
    Link(0.1, 0.9),
    Link(0.45, 0.65),
]
# Subsequent frames use complementary random values (v, 1 - v) for the two directions.
for _ in range(frames_amount):
    for link in ls:
        v = round(random.random(), 2)
        link.change_state_value(v, round(1 - v, 2))

tree.add_link(ls[0], n1_id, n2_id)
tree.add_link(ls[1], n2_id, n3_id)
tree.add_link(ls[2], n1_id, n3_id)

delete_old_frames()
n_frames = len(tree.nodes[0].redo)
filenames = [visualize_tree_at_time(tree, t) for t in tqdm(range(n_frames))]
# The slider is 0-indexed, so its maximum is the index of the last frame written.
update_html_max_frame(len(filenames) - 1)

Deleted tree_visualization_time_0.html
Deleted tree_visualization_time_1.html
Deleted tree_visualization_time_2.html
Deleted tree_visualization_time_3.html
Deleted tree_visualization_time_4.html
Deleted tree_visualization_time_5.html


 17%|█▋        | 1/6 [00:00<00:00,  9.03it/s]

./visualization/frames/tree_visualization_time_0.html
./visualization/frames/tree_visualization_time_1.html
./visualization/frames/tree_visualization_time_2.html


100%|██████████| 6/6 [00:00<00:00, 23.41it/s]

./visualization/frames/tree_visualization_time_3.html
./visualization/frames/tree_visualization_time_4.html
./visualization/frames/tree_visualization_time_5.html



