In [1]:
from abstractor.train import get_training_batch
from abstractor.utils import AbstractorModel, AbstractorModelRNN
from abstractor.utils import obtain_initial_hidden_states
from bert.utils import obtain_sentence_embeddings
from bert.utils import obtain_word_embeddings
from data.utils import load_training_dictionaries
from extractor.utils import ExtractorModel
from pytorch_transformers import BertModel
from pytorch_transformers import BertTokenizer

import numpy as np
import torch

In [2]:
# Load the training data (whatever load_training_dictionaries() returns) and
# draw one batch from it with batch_size=2.
# NOTE(review): `documents` and `extraction_labels` are not read by any later
# cell visible in this notebook; In [4] draws its own batch from `data`.
# NOTE(review): In [4] names the second return value `target_summaries`, so
# `extraction_labels` looks like a leftover name -- confirm what it holds.
data = load_training_dictionaries()
documents, extraction_labels = get_training_batch(data, batch_size=2)

In [3]:
# Load the trained abstractor model weights from disk.
model = AbstractorModelRNN()
model_path = "results/models/abstractor.pt"
# map_location="cpu" lets a checkpoint that was saved from a GPU run be
# deserialized on a CPU-only machine; load_state_dict then copies the values
# into the model's existing parameters on whatever device they live on.
# NOTE(review): the model is not switched to eval mode here. If it contains
# dropout layers, the predictions below may vary between runs -- confirm
# whether model.eval() should be called before inspecting outputs.
# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

<All keys matched successfully>

In [4]:
# Draw a fresh batch; the positional 2 is presumably batch_size, as in In [2]
# -- TODO confirm against get_training_batch's signature.
source_documents, target_summaries = get_training_batch(data, 2)

# Obtain embeddings for both sides of the batch using the BERT model and
# tokenizer that live on the loaded abstractor model.
# Each call returns three values, stored here as (embeddings, mask, tokens).
# NOTE(review): the source side passes static_embeddings=False and the target
# side static_embeddings=True; the meaning of that flag is defined in
# bert.utils and is not visible here -- confirm the asymmetry is intentional.
# The embeddings are fed to the model in In [5] / In [6]; target_tokens is
# mapped back to word pieces in In [7].
source_document_embeddings, source_mask, source_tokens = obtain_word_embeddings(
    model.bert_model, model.bert_tokenizer, source_documents, static_embeddings=False
)
target_summary_embeddings, target_mask, target_tokens = obtain_word_embeddings(
    model.bert_model, model.bert_tokenizer, target_summaries, static_embeddings=True
)

In [5]:
# Run the abstractor with teacher_forcing_pct=0 and print the highest-scoring
# vocabulary token at every target position, one list per example.
# (Presumably 0 means the decoder is never fed the reference tokens -- confirm
# against AbstractorModelRNN.forward.)
extraction_probabilities, teacher_forcing = model(
    source_document_embeddings,
    target_summary_embeddings,
    teacher_forcing_pct=0
)  # (batch_size, n_target_words, vocab_size)

# topk with k=1 keeps the reduced axis, so predicted_idx has a trailing axis of size 1.
vals, predicted_idx = torch.topk(extraction_probabilities, k=1, dim=2)

# Squeeze only that trailing axis. A bare squeeze() would also drop the batch
# axis for a batch of one (or the word axis for a one-token target), and the
# loop below would then iterate over single ids instead of sentences.
for predicted_ids in predicted_idx.squeeze(dim=2).tolist():
    print(model.bert_tokenizer.convert_ids_to_tokens(predicted_ids))
    print()

['zu', '##lly', 'bro', '##uss', '##ard', 'decided', 'to', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.', 'a', '.']

['a', 'green', 'wood', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel', '##cker', 'weasel']



In [6]:
# Run the abstractor with teacher_forcing_pct=1 and print the highest-scoring
# vocabulary token at every target position, one list per example.
# (Presumably 1 means the decoder is always fed the reference tokens -- confirm
# against AbstractorModelRNN.forward.)
extraction_probabilities, teacher_forcing = model(
    source_document_embeddings,
    target_summary_embeddings,
    teacher_forcing_pct=1
)  # (batch_size, n_target_words, vocab_size)

# topk with k=1 keeps the reduced axis, so predicted_idx has a trailing axis of size 1.
vals, predicted_idx = torch.topk(extraction_probabilities, k=1, dim=2)

# Squeeze only that trailing axis. A bare squeeze() would also drop the batch
# axis for a batch of one (or the word axis for a one-token target), and the
# loop below would then iterate over single ids instead of sentences.
for predicted_ids in predicted_idx.squeeze(dim=2).tolist():
    print(model.bert_tokenizer.convert_ids_to_tokens(predicted_ids))
    print()

['zu', '##lly', 'bro', '##uss', '##ard', 'decided', 'to', 'a', 'a', '.', 'to', 'a', '.', '.', 'a', '.', 'computer', 'program', 'helped', 'her', 'donation', 'spur', 'transplant', '##s', 'for', 'six', 'kidney', 'patients', '.', '[SEP]', 'computer', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney', 'kidney']

['a', 'green', 'of', 'a', 'green', 'wood', '##cker', '##cker', 'weasel', 'with', 'a', 'weasel', '##cker', 'its', 'back', 'has', 'weasel', 'viral', 'on', 'image', '.', 'has', 'image', 'was', 'snapped', 'by', '[SEP]', '.', '.', 'le', '-', 'may', 'london', 'london', '.', '[SEP]', '-', 'le', 'hash', '#', '#', '.', 'le', 'le', '.', '##tag', 'spawned', 'numerous', 'numerous', '##mes', '.', '[SEP]', 'london']



In [7]:
# Turn the reference summaries' token ids back into word pieces so they can be
# compared by eye with the predictions printed by the earlier cells.
id_to_token = model.bert_tokenizer.ids_to_tokens
labels = []
for sentence_ids in target_tokens.tolist():
    labels.append([id_to_token[token_id] for token_id in sentence_ids])

# One token list per example, each followed by a blank line.
for sentence_tokens in labels:
    print(sentence_tokens, end="\n\n")


['[CLS]', 'zu', '##lly', 'bro', '##uss', '##ard', 'decided', 'to', 'give', 'a', 'kidney', 'to', 'a', 'stranger', '.', 'a', 'new', 'computer', 'program', 'helped', 'her', 'donation', 'spur', 'transplant', '##s', 'for', 'six', 'kidney', 'patients', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

['[CLS]', 'a', 'photo', 'of', 'a', 'green', 'wood', '##pe', '##cker', 'flying', 'with', 'a', 'weasel', 'on', 'its', 'back', 'has', 'gone', 'viral', 'on', 'twitter', '.', 'the', 'image', 'was', 'snapped', 'by', 'amateur', 'photographer', 'martin', 'le', '-', 'may', 'near', 'london', '.', 'it', 'sparked', 'the', 'hash', '##tag', '#', 'weasel', '##pe', '##cker', 'and', 'has', 'spawned', 'numerous', 'me', '##mes', '.', '[SEP]']

