In [2]:
%load_ext autoreload
%autoreload 2

Import required libraries

In [94]:
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

from src.bertsum import BertSummarizer
from utils import tokenize_text_to_sentences, prepare_sample

# Hugging Face Hub identifiers for the fine-tuned model and the extractive dataset
checkpoint = 'eReverter/bert-finetuned-cnn_dailymail'
dataset = 'eReverter/cnn_dailymail_extractive'

Load dataset, model, and tokenizer

In [80]:
data_dict = load_dataset(dataset)
data_dict

Using custom data configuration eReverter--cnn_dailymail_extractive-724c7cce7ac202ac
Found cached dataset parquet (/home/usuaris/veu/enric.reverter/.cache/huggingface/datasets/eReverter___parquet/eReverter--cnn_dailymail_extractive-724c7cce7ac202ac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


DatasetDict({
    test: Dataset({
        features: ['src', 'tgt', 'labels'],
        num_rows: 11490
    })
    validation: Dataset({
        features: ['src', 'tgt', 'labels'],
        num_rows: 13368
    })
    train: Dataset({
        features: ['src', 'tgt', 'labels'],
        num_rows: 287113
    })
})

In [81]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BertSummarizer.from_pretrained(checkpoint)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Look at one sample

In [82]:
sample = data_dict['test'][24]['src']
sample

["(CNN)Since Iran's Islamic Revolution in 1979, women have been barred from attending most sports events involving men.",
 'But the situation appears set to improve in the coming months after a top Iranian sports official said that the ban will be lifted for some events.',
 'A plan to allow "women and families" to enter sports stadiums will come into effect in the next year, Deputy Sports Minister Abdolhamid Ahmadi said Saturday, according to state-run media.',
 "But it isn't clear exactly which games women will be able to attend.",
 'According to the state-run Press TV, Ahmadi said the restrictions would be lifted for indoor sports events.',
 'The rules won\'t change for all matches because some sports are mainly related to men and "families are not interested in attending" them, Press TV cited him as saying.',
 "Iranian authorities imposed the ban on women attending men's sports events after the revolution, deeming that mixed crowds watching games together was un-Islamic.",
 "During 

Prepare the sample

In [83]:
model_inputs = prepare_sample(sample, tokenizer)
# prepare_sample also returns the sentences it kept; pop them so only model inputs remain
updated_sample = model_inputs.pop('sample')
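
The code of prepare_sample is not shown in this notebook. As a rough sketch of what a BertSum-style preparation step usually involves, the hypothetical function below wraps every sentence in [CLS] ... [SEP], alternates the segment ids between sentences, records the position of each [CLS] token, and drops sentences that are empty or no longer fit. The whitespace tokenizer, the function name, and the returned keys are assumptions made for illustration, and the real helper may differ.

```python
def build_bertsum_style_inputs(sentences, max_tokens=512):
    """Hypothetical sketch: lay out sentences the way BertSum-style models expect."""
    tokens, segment_ids, cls_positions, kept = [], [], [], []
    for sentence in sentences:
        words = sentence.split()  # toy whitespace tokenizer, not WordPiece
        if not words:
            continue  # empty sentences are filtered out
        piece = ["[CLS]"] + words + ["[SEP]"]
        if len(tokens) + len(piece) > max_tokens:
            break  # stop once the input budget is exhausted
        cls_positions.append(len(tokens))  # index of this sentence's [CLS]
        segment = len(kept) % 2  # alternate 0/1 between consecutive sentences
        tokens.extend(piece)
        segment_ids.extend([segment] * len(piece))
        kept.append(sentence)
    return {
        "tokens": tokens,
        "segment_ids": segment_ids,
        "cls_positions": cls_positions,
        "sample": kept,
    }


demo = build_bertsum_style_inputs(["The cat sat.", "", "It purred loudly."])
print(demo["cls_positions"], demo["sample"])
```

Because some sentences can be dropped, the list of kept sentences may be shorter than the input, which is why later cells index into the returned sample rather than the original one.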

Inference

In [84]:
outputs = model(**model_inputs)
outputs

{'logits': tensor([[9.9696e-01, 8.5784e-01, 9.9671e-01, 2.9444e-02, 9.4092e-01, 6.1778e-01,
          9.5534e-01, 3.4963e-03, 2.1245e-01, 4.4468e-02, 2.4277e-04, 1.8779e-04,
          6.4062e-04, 8.5547e-04, 1.6864e-04, 1.1845e-02, 1.8611e-04]],
        grad_fn=<SqueezeBackward1>),
 'mask_cls': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
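
Although the key is named 'logits', every value above lies between 0 and 1, so the model presumably applies a sigmoid before returning them (an assumption, since the model code is not shown here). Under that reading, a fixed threshold is an alternative to taking a fixed number of sentences. The sketch below uses the scores printed above and treats mask_cls as a flag for real (non-padded) sentence slots, which is also an assumption; select_above_threshold is a hypothetical helper.

```python
# Scores copied from the output above.
scores = [
    9.9696e-01, 8.5784e-01, 9.9671e-01, 2.9444e-02, 9.4092e-01, 6.1778e-01,
    9.5534e-01, 3.4963e-03, 2.1245e-01, 4.4468e-02, 2.4277e-04, 1.8779e-04,
    6.4062e-04, 8.5547e-04, 1.6864e-04, 1.1845e-02, 1.8611e-04,
]
mask_cls = [1] * len(scores)  # assumed meaning: 1 = real sentence, 0 = padding


def select_above_threshold(scores, mask, threshold=0.5):
    """Return indices of unmasked sentences whose score exceeds the threshold."""
    return [
        i
        for i, (score, flag) in enumerate(zip(scores, mask))
        if flag == 1 and score > threshold
    ]


selected_indices = select_above_threshold(scores, mask_cls)
print(selected_indices)
```

With a threshold of 0.5 this keeps six sentences instead of three, so the threshold controls summary length only indirectly.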

In [85]:
len(outputs['logits'][0]) == len(updated_sample)

True

Build the summary from the updated sample rather than the original one: preprocessing filters out some sentences, so the scores are indexed against the filtered list.

In [86]:
# Select the 3 highest-scoring sentences, then restore document order
top_indices = outputs['logits'].topk(3).indices.detach().cpu().numpy()[0]
summary = ' '.join(updated_sample[i] for i in np.sort(top_indices))
summary

'(CNN)Since Iran\'s Islamic Revolution in 1979, women have been barred from attending most sports events involving men. A plan to allow "women and families" to enter sports stadiums will come into effect in the next year, Deputy Sports Minister Abdolhamid Ahmadi said Saturday, according to state-run media. Iranian authorities imposed the ban on women attending men\'s sports events after the revolution, deeming that mixed crowds watching games together was un-Islamic.'
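
The selection step can also be written without tensors, which makes the logic easier to test. The following is a minimal sketch, assuming one score per kept sentence; top_k_in_document_order is a hypothetical helper, and the placeholder sentences stand in for updated_sample. The scores are the ones printed by the inference cell.

```python
def top_k_in_document_order(sentences, scores, k=3):
    """Pick the k highest-scoring sentences and return them in their original order."""
    if len(sentences) != len(scores):
        raise ValueError("sentences and scores must have the same length")
    k = min(k, len(sentences))
    # Rank indices by score, highest first, and keep the best k.
    best = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Sort the chosen indices so the summary follows the article's order.
    return [sentences[i] for i in sorted(best)]


scores = [
    9.9696e-01, 8.5784e-01, 9.9671e-01, 2.9444e-02, 9.4092e-01, 6.1778e-01,
    9.5534e-01, 3.4963e-03, 2.1245e-01, 4.4468e-02, 2.4277e-04, 1.8779e-04,
    6.4062e-04, 8.5547e-04, 1.6864e-04, 1.1845e-02, 1.8611e-04,
]
sentences = [f"sentence {i}" for i in range(len(scores))]  # placeholders
selected = top_k_in_document_order(sentences, scores, k=3)
print(selected)
```

For these scores the three best sentences are at positions 0, 2, and 6, which matches the summary printed above.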

Example with Wikipedia text

In [93]:
wikipedia_text = """
Wine is an alcoholic drink typically made from fermented grapes. Yeast consumes the sugar in the grapes and converts it to ethanol and carbon dioxide, releasing heat in the process. Different varieties of grapes and strains of yeasts are major factors in different styles of wine. These differences result from the complex interactions between the biochemical development of the grape, the reactions involved in fermentation, the grape's growing environment (terroir), and the wine production process. Many countries enact legal appellations intended to define styles and qualities of wine. These typically restrict the geographical origin and permitted varieties of grapes, as well as other aspects of wine production. Wines can be made by fermentation of other fruit crops such as plum, cherry, pomegranate, blueberry, currant and elderberry.
"""

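# Split the raw text into sentences, then run the same pipeline as above
sample = tokenize_text_to_sentences(wikipedia_text)
model_inputs = prepare_sample(sample, tokenizer)
updated_sample = model_inputs.pop('sample')
outputs = model(**model_inputs)

# Keep the 3 highest-scoring sentences in document order
top_indices = outputs['logits'].topk(3).indices.detach().cpu().numpy()[0]
summary = ' '.join(updated_sample[i] for i in np.sort(top_indices))
summary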

'\nWine is an alcoholic drink typically made from fermented grapes. Yeast consumes the sugar in the grapes and converts it to ethanol and carbon dioxide, releasing heat in the process. Different varieties of grapes and strains of yeasts are major factors in different styles of wine.'
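
The summary starts with '\n' because wikipedia_text begins with a newline that survives sentence splitting. The implementation of tokenize_text_to_sentences is not shown here; as a rough stand-in, the hypothetical splitter below strips surrounding whitespace first and then splits after sentence-ending punctuation. It is a naive regex approach (an assumption, not the notebook's actual helper) and will mis-split abbreviations such as "U.S.".

```python
import re


def split_into_sentences(text):
    """Naive, hypothetical sentence splitter: break after '.', '!' or '?'."""
    stripped = text.strip()  # drop leading/trailing newlines and spaces
    if not stripped:
        return []
    parts = re.split(r"(?<=[.!?])\s+", stripped)
    return [part.strip() for part in parts if part.strip()]


demo_text = "\nWine is an alcoholic drink. Yeast consumes the sugar!  Is it good?\n"
demo_sentences = split_into_sentences(demo_text)
print(demo_sentences)
```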