In [233]:
import sys
sys.path.insert(0, "/notebooks")
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_database")
#from torch.nn.functional import softmax
#from transformers import BertTokenizer, BertForNextSentencePrediction
import torch
import random
import csv
import numpy as np
from sumproduct import Variable, Factor, FactorGraph
import pickle
import itertools

In [234]:
# Load the pre-computed stories together with their factor scores.
# NOTE: pickle.load can execute arbitrary code while deserialising; only
# load pickle files that come from a trusted source.
# The context manager guarantees the file handle is closed after loading.
with open("roc10.pickle", "rb") as pickle_file:
    stories_with_scores_saved = pickle.load(pickle_file)
# stories_with_scores_saved = pickle.load( open( "roc1k_3sent.pickle", "rb" ) )

In [235]:
# factors: a sequence of 2-D factor matrices f12, f23, ..., where matrix f_{i,i+1} has shape (|x_i|, |x_{i+1}|)

def create_2chain_graph(factors):
    """Assemble a chain factor graph from pairwise factor matrices.

    ``factors[k]`` links variable ``x{k+1}`` (its rows) to ``x{k+2}`` (its
    columns), so the column count of each matrix must equal the row count
    of the next one.

    Returns:
        ``(graph, names)`` where ``names`` is ``['x1', 'x2', ...]``, one
        entry per variable in chain order.

    Raises:
        AssertionError: if two consecutive matrices disagree on their
            shared dimension.
    """
    graph = FactorGraph(silent=True)  # silent: no message printouts

    # Consecutive matrices have to agree on the variable they share.
    for k in range(len(factors) - 1):
        assert factors[k].shape[1] == factors[k + 1].shape[0]

    # One cardinality per variable: the row count of every factor, plus the
    # column count of the last one for the final variable.
    sizes = [factor.shape[0] for factor in factors] + [factors[-1].shape[1]]

    names = ['x' + str(position) for position in range(1, len(sizes) + 1)]
    nodes = [Variable(name, size) for name, size in zip(names, sizes)]

    for k, (left, right) in enumerate(zip(nodes, nodes[1:])):
        label = 'f{}{}'.format(k + 1, k + 2)
        # The matrix is transposed and the variables are attached in the
        # matching reversed order (right-hand variable first).
        graph.add(Factor(label, factors[k].transpose()))
        graph.append(label, right)
        graph.append(label, left)

    return graph, names


def create_3chain_graph(factors):
    """Assemble a chain factor graph from third-order factor tensors.

    ``factors[k]`` is a 3-D array coupling the consecutive variables
    ``x{k+1}``, ``x{k+2}`` and ``x{k+3}`` along its axes 0, 1 and 2.
    Neighbouring tensors overlap in two variables, so for every adjacent
    pair the trailing two axes of the first must match the leading two axes
    of the second.

    A single factor (three variables) is supported, and every adjacent
    pair, including the last one, is shape-checked.

    Returns:
        ``(graph, names)`` where ``names`` is ``['x1', 'x2', ...]``, one
        entry per variable in chain order (``len(factors) + 2`` entries).

    Raises:
        AssertionError: if two adjacent tensors disagree on a shared axis.
        IndexError: if ``factors`` is empty.
    """
    g = FactorGraph(silent=True)  # init the graph without message printouts

    # Validate every adjacent pair. The (k, k+2) overlap follows from these
    # two checks, so it needs no separate assertion.
    for k in range(len(factors) - 1):
        current_shape = factors[k].shape
        next_shape = factors[k + 1].shape
        assert current_shape[1] == next_shape[0], \
            'factors {} and {} disagree on their first shared variable'.format(k, k + 1)
        assert current_shape[2] == next_shape[1], \
            'factors {} and {} disagree on their second shared variable'.format(k, k + 1)

    # One cardinality per variable: axis 0 of every factor, then the two
    # trailing axes of the last factor for the final two variables.
    last_shape = factors[-1].shape
    sizes = [factor.shape[0] for factor in factors] + [last_shape[1], last_shape[2]]

    vnames = []
    gvars = []
    for i, n in enumerate(sizes):
        vname = 'x' + str(i + 1)
        vnames.append(vname)
        gvars.append(Variable(vname, n))

    for i in range(len(gvars) - 2):
        fname = 'f{}{}{}'.format(i + 1, i + 2, i + 3)
        # Axes are reversed and the variables are attached in the matching
        # reversed order (last variable first).
        g.add(Factor(fname, factors[i].transpose(2, 1, 0)))
        g.append(fname, gvars[i + 2])
        g.append(fname, gvars[i + 1])
        g.append(fname, gvars[i])

    return g, vnames

def compute_marginals(factors, chain_creator, max_iter=15500, tolerance=1e-8):
    """Build a factor graph from ``factors`` and return per-variable marginals.

    Args:
        factors: factor arrays, passed unchanged to ``chain_creator``.
        chain_creator: callable returning ``(graph, variable_names)``, e.g.
            ``create_2chain_graph`` or ``create_3chain_graph``.
        max_iter: iteration cap forwarded to ``graph.compute_marginals``.
        tolerance: tolerance forwarded to ``graph.compute_marginals``.

    Returns:
        A list with one marginal per variable, in the order of the names
        returned by ``chain_creator``.
    """
    g, vnames = chain_creator(factors)
    g.compute_marginals(max_iter=max_iter, tolerance=tolerance)
    # Read each variable node's marginal in chain order.
    return [g.nodes[vname].marginal() for vname in vnames]

def compute_2chain_marginals(factors):
    """Return per-variable marginals for a chain built by ``create_2chain_graph``."""
    return compute_marginals(factors, chain_creator=create_2chain_graph)


def compute_3chain_marginals(factors):
    """Return per-variable marginals for a chain built by ``create_3chain_graph``."""
    return compute_marginals(factors, chain_creator=create_3chain_graph)

# def compute_2chain_marginals_orig(factors):
#     g, vnames = create_2chain_graph(factors)
#     g.compute_marginals(max_iter=15500, tolerance=1e-8)
#     rc = []
#     for vname in vnames:
#         rc.append(g.nodes[vname].marginal())
#     return rc


In [239]:
def story_score(factors, story):
    """Multiply the pairwise factor entries picked out by ``story``.

    ``story`` is a sequence of chosen indices, one per position; the k-th
    term of the product is ``factors[k][story[k]][story[k+1]]``.  A story
    with fewer than two positions scores ``1.0``.
    """
    product = 1.
    for position, (current, following) in enumerate(zip(story, story[1:])):
        product *= factors[position][current][following]
    return product

def story_score3(factors, story):
    """Multiply the third-order factor entries picked out by ``story``.

    ``story`` is a sequence of chosen indices, one per position; the k-th
    term of the product is ``factors[k][story[k]][story[k+1]][story[k+2]]``.
    A story with fewer than three positions scores ``1.0``.
    """
    product = 1.
    triples = zip(story, story[1:], story[2:])
    for position, (first, second, third) in enumerate(triples):
        product *= factors[position][first][second][third]
    return product

def produce_story(story, selection):
    """Pick one candidate per position: ``story[k][selection[k]]`` for each k."""
    chosen = []
    for position, candidate_index in enumerate(selection):
        chosen.append(story[position][candidate_index])
    return chosen

def check_naive(story_dict, nbest=1, scoring_func=story_score, correct_index=0):
    """Brute-force check: is the correct selection among the ``nbest`` best?

    Every combination of one candidate per position is scored with
    ``scoring_func(scores, selection)`` and the combinations are ranked by
    decreasing score.

    Args:
        story_dict: record with a ``'story'`` entry (one candidate list per
            position) and a ``'scores'`` entry (passed to ``scoring_func``).
        nbest: how many top-ranked selections to accept.
        scoring_func: callable ``(scores, selection) -> number``.
        correct_index: candidate index regarded as correct at every
            position; the target selection is that index repeated once per
            position, so stories of any length are handled.

    Returns:
        True if the target selection is among the ``nbest`` top-ranked ones.
    """
    story = story_dict['story']
    scores = story_dict['scores']

    selections = itertools.product(*[range(len(options)) for options in story])
    scored = [(selection, scoring_func(scores, selection)) for selection in selections]

    # Ascending sort followed by a reversal (not reverse=True) so that
    # tie-breaking among equal scores stays exactly as before.
    ranked = list(reversed(sorted(scored, key=lambda item: item[1])))

    target = (correct_index,) * len(story)
    return target in [selection for selection, _ in ranked[:nbest]]

def check_non_naive(story_dict, correct_index=0):
    """Rate how highly each position's marginal ranks ``correct_index``.

    Marginals come from ``compute_2chain_marginals(story_dict['scores'])``.
    For every position the states are ordered by decreasing marginal and
    the position's score is ``1 - rank / n_states``, where ``rank`` is the
    0-based place of ``correct_index`` in that ordering.

    Returns:
        A list with one score per position.
    """
    # Not used below; the lookup is kept so that a record without a
    # 'story' entry still raises KeyError.
    _candidates = story_dict['story']
    marginals = compute_2chain_marginals(story_dict['scores'])

    def rank_score(site):
        by_marginal_desc = sorted(enumerate(site), key=lambda pair: -pair[1])
        state_order = list(zip(*by_marginal_desc))[0]
        return 1. - (state_order.index(correct_index) / len(site))

    return [rank_score(site) for site in marginals]



In [228]:
# Number of story records in the loaded dataset.
len(stories_with_scores_saved)

10

In [231]:
# Count the stories whose all-correct selection is the single best-scoring
# combination under the 3-chain score.
sum(check_naive(record, 1, story_score3) for record in stories_with_scores_saved)

3

In [223]:
# Spot-check: 3-chain score of one hand-picked selection for story #1.
story_score3(stories_with_scores_saved[1]['scores'],[0,1,1,2,0])

3.323916982785391e-08

In [158]:
# Scratch check of the ranking idiom: sort (index, value) pairs by value,
# keep the indices in sorted order, then look up the place of original index 6.
list(zip(*sorted(enumerate('adsffge'),key=lambda x:x[1])))[0].index(6)

2

In [196]:
# Mean per-position rank score for story #0, treating candidate index 1 as the correct one.
np.mean(check_non_naive(stories_with_scores_saved[0],1))

0.5570634920634921

In [179]:
# Brute-force evaluation with the pairwise score: one boolean per story,
# True when the all-correct selection is the single best-scoring one.
# NOTE(review): the recorded count below is larger than the number of
# stories reported for roc10.pickle, so it presumably comes from a run on
# a different dataset. Re-run to refresh.
rc_naive = [check_naive(record, 1) for record in stories_with_scores_saved]
sum(rc_naive)

452

In [240]:
# Marginal-based evaluation: average the per-position rank scores inside
# each story, then average across stories.
rc_non_naive = [np.mean(check_non_naive(record, 0)) for record in stories_with_scores_saved]
np.mean(rc_non_naive)

0.6983896103896104

In [188]:
# Scratch check of the rank-score formula on a hand-made marginal.
# NOTE(review): this sorts ascending (key=x[1]) whereas check_non_naive
# sorts descending (key=-x[1]); here index 0 holds the smallest value, so
# it lands at rank 0 and the expression evaluates to 1.0.
site = [0.05,0.3,0.5,0.2,0.1]
1. - (list(zip(*sorted(enumerate(site),key=lambda x: x[1])))[0].index(0) / len(site))

1.0

In [None]:
# Dump, for every record, its candidate sentences, its raw scores and the
# per-position marginals of the 2-chain graph built from those scores.
# NOTE: `story` is the whole record dict in this loop; a later cell
# rebinds `story` to the record's 'story' entry.
for story in stories_with_scores_saved:
    rc = compute_2chain_marginals(story['scores'])
    print('Story with candidates (First sentence is each array is good)')
    print(story['story'])
    print('BERT Scores')
    print(story['scores'])
    print('Marginals')
    print(rc)

In [129]:
# Pick record #9 for manual inspection and unpack its two entries:
# the per-position candidates and the factor scores.
story_dict = stories_with_scores_saved[9]
story = story_dict['story']
scores = story_dict['scores']

In [130]:
# Exhaustive search over the selected story (same steps as check_naive):
# enumerate every combination of one candidate index per position ...
sspace_indices = list(itertools.product(*[range(len(x)) for x in story]))
# ... score each combination with the pairwise chain score ...
sspace = {}
for i,ind in enumerate(sspace_indices):
    sspace[ind] = story_score(scores,ind)
# Sum of all combination scores (not used by the cells visible below).
total_scores = sum([v for (k,v) in sspace.items()])
# List of the combination tuples, i.e. the dict keys (not used by the cells visible below).
sspace_vals = list(sspace)
# ... and rank the (combination, score) pairs from best to worst.
sorted_solutions = list(reversed(sorted(sspace.items(),key=lambda x:x[1])))

In [131]:
# Display the ranked (combination, score) pairs, best first.
sorted_solutions

[((0, 0, 0, 5, 0), 0.9999724629354875),
 ((0, 0, 0, 0, 5), 0.9999648337103483),
 ((0, 0, 0, 5, 3), 0.9999554163748986),
 ((0, 0, 0, 5, 1), 0.9999548203413117),
 ((0, 0, 0, 0, 4), 0.9999537475014888),
 ((0, 0, 0, 5, 5), 0.9999474295248325),
 ((0, 0, 0, 0, 0), 0.99994683352177),
 ((0, 0, 0, 0, 1), 0.9999419460533481),
 ((0, 0, 0, 0, 3), 0.9999406347813324),
 ((0, 0, 0, 8, 2), 0.9999195351349266),
 ((0, 0, 0, 8, 0), 0.9999035622402839),
 ((0, 0, 1, 5, 0), 0.9998893757270385),
 ((2, 1, 6, 0, 5), 0.999879961220347),
 ((0, 0, 1, 5, 3), 0.9998723305828399),
 ((0, 0, 1, 5, 1), 0.999871734598777),
 ((2, 1, 6, 0, 4), 0.9998688759524347),
 ((0, 0, 1, 5, 5), 0.999864344396397),
 ((0, 0, 0, 8, 5), 0.9998630340001454),
 ((2, 1, 6, 0, 0), 0.9998619625595432),
 ((2, 1, 6, 0, 1), 0.9998570755059475),
 ((2, 1, 6, 0, 3), 0.9998557643452266),
 ((0, 0, 1, 8, 2), 0.9998486104718716),
 ((3, 2, 2, 5, 0), 0.9998440824837149),
 ((0, 0, 1, 8, 0), 0.9998326387101921),
 ((3, 2, 2, 7, 0), 0.999827156908132),
 ((3, 

In [122]:
# Per-position marginals of the selected story from the 2-chain factor graph.
compute_2chain_marginals(scores)


[array([0.14997653, 0.13676918, 0.12168154, 0.10509127, 0.14128036,
        0.11887558, 0.06258793, 0.12129306, 0.04244454]),
 array([1.23301777e-01, 2.31324856e-04, 9.44220932e-02, 1.71360964e-01,
        1.27430693e-01, 2.38384659e-01, 1.45387785e-01, 9.94807032e-02]),
 array([0.09297309, 0.05935452, 0.32097126, 0.06825633, 0.28592907,
        0.17251574]),
 array([9.40612965e-02, 4.10015940e-06, 4.08081775e-04, 9.05526522e-01]),
 array([0.09384289, 0.01125536, 0.44750831, 0.44739343])]

In [None]:
# Turn the top-ranked combination of indices back into the candidates it selects.
produce_story(story,sorted_solutions[0][0])