In [13]:
import os

import numpy as np

In [180]:
import pickle

# Load the pre-built dialogue graph object.
# NOTE: pickle.load can execute arbitrary code; only load graph files produced
# by a trusted pipeline.
graph_path = 'dialogue_sim/graphs/selected_sbert_10.pkl'
with open(graph_path, 'rb') as graph_file:
    graph = pickle.load(graph_file)

In [181]:
import plotly.graph_objects as go
import networkx as nx

In [182]:
# Transition-weight matrix between graph nodes, plus the USER/SYSTEM role
# labels taken from entry 1 of `graph.cluster_kmeans_labels`.
transitions = np.array(graph.get_transitions())
labels = graph.cluster_kmeans_labels[1]
num_nodes = len(transitions)
# Later cells index transitions[i][j] for every i, j < num_nodes, so the
# matrix must be square; fail early with a clear message otherwise.
assert transitions.ndim == 2 and transitions.shape[0] == transitions.shape[1], (
    f"expected a square transition matrix, got shape {transitions.shape}")

In [183]:
# Layout: USER nodes on row y = y_span, SYSTEM nodes on row y = 2 * y_span,
# each row filled left to right in steps of x_span. A node with no entry in
# `labels` (the start node) is placed on row 0.
y_span = 1
x_span = 0.1
cur_x = {"USER": 0, "SYSTEM": 0}
cur_y = {"USER": y_span, "SYSTEM": y_span * 2}

node_x = []
node_y = []

for node in range(num_nodes):
    # NOTE(review): `in` tests keys for a dict but *values* for a list; this
    # relies on `labels` being keyed by node index - confirm.
    if node not in labels:
        node_x.append(x_span * (num_nodes / 6))
        node_y.append(0 * y_span)
        continue
    label = labels[node]
    node_y.append(cur_y[label])
    node_x.append(cur_x[label])
    cur_x[label] += x_span

# Edges: one line trace per directed pair (node1 -> node2), width proportional
# to the transition weight, coloured by the source node's role
# (red = USER, blue = SYSTEM, grey = unlabelled).
# NOTE(review): `thresh` is not applied below, so weak edges are still drawn;
# confirm whether edges under it should be filtered out.
thresh = 0.3
edge_traces = []
for node1 in range(num_nodes):
    # The colour depends only on the source node, so compute it once per row.
    if node1 in labels:
        c = "red" if labels[node1] == "USER" else "blue"
    else:
        c = "#888"
    for node2 in range(num_nodes):
        weight = transitions[node1][node2]
        # Skip self-loops, and zero-weight edges: those would be drawn with
        # width 0 (invisible) while still adding a trace and hover points.
        if node1 == node2 or weight <= 0:
            continue
        edge_traces.append(go.Scatter(
            x=[node_x[node1], node_x[node2], None],
            y=[node_y[node1], node_y[node2], None],
            line=dict(width=weight * 2, color=c),
            hoverinfo='text',
            text=[str(weight)],
            mode='lines'))

In [184]:
# Node scatter: one marker per node at the computed layout positions. Marker
# colour, size and hover text are assigned afterwards.
node_marker = dict(
    showscale=True,
    colorscale='Jet',
    reversescale=True,
    color=[],
    size=10,
    line_width=2,
)
node_trace = go.Scatter(x=node_x, y=node_y, mode='markers',
                        hoverinfo='text', marker=node_marker)

In [185]:
# Hover text, marker size and colour value for every node. Nodes
# 0 .. num_nodes-2 are labelled USER/SYSTEM; the last node is treated as the
# start node and gets fixed values appended after the loop.
# NOTE(review): assumes the start node is the last index - confirm.
node_text = []
node_sizes = []
lbls = []
for node in range(num_nodes - 1):
    content = graph.get_node_content(node)
    # Header "<ROLE> (<count>)" followed by up to four content items.
    text = f"{labels[node]} ({len(content)})<br>" + "<br>".join(content[:4])
    node_text.append(text)
    # Marker size scales with the number of content items.
    node_sizes.append(len(content) / 200)
    lbls.append(0 if labels[node] == "USER" else 1)

# Start node: fixed size and label, plus a mid-scale colour value so that the
# colour array has one entry per point, like node_sizes and node_text.
node_sizes.append(20)
node_text.append("START NODE")
lbls.append(0.5)

node_trace.marker.color = lbls
node_trace.marker.size = node_sizes
node_trace.text = node_text

In [186]:
# Reference dialogue rendered as an annotation under the graph. Turns alternate
# user / system and are separated by HTML "<br>" breaks so Plotly renders one
# turn per line.
# NOTE(review): text appears to be copied from a dataset dialogue; typos are
# left unchanged - confirm before editing.
dial = "i need a place to dine in the center thats expensive. <br> I have several options for you; do you prefer African, Asian, or British food?.<br> Any sort of food would be fine, as long as it is a bit expensive. Could I get the phone number for your recommendation?.<br> There is an Afrian place named Bedouin in the centre. How does that sound?. <br> Sounds good, could I get that phone number? Also, could you recommend me an expensive hotel?. <br> Bedouin's phone is 01223367660. As far as hotels go, I recommend the University Arms Hotel in the center of town.. <br> Yes. Can you book it for me?. <br> Sure, when would you like that reservation?. <br> i want to book it for 2 people and 2 nights starting from saturday.. <br> Your booking was successful. Your reference number is FRGZWQL2 . May I help you further?. <br> That is all I need to know. Thanks, good bye.. <br> Thank you so much for Cambridge TownInfo centre. Have a great day!."

In [187]:
# Assemble edges + nodes into one figure, attach the reference dialogue as an
# annotation, save it as standalone HTML and display it inline.
fig = go.Figure(
    data=edge_traces + [node_trace],
    layout=go.Layout(
        # `title.font` replaces the deprecated `titlefont`, which newer Plotly
        # releases reject.
        title=dict(text='Dialogue graph', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[dict(
            text="Target Dialogue: <br>" + dial,
            showarrow=False,
            xref="paper", yref="paper",
            x=0.005, y=-0.001)],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)
fig.update_traces(marker_showscale=False)
# write_html does not create missing directories.
os.makedirs("pictures", exist_ok=True)
fig.write_html("pictures/graph_viz_10.html")
fig.show()