In [None]:
import sys
sys.path.insert(0, "/notebooks")
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_database")
from torch.nn.functional import softmax
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch
import random
import csv
import numpy as np
from sumproduct import Variable, Factor, FactorGraph
import pickle


In [None]:
# Load the pretrained BERT tokenizer and its next-sentence-prediction model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded "cuda:0" makes `.to(device)` fail there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [None]:
def load_stories(dataset='roc', path="1000ROC.csv"):
    """Load stories as lists of sentences.

    Args:
        dataset: Name of the dataset to load. Only 'roc' is supported.
        path: CSV file read for the 'roc' dataset. The first two columns of
            every row are dropped (presumably an id and a title -- confirm
            against the data file); the remaining columns are the sentences.

    Returns:
        A list of stories, each being a list of sentence strings. Rows that
        have no columns left after dropping the first two (e.g. blank lines)
        are skipped.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    if dataset != 'roc':
        # Without this guard the function reached its return statement with
        # `stories` never assigned.
        raise ValueError("Unsupported dataset: {!r} (expected 'roc')".format(dataset))

    stories = []
    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(path, newline='') as file:
        for row in csv.reader(file, delimiter=","):
            # Slicing (rather than pop) tolerates short or empty rows.
            sentences = row[2:]
            if sentences:
                stories.append(sentences)
    return stories

In [None]:
def create_stories(stories, min_distractors=3, max_distractors=10):
    """Attach random distractor sentences to every sentence of every story.

    For each sentence a candidate list is built whose first element is the
    original sentence, followed by a random number of sentences sampled
    uniformly from random stories of the corpus (the sampled sentence may
    coincide with the original one).

    Args:
        stories: List of stories, each a list of sentence strings.
        min_distractors: Smallest number of sampled sentences per scene.
        max_distractors: Largest number of sampled sentences per scene
            (inclusive).

    Returns:
        A list parallel to ``stories``; each entry is a list of scenes, and
        each scene is a list of candidate sentences with the true sentence at
        index 0.
    """
    # Only non-empty stories can donate a sentence. Sampling the sentence index
    # from the donor story itself (instead of from the current story's length)
    # avoids an IndexError when stories differ in length.
    donor_stories = [story for story in stories if len(story) > 0]

    stories_with_candidates = []
    for story in stories:
        scenes = []
        for sentence in story:
            candidates = [sentence]
            for _ in range(random.randint(min_distractors, max_distractors)):
                donor = random.choice(donor_stories)
                candidates.append(random.choice(donor))
            scenes.append(candidates)
        stories_with_candidates.append(scenes)
    return stories_with_candidates

In [None]:
def story_compatability(scene1, scene2):
    """Score every sentence pair across two scenes with BERT next-sentence prediction.

    Args:
        scene1: Candidate sentences for the earlier scene.
        scene2: Candidate sentences for the scene that follows.

    Returns:
        A NumPy array whose entry [i][j] is the softmax probability of class
        index 0 for the pair (scene1[i], scene2[j]). Per the Hugging Face
        documentation of the NSP head, index 0 means "sentence B continues
        sentence A".
    """
    score_rows = []
    # Inference only: without no_grad every forward pass would record an
    # autograd graph that is never used, wasting memory and time.
    with torch.no_grad():
        for sent_a in scene1:
            row_scores = []
            for sent_b in scene2:
                encoded = tokenizer.encode_plus(sent_a, sent_b, return_tensors='pt').to(device)
                # First element of the model output holds the NSP logits.
                seq_relationship_logits = model(**encoded)[0]
                probs = softmax(seq_relationship_logits, dim=1)
                row_scores.append(probs[0][0].tolist())
            score_rows.append(row_scores)
    return np.array(score_rows)

In [None]:
# Load the corpus, add distractor candidates, then score each pair of
# consecutive scenes in every story.
stories = load_stories('roc')
candidated_stories = create_stories(stories)
stories_with_scores = []
for story in candidated_stories:
    # One compatibility matrix per adjacent (scene, next scene) pair.
    scenes_scores = [
        story_compatability(current_scene, next_scene)
        for current_scene, next_scene in zip(story[:-1], story[1:])
    ]
    stories_with_scores.append({'story': story, 'scores': scenes_scores})
    

In [None]:
# Persist the scored stories to disk using the newest available pickle protocol.
with open('roc1k.pickle', 'wb') as pickle_out:
    pickle.dump(stories_with_scores, pickle_out, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# for i in stories_with_scores:
#     print(i)

In [None]:
# NOTE(review): this reads "roc10.pickle", while the cell that saves the scores
# writes 'roc1k.pickle' -- confirm that loading the other file is intentional.
# pickle.load can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle, which a bare open() left dangling.
with open("roc10.pickle", "rb") as handle:
    stories_with_scores_saved = pickle.load(handle)

In [None]:
# factors: a sequence of 2-D factor matrices f12, f23, ..., where the factor between x_i and x_(i+1) has shape (|x_i|, |x_(i+1)|)

def create_2chain_graph(factors):
    """Build a chain factor graph x1 - f12 - x2 - f23 - x3 ... from pairwise factors.

    Args:
        factors: Non-empty sequence of 2-D arrays. ``factors[i]`` links
            variable ``x(i+1)`` (its rows) to variable ``x(i+2)`` (its
            columns), so the column count of each factor must equal the row
            count of the next one.

    Returns:
        A tuple ``(graph, vnames)``: the populated ``FactorGraph`` and the
        list of variable names ``['x1', 'x2', ...]`` in chain order.

    Raises:
        ValueError: If ``factors`` is empty.
        AssertionError: If consecutive factors disagree on the size of the
            variable they share.
    """
    if len(factors) == 0:
        raise ValueError("create_2chain_graph needs at least one factor")

    for i in range(len(factors) - 1):
        assert factors[i].shape[1] == factors[i+1].shape[0], (
            "factor {} has {} columns but factor {} has {} rows".format(
                i, factors[i].shape[1], i + 1, factors[i+1].shape[0]))

    # Every factor contributes the size of its row variable; the last factor
    # additionally contributes the size of the final (column) variable.
    var_sizes = [factor.shape[0] for factor in factors]
    var_sizes.append(factors[-1].shape[1])

    g = FactorGraph(silent=True)  # init the graph without message printouts
    vnames = []
    gvars = []
    for i, v_size in enumerate(var_sizes):
        vname = 'x'+str(i+1)
        vnames.append(vname)
        gvars.append(Variable(vname, v_size))

    for i in range(len(gvars)-1):
        fname = 'f{}{}'.format(i+1, i+2)
        # The factor is transposed and its variables are attached in reverse
        # order (x(i+2) first, then x(i+1)) to match the transposed axes.
        fact = Factor(fname, factors[i].transpose())
        g.add(fact)
        g.append(fname, gvars[i+1])
        g.append(fname, gvars[i])

    return g, vnames


def compute_2chain_marginals(factors):
    """Return the marginal of every chain variable for the given pairwise factors.

    The chain graph is built with ``create_2chain_graph``, marginals are
    computed with at most 15500 iterations and a tolerance of 1e-8, and the
    per-variable marginals are returned in chain order (x1, x2, ...).
    """
    graph, variable_names = create_2chain_graph(factors)
    graph.compute_marginals(max_iter=15500, tolerance=1e-8)
    return [graph.nodes[name].marginal() for name in variable_names]


In [None]:
# For every saved story, print its candidates, the pairwise BERT scores and
# the marginals computed over the chain of scenes.
for story in stories_with_scores_saved:
    rc = compute_2chain_marginals(story['scores'])
    report_sections = (
        ('Story with candidates (First sentence is each array is good)', story['story']),
        ('BERT Scores', story['scores']),
        ('Marginals', rc),
    )
    for heading, payload in report_sections:
        print(heading)
        print(payload)