In [1]:
import os

In [2]:
# Step the working directory up one level; presumably this makes the
# `src.txt_sm` package imported by later cells resolvable — confirm.
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up one more directory each time.
os.chdir('../')

In [3]:
from dataclasses import dataclass
from pathlib import Path 


@dataclass(frozen=True)
class Data_Transformation_Config:
    """Settings consumed by the data-transformation stage.

    The dataclass is frozen, so assigning to a field after construction
    raises ``dataclasses.FrozenInstanceError``. The annotations are not
    enforced at runtime; values are stored exactly as passed in.
    """
    # Directory the stage writes its output under (Data_Transformation joins
    # 'preprocessed_dataset' onto it).
    root_dir: Path
    # Location of the saved dataset that Data_Transformation loads with
    # `load_from_disk`.
    data_path: Path
    # Identifier of the tokenizer to load — presumably a Hugging Face model
    # name or local path; confirm against the project config file.
    preprocessor: str

In [6]:
from src.txt_sm.utils.common import read_yaml,create_dirs
from src.txt_sm.constants import *

class Configuration_Manager:
    """Load the project's YAML settings and hand out per-stage config objects.

    Construction reads both YAML files through ``read_yaml`` and asks
    ``create_dirs`` to create the directory named by ``artifacts_root``.
    """

    def __init__(self,config_path=CONFIG_PATH,params_path=PARAMS_PATH):
        """Read the config and params files and create the artifacts root.

        Args:
            config_path: Path of the main configuration YAML.
            params_path: Path of the parameters YAML.
        """
        self.config = read_yaml(config_path)
        self.params = read_yaml(params_path)

        artifacts_root = self.config.artifacts_root
        create_dirs([artifacts_root])

    def get_data_transformation_config(self) -> Data_Transformation_Config:
        """Build the frozen settings object for the data-transformation stage.

        Creates the stage's ``root_dir`` via ``create_dirs`` as a side effect.

        Returns:
            A ``Data_Transformation_Config`` populated from the loaded config.
        """
        # NOTE(review): the fields are read from the top level of the loaded
        # config (the same object that holds `artifacts_root`) — confirm the
        # YAML does not nest them under a stage-specific key.
        settings = self.config
        create_dirs([settings.root_dir])
        return Data_Transformation_Config(
            root_dir=settings.root_dir,
            data_path=settings.data_path,
            preprocessor=settings.preprocessor,
        )

In [7]:
from datasets import load_dataset,load_from_disk
from transformers import AutoTokenizer
from src.txt_sm.logging import logger
import os

class Data_Transformation:
    """Tokenise a saved dialogue/summary dataset and save the result to disk.

    The stage settings come from ``Configuration_Manager`` and the tokenizer
    is loaded with ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self):
        """Load the stage configuration and the tokenizer it names."""
        config_manager = Configuration_Manager()
        self.data_transformation_config = config_manager.get_data_transformation_config()
        # The tokenizer identifier lives on the stage config. The manager
        # itself only exposes `config` and `params`, so reading `preprocessor`
        # from the manager raised AttributeError.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.data_transformation_config.preprocessor
        )

    def convert_data_to_features(self, data, max_input_length=1024, max_target_length=128):
        """Tokenise a batch of dialogues and their summaries.

        Args:
            data: Mapping with a 'dialogue' entry and a 'summary' entry; each
                is passed to the tokenizer as-is (called with ``batched=True``
                by ``initiate_data_transformation``).
            max_input_length: Truncation limit passed as ``max_length`` when
                tokenising the dialogues.
            max_target_length: Truncation limit passed as ``max_length`` when
                tokenising the summaries.

        Returns:
            A dict with 'input_ids' and 'attention_mask' from the dialogue
            encodings and 'labels' taken from the summary 'input_ids'.
        """
        input_encodings = self.tokenizer(
            data['dialogue'], max_length=max_input_length, truncation=True
        )

        # Summaries are encoded inside the tokenizer's target-mode context.
        # NOTE(review): recent transformers releases deprecate
        # `as_target_tokenizer` in favour of the `text_target=` argument —
        # confirm the installed version before switching.
        with self.tokenizer.as_target_tokenizer():
            target_encodings = self.tokenizer(
                data['summary'], max_length=max_target_length, truncation=True
            )

        return {
            'input_ids': input_encodings['input_ids'],
            'attention_mask': input_encodings['attention_mask'],
            'labels': target_encodings['input_ids']
        }

    def initiate_data_transformation(self):
        """Load the dataset from disk, tokenise it, and save the result.

        The output is saved under ``<root_dir>/preprocessed_dataset``.
        """
        stage_config = self.data_transformation_config
        dataset_samsum = load_from_disk(stage_config.data_path)
        dataset_pt = dataset_samsum.map(self.convert_data_to_features,batched=True)
        dataset_pt.save_to_disk(os.path.join(stage_config.root_dir,'preprocessed_dataset'))
        

  from .autonotebook import tqdm as notebook_tqdm


[2024-03-07 12:00:18,357: INFO: config: PyTorch version 2.2.1 available.]
