# Setup

In [None]:
from ast import literal_eval
import json

import numpy as np
from scipy import spatial
from scipy.sparse.csgraph import csgraph_from_dense
from scipy.sparse.csgraph import shortest_path

# Compute Shortest Path

In [None]:
# Caption source whose story graphs are processed; it is also interpolated
# into the output file name when the storylines are saved.
caption_type = "CLIP"

# Load every story graph for this caption type. The context manager closes the
# file handle promptly; JSON is UTF-8, so don't depend on the locale encoding.
with open("/content/{}_story_graph.json".format(caption_type), encoding="utf-8") as f:
    graphs = json.load(f)

In [None]:
def get_path(Pr, i, j):
    """Reconstruct the node sequence ending at ``j`` from a predecessor matrix.

    Args:
        Pr: 2-D predecessor array as returned by scipy's ``shortest_path``
            with ``return_predecessors=True``; ``Pr[i, k]`` is the node that
            precedes ``k`` on the path from source row ``i``, and ``-9999``
            marks "no predecessor".
        i: row of ``Pr`` to follow (the source's row).
        j: target node index.

    Returns:
        List of node indices from the first node without a predecessor up to
        ``j``. When ``Pr[i, j]`` is already ``-9999`` (``j`` is the source or is
        unreachable) the result is just ``[j]``.
    """
    reversed_path = [j]
    node = j
    predecessor = Pr[i, node]
    # Walk predecessor links backwards until the sentinel is reached.
    while predecessor != -9999:
        reversed_path.append(predecessor)
        node = predecessor
        predecessor = Pr[i, node]
    reversed_path.reverse()
    return reversed_path

In [None]:
%%time

method = "pmi_edge_weights"  # key of the edge-weight mapping inside each story graph
search_method = "BF"  # scipy shortest_path method: "BF", "J" or "FW"

# Validate up front: with an unknown method no branch would assign `path`, so
# the loop would either raise NameError or silently reuse the previous story's path.
if search_method not in ("BF", "J", "FW"):
    raise ValueError(
        "Unknown search_method {!r}; expected 'BF', 'J' or 'FW'".format(search_method))


def _story_adjacency(weights, ending_nodes, num_nodes):
    """Build one story's directed graph with a dummy sink node at index ``num_nodes``.

    Args:
        weights: mapping from an edge-key string (parsed with ``literal_eval``
            into a sequence whose first two items are the source and
            destination node indices) to that edge's weight.
        ending_nodes: node index/indices that get an edge of weight -99 into the sink.
        num_nodes: number of real nodes; the sink is appended after them.

    Returns:
        A sparse csgraph whose edge weights are the negated input weights, so
        minimising path cost maximises the summed original weights.
    """
    # Mark missing edges with inf rather than 0: a dense array handed to
    # shortest_path treats 0 as "no edge", which would silently drop any edge
    # whose weight is exactly 0.
    adj_matrix = np.full((num_nodes + 1, num_nodes + 1), np.inf)
    for edge_key, w in weights.items():
        e_tup = literal_eval(edge_key)
        adj_matrix[e_tup[0], e_tup[1]] = -w
    # Connect every ending node to the single sink so one target covers them all.
    adj_matrix[ending_nodes, num_nodes] = -99
    return csgraph_from_dense(adj_matrix, null_value=np.inf)


def _shortest_path_to_sink(graph, start, sink, algorithm):
    """Return the node indices on the shortest path from ``start`` to ``sink``.

    ``algorithm`` is one of "BF", "J" or "FW". If the sink is unreachable the
    result is ``[sink]`` (that is what ``get_path`` yields for a missing predecessor).
    """
    if algorithm == "FW":
        # Floyd-Warshall is all-pairs: the predecessor matrix has one row per source.
        _, predecessors = shortest_path(graph, directed=True, method="FW",
                                        return_predecessors=True)
        return get_path(predecessors, start, sink)
    # "BF" / "J" run from the single source `start` and return a 1-D predecessor
    # array; add a leading axis so get_path can read it as row 0.
    _, predecessors = shortest_path(graph, directed=True, method=algorithm,
                                    return_predecessors=True, indices=start)
    return get_path(np.expand_dims(predecessors, axis=0), 0, sink)


story_ids = list(graphs.keys())
storylines = {}

count = 0

for story_id in story_ids:
    story_graph = graphs[story_id]
    if method not in story_graph:
        continue

    node_ids = story_graph['node_ids']
    id_to_node_name = {v: k for k, v in node_ids.items()}
    num_nodes = len(node_ids)
    sink = num_nodes  # the only extra node: a dummy end node after the real ones

    b4_and_after_info = story_graph['before_and_after_nodes']
    # The last entry lists the ending nodes (assumes keys are "0".."n-1").
    ending_nodes = b4_and_after_info[str(len(b4_and_after_info) - 1)]["after"]

    graph = _story_adjacency(story_graph[method], ending_nodes, num_nodes)

    # The search starts from the first theme node in node_ids order.
    starting_node = [node_ids[x] for x in node_ids if x.split('_')[0] == "theme"][0]

    path = _shortest_path_to_sink(graph, starting_node, sink, search_method)
    if path[0] != starting_node:
        # Unreachable sink: get_path returned just [sink], so the storyline is empty.
        print("warning: no path from the theme node to an ending node in story {}".format(story_id))

    # Drop the trailing dummy sink; the theme node also lies on the path but is
    # not a story line, so it is skipped as well.
    lines = [id_to_node_name[n].split("_")[1] for n in path[:-1]
             if id_to_node_name[n].split("_")[0] != "theme"]

    storylines[story_id] = lines

    count += 1
    if count % 100 == 0:
        print(count)

In [None]:
# Persist the extracted storylines (story_id -> list of line-name strings) as JSON.
out_path = '/content/{}_pmi_srl.json'.format(caption_type)
with open(out_path, 'w') as out_file:
    json.dump(storylines, out_file)