In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
import torch
import torch.nn.functional as F
import llama_cpp

from path import *

In [3]:
phrase = "Italy"
k = 3
max_depth = 6
G = phrase_graph(phrase, k, max_depth)

In [6]:
pos = nx.multipartite_layout(G, subset_key="depth", scale=3)
for node, position in pos.items():
    G.nodes[node]['pos'] = position

In [27]:
import networkx as nx
import plotly.graph_objects as go
from ipywidgets import widgets

seed_node = (phrase, 0)
# Assume G is your graph and phrase is defined

# Create a dictionary to store all paths from each depth 5 node to the seed
all_paths_to_seed = {node: list(nx.all_simple_paths(G, seed_node, node)) for node in G.nodes() if G.nodes[node]['depth'] == max_depth}

# Create edge traces for the graph structure
edge_x, edge_y = [], []
for edge in G.edges():
    x0, y0 = G.nodes[edge[0]]['pos']
    x1, y1 = G.nodes[edge[1]]['pos']
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'),
                        hoverinfo='none', mode='lines')

# Create node trace
node_x, node_y, node_text = [], [], []
for node in G.nodes():
    x, y = G.nodes[node]['pos']
    node_x.append(x)
    node_y.append(y)
    node_text.append(f"{node[0]} (Depth: {G.nodes[node]['depth']})")

node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    text=node_text,
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        reversescale=True,
        color=[G.nodes[node]['depth'] for node in G.nodes()],
        size=10,
        colorbar=dict(thickness=15, title='Node Depth', xanchor='left', titleside='right'),
        line_width=2
    )
)

# Create figure
fig = go.FigureWidget(data=[edge_trace, node_trace],
                layout=go.Layout(
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
                ))

# Create a widget to display clicked node info
click_info = widgets.Label(value="Click a depth 5 node to highlight paths")

def custom_click_handler(trace, points, selector):
    if not points.point_inds:
        click_info.value = "No point selected"
        return
    
    try:
        clicked_node = list(G.nodes())[points.point_inds[0]]
        node_depth = G.nodes[clicked_node]['depth']
        
        click_info.value = f"Clicked node: {clicked_node[0]} (Depth: {node_depth})"    
        
        with fig.batch_update():
            # Always remove previously highlighted paths
            fig.data = fig.data[:2]  # Keep only edge_trace and node_trace
            
            if node_depth == max_depth:
                paths = all_paths_to_seed.get(clicked_node, [])
                num_paths = len(paths)
                click_info.value += f" - Found {num_paths} path{'s' if num_paths != 1 else ''}"
                
                # Highlight new paths
                for i, path in enumerate(paths):
                    x = [G.nodes[n]['pos'][0] for n in path]
                    y = [G.nodes[n]['pos'][1] for n in path]
                    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', line=dict(color='red', width=2),
                                            hoverinfo='none', name=f'Path {i+1} to {clicked_node[0]}'))
            else:
                click_info.value += f" - Not a depth {max_depth} node, no paths to highlight"

    except IndexError as e:
        click_info.value = f"Error: {str(e)}. Please try clicking again."
    except Exception as e:
        click_info.value = f"An unexpected error occurred: {str(e)}"


# Attach the custom click event handler
fig.data[1].on_click(custom_click_handler)

display(widgets.VBox([fig, click_info]))

VBox(children=(FigureWidget({
    'data': [{'hoverinfo': 'none',
              'line': {'color': '#888', 'widt…

IndexError: tuple index out of range

IndexError: tuple index out of range

In [32]:
G.nodes[(" excellent", 6)]

{'phrase': "Italy. It's an excellent",
 'token': ' excellent',
 'depth': 6,
 'pos': array([ 0.01731314, -0.57024793])}

In [17]:


for n in G.nodes():
    if G.nodes[n]["depth"] == 6:
        print(n,"\t", len(G.nodes[n]["phrase"]))

('.', 6) 	 13
(',', 6) 	 23
(' of', 6) 	 9
(' (', 6) 	 3
(' Em', 6) 	 29
(' Republic', 6) 	 35
('ic', 6) 	 28
('\n', 6) 	 16
(' The', 6) 	 13
(' In', 6) 	 8
(' and', 6) 	 20
(' as', 6) 	 2
(' among', 6) 	 3
(' been', 6) 	 6
(' all', 6) 	 35
(' the', 6) 	 53
('S', 6) 	 18
('K', 6) 	 18
(' S', 6) 	 19
(' have', 6) 	 6
('bek', 6) 	 20
('be', 6) 	 19
('amb', 6) 	 20
('The', 6) 	 9
('-', 6) 	 12
(' first', 6) 	 10
(' most', 6) 	 13
(' United', 6) 	 12
(' ', 6) 	 24
(' addition', 6) 	 2
(' was', 6) 	 11
(' has', 6) 	 8
(' well', 6) 	 2
(' they', 6) 	 3
('In', 6) 	 4
('  ', 6) 	 21
(' A', 6) 	 21
(' C', 6) 	 22
(' D', 6) 	 22
(' B', 6) 	 22
(' company', 6) 	 31
(' it', 6) 	 6
(' latter', 6) 	 5
(' particularly', 6) 	 41
(' seen', 6) 	 2
(' experienced', 6) 	 39
(' highest', 6) 	 2
(' largest', 6) 	 5
(' Italy', 6) 	 3
('Ital', 6) 	 24
('9', 6) 	 6
('8', 6) 	 4
('5', 6) 	 7
('0', 6) 	 9
('2', 6) 	 12
('1', 6) 	 12
(' are', 6) 	 5
(' Netherlands', 6) 	 33
(' rest', 6) 	 3
(' in', 6) 	 12
(' loc

In [None]:
# import cProfile
# import pstats
# from pstats import SortKey

# cProfile.run("G = phrase_graph(phrase, k, max_depth)", "stats")
# p = pstats.Stats("stats")
# p.sort_stats(SortKey.TIME)
# p.print_stats()