# Sebastian Petrik - Stranasum - Final evaluation module

This module is intended for the final evaluation of models produced by the Stranasum Development module, along with existing models from Hugging Face.

In [None]:
%pip install evaluate rouge_score bert_score contractions transformers[sentencepiece] --quiet

import pandas as pd
import numpy as np
from datasets import load_dataset
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import tensorflow as tf
import evaluate
import bert_score
import matplotlib.pyplot as plt
import re
import contractions
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import rouge_score
from rouge_score import rouge_scorer

In [None]:
%env CUDA_VISIBLE_DEVICES=1

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front. Every visible GPU is configured in a loop rather than indexing
# gpus[0], so this cell also runs on CPU-only machines, where the list is empty.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Evaluation settings.
#   model - one of the supported Hugging Face model ids; any other value
#           (such as the empty string) selects the Stranasum saved model
#   path  - directory the Stranasum saved model is loaded from
#   ds    - name of the evaluation dataset
CONFIG = {
    'model': "",
    'path': "./summarizer_module",
    'ds': 'gigaword',
}

In [None]:
# Load the evaluation dataset selected by CONFIG['ds'] into `eval_set`.

eval_set = None

if CONFIG['ds'] == 'gigaword':
    # Alternative source (disabled): pre-exported test CSV
    # gigaword_stranasum_test = pd.read_csv(f"../input/stranasum-gigaword-70-to-25/gigaword_test.csv")

    # Gigaword test split as a DataFrame; 'document' is renamed to 'article',
    # the column name the rest of this notebook reads.
    eval_set = load_dataset("gigaword", split="test").to_pandas().rename(columns={'document': 'article'})
else:
    # Fail here with a clear message instead of leaving eval_set as None and
    # crashing in a later cell.
    raise ValueError(f"Unsupported evaluation dataset: {CONFIG['ds']!r}")

In [None]:
class TextProcessor:
    """Text normalisation and special-token helpers used throughout the notebook.

    `clean_text` normalises raw text (lowercasing, punctuation and markup
    removal, gigaword-style digit masking). `apply_special_tokens` and
    `remove_special_tokens` add / strip the `<sos>` and `<eos>` markers.
    """

    # Matches a URL; group 3 is the host part that directly follows the scheme.
    _URL_PATTERN = re.compile(r'((https*:\/*)([^\/\s]+))(.[^\s]+)')

    # Text cleanup
    def clean_text(self, text: str) -> str:
        """Return a normalised, lowercase version of `text`.

        Args:
            text: Raw text; any object is accepted and converted with `str()`.

        Returns:
            The cleaned text with digits replaced by '#'.
        """
        # lowercase
        text = str(text).lower()

        # remove &-escaped characters
        # NOTE(review): the pattern only matches '&' + one character + digits
        # 1-9 + ';' (e.g. '&#39;'); named entities such as '&amp;' and codes
        # containing 0 are left alone. Kept unchanged so the cleaning stays
        # identical to what the evaluated models presumably saw - confirm
        # against the Development module before tightening it.
        text = re.sub(r"&.[1-9]+;", " ", text)

        # replace tab, carriage-return and newline characters by spaces
        text = re.sub(r"\t", ' ', text)
        text = re.sub(r"\r", ' ', text)
        text = re.sub(r"\n", ' ', text)

        # collapse runs of two or more of the same filler character
        text = re.sub(r"__+", ' ', text)    # _
        text = re.sub(r"--+", ' ', text)    # -
        text = re.sub(r"~~+", ' ', text)    # ~
        text = re.sub(r"\+\++", ' ', text)  # +
        text = re.sub(r"\.\.+", ' ', text)  # .

        # special - fix u.s. contraction in gigaword
        text = re.sub(r"u\.s\.", 'united states', text)

        # fix contractions to base form
        text = contractions.fix(text)

        # remove special tokens <>()|&©ø"',;?~*! ; lowercase again because the
        # contraction expansion above may have produced new text
        text = re.sub(r"[<>()|&©ø\[\]\'\",;?~*!]", ' ', str(text)).lower()

        # CNN mail data cleanup
        text = re.sub(r"mailto:", ' ', text)  # remove mailto:
        text = re.sub(r"(\\x9\d)", ' ', text)  # remove literal \x9* sequences
        text = re.sub(r"[iI][nN][cC]\d+", 'INC_NUM', text)  # INC nums -> INC_NUM
        text = re.sub(r"([cC][mM]\d+)|([cC][hH][gG]\d+)", 'CM_NUM', text)  # CM# / CHG# -> CM_NUM

        # url replacement into base form: every URL is reduced to its own host.
        # A function replacement is used so the host is inserted literally and
        # a text without URLs needs no exception handling.
        text = self._URL_PATTERN.sub(lambda match: match.group(3), text)

        # drop '.', '-' and ':' when they end a word (i.e. are followed by whitespace)
        text = re.sub(r"\.\s+", ' ', text)
        text = re.sub(r"-\s+", ' ', text)
        text = re.sub(r":\s+", ' ', text)

        # remove multiple spaces
        text = re.sub(r"\s+", ' ', text)

        # apply lowercase again
        text = text.lower().strip()

        # remove trailing dot, we will apply end of sequence anyway
        text = re.sub(r"\.$", '', text).strip()

        # gigaword - UNK token
        # NOTE(review): this removes the substring 'unk' everywhere, including
        # inside words ('drunk' -> 'dr'). Kept unchanged so the cleaning stays
        # identical to what the evaluated models presumably saw - confirm
        # before restricting it to whole words.
        text = re.sub(r"unk", '', text.strip())

        # gigaword - change numbers to hashtags
        text = re.sub(r"\d", "#", text.strip())

        return text

    def apply_special_tokens(self, text) -> str:
        """Wrap the stripped `text` in `<sos>` / `<eos>` markers."""
        return "<sos> " + str(text).strip() + " <eos>"

    def remove_special_tokens(self, text) -> str:
        """Lowercase `text`, drop the `<sos>` / `<eos>` markers and strip whitespace."""
        text = text.lower()
        text = text.replace("<sos>", "").replace("<eos>", "")
        return text.strip()


# Shared module-level instance used by the cells below.
processor = TextProcessor()

In [None]:
# Final Stranasum summarization class wrapping a loadable TF graph model and text processing
class Summarizer:
    """Callable wrapper around a Stranasum module that exposes `predict`."""

    # Initialize using SummarizationModule or loaded graph tf module
    def __init__(self, module: tf.Module):
        self.module = module
        self.processor = TextProcessor()

    def summarize(self, text: str):
        """Clean `text`, run the module on it and return the full result.

        Returns:
            A 4-tuple `(prepared, summary, output_tensor, weights)`: the cleaned
            input wrapped in special tokens, the decoded summary with special
            tokens removed, and the second and third values returned by
            `module.predict` unchanged (presumably the raw output tensor and
            attention weights - confirm against the module).
        """
        cleaned = self.processor.clean_text(text)
        prepared = self.processor.apply_special_tokens(cleaned)

        # The module is called with a batch holding this single string.
        model_input = tf.constant([prepared])
        output_text, output_tensor, weights = self.module.predict(model_input)

        decoded = bytes.decode(output_text.numpy())
        summary = self.processor.remove_special_tokens(decoded)
        return prepared, summary, output_tensor, weights

    # Shorthand for text output only
    def __call__(self, text: str):
        _, summary, _, _ = self.summarize(text)
        return summary

In [None]:
# Select the summarizer to evaluate from CONFIG['model'].
#
# After this cell:
#   summarize - callable taking one article string and returning its summary
#   model / tokenizer - the loaded model (and tokenizer, for Hugging Face models)

# Supported Hugging Face checkpoints -> (tokenizer class, model class).
_HF_LOADERS = {
    "google/pegasus-xsum": (PegasusTokenizer, PegasusForConditionalGeneration),
    "google/roberta2roberta_L-24_gigaword": (AutoTokenizer, AutoModelForSeq2SeqLM),
    "a1noack/bart-large-gigaword": (AutoTokenizer, AutoModelForSeq2SeqLM),
}

model, tokenizer = None, None
summarize = None

if CONFIG['model'] in _HF_LOADERS:
    tokenizer_cls, model_cls = _HF_LOADERS[CONFIG['model']]

    tokenizer = tokenizer_cls.from_pretrained(CONFIG['model'])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model_cls.from_pretrained(CONFIG['model']).to(device)

    def hf_sum(src_text):
        """Summarize `src_text` with the loaded Hugging Face model.

        Generation is capped at 25 new tokens; the decoded output is passed
        through `processor.clean_text` before it is returned.
        """
        batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)
        translated = model.generate(**batch, max_new_tokens=25)
        tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

        return processor.clean_text(tgt_text[0])

    summarize = hf_sum

else:
    # stranasum
    model = tf.saved_model.load(CONFIG['path'])
    summarize = Summarizer(model)
    

In [None]:
# Smoke test: show the reference summary of the second evaluation row next to
# the summary the selected summarizer produces for the same article.
eval_set['summary'].iloc[1], summarize(eval_set['article'].iloc[1])

In [None]:
# Same row, but cleaned with TextProcessor before it is passed to the summarizer.
# NOTE(review): the Stranasum Summarizer cleans its input itself, so in that
# configuration the article is cleaned twice - confirm this is intended.
summarize(processor.clean_text(eval_set['article'].iloc[1]))

In [None]:
def clean_frame(frame):
    """Clean the 'summary' column of `frame` with `processor.clean_text`.

    The column is replaced on the frame that was passed in (the frame is
    modified, not copied) and the same frame is returned for convenience.
    """
    # One pass over the column instead of a positional row lookup per item.
    frame['summary'] = [processor.clean_text(summary) for summary in frame['summary']]
    return frame

# `eval_frame` is first defined here: it is built from the loaded `eval_set`.
# A copy is cleaned so the raw dataset in `eval_set` stays untouched.
eval_frame = clean_frame(eval_set.copy())

In [None]:
def summarize_frame(frame, summarize_fn):
    """Return a copy of `frame` with a 'predicted' column added.

    Args:
        frame: DataFrame with an 'article' column.
        summarize_fn: Callable applied to each article; its return value is
            stored in the 'predicted' column of the matching row.

    Returns:
        A new DataFrame; the input frame is not modified.
    """
    result = frame.copy()
    result['predicted'] = '<NONE>'  # placeholder until each row is filled in

    total = result.shape[0]
    predicted_col = result.columns.get_loc('predicted')

    for position, article in enumerate(result['article']):
        # progress report every 100 rows
        if position % 100 == 0:
            print(f"Summarising ... {position}/{total}")

        result.iloc[position, predicted_col] = summarize_fn(article)

    return result

In [None]:
%%time
eval_frame = summarize_frame(eval_frame, summarize)

In [None]:
eval_frame

In [None]:
def rouge_full(
    predictions, references, rouge_types=None, use_aggregator=True
):
    """Compute ROUGE scores, keeping precision, recall and F-measure.

    Args:
        predictions: Iterable of predicted texts.
        references: Iterable of reference texts, or of lists of reference
            texts (multi-reference mode, detected from the first item).
        rouge_types: ROUGE variants to compute; defaults to rouge1, rouge2,
            rougeL and rougeLsum.
        use_aggregator: If True, aggregate over all pairs with a bootstrap
            aggregator and return the `mid` value per ROUGE type; otherwise
            return the list of per-pair scores per ROUGE type.

    Returns:
        A dict keyed by ROUGE type.
    """
    if rouge_types is None:
        rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

    # Work on plain lists: `references[0]` on a pandas Series is a label
    # lookup and fails when the index has no label 0 (e.g. a sliced frame).
    references = list(references)
    predictions = list(predictions)

    multi_ref = bool(references) and isinstance(references[0], list)

    scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types)
    aggregator = rouge_score.scoring.BootstrapAggregator() if use_aggregator else None
    scores = []

    for ref, pred in zip(references, predictions):
        if multi_ref:
            score = scorer.score_multi(ref, pred)
        else:
            score = scorer.score(ref, pred)

        if use_aggregator:
            aggregator.add_scores(score)
        else:
            scores.append(score)

    if use_aggregator:
        # keep only the mid value of each aggregate
        return {key: value.mid for key, value in aggregator.aggregate().items()}

    if not scores:
        # no pairs were scored: one empty list per requested ROUGE type
        return {key: [] for key in rouge_types}

    return {key: [score[key] for score in scores] for key in scores[0]}


In [None]:
# Metrics evaluation
# BLEU and BERTScore are loaded through the `evaluate` library. ROUGE is
# computed by rouge_full() instead, so the `evaluate` ROUGE loader stays disabled.

bleu_metric = evaluate.load('bleu')
bertscore_metric = evaluate.load('bertscore')
# rouge_metric = evaluate.load('rouge')

In [None]:
def plotbins(x):
    """Show a 100-bin histogram of the values in `x` with labelled axes."""
    axes = plt.gca()
    axes.hist(x, bins=100)
    axes.set_xlabel("score")
    axes.set_ylabel("counts")
    plt.show()

def evaluate_frame(frame: pd.DataFrame):
    """Score the 'predicted' column of `frame` against its 'summary' column.

    Computes BLEU, ROUGE (via `rouge_full`) and BERTScore, prints the current
    CONFIG followed by two '&'-separated result rows (apparently meant for
    pasting into a LaTeX table), and returns the raw results.

    Args:
        frame: DataFrame with 'predicted' and 'summary' columns.

    Returns:
        A tuple `(bleu, rouge, b)`: the BLEU result, the ROUGE result and a
        dict with the averaged BERTScore precision ('p'), recall ('r') and
        F1 ('f').
    """
    # Plain lists instead of Series: rouge_full looks at references[0], which
    # on a Series is a label lookup and fails when the frame's index has no
    # label 0 (e.g. a sliced frame). Lists are also the safe input for the
    # `evaluate` metrics.
    preds = frame['predicted'].tolist()
    refs = frame['summary'].tolist()

    bleu = bleu_metric.compute(references=refs, predictions=preds)
    rouge = rouge_full(references=refs, predictions=preds)
    bertscore = bertscore_metric.compute(references=refs, predictions=preds, lang="en", verbose=True)

    # BERTScore returns per-pair values; report their averages
    b = dict(
        p=np.average(bertscore['precision']),
        r=np.average(bertscore['recall']),
        f=np.average(bertscore['f1']),
    )

    def fmt(value):
        """Format a score with three decimals."""
        return f"{value:.3f}"

    def rouge_prf(name):
        """Format precision, recall and F-measure of one ROUGE type."""
        score = rouge[name]
        return f"{fmt(score.precision)} & {fmt(score.recall)} & {fmt(score.fmeasure)}"

    print("\n---------")
    print(CONFIG)

    print(f"{CONFIG['model']} & {fmt(bleu['bleu'])} & {rouge_prf('rouge1')} & {rouge_prf('rouge2')} & {rouge_prf('rougeL')}")

    print(f"{CONFIG['model']} & m & {rouge_prf('rougeLsum')} & {fmt(b['p'])} & {fmt(b['r'])} & {fmt(b['f'])}")

    print("---------\n")

    return bleu, rouge, b

# Print and return the BLEU, ROUGE and BERTScore results for the summarized frame.
evaluate_frame(eval_frame)

In [None]:
def pretty_summaries(frame):
    """Print article, reference summary and prediction for every row of `frame`.

    Expects the columns 'article', 'summary' and 'predicted'.
    """
    for _, row in frame.iterrows():
        # one print per row; the pieces are joined by newlines
        print(
            f"\n ------------------",
            f"Article  : {row['article']}",
            f"\nSummary  : {row['summary']}",
            f"\nPredicted: {row['predicted']}",
            "",
            f"------------------",
            sep="\n",
        )
        
# Print a sample of 20 rows (positions 100-119) for manual inspection.
pretty_summaries(eval_frame[100:120])