### Script to generate summaries using a chunking-based BART method

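Set `TEXTS_COUNT`, `test_path_txt`, `test_path_sum` and `target_path_csv` in the configuration cell below to select the documents, reference summaries and per-section target summaries to evaluate. Generated summaries and scores are written to `./output/`.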


In [4]:
import sys
sys.path.insert(0, '../')
from utilities.BART_utilities import *
import utilities.paper_functions as p_fct
import utilities.functions as fct

import pandas as pd
import numpy as np
import os

import time
from tqdm import tqdm


In [5]:
# Number of documents (taken in directory-listing order) to summarize and score.
TEXTS_COUNT = 100

# Dev split of the SCOTUS data: source texts, reference summaries, and the
# per-section target summaries (facts_of_the_case / question / conclusion).
_scotus_dir = '../SCOTUS_data'
test_path_txt = f'{_scotus_dir}/text_dev'
test_path_sum = f'{_scotus_dir}/summary_dev'
target_path_csv = f'{_scotus_dir}/paragraph_target_df_dev.csv'

In [6]:
# Load the pretrained BART-large tokenizer and conditional-generation model.
# NOTE(review): AdamW and BartConfig are not used by any visible cell, and
# transformers.AdamW is deprecated (removed in some newer transformers
# releases) — consider dropping them if this import fails; confirm first.
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

PRETRAINED_NAME = "facebook/bart-large"

tokenizer = BartTokenizer.from_pretrained(PRETRAINED_NAME, add_prefix_space=True)
model = BartForConditionalGeneration.from_pretrained(PRETRAINED_NAME)

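### Using a fine-tuned model
1. In the following cell, uncomment the `LitModel.load_from_checkpoint(...)` statement (it reassigns `bart_model`, replacing the default pretrained wrapper).
2. Replace the checkpoint path in that statement with the path to your fine-tuned model.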

In [7]:
# Wrap the pretrained tokenizer and model in the project's LitModel.
bart_model = LitModel(
    learning_rate=2e-5,
    tokenizer=tokenizer,
    model=model,
)

# Fine-tuned alternative (uncomment and set the checkpoint path):
# bart_model = LitModel.load_from_checkpoint("/home/pahelibhattacharya/HULK/Abhay/models/BART_large_IN_MCS.ckpt",
#                                       learning_rate = 2e-5, tokenizer = tokenizer, model = model).to("cuda")

In [8]:
def generate_summary_gpu(nested_sentences, p=0.2, device='cuda'):
    '''
    Generate an abstractive summary for each chunk of a document.

    Uses the module-level `tokenizer` and `bart_model` defined earlier in
    this notebook.

    input:  nested_sentences - iterable of text chunks (strings)
            p - target number of summary words per input word of a chunk
            device - torch device used for generation (default 'cuda')
    output: flat list of decoded summary strings, in chunk order
    '''
    # Move the model to the device once, not once per chunk.
    generator = bart_model.model.to(device)
    summaries = []
    for chunk in nested_sentences:
        # The target is computed from a word count but passed to generate()
        # as a token limit, so the resulting length is only approximate.
        target_len = int(p * len(chunk.split(" ")))
        # truncation=True cuts inputs longer than the tokenizer's maximum
        # length. NOTE(review): confirm the chunk size produced by
        # nest_sentences fits within that limit, or content is dropped.
        input_ids = tokenizer.encode(chunk, truncation=True, return_tensors='pt').to(device)
        summary_ids = generator.generate(
            input_ids,
            length_penalty=0.01,
            # Clamp so very short chunks never request a negative min_length.
            min_length=max(target_len - 5, 0),
            max_length=target_len + 5,
        )
        summaries.extend(
            tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for g in summary_ids
        )
    return summaries

In [9]:
def BART_summarize(text, req_len=512, chunk_size=1024):
    '''
    Summarize a document by chunking it and summarizing each chunk with BART.

    input:  text - the full document
            req_len - maximum number of space-separated words in the summary
                      (previously ignored: it was overwritten with 512)
            chunk_size - size argument forwarded to p_fct.nest_sentences
                         (unit defined by that helper — NOTE(review): confirm)
    output: summary string of at most req_len space-separated words
    '''
    input_len = len(text.split(" "))

    nested = p_fct.nest_sentences(text, chunk_size)
    # Each chunk is summarized to roughly this fraction of its word count, so
    # the concatenated chunk summaries target about req_len words overall.
    # Note: for documents shorter than req_len words, p exceeds 1.
    p = float(req_len / input_len)

    abs_summ = " ".join(generate_summary_gpu(nested, p))

    # Hard cap: generation lengths are approximate, so trim any excess words.
    summary_words = abs_summ.split(" ")
    if len(summary_words) > req_len:
        abs_summ = " ".join(summary_words[:req_len])

    return abs_summ


In [10]:
def evaluate_models(document, ref, short=True):
    '''
    Summarize `document` with BART and score the summary against `ref`.

    input:  document - source text to summarize
            ref - reference summary text
            short - forwarded unchanged to fct.evaluations
    output: (bart_evaluations, summary) - the object returned by
            fct.evaluations with an added 'Execution time' entry (seconds,
            covering both summarization and evaluation), and the generated
            summary string
    '''
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted, corrupting the measured duration.
    start_time = time.perf_counter()

    summary = BART_summarize(document)
    bart_evaluations = fct.evaluations(summary, ref, short)

    execution_time = time.perf_counter() - start_time
    bart_evaluations['Execution time'] = execution_time

    return bart_evaluations, summary

In [11]:
# Load up to TEXTS_COUNT reference summaries and their matching source texts.
summary_ref = []
texts = []

# NOTE(review): os.listdir order is filesystem-dependent; later cells pair
# texts[i] with row i of target_path_csv by position, so this order must
# match that CSV's row order — confirm.
for file_name in tqdm(os.listdir(test_path_sum)[:TEXTS_COUNT]):
    with open(os.path.join(test_path_sum, file_name), 'r', encoding="utf-8") as f:
        summary_ref.append(f.read())
    # The source text has the same file name as its reference summary.
    # Opened with `with` so the handle is closed (it previously leaked).
    with open(os.path.join(test_path_txt, file_name), 'r', encoding="utf-8") as f:
        texts.append(f.read())

100%|██████████| 100/100 [00:01<00:00, 53.81it/s]


In [12]:
# Generate a summary for every loaded text, score it, and save the summaries.
summary_gen = []
per_text_results = []

df_target = pd.read_csv(target_path_csv)

# Iterate over the texts actually loaded (at most TEXTS_COUNT), so fewer
# files on disk no longer raises IndexError.
for document, reference in tqdm(zip(texts, summary_ref), total=len(texts)):
    # "TXT" is forwarded as fct.evaluations' `short` argument (truthy).
    r, summary = evaluate_models(document, reference, "TXT")

    summary_gen.append(summary)
    per_text_results.append(r)

# Concatenate once instead of growing a DataFrame inside the loop (quadratic).
# Assumes each `r` is a one-row DataFrame — TODO confirm against fct.evaluations.
results = pd.concat(per_text_results, ignore_index=True) if per_text_results else pd.DataFrame()

# Create the output directory if needed before writing into it.
os.makedirs("./output", exist_ok=True)
df_results = pd.DataFrame({"Text": texts, "Reference": summary_ref, "Generated": summary_gen})
df_results.to_csv("./output/results_BART_dev.csv", index=False)


  attn_output = torch.nn.functional.scaled_dot_product_attention(
100%|██████████| 100/100 [19:39<00:00, 11.79s/it]


In [13]:
from rouge_score import rouge_scorer
from bert_score import BERTScorer

# Reference sections scored separately; each must be a column of df_target.
# Any other CSV columns are ignored instead of raising KeyError.
TARGET_COLUMNS = ['facts_of_the_case', 'question', 'conclusion']
ROUGE_TYPES = ['rouge1', 'rouge2', 'rougeL']

ROUGE_scorer = rouge_scorer.RougeScorer(ROUGE_TYPES, use_stemmer=True)
BERT_scorer = BERTScorer(lang="en")

# Score every generated summary (not a fixed TEXTS_COUNT).
n_generated = len(summary_gen)
if len(df_target) < n_generated:
    raise ValueError(
        f"df_target has {len(df_target)} rows but {n_generated} summaries were generated"
    )

scores = {
    col: {metric: [] for metric in ROUGE_TYPES + ['bert_score']}
    for col in TARGET_COLUMNS
}

for column_name in TARGET_COLUMNS:
    # NOTE(review): row i of df_target is assumed to describe texts[i].
    refs = df_target[column_name].iloc[:n_generated].tolist()

    # ROUGE: the recall component (not F-measure) is recorded.
    for ref, gen in zip(refs, summary_gen):
        rouge_result = ROUGE_scorer.score(ref, gen)
        for metric in ROUGE_TYPES:
            scores[column_name][metric].append(rouge_result[metric].recall)

    # BERTScore: one batched call per section; the returned F1 tensor holds
    # one score per (generated, reference) pair, in order.
    _, _, bert_f1 = BERT_scorer.score(summary_gen, refs)
    scores[column_name]['bert_score'].extend(bert_f1.tolist())

avg_scores_target = {
    col: {metric: sum(values) / len(values) for metric, values in col_scores.items()}
    for col, col_scores in scores.items()
}

for col, metrics in avg_scores_target.items():
    print(f"\nScores moyens pour {col} :")
    print(metrics)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Scores moyens pour facts_of_the_case :
{'rouge1': 0.38908383734598234, 'rouge2': 0.061016172261099746, 'rougeL': 0.21610336427356894, 'bert_score': 0.7876326262950897}

Scores moyens pour question :
{'rouge1': 0.431647948600343, 'rouge2': 0.057740497056222265, 'rougeL': 0.33810648033418345, 'bert_score': 0.7837799018621445}

Scores moyens pour conclusion :
{'rouge1': 0.41191027332698676, 'rouge2': 0.062408293079037384, 'rougeL': 0.2200530137431965, 'bert_score': 0.7889574688673019}


In [14]:
# Combine the per-text means (row 'global') with the per-section averages.
# NOTE(review): the global row comes from fct.evaluations while the section
# rows use ROUGE recall; confirm both use the same metric variant.
metrics = [col for col in results.columns if col in ['rouge1', 'rouge2', 'rougeL', 'bert_score', 'Execution time']]
means = results[metrics].mean()

df_avg_scores_target = pd.DataFrame(avg_scores_target).T

global_row = pd.DataFrame(means).T
global_row.index = ['global']

# Section rows have no timing. A float NaN (not None) keeps the column
# float64, so pd.concat does not emit the all-NA-entries FutureWarning.
if 'Execution time' not in df_avg_scores_target.columns:
    df_avg_scores_target = df_avg_scores_target.assign(**{'Execution time': float('nan')})

df_score = pd.concat([global_row, df_avg_scores_target], axis=0)

  df_score = pd.concat([global_row, df_avg_scores_target], axis=0)


In [15]:
# Show the first five rows of the score table (global + section rows).
print(df_score.head(n=5))

                     rouge1    rouge2    rougeL  bert_score  Execution time
global             0.421710  0.129064  0.186344    0.816769       11.788604
facts_of_the_case  0.389084  0.061016  0.216103    0.787633             NaN
question           0.431648  0.057740  0.338106    0.783780             NaN
conclusion         0.411910  0.062408  0.220053    0.788957             NaN


In [16]:
# Persist the score table, keeping the row labels as the first column.
scores_csv_path = "./output/scores_BART_dev.csv"
df_score.to_csv(scores_csv_path)

In [19]:
# Sum of the per-text execution times, in seconds.
total_execution_time = results["Execution time"].sum()
print("Execution time in total : ", total_execution_time)

Execution time in total :  1178.8604154586792


### Scores of the global summary for each text

In [17]:
# Style the per-text score table with fct.highlight_min_max; axis=None passes
# the whole DataFrame to the function at once.
highlight_fn = fct.highlight_min_max
styled_df = results.style.apply(highlight_fn, axis=None)

styled_df

Unnamed: 0,rouge1,rouge2,rougeL,bert_score,Execution time
0,0.53866,0.317829,0.270619,0.866737,13.91762
1,0.544928,0.212209,0.217391,0.828694,10.569826
2,0.279302,0.095,0.147132,0.834855,10.726489
3,0.438486,0.120253,0.189274,0.819487,10.699037
4,0.456284,0.106849,0.191257,0.80916,11.485411
5,0.339623,0.089189,0.134771,0.820974,12.602061
6,0.512894,0.152299,0.229226,0.813714,11.184083
7,0.439024,0.119565,0.176152,0.80466,12.215755
8,0.444444,0.122283,0.181572,0.815485,11.594822
9,0.595361,0.183463,0.283505,0.819019,11.735221
