### import modules

In [None]:
import os

import pandas as pd
import plotly.graph_objects as go
import plotly.colors
from plotly.subplots import make_subplots
from typing import Optional

from load_tree import Tree

### load the tree

In [None]:
# File holding the serialized clustering tree. os.path.join with a single
# argument returns it unchanged, so the path is relative to the working directory.
load_path = os.path.join("gbert_large_paraphrase_cosine.jsonl")

# Tree.build_tree_from_path comes from the project-local load_tree module (not
# shown here); presumably it parses the file into the tree used below — verify there.
tree = Tree.build_tree_from_path(load_path)

### build dataframe from tree for visualization

#### setup

In [None]:
def build_hierarchical_dataframe(
    tree: Tree, tree_level: int = -1, root_index: Optional[int] = None
) -> pd.DataFrame:
    """Build a dataframe from the imported tree to input to the plotly Treemap method.

    Every node reachable from the selected root node(s) becomes one row. The
    children of a node are emitted before the node itself (post-order).

    Args:
        tree: The clustering tree.
        tree_level: The level of the clustering tree which is considered the top level for the visualization.
        root_index: Specifies a specific cluster for the given tree_level for which the visualization is realized.
            If None, all clusters of that level are included and attached to a common parent named "root".

    Returns:
        A dataframe with the columns "id", "parent", "label", "description",
        "keywords", "value" (node weight) and "color" (node sentiment score).
        The dataframe is empty (columns only) if the selected level has no nodes.
    """
    columns = ["id", "parent", "label", "description", "keywords", "value", "color"]

    # Rows are collected in a plain list and turned into a dataframe once at the
    # end; growing a dataframe row by row re-allocates it on every insert.
    rows = []

    def add_nodes(node, parent_name: str = ""):
        # Descend first so that the children precede their parent in the output.
        if node.children:
            for child in node.children.values():
                add_nodes(child, node.name)

        rows.append(
            [
                node.name,
                parent_name,
                node.label,
                node.description,
                node.keywords,
                node.weight(),
                node.sentiment_score(),
            ]
        )

    if root_index is not None:
        # Single cluster: it is the root of the treemap, hence the empty parent.
        root_node = tree.levels[tree_level][root_index]
        add_nodes(root_node)
    else:
        # Whole level: all clusters share the parent name "root".
        for root_node in tree.levels[tree_level].values():
            add_nodes(root_node, parent_name="root")

    return pd.DataFrame(rows, columns=columns)

#### build

In [None]:
# Flatten the tree into treemap rows: 102 is passed as tree_level and None as
# root_index, so every cluster in tree.levels[102] is added under the parent "root".
# NOTE(review): 102 is specific to the loaded tree — confirm it is a valid level index.
df_hierarchical = build_hierarchical_dataframe(tree, 102, None)
# Last expression of the cell: shows the final rows of the dataframe in the notebook.
df_hierarchical.tail()

### create figure

#### setup

In [None]:
def format_text(row, keywords_per_row: int = 5):
    """Render the HTML hover text for one row of the hierarchical dataframe.

    Args:
        row: Object exposing the attributes label, description, keywords and color.
            label and description are skipped when they are None.
        keywords_per_row: Number of keywords joined onto each "<br>"-terminated line.

    Returns:
        The topic, description, keyword lines and the sentiment score (rounded to
        two decimals) concatenated into a single HTML string.
    """
    parts = []

    if row.label is not None:
        parts.append("<b>Topic</b> " + row.label + "<br>")
    if row.description is not None:
        parts.append("<b>Description</b> " + row.description + "<br><br>")

    # Slicing past the end of the sequence is safe, so the last (shorter) chunk
    # needs no special handling.
    keywords = row.keywords
    keyword_lines = [
        ", ".join(keywords[offset:offset + keywords_per_row]) + "<br>"
        for offset in range(0, len(keywords), keywords_per_row)
    ]
    parts.append("<b>Keywords</b> " + "".join(keyword_lines) + "<br>")

    parts.append("<b>Sentiment score</b> " + str(round(row.color, 2)))

    return "".join(parts)

#### create

In [None]:
# One hover text per dataframe row, rendered by format_text.
hover_texts = df_hierarchical.apply(format_text, axis=1)

# Tiles are coloured by the "color" column on the "prgn" scale, pinned to the
# range [-1, 1] with its midpoint at 0.
marker_style = {
    "colors": df_hierarchical["color"],
    "colorscale": "prgn",
    "colorbar_title": "Sentiment<br>score",
    "cmid": 0,
    "cmin": -1,
    "cmax": 1,
    "showscale": True,
}

treemap = go.Treemap(
    labels=df_hierarchical["id"],
    parents=df_hierarchical["parent"],
    values=df_hierarchical["value"],
    branchvalues="total",
    marker=marker_style,
    hoverinfo="text",
    hovertext=hover_texts,
)

# The figure holds the treemap as its only trace.
fig = go.Figure(data=[treemap])

fig.update_layout(
    font={"family": "Times New Roman", "size": 16},
    width=800,
    height=800,
    margin={"t": 25, "l": 5, "r": 5, "b": 5},
)

# optionally, export as standalone HTML file
# fig.write_html("treemap.html")

fig.show()