In [1]:
import sys
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_vlm")
sys.path.insert(0, "/notebooks/nebula3_database")
sys.path.insert(0, "/notebooks/")
import os
import math
import random
import bisect
import pickle
import time
import numpy as np


In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import urllib
import subprocess
import re
import tempfile
import itertools
import torch
import spacy
from torch.nn.functional import softmax as torch_softmax
from sumproduct import Variable, Factor, FactorGraph

from typing import List, Tuple
from operator import itemgetter 
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, BertTokenizer, BertForNextSentencePrediction
from database.arangodb import DatabaseConnector
from config import NEBULA_CONF
from movie_db import MOVIE_DB


In [3]:
class PIPELINE:
    """Hold the Nebula configuration values and the database handle used by this notebook.

    Attributes set by the constructor:
        db_host:  value returned by ``NEBULA_CONF().get_database_host()``.
        database: value returned by ``NEBULA_CONF().get_playground_name()``.
        gdb:      a ``DatabaseConnector`` instance.
        db:       result of ``gdb.connect_db(database)``.
    """

    def __init__(self):
        # Read both settings first, then open the connection with them.
        settings = NEBULA_CONF()
        host = settings.get_database_host()
        playground = settings.get_playground_name()

        connector = DatabaseConnector()

        self.db_host = host
        self.database = playground
        self.gdb = connector
        self.db = connector.connect_db(playground)

# Notebook-wide handles: the database pipeline, the movie DB helper and the CLIP wrapper.
pipeline = PIPELINE()
mdb = MOVIE_DB()
# NOTE(review): this import sits in the middle of a cell; consider moving it to the
# import cell at the top so the notebook's dependencies are visible in one place.
from vlm.clip_api import CLIP_API
# NOTE(review): 'vit' presumably selects the CLIP backbone -- confirm against CLIP_API.
clip=CLIP_API('vit')
# Collection names used by the S2 pipeline. In the visible part of this notebook only
# s2_results_orig_collection_name is actually queried.
s2_collection_name = 's2_pipeline_after_gpt'
s2_results_orig_collection_name = 's2_pipeline_optim_orig'
s2_results_relaxed_collection_name = 's2_pipeline_optim_relaxed'
s2_compatibility_collection_name = 's2_pipeline_compatibility_scores'

In [4]:
# Load the pretrained BERT next-sentence-prediction model that story_compatability()
# uses to score sentence pairs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
# Use the first GPU when there is one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing here.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [None]:
def create_2chain_graph(factors):
    """Build a chain factor graph x1 - f12 - x2 - f23 - x3 ... from pairwise tables.

    Args:
        factors: sequence of 2-D arrays. ``factors[i]`` links variable ``x(i+1)``
            (axis 0) to variable ``x(i+2)`` (axis 1); consecutive tables must
            agree on the size of the variable they share (checked with ``assert``).

    Returns:
        ``(graph, names)`` where ``graph`` is the populated ``FactorGraph`` and
        ``names`` lists the variable names ``'x1', 'x2', ...`` in chain order.
    """
    graph = FactorGraph(silent=True)  # no message printouts

    # Neighbouring tables share one variable, so their facing dimensions must match.
    for k in range(1, len(factors)):
        assert factors[k - 1].shape[1] == factors[k].shape[0]

    # One variable per table row-dimension, plus the trailing column-dimension.
    sizes = [table.shape[0] for table in factors] + [factors[-1].shape[1]]

    names = ['x' + str(k) for k in range(1, len(sizes) + 1)]
    nodes = [Variable(name, size) for name, size in zip(names, sizes)]

    for k in range(len(factors)):
        label = 'f{}{}'.format(k + 1, k + 2)
        # The table is handed over transposed, so the variables are attached
        # in reverse order as well: x(k+2) first, then x(k+1).
        graph.add(Factor(label, factors[k].transpose()))
        graph.append(label, nodes[k + 1])
        graph.append(label, nodes[k])

    return graph, names


def create_3chain_graph(factors):
    """Build a chain factor graph whose factors each span three consecutive variables.

    ``factors[i]`` is a 3-D array linking ``x(i+1)`` (axis 0), ``x(i+2)`` (axis 1)
    and ``x(i+3)`` (axis 2), so ``n`` factors describe ``n + 2`` variables.

    Compared with the previous implementation this version
      * accepts a single factor (it used to fail with ``IndexError`` on
        ``factors[-2]``), and
      * checks *every* adjacent pair of factors, including the last one, which
        used to go unchecked.

    Args:
        factors: non-empty sequence of 3-D arrays. Adjacent tables overlap in two
            variables, so ``factors[i].shape[1:]`` must equal
            ``factors[i + 1].shape[:2]``.

    Returns:
        ``(graph, names)`` where ``graph`` is the populated ``FactorGraph`` and
        ``names`` lists the variable names ``'x1', 'x2', ...`` in chain order.

    Raises:
        AssertionError: if two adjacent factors disagree on a shared dimension.
        IndexError: if ``factors`` is empty.
    """
    graph = FactorGraph(silent=True)  # init the graph without message printouts

    # Adjacent factors share their two overlapping variables. Checking each
    # neighbouring pair also implies factors[i].shape[2] == factors[i+2].shape[0].
    for k in range(len(factors) - 1):
        assert factors[k].shape[1] == factors[k + 1].shape[0]
        assert factors[k].shape[2] == factors[k + 1].shape[1]

    # The leading axis of every factor, plus the two trailing axes of the last one.
    sizes = [table.shape[0] for table in factors]
    sizes.append(factors[-1].shape[1])
    sizes.append(factors[-1].shape[2])

    names = ['x' + str(k + 1) for k in range(len(sizes))]
    nodes = [Variable(name, size) for name, size in zip(names, sizes)]

    for k in range(len(factors)):
        label = 'f{}{}{}'.format(k + 1, k + 2, k + 3)
        # The table is handed over with its axes reversed, so the variables are
        # attached in reverse order too: x(k+3), x(k+2), x(k+1).
        graph.add(Factor(label, factors[k].transpose(2, 1, 0)))
        graph.append(label, nodes[k + 2])
        graph.append(label, nodes[k + 1])
        graph.append(label, nodes[k])

    return graph, names

def compute_marginals(factors, chain_creator, max_iter=15500, tolerance=1e-8):
    """Build a factor graph from ``factors`` and return the marginal of every variable.

    Args:
        factors: factor tables, passed unchanged to ``chain_creator``.
        chain_creator: callable returning ``(graph, variable_names)``, e.g.
            ``create_2chain_graph`` or ``create_3chain_graph``.
        max_iter: forwarded to ``graph.compute_marginals``. Defaults to the
            previously hard-coded value.
        tolerance: forwarded to ``graph.compute_marginals``. Defaults to the
            previously hard-coded value.

    Returns:
        A list with one marginal per variable, in the order of ``variable_names``.
    """
    graph, var_names = chain_creator(factors)
    graph.compute_marginals(max_iter=max_iter, tolerance=tolerance)
    return [graph.nodes[name].marginal() for name in var_names]

def compute_2chain_marginals(factors):
    """Return the per-variable marginals of a chain built by ``create_2chain_graph``.

    Thin wrapper around ``compute_marginals``.
    """
    marginals = compute_marginals(factors, chain_creator=create_2chain_graph)
    return marginals


def compute_3chain_marginals(factors):
    """Return the per-variable marginals of a chain built by ``create_3chain_graph``.

    Thin wrapper around ``compute_marginals``.
    """
    marginals = compute_marginals(factors, chain_creator=create_3chain_graph)
    return marginals

# def compute_2chain_marginals_orig(factors):
#     g, vnames = create_2chain_graph(factors)
#     g.compute_marginals(max_iter=15500, tolerance=1e-8)
#     rc = []
#     for vname in vnames:
#         rc.append(g.nodes[vname].marginal())
#     return rc

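# story_compatability -- input: two scenes, each a list of sentence strings;
#                        output: a 2-D array of scores, one per (scene1 sentence, scene2 sentence) pair.
# score_story         -- input: a story, i.e. a list of scenes (a list of lists of strings);
#                        output: a list of such score arrays, one per pair of consecutive scenes.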

def story_compatability(scene1, scene2):
    """Score every sentence of ``scene1`` against every sentence of ``scene2``.

    Each pair ``(sent_a, sent_b)`` is encoded with the module-level ``tokenizer``
    and fed to the module-level next-sentence-prediction ``model`` on ``device``.
    The score is the softmax probability at index 0 of the model's logits; per the
    Hugging Face documentation for ``BertForNextSentencePrediction`` that index is
    the "sentence B continues sentence A" class.

    Args:
        scene1: iterable of sentence strings (rows of the result).
        scene2: iterable of sentence strings (columns of the result).

    Returns:
        ``np.ndarray`` of shape ``(len(scene1), len(scene2))``. The array is always
        2-D, even when ``scene1`` is empty, so ``.shape[0]`` / ``.shape[1]`` stay
        meaningful for the chain builders.
    """
    # Materialise both inputs: sizes are needed up front, and a one-shot iterator
    # for scene2 would otherwise be exhausted after the first row.
    scene1, scene2 = list(scene1), list(scene2)
    scores = np.zeros((len(scene1), len(scene2)))

    # Pure inference: disable autograd so no graph is built for each forward pass.
    with torch.no_grad():
        for row, sent_a in enumerate(scene1):
            for col, sent_b in enumerate(scene2):
                encoded = tokenizer.encode_plus(sent_a, sent_b, return_tensors='pt').to(device)
                seq_relationship_logits = model(**encoded)[0]
                probs = torch_softmax(seq_relationship_logits, dim=1)
                scores[row, col] = probs[0][0].item()
    return scores


def score_story(story):
    """Score each pair of consecutive scenes in ``story``.

    Args:
        story: sequence of scenes, each scene being a list of sentences.

    Returns:
        A list with ``len(story) - 1`` entries; entry ``i`` is the result of
        ``story_compatability(story[i], story[i + 1])``. Stories with fewer than
        two scenes give an empty list.
    """
    return [
        story_compatability(story[pos], story[pos + 1])
        for pos in range(len(story) - 1)
    ]
    
    


In [None]:
def flatten(lst):
    """Concatenate the items of an iterable of iterables into one flat list (one level deep)."""
    return [x for l in lst for x in l]


def softmax(x):
    """Softmax of ``x`` along axis 0, computed in a numerically stable way.

    The maximum is subtracted before exponentiating. This leaves the result
    unchanged mathematically but avoids the overflow (``inf / inf -> nan``) that
    the plain ``exp(x) / sum(exp(x))`` form hits for large inputs. As before, a
    2-D input is normalised column-wise (over axis 0).
    """
    if np.size(x) == 0:
        # Nothing to normalise; np.max would raise on an empty input.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=0))
    return exps / np.sum(exps, axis=0)


def normalize(x):
    """Standardise ``x`` to zero mean and unit (population) standard deviation.

    Note: a constant input has zero standard deviation, so the result is NaN.
    """
    return (x - np.mean(x)) / np.std(x)



In [None]:
# Fetch every document of the "optim_orig" results collection.
query = 'FOR doc IN {} RETURN doc'.format(s2_results_orig_collection_name)
cursor = pipeline.db.aql.execute(query)
# Materialise the cursor into a list so the documents can be scanned repeatedly below.
all_docs = list(cursor)


In [9]:
# Group the fetched documents by movie in ONE pass over all_docs, instead of
# re-filtering the whole list once per movie. Grouping through a dict also makes
# the key order of all_movies follow the order of all_docs rather than arbitrary
# set iteration order, so "the first movie" picked later no longer depends on
# how a set happens to iterate.
docs_by_movie = {}
for doc in all_docs:
    docs_by_movie.setdefault(doc['movie_id'], []).append(doc)

movies = set(docs_by_movie)
all_movies = {}

for mid, movie_docs in docs_by_movie.items():
    # Scene elements of one movie, ordered by their 'scene_element' field.
    elements = sorted(movie_docs, key=lambda y: y['scene_element'])
    all_movies[mid] = elements
    # n = len(elements[0]['sentences'])    # Number of copies
    # for i in range(n):
    #     story=[elem['sentences'][i] for elem in elements]
    #     for (elem,story_part) in zip(elements,story):
    #         elem['sentences'][i]=zip(elem['sentences'][i],story[i])     # Update scene_element with scores
        


In [18]:
# Take the first movie in all_movies (its list of scene-element documents).
story_obj = all_movies[list(all_movies.keys())[0]]
# For every scene element keep only the first entry of its 'sentences' field.
story = [x['sentences'][0] for x in story_obj]
# One score matrix per pair of consecutive scene elements.
rc = score_story(story)

In [23]:
# Size of the first story entry (the recorded output below shows 4).
len(story[0])

4