In [1]:
import os

import evaluate
import numpy as np
import pandas as pd
import textstat
import torch
from tqdm import tqdm
from transformers import pipeline
from transformers import BartTokenizer, BartForConditionalGeneration

from src.RAG_Calculater import RAG, get_top_n_articles
from src.Prompt_Factory import prompt_factory
from src.Case_Builder import (bert_version,
                              genai_version,
                              genai_model_name,
                              dataset_name,
                              massage_strategy
                              )

In [2]:
# Load the train and validation splits from JSON. The file paths are built
# from `dataset_name` and `bert_version`, both imported from src.Case_Builder.
data_train = pd.read_json(f'src/dataset/clean/{dataset_name}/{bert_version}_train.json')
data_val = pd.read_json(f'src/dataset/clean/{dataset_name}/{bert_version}_validation.json')

In [3]:
# Keep only a small slice of the validation split for the rest of the notebook.
# `.loc[:4]` is a label-based slice that includes the end label, so with a
# default 0..n-1 index this is the first five rows (presumably the index is
# the default one -- TODO confirm against the JSON files). `.copy()` detaches
# the slice so later column assignments do not touch the full frame.
data_val = data_val.loc[:4].copy()

In [4]:
# Derive the 'rag_sentences' column on both splits by running RAG over each
# row's 'sentences_similarity' value.
for frame in (data_train, data_val):
    frame['rag_sentences'] = frame['sentences_similarity'].apply(RAG)

In [None]:
# Text-generation pipeline for the configured model; the generation settings
# below are applied to every call made through `chatbot`.
chatbot = pipeline(
    "text-generation",
    model=genai_model_name,
    min_new_tokens=256,
    max_new_tokens=512,
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
)

In [None]:
# Load the ROUGE and BERTScore metrics through the `evaluate` library.
# Both objects are read as module-level globals by get_results().
rouge = evaluate.load('rouge')
bertscore = evaluate.load("bertscore")

# Load a pre-trained BART checkpoint and its tokenizer. get_results() uses
# this pair to derive its "BARTScore" column from the model's loss.
bart_model_name = "facebook/bart-large-cnn"
bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name)
bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)

def get_results(results, batch_size=5):
    """Score generated summaries against their references.

    Computes ROUGE-1/2/L, BERTScore (precision / recall / F1), two
    readability scores on the predictions (Flesch-Kincaid grade and
    Dale-Chall) and a BART-based score, then averages each metric over all
    rows.

    Relies on the module-level ``rouge``, ``bertscore``, ``bart_tokenizer``
    and ``bart_model`` objects.

    Args:
        results: DataFrame with a 'prediction' and a 'reference' column, one
            row per generated summary.
        batch_size: Number of prediction/reference pairs passed to each
            BERTScore call. Defaults to 5, the previously hard-coded value.

    Returns:
        A one-row DataFrame with the columns ROUGE1, ROUGE2, ROUGEL,
        BERTScore_Precision, BERTScore_Recall, BERTScore_F1, FKGL, DCRS and
        BARTScore.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Convert once; every metric below consumes plain Python lists.
    predictions = results['prediction'].to_list()
    references = results['reference'].to_list()

    # Compute ROUGE metrics
    print("ROUGE Metrics Calculater:")
    rouge_results = rouge.compute(
        predictions=predictions,
        references=references,
        use_aggregator=True,
        use_stemmer=True,
    )

    # Compute BERTScore in small batches to bound memory use per call.
    print("BERTScore Calculater:")
    bertscore_results = {"precision": [], "recall": [], "f1": []}
    for start in tqdm(range(0, len(predictions), batch_size)):
        stop = start + batch_size
        chunk_scores = bertscore.compute(
            predictions=predictions[start:stop],
            references=references[start:stop],
            model_type="microsoft/deberta-xlarge-mnli",
        )
        for key in bertscore_results:
            bertscore_results[key].extend(chunk_scores[key])

    # Compute FKGL and DCRS for Readability
    print("FKGL Metrics Calculater:")
    fkgl_scores = [textstat.flesch_kincaid_grade(p) for p in predictions]
    print("DCRS Metrics Calculater:")
    dcrs_scores = [textstat.dale_chall_readability_score(p) for p in predictions]

    # Compute BARTScore for Factuality
    def compute_bart_score(pred, ref):
        """Return the mean of the two directional BART losses for one pair.

        The value is an average of model losses, so lower means the two
        texts are more likely given each other.
        """
        ref_enc = bart_tokenizer(ref, return_tensors="pt", truncation=True, max_length=1024)
        pred_enc = bart_tokenizer(pred, return_tensors="pt", truncation=True, max_length=1024)
        # Pure inference: disable autograd so no graph is built for the two
        # forward passes (saves memory and time, results are unchanged).
        with torch.no_grad():
            ref_to_pred_score = bart_model(**ref_enc, labels=pred_enc["input_ids"]).loss.item()
            pred_to_ref_score = bart_model(**pred_enc, labels=ref_enc["input_ids"]).loss.item()
        return (ref_to_pred_score + pred_to_ref_score) / 2

    print("BARTScore Calculater:")
    # Pairs are scored one at a time, so a single pass over them is enough.
    bart_scores = [
        compute_bart_score(pred, ref)
        for pred, ref in tqdm(zip(predictions, references), total=len(predictions))
    ]

    final_results = {
        "ROUGE1": [rouge_results['rouge1']],
        "ROUGE2": [rouge_results['rouge2']],
        "ROUGEL": [rouge_results['rougeL']],
        "BERTScore_Precision": [np.average(bertscore_results["precision"])],
        "BERTScore_Recall": [np.average(bertscore_results["recall"])],
        "BERTScore_F1": [np.average(bertscore_results["f1"])],
        "FKGL": [np.average(fkgl_scores)],
        "DCRS": [np.average(dcrs_scores)],
        "BARTScore": [np.average(bart_scores)],
    }

    return pd.DataFrame(final_results)

In [None]:
def get_predictions(prompt_strategy_used):
    """Generate one summary per validation row with the `chatbot` pipeline.

    For every row of the module-level ``data_val`` a prompt is built with
    ``prompt_factory`` and sent to ``chatbot``. When ``massage_strategy`` is
    "few_shot", the three training rows selected by ``get_top_n_articles``
    are handed to the prompt builder as examples.

    Args:
        prompt_strategy_used: Prompt variant identifier forwarded unchanged
            to ``prompt_factory``.

    Returns:
        DataFrame with a 'reference' column (the row's 'summary' items joined
        with single spaces) and a 'prediction' column (the generated text),
        one row per validation row.
    """
    predictions = []
    references = []
    n_rows = len(data_val)

    for pos in range(n_rows):
        print(f"\n {pos+1} / {n_rows}", end="")

        # `pos` is a position, not an index label: use .iloc so this also
        # works when data_val does not carry a default 0..n-1 index.
        target_row = data_val.iloc[pos]

        if massage_strategy == "few_shot":
            ref_rows_indexes = get_top_n_articles(data_train['title_embedding'], target_row['title_embedding'], n=3)
            ref_rows = data_train.loc[ref_rows_indexes].reset_index(drop=True)
        else:
            ref_rows = None

        prompt = prompt_factory(prompt_strategy_used, target_row, ref_rows)
        messages = [{"role": "user", "content": prompt}]
        reference = " ".join(target_row['summary'])

        generation = chatbot(messages)

        # NOTE(review): "BioGBT" has to match the spelling used for
        # genai_version in src.Case_Builder -- confirm it is not meant to be
        # "BioGPT".
        if genai_version == "BioGBT":
            # Keep only the text following the last "## Answer:\n" marker.
            answer = generation[0]['generated_text'].split("## Answer:\n")[-1]
        else:
            # 'generated_text' is indexed as a list of message dicts here;
            # the last entry is taken as the model's reply.
            answer = generation[0]['generated_text'][-1]['content']

        predictions.append(answer)
        references.append(reference)

    return pd.DataFrame({
        'reference': references,
        'prediction': predictions
    })

In [None]:
# Run every prompt strategy end to end: generate, persist, score, persist.
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before spending time on generation.
os.makedirs('results', exist_ok=True)

for prompt_strategy_used in [1,2]:
    model_results = get_predictions(prompt_strategy_used)

    # Save the generations before scoring, so a failure while computing
    # metrics does not throw away the (expensive) model outputs.
    model_results.to_csv(f'results/{genai_version}_summaries_{bert_version}_{dataset_name}_{prompt_strategy_used}_val.csv', index=False)

    result_df = get_results(model_results)
    result_df.to_csv(f'results/{genai_version}_results_{bert_version}_{dataset_name}_{prompt_strategy_used}_val.csv', index=False)