In [1]:
import math
import torch
import torch.nn.functional as F
import llama_cpp

import networkx as nx
import numpy as np

from path import *

In [29]:
def p1(n, k=2):
    """Return the geometric sum 1 + k + k**2 + ... + k**n as an exact integer.

    This is the number of nodes in a full k-ary tree of depth n
    (e.g. p1(12, 2) == 8191). Returns 0 for negative n.
    """
    return sum(k**i for i in range(n + 1))


# Sanity-check p1 against the closed form (k**(n+1) - 1) / (k - 1), valid for k != 1.
n = 12
k = 3
# k**(n+1) - 1 is always divisible by k - 1, so integer division keeps the check exact;
# float division would silently lose precision once k**(n+1) exceeds 2**53.
closed_form = (k**(n + 1) - 1) // (k - 1)
print(p1(n, k), closed_form)
p1(n, k) == closed_form

797161.0 797161.0


True

In [2]:
# Seed phrase and graph-shape parameters passed to `phrase_graph` (from the `path` module).
# Presumably k is the branching factor and max_depth the number of levels — confirm in `path`.
phrase, k, max_depth = "Italy", 2, 12

# NOTE(review): the phrase-assignment stage ran at roughly 10 s per item and was interrupted;
# consider caching the resulting graph to disk so the rest of the notebook can be re-run.
G = phrase_graph(phrase, k, max_depth)

Building graph: 100%|█████████▉| 8190/8191 [04:51<00:00, 28.10it/s]
Assigning phrases:   0%|          | 4/1892 [00:39<5:08:16,  9.80s/it]


KeyboardInterrupt: 

In [None]:
def arc_hierarchical_layout(G, subset_key="depth", initial_radius=100, radius_growth_factor=1.5,
                            arc_angle=np.pi/4, start_angle=np.pi/2):
    """Place nodes on concentric arcs, one arc per layer.

    Nodes are grouped by the attribute ``G.nodes[node][subset_key]``. Layer ``d`` lies on a
    circle of radius ``initial_radius * radius_growth_factor ** d``. Its nodes are spread
    evenly over ``arc_angle`` radians, starting at ``start_angle`` and moving towards smaller
    angles. The y coordinate is negated, so with the default ``start_angle`` the first node of
    every layer sits directly below the origin and layers fan out downward.

    Parameters
    ----------
    G : graph
        Any object exposing ``G.nodes()`` and ``G.nodes[node][subset_key]`` (e.g. networkx).
    subset_key : str
        Node attribute holding the numeric layer index.
    initial_radius : float
        Radius of layer 0.
    radius_growth_factor : float
        Multiplicative radius increase per layer.
    arc_angle : float
        Angular span in radians covered by a layer with more than one node.
    start_angle : float
        Angle in radians of the first node in each layer. The default ``np.pi/2`` reproduces
        the original one-sided fan; pass ``np.pi/2 + arc_angle/2`` to center each fan.

    Returns
    -------
    dict
        Maps each node to a length-2 ``np.ndarray`` ``[x, y]``. An empty graph yields ``{}``.
    """
    layers = {}
    for node in G.nodes():
        layers.setdefault(G.nodes[node][subset_key], []).append(node)

    pos = {}
    for depth, nodes in layers.items():
        radius = initial_radius * (radius_growth_factor ** depth)
        # A lone node has no span to spread over; avoid dividing by zero.
        angle_step = arc_angle / (len(nodes) - 1) if len(nodes) > 1 else 0
        # NOTE(review): order within a layer follows G.nodes() iteration order, not the
        # parents' order, so sibling subtrees may interleave and edges may cross.
        for i, node in enumerate(nodes):
            angle = start_angle - i * angle_step
            pos[node] = np.array([radius * np.cos(angle), -radius * np.sin(angle)])

    return pos

# Compute arc positions (quarter-circle fan per layer, radius x1.5 per layer) and store
# each one on its node as the 'pos' attribute, which the plotting cell reads.
pos = arc_hierarchical_layout(
    G,
    subset_key="depth",
    initial_radius=100,
    radius_growth_factor=1.5,
    arc_angle=np.pi / 2,
)
nx.set_node_attributes(G, pos, name="pos")

In [None]:
import networkx as nx
import plotly.graph_objects as go
from ipywidgets import widgets

# Node keys appear to be (phrase, depth) tuples, so the seed/root is the phrase at depth 0.
seed_node = (phrase, 0)

# For each deepest node (depth == max_depth), precompute every simple path from the seed
# so the click handler only needs a dict lookup.
# NOTE(review): all_simple_paths can grow combinatorially if branches reconverge in G;
# compute paths lazily in the click handler if this cell becomes slow.
all_paths_to_seed = {
    node: list(nx.all_simple_paths(G, seed_node, node))
    for node, depth in G.nodes(data='depth')
    if depth == max_depth
}

# All edges go into a single line trace; a None between segments lifts the pen so
# consecutive edges are not joined.
edge_x, edge_y = [], []
for u, v in G.edges():
    (x0, y0), (x1, y1) = G.nodes[u]['pos'], G.nodes[v]['pos']
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'),
                        hoverinfo='none', mode='lines')

# Node markers colored by depth. Points are emitted in G.nodes() order; the click handler
# relies on that order to map a clicked point index back to its node.
node_x, node_y, node_text, node_depths = [], [], [], []
for node, attrs in G.nodes(data=True):
    x, y = attrs['pos']
    node_x.append(x)
    node_y.append(y)
    node_depths.append(attrs['depth'])
    node_text.append(f"{node[0]} (Depth: {attrs['depth']})")

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    text=node_text,
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        reversescale=True,
        color=node_depths,
        size=10,
        # `title.side` replaces the deprecated `titleside` (dropped in newer plotly releases).
        colorbar=dict(thickness=15, title=dict(text='Node Depth', side='right'), xanchor='left'),
        line_width=2,
    ),
)

# fig.data[0] holds the edges and fig.data[1] the nodes; the click handler attaches to fig.data[1].
fig = go.FigureWidget(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1000,
        height=1000,
    ),
)

# Status line shown under the figure and updated by the click handler.
click_info = widgets.Label(value=f"Click a depth {max_depth} node to highlight paths")

In [None]:
from ipywidgets import Button, VBox, HBox

def custom_click_handler(trace, points, selector):
    """Plotly ``on_click`` callback for the node-marker trace.

    Shows the clicked node in ``click_info``. If the node is at ``max_depth``, it draws every
    precomputed seed-to-node path from ``all_paths_to_seed`` as a red line trace named
    ``'path'``. Path traces from the previous click are removed first, so ``fig.data`` does
    not grow with each click.

    Parameters
    ----------
    trace : the trace that received the click (unused).
    points : plotly points object; ``points.point_inds`` index into ``list(G.nodes())``,
        which is the order the node trace was built in.
    selector : click-modifier state (unused).
    """
    if not points.point_inds:
        click_info.value = "No point selected"
        return

    try:
        clicked_node = list(G.nodes())[points.point_inds[0]]
        node_depth = G.nodes[clicked_node]['depth']
        click_info.value = f"Clicked node: {clicked_node[0]} (Depth: {node_depth})"

        # Drop the previous highlight entirely; merely hiding it would leave dead traces
        # behind on every click. Assigning a subset of fig.data is how plotly removes traces.
        fig.data = tuple(t for t in fig.data if t.name != 'path')

        if node_depth != max_depth:
            click_info.value += f" - Not a depth {max_depth} node, no paths to highlight"
            return

        paths = all_paths_to_seed.get(clicked_node, [])
        num_paths = len(paths)
        click_info.value += f" - Found {num_paths} path{'s' if num_paths != 1 else ''}"

        if paths:
            # hoverinfo='skip' (unlike 'none') keeps the red lines from capturing hover/click
            # events, so nodes lying under a highlighted path stay clickable.
            fig.add_traces([
                go.Scatter(
                    x=[G.nodes[n]['pos'][0] for n in path],
                    y=[G.nodes[n]['pos'][1] for n in path],
                    mode='lines',
                    line=dict(color='red', width=2),
                    hoverinfo='skip',
                    name='path',
                    showlegend=False,
                )
                for path in paths
            ])

    except Exception as e:
        # Report failures in the status label rather than letting the callback fail unnoticed.
        click_info.value = f"An unexpected error occurred: {type(e).__name__}: {str(e)}"

def reset_plot(b):
    """Reset-button callback: remove every highlighted-path trace and report the reset.

    ``b`` is the clicked button (unused). Path traces are identified by ``name == 'path'``.
    They are removed, not just hidden with ``visible=False``, so repeated highlighting
    does not make ``fig.data`` grow without bound.
    """
    fig.data = tuple(trace for trace in fig.data if trace.name != 'path')
    click_info.value = "Plot reset to initial state"

# Reset button that clears highlighted paths.
reset_button = Button(description="Reset Plot")
reset_button.on_click(reset_plot)

# fig.data[1] is the node-marker trace (fig.data[0] holds the edges).
fig.data[1].on_click(custom_click_handler)

# Figure on top, then the reset button next to the status label.
button_info_box = HBox([reset_button, click_info])
display(VBox([fig, button_info_box]))

VBox(children=(FigureWidget({
    'data': [{'hoverinfo': 'none',
              'line': {'color': '#888', 'widt…