In [1]:
from abstractor.train import get_training_batch
from abstractor.utils import AbstractorModel
from abstractor.utils import obtain_initial_hidden_states
from bert.utils import obtain_sentence_embeddings
from bert.utils import obtain_word_embeddings
from data.utils import load_training_dictionaries
from extractor.utils import ExtractorModel
from pytorch_transformers import BertModel
from pytorch_transformers import BertTokenizer

import numpy as np
import torch

In [2]:
# Load data:
model = AbstractorModel()
model_path = "results/models/abstractor.pt"
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [3]:
data = load_training_dictionaries()
documents, extraction_labels = get_training_batch(data, batch_size=2)

In [4]:
source_documents, target_summaries = get_training_batch(data, 2)

# Obtain embeddings
source_document_embeddings, source_mask, source_tokens = obtain_word_embeddings(
    model.bert_model, model.bert_tokenizer, source_documents
)
target_summary_embeddings, target_mask, target_tokens = obtain_word_embeddings(
    model.bert_model, model.bert_tokenizer, target_summaries
)

for x in source_documents:
    print(x)
    print()

['that may sound like an esoteric adage , but when zully broussard selflessly decided to give one of her kidneys to a stranger , her generosity paired up with big data . it resulted in six patients receiving transplants .', 'that changed when a computer programmer named david jacobs received a kidney transplant . he had been waiting on a deceased donor list , when a live donor came along -- someone nice enough to give away a kidney to a stranger .']

['" weasels will go for anything that looks like food -- they \'ve got a high metabolism and they \'ve got to eat a lot , " she said . " it does n\'t surprise me that a weasel took a punt -- i \'ve seen a photo of a weasel charging a group of sparrows , they \'re very hungry animals . "', 'weasels would not normally target green woodpeckers , pacheco said -- their predators are normally the size of a stoat or larger . but the birds are known to spend a fair amount of time on the ground pulling up worms and hunting insects .', 'the pluckine

In [5]:
source_document_embeddings

tensor([[[-0.7868,  0.3158,  0.1873,  ...,  0.1196,  0.4806,  0.2568],
         [-1.2319,  0.1560,  0.0955,  ...,  0.3285,  0.2373,  0.1838],
         [-0.0765,  0.0044,  0.2287,  ..., -0.2080,  0.3268,  0.2944],
         ...,
         [-0.1307,  0.5610, -0.3290,  ...,  0.7418,  0.5953,  0.3729],
         [-0.1307,  0.5610, -0.3290,  ...,  0.7418,  0.5953,  0.3729],
         [-0.1307,  0.5610, -0.3290,  ...,  0.7418,  0.5953,  0.3729]],

        [[-0.7868,  0.3158,  0.1873,  ...,  0.1196,  0.4806,  0.2568],
         [-0.4989,  0.5350, -0.0610,  ...,  0.3245,  0.9572, -0.4521],
         [-0.4652, -0.8745, -1.1583,  ...,  0.5437,  0.3467, -0.2192],
         ...,
         [-0.7139,  0.2176,  0.1208,  ...,  0.7653,  0.6087,  0.0047],
         [-0.8020,  0.1763, -0.1569,  ...,  0.2405,  0.0256,  0.2617],
         [ 0.8146,  0.0918, -0.2448,  ...,  0.1399, -0.6142, -0.4597]]])

In [6]:
# Obtain extraction probability for each word in vocabulary
extraction_probabilities, teacher_forcing = model(
    source_document_embeddings,
    target_summary_embeddings,
    teacher_forcing_pct=0
)  # (batch_size, n_target_words, vocab_size)

vals, predicted_idx = torch.topk((extraction_probabilities), k=1, dim=2)

for x in [model.bert_tokenizer.convert_ids_to_tokens(p) for p in predicted_idx.squeeze().tolist()]:
    print(f"{x}")
    print()

['zu', '##lly', 'bro', '##uss', '##ard', 'decided', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a', 'kidney', 'to', 'give', 'a']

['zu', 'with', 'a', 'kidney', 'snapped', 'by', 'amateur', 'photographer', 'martin', 'le', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-', 'may', 'near', 'london', '.', '[SEP]', '-']



In [7]:
# Obtain extraction probability for each word in vocabulary
extraction_probabilities, teacher_forcing = model(
    source_document_embeddings,
    target_summary_embeddings,
    teacher_forcing_pct=1
)  # (batch_size, n_target_words, vocab_size)

vals, predicted_idx = torch.topk((extraction_probabilities), k=1, dim=2)

for x in [model.bert_tokenizer.convert_ids_to_tokens(p) for p in predicted_idx.squeeze().tolist()]:
    print(f"{x}")
    print()

['zu', '##lly', 'bro', '##uss', '##ard', 'decided', 'to', 'give', 'a', 'kidney', 'to', 'give', 'kidney', '.', '[SEP]', 'kidney', 'computer', 'program', 'helped', 'her', 'donation', 'spur', 'transplant', '##s', 'for', 'give', 'kidney', 'to', '.', '[SEP]', '-', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']

['zu', 'kidney', 'of', 'a', 'kidney', 'wood', '##pe', '##cker', 'and', 'with', 'a', 'kidney', 'on', '-', 'back', 'has', 'le', 'viral', 'on', '-', '.', '[SEP]', 'hash', 'was', 'snapped', 'by', 'amateur', 'photographer', 'martin', 'le', '-', 'may', 'near', 'london', '.', '[SEP]', 'sparked', 'the', 'hash', '##tag', 'of', 'weasel', 'on', '##cker', 'and', 'has', 'le', 'numerous', 'me', '##mes', '.', '[SEP]', '-']



In [8]:
target_summaries

[['zully broussard decided to give a kidney to a stranger .',
  'a new computer program helped her donation spur transplants for six kidney patients .'],
 ['a photo of a green woodpecker flying with a weasel on its back has gone viral on twitter .',
  'the image was snapped by amateur photographer martin le - may near london .',
  'it sparked the hashtag #weaselpecker and has spawned numerous memes .']]

In [9]:
documents = torch.tensor([[  101, 16950,  9215, 22953, 17854,  4232,  2787,  2000,  2507,  1037,
         14234,  2000,  1037,  7985,  1012,  1037,  2047,  3274,  2565,  3271,
          2014, 13445, 12996, 22291,  2015,  2005,  2416, 14234,  5022,  1012,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  1037,  6302,  1997,  1037,  2665,  3536,  5051,  9102,  3909,
          2007,  1037, 29268,  2006,  2049,  2067,  2038,  2908, 13434,  2006,
         10474,  1012,  1996,  3746,  2001,  5941,  2011,  5515,  8088,  3235,
          3393,  1011,  2089,  2379,  2414,  1012,  2009, 13977,  1996, 23325,
         15900,  1001, 29268,  5051,  9102,  1998,  2038, 18379,  3365,  2033,
          7834,  1012,   102]])

model.bert_model(documents)

(tensor([[[-0.2617,  0.0380,  0.0511,  ...,  0.1047,  0.7850, -0.2537],
          [-1.1310, -0.7267,  0.6471,  ..., -0.1336,  1.3996,  0.7395],
          [ 0.3222, -0.3686,  0.7603,  ..., -0.3291,  0.1432, -0.6941],
          ...,
          [ 0.4166, -0.3199,  0.7278,  ..., -0.7318,  0.0372, -1.3327],
          [ 0.2818, -0.2167,  0.7193,  ..., -0.7352,  0.0445, -1.2620],
          [ 0.5745,  0.1054,  0.4005,  ..., -0.7345,  0.0988, -1.7492]],
 
         [[-0.1115, -0.1828, -0.4249,  ...,  0.2145,  0.5005,  0.3914],
          [ 0.2505, -0.1153, -0.4112,  ...,  0.1827, -0.3702,  0.4494],
          [ 0.5941, -0.2543,  0.0724,  ..., -0.0385, -0.3432,  0.4666],
          ...,
          [ 0.8742,  0.2570,  0.4955,  ..., -0.0061, -0.3800,  0.0748],
          [ 0.4946,  0.2816, -0.2287,  ...,  0.0602, -0.5411, -0.3153],
          [ 0.3065, -0.1702,  0.2163,  ...,  0.4447, -0.3457, -0.2315]]]),
 tensor([[-0.2098, -0.5410, -0.9906,  ..., -0.8945, -0.5418,  0.2921],
         [-0.9284, -0.5976, -

In [10]:
documents.shape

torch.Size([2, 53])

In [11]:
documents[0].shape

torch.Size([53])

In [12]:
model.bert_model(documents[0].view(1, -1))[0].shape

torch.Size([1, 53, 768])