In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [26]:
from FilesToUploadToColab.experiments import DirectTemplate, EnumeratedTemplate, Experiment
from pytorch_pretrained_bert import BertForMaskedLM, GPT2LMHeadModel
import pandas as pd
from copy import deepcopy
import numpy as np

In [3]:
bert = BertForMaskedLM.from_pretrained('bert-large-uncased')
gpt = GPT2LMHeadModel.from_pretrained('gpt2')

In [54]:
class KnowledgeMiner:
    def __init__(self, dev_data_path, device, Template, bert, template_loc = None, language_model = None):
        self.sentences_mask_tail = Template(
            dev_data_path, 
            False, 
            True, 
            device, 
            template_loc = template_loc, 
            language_model = language_model
        )
        self.sentences_mask_head = Template(
            dev_data_path, 
            True, 
            False, 
            device, 
            template_loc = template_loc, 
            language_model = language_model
        )
        self.sentences_mask_both = Template(
            dev_data_path,
            True,
            True, 
            device,
            template_loc = template_loc, 
            language_model = language_model
        )
        
        bert.eval()
        bert.to(device)
        self.bert = bert

        self.device = device
        self.results = []
    
    def make_predictions(self):
        data = []
        for idx, ((sent, masked_tail, ids, label), (_, masked_head, _, _), (_, masked_both, _, _)) \
                in enumerate(zip(self.sentences_mask_tail, self.sentences_mask_head, self.sentences_mask_both)):
            tail_masked_ids = [idx for idx, token in enumerate(masked_tail) if token == 103]
            head_masked_ids = [idx for idx, token in enumerate(masked_head) if token == 103]

            # conditional
            logprob_tail_conditional = self.predict(sent, masked_tail, ids, tail_masked_ids)
            logprob_head_conditional = self.predict(sent, masked_head, ids, head_masked_ids)
            # marginal
            logprob_tail_marginal = self.predict(sent, masked_both, ids, tail_masked_ids)
            logprob_head_marginal = self.predict(sent, masked_both, ids, head_masked_ids)

            NLL = -logprob_tail_conditional/len(tail_masked_ids)

            mutual_inf = logprob_tail_conditional - logprob_tail_marginal
            mutual_inf += logprob_head_conditional - logprob_head_marginal
            mutual_inf /= 2.
            try:
                print(idx, (NLL.item(), mutual_inf.item(), label, self.sentences_mask_tail.id_to_text(sent)))
                data.append((NLL.item(), logprob_tail_conditional.item(), logprob_tail_marginal.item(),
                             logprob_head_conditional.item(), logprob_head_marginal.item(),
                             mutual_inf.item(), label, self.sentences_mask_tail.id_to_text(sent)))
            except AttributeError:
                print(idx, (NLL, mutual_inf, label, self.sentences_mask_tail.id_to_text(sent)))
                data.append((NLL,  logprob_tail_conditional.item(), logprob_tail_marginal.item(),
                             logprob_head_conditional.item(), logprob_head_marginal.item(),
                             mutual_inf.item(), label, self.sentences_mask_tail.id_to_text(sent)))
        # prep data
        df = pd.DataFrame(data, columns = ('nll','tail_conditional','tail_marginal',
                                           'head_conditional','head_marginal','mut_inf','label','sent')) 
        self.results = df
        return df
    
    def predict(self, sent, masked, ids, masked_ids):
        print(masked_ids)
        logprob = 0
        masked = deepcopy(masked)
        masked_ids = masked_ids.copy()
        for _ in range(len(masked_ids)):
            # make prediction
            pred = self.bert(masked.reshape(1,-1),ids.reshape(1,-1)).log_softmax(2)
            
            # get log probs for each token
            max_log_prob = -np.inf
            
            for idx in masked_ids:
                if pred[0, idx, sent[idx]] > max_log_prob:
                    most_likely_idx = idx
                    max_log_prob = pred[0, idx, sent[idx]]
            
            logprob += max_log_prob
            masked[most_likely_idx] = sent[most_likely_idx]
            masked_ids.remove(most_likely_idx) 
            print(self.sentences_mask_tail.id_to_text(pred[0, most_likely_idx,:].topk(10)[1]))
            print(self.sentences_mask_tail.id_to_text(sent[most_likely_idx:most_likely_idx+1]))
        return logprob


In [59]:
ck_miner = KnowledgeMiner(
    '../FilesToUploadToColab/NovelTuples.csv', 
    'cpu', 
    EnumeratedTemplate, 
    bert, 
    language_model = gpt, 
    template_loc = '../FilesToUploadToColab/relation_map_multiple.json'
)

In [None]:
df = ck_miner.make_predictions()

[6, 7]
are feel look can do sit get become will stand
take
notes care comfort pictures notice control it note action orders
notice
[3]
do see notice look listen hear know watch work learn
sit
[6, 7]
are were feel do go can get would will become
take
it care risks control action everything advantage life them charge
notice
[3]
die do sleep can run fall live fight want come
sit
0 (4.850859642028809, 1.0750312805175781, 11, '[CLS] when you sit , you take notice . [SEP]')
[7]
burn grow heal survive rain dry die bleed soil weep
temperature
[3, 4]
cold hot warm here up out heat outside wet water
moisture
see find smell feel get have need taste lose want
soil
[7]
die run know fight live sleep think kill listen change
temperature
[3, 4]
were are get see want need got feel saw have
soil
yourself it ##ed themselves something myself them yourselves up someone
moisture
1 (11.350737571716309, 2.0678091049194336, 9, '[CLS] when you soil moisture , you temperature . [SEP]')
[9, 10]
up out high money 

[7]
write dream die wonder know believe laugh love live read
success
[3, 4]
were are get see have want feel got saw need
science
work learn fail show study test believe progress experiment play
fiction
[7]
die run know fight live sleep think kill listen change
success
[3, 4]
were are get see want need got feel saw have
science
class learn teach it work me fail yourself fiction people
fiction
15 (9.481706619262695, -0.5286884307861328, 9, '[CLS] when you science fiction , you success . [SEP]')
[6, 7]
buy sell get pay look make shop do go eat
post
orders advertisements flyers it them messages ads advertising cards posters
office
[3]
come leave arrive land pay write go call do sign
shop
[6, 7]
are were feel do go can get would will become
post
it messages posts them posters guards notes orders advertisements cards
office
[3]
die do sleep can run fall live fight want come
shop
16 (10.592296600341797, 0.9979822635650635, 38, '[CLS] when you shop , you post office . [SEP]')
[7]
die change ar

[7]
change die lose grow survive adapt do know disappear kill
warming
[3, 4]
happy hungry up good ready there here out together hard
change
do can really first ... did must truly actually all
climate
[7]
die run know fight live sleep think kill listen change
warming
[3, 4]
hungry tired it alone there angry up ready happy out
change
do can really first did ... need must could actually
climate
30 (9.85317611694336, 2.630490779876709, 63, '[CLS] when you climate change , you warming . [SEP]')
[7]
win play lose coach learn watch know shine fight change
team
[3, 4]
together out up teams one it team here there fighting
coach
play can are need call say become want have get
basketball
[7]
die run know fight live sleep think kill listen change
team
[3, 4]
hungry tired it alone there angry up ready happy out
coach
play were become are make lose need played see change
basketball
31 (7.0933074951171875, 4.0369391441345215, 7, '[CLS] when you basketball coach , you team . [SEP]')
[7]
swim drown run

see hear say have said are mean need get were
v8
[3]
die do sleep can run fall live fight want come
power
45 (12.894832611083984, 2.209825277328491, 7, '[CLS] when you power , you v8 engine . [SEP]')
[9, 10]
sleep eat die feed drink awake come live heal can
series
watch do make have win host get start see join
television
[12]
watch dance sing talk write audition listen work smile entertain
host
[9, 10]
old up older out young home hungry angry here it
series
write make do watch read create produce have like sell
television
[12]
think sleep talk forget relax work fight eat run listen
host
46 (10.711282730102539, 3.2334110736846924, 16, '[CLS] one of the things you do when you television series is host . [SEP]')
[6, 7]
it think idea do know believe something yourself act thinking
system
think use are say mean have get create do become
operating
[3]
run work write start call do boot die load drive
idea
[6, 7]
strong out stronger it better good free alive happy different
system
have use nee

[6, 7]
are become will feel must go leave do can fall
reform
yourself yourselves together me it them him again completely quickly
movement
[3]
reform win can fight change do go die succeed work
part
[6, 7]
are were feel do go can get would will become
reform
yourself them it quickly completely people again me him themselves
movement
[3]
die do sleep can run fall live fight want come
part
62 (11.425189971923828, -0.7350168228149414, 7, '[CLS] when you part , you reform movement . [SEP]')
[11]
variance matrix mean vector inverse density product function correlation kernel
ensemble
[4, 5, 6, 7]
group ensemble band one whole music individual orchestra song single
co
##s ##e ##o ##a ##es ##t ##y ##re ##us ##r
matrix
- ##hesion ##r ##hesive ##hn ##hort ##set ##spar ##rt ##g
##var
##iance ##icative ##ization ##iation ##iate ##ification ##iable ##ian ##ity ##ie
##iance
[11]
best man king lord same first one day better last
ensemble
[4, 5, 6, 7]
lord sun one man first world day heart light king