In [None]:
def get_cap_and_imgs(stories, story_id, img2cap, cap_type = "VIST", display = False):
    """
    Retrieve the image names and corresponding captions for a given story id.

    Args:
        stories: mapping of story id (str) -> record with an 'images' list.
        story_id: story id; converted with str() before lookup.
        img2cap: mapping of image name -> caption string.
        cap_type: "VIST" looks each image up by its bare id. Any other value
            (e.g. "CLIP", "BLIP") looks it up as id + file extension, trying
            '.jpg', '.gif', '.png', '.bmp' in that order and keeping the first
            hit. For "CLIP" a single trailing full stop is removed.
        display: if True, print each "image_name : caption" pair found.

    Returns:
        (image_list, cap_list): parallel lists with one entry per image that
        has a caption; images without a caption are skipped. If `story_id` is
        unknown a message is printed and two empty lists are returned, so
        callers can still unpack the result.
    """
    story_key = str(story_id)
    if story_key not in stories:
        print("This story id does not exist.")
        return [], []

    images = stories[story_key]['images']

    image_formats = ['.jpg', '.gif', '.png', '.bmp']
    image_list = []
    cap_list = []

    for img in images:
        # VIST captions are keyed by the bare image id; the other caption sets
        # are keyed by file name, whose extension has to be guessed.
        if cap_type == "VIST":
            candidates = [img]
        else:
            candidates = [img + f for f in image_formats]

        for img_name in candidates:
            if img_name not in img2cap:
                continue
            caption = img2cap[img_name]
            if display:
                print(img_name, ":" , caption)
            if cap_type == "CLIP" and caption.endswith("."):
                caption = caption[:-1]  # get rid of last full stop for CLIP captions
            # append both together so the two lists can never drift apart
            image_list.append(img_name)
            cap_list.append(caption)
            break  # one caption per image

    return image_list, cap_list

In [None]:
import json  # this cell runs before the Setup cell's imports, so bring json into scope here

split = "test"
caption_type = "CLIP"


def _load_json(path):
    """Read and parse one JSON file, closing the handle afterwards."""
    with open(path) as fh:
        return json.load(fh)


cks = _load_json("/content/{}_{}_comet_ck.json".format(caption_type, split))
stories = _load_json("/content/{}_stories.json".format(split))
img2cap = _load_json("/content/{}_captions.json".format(caption_type))
vist_cks = _load_json("/content/Common Sense Dicts/cks_VIST_{}.json".format(split))

In [None]:
# Relations whose COMET generations are looked up for every caption
# (keys of `cks` are "<caption> <rel> [GEN]").
rels = ["AtLocation", "CapableOf", "xNeed", "xIntent", "xWant", "xEffect", "xReact", "xAttr"]

cks_dict = {}          # story_id -> {img_num: {rel: [ck, ...], "caption": caption}}
problem_stories = []   # story ids that did not yield exactly 5 captions

for story_id in vist_cks:
    # Use the caption type the caption/COMET files were loaded for, so the
    # caption text (e.g. CLIP captions without the trailing full stop)
    # matches the keys in `cks`.
    _, cap_list = get_cap_and_imgs(stories, story_id, img2cap, cap_type = caption_type)
    if len(cap_list) != 5:
        problem_stories.append(story_id)
        continue
    story_cks = {}  # img_num -> {rel: [ck, ...], "caption": caption}
    for img_num, caption in enumerate(cap_list):
        image_ck = {}
        for rel in rels:
            key = caption + " " + rel + " " + "[GEN]"
            generations = [x.strip() for x in cks[key]]
            # drop COMET's "none" placeholder whatever whitespace surrounds it
            image_ck[rel] = [x for x in generations if x != "none"]
        image_ck["caption"] = caption
        story_cks[img_num] = image_ck
    cks_dict[story_id] = story_cks
print(len(problem_stories))
print(len(cks_dict))
print(len(vist_cks))

0
3385
3385


In [None]:
import json  # used before the Setup cell's imports run

# Write straight into the "Common Sense Dicts" folder rather than writing to the
# current working directory and then moving from a hard-coded /content path,
# which fails whenever the working directory is not /content.
cks_out_path = "/content/Common Sense Dicts/cks_{}_{}.json".format(caption_type, split)
with open(cks_out_path, 'w') as f:
    json.dump(cks_dict, f)

cks_out_path

'/content/drive/My Drive/PhD Project/VIST Model/VIST Model Data/Common Sense Dicts/cks_BLIP_test.json'

# Setup

In [None]:
!pip install -U sentence-transformers

In [None]:
from scipy import spatial
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import shutil
import torch
import torch.nn.functional as F
import nltk
from nltk.stem.porter import PorterStemmer


# NLTK resources needed later (tokenisation, tagging, WordNet).
for nltk_resource in ('averaged_perceptron_tagger', 'punkt', 'brown', 'wordnet', 'omw-1.4'):
    nltk.download(nltk_resource)

# Run on the GPU whenever torch can see one.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('We are using GPU.' if use_cuda else 'We are using CPU.')

We are using GPU.


In [None]:
### initialize sentence transformer on the device selected in the setup cell
text_embedder = SentenceTransformer('all-mpnet-base-v2').to(device)
text_embedder.eval()

# Create Story Graphs

In [None]:
split = "valid"
caption_type = "CLIP"

# commonsense per caption, the story records, and the image captions
with open("/content/cks_{}_{}.json".format(caption_type, split)) as cks_file:
    cks_dict = json.load(cks_file)
with open("/content/{}_stories.json".format(split)) as stories_file:
    stories = json.load(stories_file)
with open("/content/{}_captions.json".format(caption_type)) as captions_file:
    img2cap = json.load(captions_file)

In [None]:
def filter_out_sim_nodes(embs, cks, threshold = 0.50):
    """
    Greedily de-duplicate phrases by the cosine similarity of their embeddings.

    Phrases are visited in order. A phrase is kept only if its cosine
    similarity to every phrase kept so far is <= `threshold`; the first
    phrase is always kept. A lower threshold filters out more nodes.

    Args:
        embs: sequence of embedding vectors, parallel to `cks`.
        cks: list of phrases.
        threshold: cosine-similarity cut-off.

    Returns:
        (nodes, final_embs): the kept phrases and their embeddings (lists).
        Both are empty when `embs` is empty.
    """
    if len(embs) == 0:
        return [], []

    nodes = [cks[0]]
    final_embs = [embs[0]]
    for i in range(1, len(embs)):
        # any() stops at the first already-kept embedding that is too similar
        found_sim = any(
            1 - spatial.distance.cosine(embs[i], kept_emb) > threshold
            for kept_emb in final_embs
        )
        if not found_sim:
            nodes.append(cks[i])
            final_embs.append(embs[i])

    return nodes, final_embs

In [None]:
def join_b4_and_after_with_cap(node_ids, edges, cap_node_name, nodes, direction = "before", img_num = None):
    """
    Add one node per phrase in `nodes` and link each one to a caption node.

    New nodes are named "<direction>_<phrase>_<img_num>" and receive the next
    free integer id in `node_ids`. Each node is joined to the caption node by
    an edge in both directions.

    Args:
        node_ids: dict node name -> integer id (updated in place).
        edges: list of [source_id, dest_id] pairs (appended to in place).
        cap_node_name: name of an existing caption node ("cap_<caption>_<img_num>").
        nodes: phrases to add.
        direction: prefix for the new node names, e.g. "before" / "after".
        img_num: image index used as the name suffix. Defaults to the suffix
            of `cap_node_name`, so the function no longer depends on a global
            `img_num` defined by the calling cell.

    Returns:
        (node_ids, edges, ids): the updated dict and list, plus the id of each
        phrase in `nodes`, in order.
    """
    if img_num is None:
        img_num = cap_node_name.rsplit("_", 1)[-1]
    img_num = str(img_num)

    ids = []
    for phrase in nodes:
        node_name = direction + "_" + phrase + "_" + img_num
        # Only new names get a new id. Re-assigning len(node_ids) to an existing
        # name would hand out an id that the next new node also receives.
        if node_name not in node_ids:
            node_ids[node_name] = len(node_ids)
        node_id = node_ids[node_name]
        ids.append(node_id)

        cap_id = node_ids[cap_node_name]
        edges.append([node_id, cap_id])
        edges.append([cap_id, node_id])

    return node_ids, edges, ids

In [None]:
def clean_ck(cks):
    """
    Remove a leading "to" word from each commonsense phrase.

    Only the first space-separated token is examined: "to go home" becomes
    "go home", while "toward x" is left untouched. A new list is returned;
    the input list is not modified.
    """
    def _strip_leading_to(phrase):
        head, _, rest = phrase.partition(' ')
        return rest if head == 'to' else phrase

    return [_strip_leading_to(phrase) for phrase in cks]

In [None]:
# Relations turned into graph nodes: xNeed / xIntent become "before" nodes of
# an image's caption, xWant / xEffect become "after" nodes.
rels = ["xNeed", "xIntent", "xWant", "xEffect"]

idx = 0
graphs = {}

# process stories in numeric id order
story_ids = list(cks_dict.keys())
print(len(story_ids))
story_ids = [str(x) for x in sorted(int(x) for x in story_ids)]

for story_id in story_ids:

    story_cks = cks_dict[story_id]
    node_ids = {}                # node name -> integer id
    edges = []                   # [source_id, dest_id]; every link is added in both directions
    cap_b4_and_after_nodes = {}  # img_num: {"before": ids, "after": ids}

    for img_num in range(len(story_cks)):
        # JSON keys are strings. join_b4_and_after_with_cap may also read this
        # cell-level `img_num` when naming nodes.
        img_num = str(img_num)
        image_cks = story_cks[img_num]

        intents = ["wants " + x for x in image_cks["xIntent"]]
        before_cks = clean_ck(image_cks["xNeed"] + intents)
        after_cks = clean_ck(image_cks["xWant"] + image_cks["xEffect"])

        # drop near-duplicate phrases; the kept embeddings are not needed here
        before_nodes, _ = filter_out_sim_nodes(text_embedder.encode(before_cks), before_cks)
        after_nodes, _ = filter_out_sim_nodes(text_embedder.encode(after_cks), after_cks)

        cap_node_name = "cap_" + image_cks["caption"] + "_" + img_num
        node_ids[cap_node_name] = len(node_ids)

        node_ids, edges, before_ids = join_b4_and_after_with_cap(node_ids, edges, cap_node_name, before_nodes, direction = "before")
        node_ids, edges, after_ids = join_b4_and_after_with_cap(node_ids, edges, cap_node_name, after_nodes, direction = "after")

        cap_b4_and_after_nodes[img_num] = {"before": before_ids, "after": after_ids}

    # connect image i's after nodes with image i+1's before nodes (both directions)
    for i in range(len(cap_b4_and_after_nodes) - 1):
        cap_after = cap_b4_and_after_nodes[str(i)]["after"]
        next_cap_before = cap_b4_and_after_nodes[str(i + 1)]["before"]
        for source in cap_after:
            for dest in next_cap_before:
                edges.append([source, dest])
                edges.append([dest, source])

    graphs[story_id] = {
        "node_ids": node_ids,
        "edges": edges,
        "before_and_after_nodes": cap_b4_and_after_nodes,
    }

    idx += 1
    if (idx % 100) == 0:
        print(idx)

3354
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300


In [None]:
# Save under /content, the folder the "Create Directed Story Graph" step loads it from.
with open('/content/story_graph_{}_{}.json'.format(caption_type, split), 'w') as f:
    json.dump(graphs, f)

# Save Node Phrase Embeddings

In [None]:
caption_type = "CLIP"
EMB_DIM = 768  # width of the all-mpnet-base-v2 sentence embeddings stored per node

with open("/content/{}_story_graph.json".format(caption_type)) as fh:
    graph = json.load(fh)

path_to_stories = "/content/Data"
with open("{}/{}_stories.json".format(path_to_stories, "train")) as fh:
    train_stories = json.load(fh)
with open("{}/{}_stories.json".format(path_to_stories, "valid")) as fh:
    valid_stories = json.load(fh)
with open("{}/{}_stories.json".format(path_to_stories, "test")) as fh:
    test_stories = json.load(fh)

stories = {**train_stories, **valid_stories, **test_stories}

# One (nodes x EMB_DIM) slab per story. The node axis is sized from the data
# (largest graph) instead of a hard-coded 71, so no graph can overflow it.
# mode='w+' creates (or overwrites) the .dat file.
max_graph_nodes = max(len(g["node_ids"]) for g in graph.values())
node_embs = np.memmap('{}_node_embs.dat'.format(caption_type), dtype='float32', mode='w+',
                      shape=(len(graph), max_graph_nodes, EMB_DIM))
print(node_embs.shape)

(33691, 71, 768)


In [None]:
### find the maximum number of nodes in any story graph (0 if there are none)
max_nodes = max((len(story_graph["node_ids"]) for story_graph in graph.values()), default=0)
max_nodes

71

In [None]:
%%time
MAX_NUM_EMBS = max_nodes
story_to_ind = {}
idx = 0

for story_id in graph:

    story_to_ind[story_id] = idx
    node_ids = graph[story_id]["node_ids"]
    node_ids = sorted(node_ids.items(), key=lambda x: x[1])
    node_names = [x[0].split('_')[1] for x in node_ids]
    embs = text_embedder.encode(node_names)

    if MAX_NUM_EMBS > embs.shape[0]:
        pad_len = MAX_NUM_EMBS - embs.shape[0]
        padding = np.zeros((pad_len, 768))
        embs = np.vstack((embs, padding))

    node_embs[idx]= embs
    idx += 1
    if (idx % 1000) == 0:
        print(idx)

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
CPU times: user 33min 33s, sys: 34.3 s, total: 34min 8s
Wall time: 32min 41s


In [None]:
### save dictionary storing story id -> node_embs row index mapping
# written to /content, the folder the cosine-weight step reads it back from
with open('/content/{}_node_inds.json'.format(caption_type), 'w') as f:
    json.dump(story_to_ind, f)

CPU times: user 2.87 s, sys: 6.48 s, total: 9.35 s
Wall time: 38.3 s


'/content/drive/My Drive/PhD Project/VIST Model/VIST Model Data/VIST_node_embs.dat'

# Create Directed Story Graph

In [None]:
caption_type = "CLIP"
split = "valid"

undirected_graph_path = "/content/story_graph_{}_{}.json".format(caption_type, split)
with open(undirected_graph_path) as graph_file:
    graphs = json.load(graph_file)
len(graphs)

In [None]:
### make into directed graph: before_i -> cap_i -> after_i -> before_{i+1}
final_graphs = {}

for story_id, story_graph in graphs.items():

    node_ids = story_graph['node_ids']
    # caption node ids in insertion order (one per image, in image order)
    cap_node_ids = [node_id for name, node_id in node_ids.items() if name.split('_')[0] == "cap"]
    b4_and_after = story_graph['before_and_after_nodes']

    new_edges = []
    prev_after_nodes = []  # after nodes of the previous image (none before the first image)

    for i in range(len(b4_and_after)):
        before_nodes = b4_and_after[str(i)]['before']
        after_nodes = b4_and_after[str(i)]['after']

        # previous image's after nodes --> this image's before nodes
        new_edges.extend([src, dst] for src in prev_after_nodes for dst in before_nodes)
        # before nodes --> caption node of this image
        new_edges.extend([src, cap_node_ids[i]] for src in before_nodes)
        # caption node --> after nodes of this image
        new_edges.extend([cap_node_ids[i], dst] for dst in after_nodes)

        prev_after_nodes = after_nodes

    final_graphs[story_id] = {
        "node_ids": node_ids,
        "edges": new_edges,
        "before_and_after_nodes": b4_and_after,
    }

assert len(final_graphs) == len(graphs)

In [None]:
directed_graph_path = '/content/story_graph_{}_{}_directed.json'.format(caption_type, split)
with open(directed_graph_path, 'w') as out_file:
    json.dump(final_graphs, out_file)

# Adding Theme Nodes

In [None]:
with open("/content/img_concepts.json") as concepts_file:
    img2concepts = json.load(concepts_file)

# Re-key by the part of each file name before the first '.', so keys match the
# image ids listed in stories[...]['images'].
img_concepts = {name.split(".")[0]: concepts for name, concepts in img2concepts.items()}

In [None]:
caption_type = "CLIP"


def _load_directed_graphs(split_name):
    """Load the directed story graphs saved for one data split."""
    with open("/content/story_graph_{}_{}_directed.json".format(caption_type, split_name)) as fh:
        return json.load(fh)


train_graphs = _load_directed_graphs("train")
valid_graphs = _load_directed_graphs("valid")
test_graphs = _load_directed_graphs("test")

# later splits win on duplicate story ids
graphs = {**train_graphs, **valid_graphs, **test_graphs}
len(graphs)

CPU times: user 3.57 s, sys: 402 ms, total: 3.97 s
Wall time: 4.05 s


33678

In [None]:
path_to_stories = "/content"


def _load_split_stories(split_name):
    """Load the story records of one data split."""
    with open("{}/{}_stories.json".format(path_to_stories, split_name)) as fh:
        return json.load(fh)


train_stories = _load_split_stories("train")
valid_stories = _load_split_stories("valid")
test_stories = _load_split_stories("test")

stories = {**train_stories, **valid_stories, **test_stories}

In [None]:
final_graphs = {}
# Generic concepts that are excluded from theme nodes.
ignore_concepts = set(["person", "no person", "people", "no people", "adult", "child", "man",
                       "woman", "girl", "boy", "wear", "group", "many", "group together"])
count = 0
missing_concept_story_ids = []  # one entry per image that has no concepts

for story_id in graphs:
    image_seq = stories[story_id]['images']
    node_ids = graphs[story_id]['node_ids'].copy()
    cap_nodes_ids = [node_ids[x] for x in node_ids if x.split("_")[0] == "cap"]
    b4_and_after_nodes = graphs[story_id]['before_and_after_nodes']

    # add one theme node per image, named "theme_<concepts>_<i>"
    theme_nodes_ids = []
    for i, img in enumerate(image_seq):
        if img not in img_concepts:
            missing_concept_story_ids.append(story_id)
        theme_concepts = [x for x in img_concepts.get(img, []) if x not in ignore_concepts]
        # Fall back to NULL when nothing is left. An empty node text has no
        # tokens, which makes max() over its PMI scores fail in the edge-weight step.
        theme_text = " ".join(theme_concepts) if theme_concepts else "NULL"
        node_name = "theme_" + theme_text + "_" + str(i)
        node_ids[node_name] = len(node_ids)
        theme_nodes_ids.append(node_ids[node_name])

    ### connect edges
    edges = []
    num_images = len(b4_and_after_nodes)
    for i in range(num_images):
        before_nodes = b4_and_after_nodes[str(i)]["before"]
        after_nodes = b4_and_after_nodes[str(i)]["after"]
        for node in before_nodes:
            edges.append([theme_nodes_ids[i], node])   # theme node --> before nodes
            edges.append([node, cap_nodes_ids[i]])     # before nodes --> cap
        for node in after_nodes:
            edges.append([cap_nodes_ids[i], node])     # cap --> after nodes
            if i != num_images - 1:
                edges.append([node, theme_nodes_ids[i + 1]])  # after nodes --> next theme node

    final_graphs[story_id] = {
        "node_ids": node_ids,
        "edges": edges,
        "before_and_after_nodes": b4_and_after_nodes,
    }

    count += 1
    if (count % 10000) == 0:
        print(count)

print("images without concepts:", len(missing_concept_story_ids))

10000
20000
30000


In [None]:
theme_graph_path = '/content/{}_story_graph.json'.format(caption_type)
with open(theme_graph_path, 'w') as out_file:
    json.dump(final_graphs, out_file)

# Getting Graph Weights

In [None]:
caption_type = "CLIP"

with open("/content/{}_story_graph.json".format(caption_type)) as graph_file:
    graphs = json.load(graph_file)

In [None]:
path_to_stories = "/content"

split_stories = {}
for split_name in ("train", "valid", "test"):
    with open("{}/{}_stories.json".format(path_to_stories, split_name)) as fh:
        split_stories[split_name] = json.load(fh)
train_stories = split_stories["train"]
valid_stories = split_stories["valid"]
test_stories = split_stories["test"]

# all splits for graph look-ups; only train + valid feed the PMI statistics
stories = {**train_stories, **valid_stories, **test_stories}
pmi_stories = {**train_stories, **valid_stories}

In [None]:
def combine_stories(stories, pmi_stories, vst = True):
    """
    Append every story in `stories` to `pmi_stories` under consecutive integer
    keys, starting at the current len(pmi_stories).

    Args:
        stories: mapping story id -> story record.
        pmi_stories: dict that is extended in place and also returned.
        vst: when True store only record['story'], otherwise the whole record.
    """
    for record in stories.values():
        next_key = len(pmi_stories)
        pmi_stories[next_key] = record['story'] if vst == True else record
    return pmi_stories

# Re-key train + valid stories as 0..N-1, keeping only their sentence lists.
pmi_stories = combine_stories(valid_stories, combine_stories(train_stories, {}))
print(len(pmi_stories))

45145


## PMI Weights

In [None]:
from collections import Counter
from itertools import combinations

stemmer = PorterStemmer()

# Sentence-level counts: each stemmed word and each unordered pair of stemmed
# words is counted at most once per sentence.
word_freqs = Counter()
cooccurence_freqs = Counter()   # keys are (a, b) with a < b
all_sents = []

count = 0
for story_id in pmi_stories:
    story_sents = pmi_stories[story_id]  # list of sentences
    for sent in story_sents:
        all_sents.append(sent)
        # Sorted unique stems give every pair one canonical orientation. Building
        # pairs from set iteration order made the (a, b) vs (b, a) split depend
        # on string hashing.
        word_tokens = sorted({stemmer.stem(token) for token in nltk.word_tokenize(sent)})
        word_freqs.update(word_tokens)
        cooccurence_freqs.update(combinations(word_tokens, 2))
    count += 1
    if (count % 10000) == 0:
        print(count)

10000
20000
30000
40000


In [None]:
def calc_pmi(word_i, word_j, word_freqs, cooccurence_freqs, all_sents, stem = None):
    """
    Normalised pointwise mutual information (NPMI) of two words.

    Probabilities are sentence-level, with W = len(all_sents):
    p(w) = count(w) / W and p(wi, wj) = count(wi, wj) / W.
    NPMI = log(p(wi, wj) / (p(wi) * p(wj))) / -log(p(wi, wj)).

    Args:
        word_i, word_j: raw words; both are stemmed before lookup.
        word_freqs: stemmed word -> sentence count.
        cooccurence_freqs: (stem_a, stem_b) -> co-occurrence count; counts
            stored under either key order are summed.
        all_sents: the sentences the counts came from (only its length is used).
        stem: optional stemming callable; defaults to the notebook's global
            `stemmer.stem`.

    Returns:
        -1 if either word is unseen or the pair never co-occurs; 1.0 for the
        same stem twice, or when the pair co-occurs in every sentence (where
        -log(p) would be 0); otherwise the NPMI value.
    """
    if stem is None:
        stem = stemmer.stem

    W = len(all_sents)
    word_i = stem(word_i)
    word_j = stem(word_j)

    if word_i not in word_freqs or word_j not in word_freqs:
        return -1  # unseen word: never occurring together

    if word_i == word_j:
        return 1.0  # a word always co-occurs with itself; NPMI is 1 by definition

    wi = word_freqs[word_i]
    wj = word_freqs[word_j]

    wij = 0
    if (word_i, word_j) in cooccurence_freqs:
        wij += cooccurence_freqs[(word_i, word_j)]
    if (word_j, word_i) in cooccurence_freqs:
        wij += cooccurence_freqs[(word_j, word_i)]

    if wij == 0:
        return -1

    pi = wi / W
    pj = wj / W
    pij = wij / W
    if pij >= 1:
        return 1.0  # -log(pij) == 0: perfect co-occurrence

    pmi = np.log(pij / (pi * pj))
    normalized_pmi = pmi / -np.log(pij)

    return normalized_pmi

In [None]:
%%time

story_ids = list(graphs.keys())
count = 0
# story_ids = ['45575']
for story_id in story_ids:

    b4_and_after_info = graphs[story_id]['before_and_after_nodes']
    edges = graphs[story_id]['edges']
    node_ids = graphs[story_id]['node_ids']
    id_to_node_name = dict((v,k) for k,v in node_ids.items())

    edge_weights = {}
    for e in edges:
        source, dest = e[0], e[1]
        source_name = id_to_node_name[source].split("_")[1]
        dest_name = id_to_node_name[dest].split("_")[1]

        source_tokens = nltk.word_tokenize(source_name)
        dest_tokens = nltk.word_tokenize(dest_name)

        pmi_list = []
        for word_i in source_tokens:
            for word_j in dest_tokens:
                pmi = calc_pmi(word_i, word_j, word_freqs, cooccurence_freqs, all_sents)
                pmi_list.append(pmi)

        max_weight = max(pmi_list)
        print(source_name, " | ",  dest_name, " | ", max_weight)
        edge_weights[str((source, dest))] = max_weight

    count += 1
    if (count % 100) == 0:
        print(count)

    graphs[story_id]['pmi_edge_weights'] = edge_weights

In [None]:
# persist the graphs, which now carry 'pmi_edge_weights'
pmi_graph_path = '/content/{}_story_graph.json'.format(caption_type)
with open(pmi_graph_path, 'w') as graph_file:
    json.dump(graphs, graph_file)

## Cosine Weights

In [None]:
# download graph phrase node embeddings: https://drive.google.com/drive/folders/1pco6VIbPlhOqoIkMYh33npq_NjMubQUc
!wget --header="Host: doc-0k-50-docs.googleusercontent.com" --header="User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" --header="Cookie: AUTH_kalqqtkfvspenavrni8dvae4pqdav05i_nonce=2b2eo12n628bm" --header="Connection: keep-alive" "https://doc-0k-50-docs.googleusercontent.com/docs/securesc/7kgsp37ueou2iob822ka25iskagriljd/vjor04auljqam5e8jktlbqidvnsbovdv/1665131325000/01091172001417006453/01091172001417006453/1-4yJoG14fCsZLddxDVfg8SUbZaZQnYAy?e=download&ax=ALW9-sCohthSFwKsT6XlD0yDpAG1OjI7cgpxaRRS8YfgzUa8JIEW9PQFdNEInmTU9rlbyw3yQ2RavflOYmQqEm_GvISnZOAxrzt6d8OjNgi6xM4tK5wPzwIv6SVIosm338M8z4J3aSZXMRO0drF-2CWQHmnaBp5V1R5T5OqxnlV7_145vmnIP18IivW5tJeH6WtBBddcIHNdskeeJJAImRUZqy96oFwAca7YFdi3dL1etjV80ZC74RsqKUfK7q7Oyzq5QDEigruOI3ivmWF_hwKVeKPjrs46BSEPJ-DjGwk3-6zv2L3Dmy3RO_4Y2wBzMX_6G8_CJRyBrlmkYpV4BeZ_CBXxKB-ljapBQQ-sz-Ul04tcEI7vFfGTNmINvMtcOWDidAihWk2BAQZ7xfe5Tg5i7ONy66-I63Sg3nNOfTSvaalSMyEGNQCcLF3HpyjJuJrZscH9B6wtpZ1XGREBi0m_rcgOrGtd0WDKDsNDW0EgOy1c-kJ3UF42bqZ9RffeC6zFKXw_XYgx1OabZFjoDLKYSdNQK0lJfA77kItFzh6aCjsMBgwH6_40UdYuTAiBhihL_UW0NYlvLhtBvu-KdJdPwAgXyLMO7ql1sU8XhMxYBFVCCmQk1SG0ZdcNgzfTnkzsWDYBWWnAe9HbYBQMLLMBlCHiB5pbW2mvILfU2s6wbg6FbGC5BFYRtT87uAX99Bpi72EpEhFQE5fW3BKrvHzkewMB3x4Vss6oCSyQp1xwq2sIXjowIeCEY2RwdBoGomNMhlY9ML5ij_ZMyY_GmCcfov3VfOiRX8FkNm5Pod0TAqc-PVROtMLFIvkIfbi48Y4ZJbJaldY4LScQh8_KjdO334smjv-ETaMX_Sa7foT5xHVphfB_CPywqoPI&uuid=31f5e461-4cc1-493b-9b7b-2e3300aca038&authuser=0&nonce=2b2eo12n628bm&user=01091172001417006453&hash=6ulrt9n8g1of7c7hk35m1dh0g9g2ul63" -c -O 'VIST_node_embs.dat'

--2022-10-07 08:32:29--  https://doc-0k-50-docs.googleusercontent.com/docs/securesc/[REDACTED: signed, account-specific download URL]

In [None]:
graph_data_path = "/content"
with open(graph_data_path + "/{}_node_inds.json".format(caption_type)) as fh:
    node_inds = json.load(fh)  # story id -> row of node_embs

NODE_SLOTS = 71   # node axis the .dat file was written with (max nodes per graph)
EMB_DIM = 768     # sentence-embedding width
# Read-only: this step only reads embeddings. The row count comes from the
# file size, so it does not have to equal len(graphs).
node_embs = np.memmap("/content/{}_node_embs.dat".format(caption_type),
                      dtype='float32', mode='r').reshape(-1, NODE_SLOTS, EMB_DIM)
print(node_embs.shape)

(33691, 71, 768)


In [None]:
count = 0
story_ids = list(graphs.keys())
for story_id in story_ids:

    edges = graphs[story_id]['edges']
    # Copy this story's (nodes x dim) matrix out of the memmap once instead of
    # reading two rows from it for every edge.
    story_embs = np.array(node_embs[node_inds[story_id]])

    edge_weights = {}
    for source_id, dest_id in edges:
        sim = 1 - spatial.distance.cosine(story_embs[source_id], story_embs[dest_id])
        # plain Python float so json.dump works even if a float32 scalar comes back
        edge_weights[str((source_id, dest_id))] = float(sim)

    graphs[story_id]['cosine_sim_weights'] = edge_weights
    count += 1
    if (count % 1000) == 0:
        print(count)

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000


In [None]:
# persist the graphs, which now also carry 'cosine_sim_weights'
cosine_graph_path = '/content/{}_story_graph.json'.format(caption_type)
with open(cosine_graph_path, 'w') as graph_file:
    json.dump(graphs, graph_file)