In [10]:
# !pip install plotly networkx ipywidgets --quiet   # run once
import json, textwrap, networkx as nx, plotly.graph_objects as go
from collections import defaultdict
from IPython.display import display


## Load the Star Wars Story Data

The helper functions below read the story data from a JSON file and build a directed graph from it.

In [11]:
def wrap_text(txt, width=40):
    """Wrap ``txt`` to ``width`` columns and join the pieces with ``<br>``.

    The ``<br>`` separator makes the result usable as multi-line HTML text
    (for example in hover labels).
    """
    wrapped_lines = textwrap.wrap(txt, width)
    return "<br>".join(wrapped_lines)

def load_story_json(path):
    """Read a JSON file and return the decoded document.

    Parameters
    ----------
    path : str or path-like
        Location of the JSON file.

    Returns
    -------
    object
        Whatever ``json.load`` produces for the file's content.

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8: without `encoding`, open() uses the
    # platform's locale encoding, which can garble or reject non-ASCII text.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def build_story_graph(story_dict):
    """Turn a story dictionary into a directed graph.

    Every key of ``story_dict["graph"]["nodes"]`` becomes a graph node.
    Edges are inferred from the IDs alone: dropping the last ``_``-separated
    segment of an ID gives the ID of its parent (``node_0_1`` -> ``node_0``),
    and an edge parent -> child is added whenever that parent ID is itself
    a node.

    Returns
    -------
    (G, node_attrs, root)
        ``G`` is an ``nx.DiGraph``; ``node_attrs`` maps each node ID to the
        value stored under it in the input; ``root`` is always the string
        ``"node_0"``.
    """
    raw_nodes = story_dict["graph"]["nodes"]

    # Fresh {id: data} mapping (the data objects themselves are shared).
    node_attrs = dict(raw_nodes.items())

    G = nx.DiGraph()
    G.add_nodes_from(node_attrs)

    # Link each node to the node named by its ID minus the final segment.
    for node_id in node_attrs:
        parent_id, sep, _ = node_id.rpartition('_')
        if sep and parent_id in G:
            G.add_edge(parent_id, node_id)

    return G, node_attrs, "node_0"


In [12]:
def bfs_layout(G, root="node_0"):
    """Place every node reachable from ``root`` on a (depth, y) grid.

    Parameters
    ----------
    G : networkx graph
        Graph to lay out.
    root : node
        Start node of the breadth-first traversal.

    Returns
    -------
    dict
        ``{node: (x, y)}`` where ``x`` is the number of edges on the shortest
        path from ``root`` and ``y`` spaces the nodes of one depth one unit
        apart, centred on ``y = 0``. Nodes not reachable from ``root`` are
        omitted.

    Notes
    -----
    Nodes within a depth are ordered with plain ``sorted()``, so string IDs
    sort lexicographically (``'node_0_10'`` comes before ``'node_0_2'``).
    """
    # A single BFS pass gives every node's depth; this avoids running a
    # separate shortest-path search for each node.
    depth = {root: 0}
    for parent, child in nx.bfs_edges(G, root):
        depth[child] = depth[parent] + 1

    levels = defaultdict(list)
    for node, lvl in depth.items():
        levels[lvl].append(node)

    pos = {}
    for lvl, nodes in levels.items():
        centre = (len(nodes) - 1) / 2        # center each level on y=0
        for i, n in enumerate(sorted(nodes)):
            pos[n] = (lvl, i - centre)
    return pos


In [13]:
import re, networkx as nx
from collections import defaultdict

def natural_key(s):
    """Sort key that compares embedded numbers by value, not character by character.

    The string is split on runs of decimal digits; each digit run becomes an
    ``int`` and every other piece stays a string, e.g.
    ``'node_0_12_3'`` -> ``['node_', 0, '_', 12, '_', 3, '']``.
    With this key ``'node_0_2'`` sorts before ``'node_0_10'``.
    """
    # str.isdecimal() accepts only the decimal digits that `\d` splits on and
    # int() can parse. str.isdigit() also accepts characters such as
    # superscript digits, which would make int() raise ValueError.
    return [int(tok) if tok.isdecimal() else tok for tok in re.split(r'(\d+)', s)]

def flat_tree_layout(G, root="node_0", x_gap=2.0, y_gap=1.4):
    """Lay out the part of ``G`` reachable from ``root`` left to right.

    Parameters
    ----------
    G : networkx directed graph
        Graph to lay out; node IDs must be strings because siblings are
        ordered with ``natural_key``.
    root : node
        Start node of the breadth-first traversal.
    x_gap : float
        Horizontal distance between consecutive depths.
    y_gap : float
        Vertical distance between neighbouring nodes of one depth.

    Returns
    -------
    dict
        ``{node: (x, y)}`` with
          • x = depth * x_gap
          • y = centred siblings; if a node has one child, child keeps same y
        Nodes not reachable from ``root`` are omitted.
    """
    # One BFS pass records both the depth and the BFS parent of every node.
    # The BFS parent is always one level up, so it is positioned before its
    # children; an arbitrary predecessor need not be.
    depth = {root: 0}
    bfs_parent = {}
    for parent, child in nx.bfs_edges(G, root):
        depth[child] = depth[parent] + 1
        bfs_parent[child] = parent

    # group by depth
    levels = defaultdict(list)
    for n, d in depth.items():
        levels[d].append(n)

    pos = {}
    for d in sorted(levels):
        level_nodes = sorted(levels[d], key=natural_key)
        mid = (len(level_nodes) - 1) / 2
        for idx, n in enumerate(level_nodes):
            # inherit y if this edge is part of a chain (parent has exactly
            # one successor); the root has no entry in bfs_parent
            if n in bfs_parent and len(G.succ[bfs_parent[n]]) == 1:
                y = pos[bfs_parent[n]][1]
            else:
                y = (idx - mid) * y_gap
            pos[n] = (d * x_gap, y)
    return pos


In [14]:
def story_fig(G, node_data, pos, highlight=None, title="Star-Wars Story"):
    """Build a Plotly figure of the story graph.

    Scenes (the nodes in ``pos``) are drawn as squares, choice points as
    circles at the midpoint of each edge, and every edge as an arrow
    annotation pointing from parent to child.

    Parameters
    ----------
    G : networkx directed graph
        Only ``G.edges`` is read.
    node_data : dict
        ``{node_id: attrs}``; ``attrs['story']`` supplies the hover text.
        A missing or empty story yields a hover box with just the node ID.
    pos : dict
        ``{node_id: (x, y)}``. Node IDs must be strings. Edges with an
        endpoint that is not in ``pos`` are not drawn.
    highlight : iterable, optional
        Node IDs and/or edge IDs of the form ``"u->v"`` to draw in an
        accent colour.
    title : str
        Figure title.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    highlight = set(highlight or [])

    # === A. squares (story nodes) ===
    sx, sy, stxt, shover, scol = [], [], [], [], []
    for nid, (x, y) in pos.items():
        sx.append(x)
        sy.append(y)
        stxt.append(f"ID: {nid.split('_')[-1]}")  # short label: last ID segment
        # Tolerate nodes without story text instead of aborting the figure.
        story = (node_data.get(nid) or {}).get("story") or ""
        shover.append(f"<b>{nid}</b><br>{wrap_text(story, 30)}")
        scol.append("#FF9966" if nid in highlight else "#A0CBE8")

    # === B. circles (choice points, mid-edge) and C. arrows ===
    cx, cy, chover, ccol = [], [], [], []
    arrows = []
    for u, v in G.edges:
        # Nodes unreachable from the layout root have no position; skip
        # their edges rather than failing on the lookup.
        if u not in pos or v not in pos:
            continue
        (x0, y0), (x1, y1) = pos[u], pos[v]
        eid = f"{u}->{v}"
        cx.append((x0 + x1) / 2)
        cy.append((y0 + y1) / 2)
        chover.append(f"{u} → {v}")
        ccol.append("#66FF99" if eid in highlight else "#FFDAC1")
        # Arrow tail at the parent (ax, ay), head at the child (x, y).
        arrows.append(dict(
            x=x1, y=y1, ax=x0, ay=y0,
            xref='x', yref='y', axref='x', ayref='y',
            showarrow=True, arrowhead=2, arrowwidth=1, opacity=.5
        ))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=sx, y=sy, mode="markers+text",
        marker=dict(symbol="square", size=22, color=scol),
        text=stxt, textposition="top center",
        hovertext=shover, hoverinfo="text",
        name="Scenes"
    ))
    fig.add_trace(go.Scatter(
        x=cx, y=cy, mode="markers",
        marker=dict(symbol="circle", size=15, color=ccol),
        hovertext=chover, hoverinfo="text",
        name="Choices"
    ))

    fig.update_layout(
        title=title, annotations=arrows,
        xaxis=dict(visible=False), yaxis=dict(visible=False),
        hovermode="closest", plot_bgcolor="white", paper_bgcolor="white",
        margin=dict(l=40, r=40, t=60, b=40)
    )
    # Keep markers and labels at the plot edge from being cut off.
    fig.update_traces(cliponaxis=False)
    return fig


In [15]:
# Read the story definition; the path is relative to the working directory.
story_data = load_story_json("star_wars_story.json")
# Build the directed story graph: `attrs` maps node ID -> node data,
# `root` is the ID of the start node.
G, attrs, root = build_story_graph(story_data)
pos = flat_tree_layout(G, root)             # {node: (x, y)} for nodes reachable from root
# Assemble the Plotly figure and render it in the notebook.
fig = story_fig(G, attrs, pos, title="Star Wars Interactive Story")
fig.show()
