In [31]:
import os

In [32]:
%pwd

'd:\\Test Summ'

In [3]:
os.chdir('../')

In [4]:
%pwd

'd:\\Test Summ'

In [33]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ModelEvaluationConfig:
    root_dir: Path
    data_path: Path
    model_path: Path
    token_path: Path
    metric_file_name: Path
    

In [34]:
from textSummarizer.constants import *
from textSummarizer.utils.common import read_yaml, create_directories

In [39]:
class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        
    def get_model_eval_config(self) -> ModelEvaluationConfig:
        config = self.config.model_eval
        create_directories([config.root_dir])
        
        model_eval_config = ModelEvaluationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            model_path = config.model_path,
            token_path = config.token_path,
            metric_file_name = config.metric_file_name
        )   
        return model_eval_config 

In [36]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset, load_from_disk, load_metric
import torch
import pandas as pd
from tqdm import tqdm

In [37]:
class ModelEvaluation:
    def __init__(self, config: ModelEvaluationConfig):
        self.config = config
        
    def generate_batch_sixe_chunk(self,list_of_element,batch_sixe):
            for i in range(0,len(list_of_element),batch_sixe):
                yield list_of_element[i : i+batch_sixe]
                
    def calculate_metric_on_test(self,dataset,metric,model,tokenixer,batch_sixe=16,column_text="article",column_summary="highlights"):
        article_batch = list(self.generate_batch_sixe_chunk(dataset[column_text], batch_sixe))
        target_batch =  list(self.generate_batch_sixe_chunk(dataset[column_summary], batch_sixe))     
        
        for article_batch,target_batch in tqdm(zip(article_batch, target_batch), total=len(article_batch)):
            inputs = tokenixer(article_batch,max_length=1024,truncation=True,padding="max_length",return_tensors="pt") 
            summaries = model.generate(input_ids=inputs["input_ids"],attention_mask=inputs["attention_mask"],length_penalty=0.8, num_beams=8, max_length=128)
            decoded_sum = [tokenixer.decode(s,skip_special_tokens=True,clean_up_tokenization_spaces=True) for s in summaries] 
            decoded_sum = [d.replace(""," ") for d in decoded_sum] 
            metric.add_batch(predictions=decoded_sum,references=target_batch)  
            
            score = metric.compute()
            return score
        
    def evaluate(self):
        tokenixer = AutoTokenizer.from_pretrained(self.config.token_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_path)   
        
        dataset_samsum_pt = load_from_disk(self.config.data_path)
        rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
        rouge_metric = load_metric('rouge')
        
        score = self.calculate_metric_on_test(
            dataset_samsum_pt['test'][0:10],rouge_metric,model,tokenixer,batch_sixe=2,column_text = 'dialogue', column_summary= 'summary'
        ) 
        
        rouge_dict = dict((n,score[n].mid.fmeasure)for n in rouge_names)
        
        df = pd.DataFrame(rouge_dict, index = ['bart_model'] )
        df.to_csv(self.config.metric_file_name, index=False)

In [40]:
try:
    config = ConfigurationManager()
    model_evaluation_config = config.get_model_eval_config()
    model_evaluation_config = ModelEvaluation(config=model_evaluation_config)
    model_evaluation_config.evaluate()
except Exception as e:
    raise e

[2024-01-15 22:12:50,638: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-01-15 22:12:50,641: INFO: common: yaml file: params.yaml loaded successfully]
[2024-01-15 22:12:51,329: INFO: common: File Created artifacts/model_evaluation]


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  0%|          | 0/5 [00:00<?, ?it/s]

[2024-01-15 22:19:15,702: INFO: rouge_scorer: Using default tokenizer.]


  0%|          | 0/5 [03:40<?, ?it/s]
