In [1]:
from bert.utils import obtain_sentence_embeddings
from data.utils import load_training_dictionaries
from extractor.utils import ExtractorModel
from extractor.train import get_training_batch
from pytorch_transformers import BertModel
from pytorch_transformers import BertTokenizer

import torch
import numpy as np

In [2]:
# Load data:
model = ExtractorModel()
model_path = "results/models/extractor.pt"
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [3]:
data = load_training_dictionaries()
documents, extraction_labels = get_training_batch(data, batch_size=2)

In [4]:
sentence_embeddings, mask = obtain_sentence_embeddings(model.bert_model, model.bert_tokenizer, documents)

In [5]:
extraction_probabilities = model(sentence_embeddings)
extraction_probabilities

tensor([[0.0042, 0.9473, 0.0042, 0.0052, 0.0042, 0.0052, 0.0040, 0.1575, 0.0052,
         0.0052, 0.0052, 0.0042, 0.0274, 0.0052, 0.0040, 0.0044, 0.0040, 0.0043,
         0.0043, 0.0043, 0.0040, 0.0043, 0.9157, 0.0043, 0.0043, 0.0043, 0.0043,
         0.0043, 0.0040],
        [0.0225, 0.1229, 0.0226, 0.0356, 0.0225, 0.0356, 0.0222, 0.9789, 0.0356,
         0.0356, 0.9018, 0.0224, 0.9699, 0.0356, 0.0685, 0.0685, 0.0685, 0.0685,
         0.0685, 0.0685, 0.0685, 0.0685, 0.0685, 0.0685, 0.0685, 0.0685, 0.0685,
         0.0685, 0.0685]], grad_fn=<SqueezeBackward0>)

In [6]:
max_extracted_sentences = int(torch.max(torch.tensor([x.sum() for x in extraction_labels])))
vals, predicted_idx = torch.topk((extraction_probabilities * mask), k=max_extracted_sentences, dim=1)
vals

tensor([[0.9473, 0.9157, 0.1575],
        [0.9789, 0.9699, 0.9018]], grad_fn=<TopkBackward>)

In [7]:
for i, doc in enumerate(documents):
    label_idx = extraction_labels[i].numpy()[:len(doc)]
    label_idx = label_idx.astype(bool)
    
    n_summary_sentences = label_idx.sum()
    
    target = np.array(doc)[label_idx]
    prediction = np.array(doc)[predicted_idx[i]][:n_summary_sentences]
    
    print("\n\n".join(target.tolist()))
    print("\n\n---------------------------\n\n")
    print("\n\n".join(prediction.tolist()))
    
    print("\n\n-----------------------------------------------------")
    print("-----------------------------------------------------")
    print("-----------------------------------------------------\n\n")


that may sound like an esoteric adage , but when zully broussard selflessly decided to give one of her kidneys to a stranger , her generosity paired up with big data . it resulted in six patients receiving transplants .

that changed when a computer programmer named david jacobs received a kidney transplant . he had been waiting on a deceased donor list , when a live donor came along -- someone nice enough to give away a kidney to a stranger .


---------------------------


that may sound like an esoteric adage , but when zully broussard selflessly decided to give one of her kidneys to a stranger , her generosity paired up with big data . it resulted in six patients receiving transplants .

that changed when a computer programmer named david jacobs received a kidney transplant . he had been waiting on a deceased donor list , when a live donor came along -- someone nice enough to give away a kidney to a stranger .


-----------------------------------------------------
----------------

In [8]:
sentence_embeddings[0][0]

tensor([ 9.1648e-01, -3.2343e-01,  2.4150e-01,  3.9926e-01, -2.5483e-02,
        -1.1539e-01, -4.3888e-01,  2.1978e-01,  9.9330e-02, -3.9748e-01,
        -1.6836e-01, -4.5559e-01, -1.3534e-01,  3.1823e-01,  1.2135e-01,
         7.3621e-01,  4.4606e-01,  4.9633e-01, -1.7027e-01,  1.8118e-01,
         3.4028e-01,  6.3721e-02,  1.4311e-01,  7.1670e-01,  1.5095e-02,
         3.0640e-01, -3.6382e-01, -2.3979e-01, -4.5630e-01, -1.3057e-01,
         1.6660e-01,  2.7311e-01, -6.4907e-01,  3.4096e-01,  3.8613e-02,
        -5.8771e-01,  3.2407e-02, -3.5140e-01, -4.7036e-01, -3.5125e-01,
        -7.4052e-01, -3.8173e-02, -2.2250e-02,  3.7574e-01,  6.4439e-01,
        -3.5434e-01, -8.7023e-01,  1.0865e-01, -3.8263e-01, -6.0926e-02,
        -5.8005e-01,  1.6905e-01, -6.1255e-02,  4.3012e-01,  2.5527e-01,
         4.4028e-01, -1.9812e-01, -5.3067e-01, -2.4937e-01, -7.9281e-01,
         2.0003e-01, -2.9016e-02, -5.7169e-02, -2.8392e-01, -2.6981e-01,
         1.3185e-01,  6.6718e-01, -7.0120e-02,  2.5

In [9]:
sentence_embeddings[0][1]

tensor([ 4.8680e-01,  4.9151e-02,  1.9606e-01, -3.4058e-01,  3.0008e-01,
        -6.2466e-01,  9.2847e-02,  4.5261e-01, -2.3005e-02, -1.7222e-02,
         5.4992e-01, -3.9054e-01,  1.5670e-01,  8.9556e-01,  4.3296e-01,
         2.2120e-01, -4.7739e-02, -3.5062e-02,  3.1028e-01, -4.0670e-01,
         2.2673e-02, -2.3472e-01, -1.0332e-01, -2.8811e-01, -4.9319e-01,
         9.1536e-02, -7.1086e-01, -6.2026e-01, -5.2332e-01,  4.0454e-01,
        -2.8737e-02, -2.8800e-02, -1.9619e-01,  1.9554e-02,  1.2328e-01,
        -2.6908e-01,  2.7553e-01,  9.1126e-02, -6.2420e-02, -3.1911e-01,
        -4.1409e-01, -1.6598e-02,  2.6347e-02,  2.6132e-01, -5.3553e-02,
        -2.1431e-01, -2.6839e+00,  2.0801e-02, -3.7034e-01, -1.9504e-01,
         2.7280e-01, -4.6700e-01,  3.4906e-02,  2.9805e-01,  3.3433e-01,
         3.0009e-01, -1.3517e-01, -2.2972e-01,  1.3278e-01, -3.0929e-01,
         3.6339e-01, -2.9009e-01,  2.6028e-01, -7.3844e-02, -8.9725e-02,
         2.8340e-01,  4.2069e-01,  4.6197e-02, -3.7