In [4]:
import os

In [6]:
%pwd

'D:\\text_Summerizer\\text_summarizer'

In [7]:
import torch

In [5]:
# Switch the kernel's working directory to the project checkout, presumably so the
# project-relative paths used by later cells resolve — TODO confirm.
# NOTE(review): this is a hard-coded, machine-specific absolute path; the cell fails
# with an OSError on any machine where this directory does not exist.
os.chdir("D:/text_Summerizer/text_summarizer")

In [9]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelEvaluationConfig:
    """Immutable bundle of the locations used by the model-evaluation stage.

    The dataclass is frozen, so attributes cannot be reassigned after
    construction.
    """

    # Directory that the configuration manager creates before evaluation.
    root_dir: Path
    # Location passed to `load_from_disk` to read the dataset.
    data_path: Path
    # Location passed to `from_pretrained` for both tokenizer and model.
    model_path: Path
    # Destination handed to `DataFrame.to_csv` for the metric table.
    metric_file_name: Path

In [10]:
from textSummarizer.constants import *
from textSummarizer.utils.common import read_yaml,create_directories


In [11]:
class ConfriguationManager:
    """Reads the project's config and params files and hands out stage configs.

    Both files are loaded with `read_yaml` when the manager is constructed, and
    the directory named by `artifacts_root` in the config is created right away.
    """

    def __init__(
        self,
        config_file_path=CONFIG_FILE_PATH,
        params_file_path=PARAMS_FILE_PATH,
    ):
        """Load both files and create the artifacts root directory.

        Args:
            config_file_path: Path given to `read_yaml` for the main config.
            params_file_path: Path given to `read_yaml` for the params.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_model_evaluation_config(self) -> ModelEvaluationConfig:
        """Build the `ModelEvaluationConfig` from the `model_evaluation` section.

        The section's `root_dir` is created (via `create_directories`) before
        the config object is returned.
        """
        section = self.config.model_evaluation

        create_directories([section.root_dir])

        return ModelEvaluationConfig(
            root_dir=section.root_dir,
            data_path=section.data_path,
            model_path=section.model_path,
            metric_file_name=section.metric_file_name,
        )
    

In [12]:
from transformers import AutoModelForSeq2SeqLM,AutoTokenizer
from datasets import load_dataset,load_from_disk
import pandas as pd
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
!pip install evaluate -q

In [13]:
# `torch.cuda.is_available` must be *called*: the bare function object is always
# truthy, which would select "cuda" even on machines without a usable GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [14]:
import evaluate

In [18]:
class ModelEvaluation:
    """Scores a saved seq2seq model with ROUGE and writes the scores to a CSV file."""

    def __init__(self, config):
        # `evaluate` reads model_path, data_path and metric_file_name from this object.
        self.config = config

    def generate_batch_sized_chunks(self, list_of_elements, batch_size):
        """Yield consecutive slices of `list_of_elements`, stepping by `batch_size`."""
        total = len(list_of_elements)
        for start in range(0, total, batch_size):
            stop = start + batch_size
            yield list_of_elements[start:stop]

    def calculate_metric_on_test_ds(
        self,
        dataset,
        metric,
        model,
        tokenizer,
        batch_size=16,
        column_text="article",
        column_summary="highlights",
        device="cpu",
    ):
        """Generate a summary for every text in `dataset` and score it with `metric`.

        Args:
            dataset: Mapping-like object indexed by column name.
            metric: Object providing `add_batch(predictions=, references=)` and `compute()`.
            model: Object providing `generate(...)`.
            tokenizer: Callable tokenizer that also provides `decode(...)`.
            batch_size: Number of examples handled per generation call.
            column_text: Key of the input texts in `dataset`.
            column_summary: Key of the reference summaries in `dataset`.
            device: Device the tokenized tensors are moved to before generation.

        Returns:
            Whatever `metric.compute()` returns after all batches were added.
        """
        article_batches = list(
            self.generate_batch_sized_chunks(dataset[column_text], batch_size)
        )
        target_batches = list(
            self.generate_batch_sized_chunks(dataset[column_summary], batch_size)
        )

        paired_batches = zip(article_batches, target_batches)
        for texts, references in tqdm(paired_batches, total=len(article_batches)):
            # Every batch is truncated and padded to exactly 1024 tokens.
            encoded = tokenizer(
                texts,
                max_length=1024,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                length_penalty=0.8,
                num_beams=8,
                max_length=256,
                early_stopping=True,
            )

            # Turn each generated id sequence back into plain text.
            predictions = []
            for token_ids in generated:
                text = tokenizer.decode(
                    token_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                predictions.append(text)

            metric.add_batch(predictions=predictions, references=references)

        return metric.compute()

    def evaluate(self):
        """Run ROUGE evaluation on the first 10 test examples and save the scores.

        Loads tokenizer and model from `config.model_path`, the dataset from
        `config.data_path`, and writes a one-row table of the four ROUGE scores
        to `config.metric_file_name`.

        Returns:
            The return value of `DataFrame.to_csv` for that path.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        model_bart = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_path)
        model_bart = model_bart.to(device)

        dataset_samsum_pt = load_from_disk(self.config.data_path)

        rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
        rouge_metric = evaluate.load('rouge')

        # Only a small slice of the test split is scored.
        test_subset = dataset_samsum_pt['test'][0:10]
        score = self.calculate_metric_on_test_ds(
            dataset=test_subset,
            metric=rouge_metric,
            model=model_bart,
            tokenizer=tokenizer,
            batch_size=4,
            column_text='dialogue',
            column_summary='summary',
            device=device,
        )

        rouge_dict = {}
        for name in rouge_names:
            rouge_dict[name] = score[name]
        scores_frame = pd.DataFrame(rouge_dict, index=['bart'])

        return scores_frame.to_csv(self.config.metric_file_name, index=False)


In [16]:
# Prefer the GPU when CUDA is usable; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
# Run the model-evaluation stage end to end: build the configuration manager,
# fetch the evaluation config, then evaluate and write the metrics file.
try:
    config = ConfriguationManager()
    data_evaluation_config = config.get_model_evaluation_config()
    data_evaluation = ModelEvaluation(config=data_evaluation_config)
    data_evaluation.evaluate()
except Exception:
    # Bare `raise` re-raises the active exception unchanged, without binding an
    # unused name or adding a redundant re-raise entry to the traceback.
    raise