In [2]:
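# Dependencies used by this notebook — uncomment and run once in a fresh environment.
# %pip install bm25s
# %pip install spacy
# %pip install -U 'spacy[cuda12x]'
# %pip install rouge_score
# %pip install bert_score
# %pip install pysbd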

In [14]:
import sys
sys.path.insert(0, '../')
import utilities.functions as fct
import time
import os
import pandas as pd
from tqdm import tqdm

# BM25

In [4]:
def evaluate_models(document, paragraph_target, short=True, model_name=None):
    """Summarize one document and score the summary against a reference text.

    The document is split into sentences with pySBD, a query is derived from
    the document with ``fct.select_query``, and ``fct.bb25LegalSum`` selects
    the summary sentences. The space-joined summary is then scored by
    ``fct.evaluations``.

    Args:
        document: Source text to summarize.
        paragraph_target: Reference text the generated summary is scored against.
        short: Forwarded unchanged to ``fct.evaluations``.
        model_name: Model identifier forwarded to ``fct.bb25LegalSum``. When
            omitted, the notebook-level ``model`` variable is used, so existing
            three-argument calls keep working.

    Returns:
        A ``(results, summaries)`` tuple. ``results`` is the object returned by
        ``fct.evaluations`` (assumed to be a pandas DataFrame — confirm in
        utilities.functions) with its index reset and an added
        ``'Execution time'`` column holding the elapsed seconds for
        segmentation, summarization and scoring together. ``summaries`` is a
        one-element list containing the summary sentences joined by newlines.

    Raises:
        NameError: If ``model_name`` is omitted and no notebook-level ``model``
            variable has been defined yet.
    """
    if model_name is None:
        # Backward-compatible fallback: earlier callers relied on the global
        # `model` that a later notebook cell defines.
        model_name = model

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments during a long run.
    start_time = time.perf_counter()

    # Split the source document into sentences.
    sentences = fct.sent_segmentation(document, method="pySBD")

    # Build the query and select the summary sentences.
    query = fct.select_query(document)
    summary = fct.bb25LegalSum(sentences, model_name, query)

    # Score the generated summary against the reference text.
    bb25_evaluation = fct.evaluations(" ".join(summary), paragraph_target, short)

    # The timing deliberately includes the scoring step, as it always has.
    bb25_evaluation['Execution time'] = time.perf_counter() - start_time

    # Same outcome as concatenating onto an empty frame with ignore_index=True,
    # without building the throwaway empty DataFrame.
    results = bb25_evaluation.reset_index(drop=True)
    summaries = ["\n".join(summary)]

    return results, summaries

### BM25 with pySBD on 100 cleaned documents

In [5]:
test_path_txt = '../SCOTUS_data/text_dev'
test_path_sum = '../SCOTUS_data/summary_dev'
target_path_csv = '../SCOTUS_data/paragraph_target_df_dev.csv'

TEXTS_COUNT = 100

summary_ref = []
texts = []

# Each reference summary in `test_path_sum` is paired with the source text that
# has the same file name in `test_path_txt`.
# NOTE(review): os.listdir() gives no ordering guarantee, so which files make up
# the first TEXTS_COUNT (and how they line up with the rows of the target CSV
# used later) depends on the filesystem — confirm the intended order.
for file_name in tqdm(os.listdir(test_path_sum)[:TEXTS_COUNT]):
    with open(os.path.join(test_path_sum, file_name), 'r', encoding="utf-8") as summary_file:
        summary_ref.append(summary_file.read())
    # Context manager so the source-text handle is closed as well.
    with open(os.path.join(test_path_txt, file_name), 'r', encoding="utf-8") as text_file:
        texts.append(text_file.read())

100%|██████████| 100/100 [00:01<00:00, 65.72it/s]


In [6]:
model = "nlpaueb/legal-bert-base-uncased"

df_target = pd.read_csv(target_path_csv)

# Create the output folder up front so a missing directory cannot discard the
# results of the long-running loop below.
os.makedirs("./output", exist_ok=True)

summary_gen = []
per_document_results = []

# Iterate over the loaded documents themselves, so a dataset with fewer than
# TEXTS_COUNT files does not raise an IndexError.
for document, reference in tqdm(zip(texts, summary_ref), total=len(texts)):
    r, summary = evaluate_models(document, reference, True)

    # `summary` is a one-element list; later cells read it as summary_gen[i][0].
    summary_gen.append(summary)
    per_document_results.append(r)

# One concat at the end instead of re-copying the growing frame on every iteration.
results = (
    pd.concat(per_document_results, ignore_index=True)
    if per_document_results
    else pd.DataFrame()
)

# NOTE(review): "Generated" stores the one-element lists, so the CSV contains
# their list repr rather than the plain summary text — confirm this is intended.
df_results = pd.DataFrame({"Text": texts, "Reference": summary_ref, "Generated": summary_gen})
df_results.to_csv("./output/results_BM25_dev.csv", index=False)


100%|██████████| 100/100 [43:42<00:00, 26.22s/it]


In [7]:
from rouge_score import rouge_scorer
from bert_score import BERTScorer

# Section columns of the target CSV that are scored, and the ROUGE variants used.
TARGET_SECTIONS = ['facts_of_the_case', 'question', 'conclusion']
ROUGE_METRICS = ['rouge1', 'rouge2', 'rougeL']

ROUGE_scorer = rouge_scorer.RougeScorer(ROUGE_METRICS, use_stemmer=True)
BERT_scorer = BERTScorer(lang="en")

# scores[section][metric] -> list with one value per document.
scores = {
    section: {metric: [] for metric in ROUGE_METRICS + ['bert_score']}
    for section in TARGET_SECTIONS
}

# The generated summary does not depend on the section, so read it once per document.
generated_summaries = [summary_gen[i][0] for i in range(TEXTS_COUNT)]

# Loop over the known sections rather than every CSV column, so an extra column
# (for example a saved index) cannot raise a KeyError.
for column_name in TARGET_SECTIONS:
    for i in range(TEXTS_COUNT):
        ref = df_target[column_name].iloc[i]
        gen = generated_summaries[i]

        # ROUGE: only the recall component of each variant is kept.
        rouge_result = ROUGE_scorer.score(ref, gen)
        for metric in ROUGE_METRICS:
            scores[column_name][metric].append(rouge_result[metric].recall)

        # BERTScore: keep the third returned tensor (F1 per the bert_score docs).
        _, _, bert_f1 = BERT_scorer.score([gen], [ref])
        scores[column_name]['bert_score'].append(bert_f1.mean().item())

# Per-section averages; NaN instead of a ZeroDivisionError when nothing was scored.
avg_scores_target = {
    section: {
        metric: (sum(values) / len(values)) if values else float('nan')
        for metric, values in section_scores.items()
    }
    for section, section_scores in scores.items()
}

for col, metrics in avg_scores_target.items():
    print(f"\nScores moyens pour {col} :")
    print(metrics)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Scores moyens pour facts_of_the_case :
{'rouge1': 0.3260580939501271, 'rouge2': 0.04578622265580652, 'rougeL': 0.19392047626640999, 'bert_score': 0.7872273826599121}

Scores moyens pour question :
{'rouge1': 0.37437800523822196, 'rouge2': 0.04377399909499029, 'rougeL': 0.30698746717832215, 'bert_score': 0.7839809638261795}

Scores moyens pour conclusion :
{'rouge1': 0.3362871780123839, 'rouge2': 0.04518677434537322, 'rougeL': 0.19120166132210803, 'bert_score': 0.7867434746026993}


In [None]:
# Average the per-document metrics that are actually present in `results`.
metrics = [col for col in results.columns if col in ['rouge1', 'rouge2', 'rougeL', 'bert_score', 'Execution time']]
means = results[metrics].mean()

# One row per target section, one column per metric.
df_avg_scores_target = pd.DataFrame(avg_scores_target).T

# Single row labelled 'global' holding the whole-summary averages.
global_row = means.to_frame(name='global').T

if 'Execution time' not in df_avg_scores_target.columns:
    # Sections have no timing. Fill with a float NaN rather than None so the
    # column stays numeric and the concat below does not depend on pandas'
    # deprecated handling of all-NA object columns.
    df_avg_scores_target = df_avg_scores_target.assign(**{'Execution time': float('nan')})

df_score = pd.concat([global_row, df_avg_scores_target], axis=0)

In [9]:
# Bare last expression so Jupyter renders the table with its rich display
# instead of the plain-text print() output.
df_score.head()

                     rouge1    rouge2    rougeL  bert_score  Execution time
global             0.385961  0.135818  0.194667    0.819438       26.219409
facts_of_the_case  0.326058  0.045786  0.193920    0.787227             NaN
question           0.374378  0.043774  0.306987    0.783981             NaN
conclusion         0.336287  0.045187  0.191202    0.786743             NaN


In [10]:
# Persist the combined score table ('global' row plus one row per target section);
# the row labels are written too because the index is not disabled.
df_score.to_csv("./output/scores_BM25_dev.csv")

In [19]:
# Sum of the per-document 'Execution time' values collected in `results`.
total_execution_time = results["Execution time"].sum()
print("Execution time in total : ", total_execution_time)

Execution time in total :  2621.940869808197


### Score sur le résumé global pour chaque texte

In [18]:
# Style the per-document results table. axis=None passes the whole DataFrame to
# fct.highlight_min_max in a single call.
# NOTE(review): presumably the helper highlights the extreme values — confirm
# in utilities.functions.
styled_df = results.style.apply(fct.highlight_min_max, axis=None)

styled_df

Unnamed: 0,rouge1,rouge2,rougeL,bert_score,Execution time
0,0.446623,0.266376,0.202614,0.852792,8.103309
1,0.583554,0.260638,0.291777,0.858557,7.615244
2,0.207407,0.064356,0.093827,0.790141,15.158792
3,0.502857,0.166667,0.308571,0.816922,16.805995
4,0.563107,0.272727,0.33657,0.853419,29.782994
5,0.537234,0.192513,0.276596,0.852082,45.752655
6,0.342318,0.067568,0.177898,0.794284,19.23982
7,0.455982,0.151584,0.218962,0.835727,31.565603
8,0.421687,0.140097,0.159036,0.807131,21.189183
9,0.383838,0.142132,0.222222,0.787479,29.01312
