In [151]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import logging

model_name = "facebook/bart-large-cnn"

class Summarizer():
    def __init__(self, model_name, max_length=1024, verbose=False):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, verbose=False)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True, output_scores=True,)
        self.article_max_length = max_length
        if verbose:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.INFO)
        self.max_length = 100
        self.min_length = 60
            
    def check_token_number(self, article):
        WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
        input_ids = self.tokenizer([WHITESPACE_HANDLER(article)], 
                                   return_tensors="pt")
        logging.info(f"There is {len(input_ids.input_ids[0])} word")
        if len(input_ids[0]) > 1024:
            return [0, 0]
        self.current_input_ids = input_ids.input_ids
        return [1, input_ids.input_ids]
    
    def summarize(self,article):
        flag, output = self.check_token_number(article)
        if flag == 0:
            logging.error("The article length contains more than 1024 word")
            return -1
        else:
            self.current_output = self.model.generate(
                input_ids=output,
                max_length=self.max_length,
                no_repeat_ngram_size=2,
                num_beams=4,
                output_attentions=True, 
                output_scores=True, 
                return_dict_in_generate=True,)
            output_ids = self.current_output.sequences[0]
            summary = self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )
            return summary
        
    def increase_max_length(self):
        self.max_length += 50
        self.min_length += 50
        logging.info(f"output max length : {self.max_length} output min length: {self.min_length}")

    def decrease_max_length(self):
        if self.min_length -50 <= 0:
            logging.error(f"output max length cannot be lower with the current value of {self.max_length}")
            return None
        self.max_length -= 50
        self.min_length -= 50
        logging.info(f"output max length : {self.max_length} output min length: {self.min_length}")
    


In [152]:
input_file = "input.txt"
with open(input_file, "r") as file:
    review = file.read()
file.close()

In [153]:
summarizer = Summarizer(model_name)

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/vocab.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/bart-large-cnn/resolve/main/generation_config.json HTTP/1.1" 200 0


In [154]:
print(summarizer.summarize(review))

INFO:root:There is 807 word


Former President Donald Trump held a CNN town hall in New Hampshire on Wednesday night. He said the US should be willing to blow up the debt ceiling to avoid a default. Trump also said he would return to one of his harshest immigration enforcement policies: separating migrant families at the border.


In [155]:
summarizer.current_output.encoder_attentions

(tensor([[[[4.8211e-05, 1.6352e-04, 5.9066e-04,  ..., 9.7105e-04,
            2.5769e-03, 3.4959e-04],
           [5.7326e-02, 2.2401e-02, 6.7632e-03,  ..., 1.2028e-03,
            1.2189e-03, 7.9913e-03],
           [2.1035e-01, 3.8867e-02, 1.7565e-02,  ..., 4.3584e-04,
            3.9148e-04, 3.4857e-03],
           ...,
           [2.1035e-02, 7.5629e-04, 1.6593e-03,  ..., 1.2840e-02,
            2.4875e-02, 1.6699e-02],
           [7.6485e-03, 3.7942e-04, 1.7795e-04,  ..., 5.8411e-03,
            5.3894e-02, 1.2054e-02],
           [3.5860e-01, 8.3769e-04, 5.5471e-04,  ..., 4.4215e-03,
            1.2306e-02, 2.7352e-02]],
 
          [[5.5225e-04, 1.3011e-04, 1.2595e-04,  ..., 1.2468e-05,
            8.1877e-03, 3.1829e-02],
           [1.8773e-02, 4.0166e-03, 1.3582e-02,  ..., 1.0695e-03,
            1.4611e-03, 5.5162e-03],
           [7.1583e-02, 1.1715e-02, 4.6135e-03,  ..., 5.3804e-04,
            4.5545e-04, 2.1474e-03],
           ...,
           [2.3675e-03, 4.3009e-04, 2.

In [156]:
from bertviz import model_view, head_view
tokens = summarizer.tokenizer.convert_ids_to_tokens(summarizer.current_input_ids[0]) 
head_view(summarizer.current_output.encoder_attentions, tokens)