In [None]:
import os
# Move the working directory up one level so that relative paths used by later
# cells resolve from the parent directory (presumably the project root — confirm).
# NOTE(review): this is not idempotent — re-running the cell climbs one more
# directory each time, so run it exactly once per kernel session.
os.chdir("../")
# IPython magic: display the working directory that is now in effect.
%pwd

In [None]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class DataTransformationConfig:
    """Immutable bundle of settings for the data-transformation stage.

    Instances are frozen, so attributes cannot be reassigned after
    construction. The declaration order below is also the positional order
    of the generated ``__init__``.
    """

    # --- locations -------------------------------------------------------
    root_dir: Path     # stage directory; ConfigurationManager passes it to create_directories
    data_path: Path    # input data file(s) handed to load_dataset as data_files

    # --- model / length settings ----------------------------------------
    checkpoint: str    # checkpoint identifier (presumably a Hugging Face model name — confirm)
    max_length: int    # stored on DataTransformation; exact meaning not visible here — TODO confirm
    min_length: int    # stored on DataTransformation; exact meaning not visible here — TODO confirm

    # --- output / preprocessing -----------------------------------------
    output_dir: Path   # directory DataTransformation.save_dataset writes to
    prefix: str        # text prepended to each input document before tokenization
    sample_size: int   # rows selected per split by DataTransformation.data_sample_loader

In [1]:
from TextSummarizer.constants import *
from TextSummarizer.utils.file_utils import *
from TextSummarizer.utils.config_utils import *
from TextSummarizer.utils.lib_utils import *

In [None]:
class ConfigurationManager:
    """Read the project's YAML files and build per-stage config objects.

    The config and params files are parsed once at construction time with
    ``read_yaml``; the results are kept on ``self.config`` and ``self.params``.
    """

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Load both YAML files and hand the artifacts root to create_directories.

        Args:
            config_filepath: Path of the main configuration YAML.
            params_filepath: Path of the parameters YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_data_transformation_config(self) -> DataTransformationConfig:
        """Build the settings object for the data-transformation stage.

        Reads the ``data_transformation`` section of the loaded config,
        passes its ``root_dir`` to ``create_directories`` and copies the
        section's values into a ``DataTransformationConfig``.

        Returns:
            DataTransformationConfig: populated from the YAML section. Values
            are passed through as read from YAML (no conversion to ``Path``).
        """
        config = self.config.data_transformation

        create_directories([config.root_dir])

        # Only fields declared on DataTransformationConfig are passed: the
        # dataclass has no `tokenizer_name` field, so supplying it would raise
        # TypeError. The tokenizer is selected through `checkpoint`.
        data_transformation_config = DataTransformationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            checkpoint=config.checkpoint,
            max_length=config.max_length,
            min_length=config.min_length,
            output_dir=config.output_dir,
            prefix=config.prefix,
            sample_size=config.sample_size,
        )

        return data_transformation_config
    
    

In [None]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer ,AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import pipeline
from datasets import Dataset
from sklearn.model_selection import train_test_split
import torch
from tqdm import tqdm
import evaluate
import nltk
from datasets import DatasetDict, load_dataset
from pathlib import Path
from transformers import DataCollatorForSeq2Seq
from TextSummarizer.logging import logger

In [None]:
class DataTransformation:
    """Load, split, tokenize, sample and save the summarization dataset.

    Typical flow: ``load_data_into_DatasetDict`` -> ``tokenize_dataset`` ->
    ``save_dataset``, with ``data_sample_loader`` available to cut every
    split down to a fixed number of rows.
    """

    # Truncation limits (in tokens) used when tokenizing documents and summaries.
    # NOTE(review): these are independent of config.max_length / config.min_length,
    # whose intended meaning is not visible in this notebook — confirm.
    MAX_INPUT_TOKENS = 1024
    MAX_TARGET_TOKENS = 128

    def __init__(self, config: DataTransformationConfig):
        """Copy the settings off ``config`` and load the tokenizer.

        Args:
            config: Settings for this stage. ``config.checkpoint`` is passed to
                ``AutoTokenizer.from_pretrained``.
        """
        logger.info("Initializing DataTransformation with config")
        self.config = config
        self.checkpoint = self.config.checkpoint
        self.max_length = self.config.max_length
        self.min_length = self.config.min_length
        # `_preprocess_single_split` reads self.prefix, so it must be set here.
        self.prefix = self.config.prefix
        # The tokenizer is loaded from the configured checkpoint (there is no
        # `model_name` attribute on this class).
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.dataset_path = self.config.data_path
        self.sample_size = self.config.sample_size
        logger.info(f"Initialized with checkpoint={self.checkpoint}, max_length={self.max_length}, min_length={self.min_length}, dataset_path={self.dataset_path}, sample_size={self.sample_size}")

    def load_data_into_DatasetDict(self, dataset_type: str = "csv", seed: int = 42) -> DatasetDict:
        """Load the raw file and return train/validation/test splits.

        The standardized columns are renamed to ``text``, ``summary`` and
        ``title``; the original un-standardized columns are dropped. 20% of the
        rows become the test split, and 10% of the remainder becomes the
        validation split.

        Args:
            dataset_type: Loader name passed to ``load_dataset`` (e.g. "csv").
            seed: Seed used for both splits so the partition is reproducible.

        Returns:
            DatasetDict with the keys "train", "validation" and "test".

        Raises:
            AssertionError: If a required standardized column is missing.
        """
        logger.info("Loading data into DatasetDict")
        dataset = load_dataset(dataset_type, data_files=self.dataset_path)
        logger.info(f"Loaded dataset with {dataset_type} from {self.dataset_path}")

        # Ensure required columns exist
        required_features = ['transcript_standardized', 'description_standardized', 'title_standardized']
        feature_names = list(dataset["train"].features.keys())
        missing_features = [col for col in required_features if col not in feature_names]
        # Report the columns that are actually absent rather than the whole required list.
        assert not missing_features, f"Missing required columns: {missing_features}"
        logger.info(f"Dataset contains required features: {required_features}")

        # Remove unnecessary columns (this also frees the name "title" for the rename below)
        columns_to_drop = ['description', 'tags', 'title', 'ratings', 'transcript']
        logger.info(f"Removing unnecessary columns: {columns_to_drop}")
        dataset = dataset.remove_columns(columns_to_drop)

        # Standardize column names
        logger.info("Renaming columns to standardized names")
        dataset = dataset.rename_column("transcript_standardized", "text")
        dataset = dataset.rename_column("description_standardized", "summary")
        dataset = dataset.rename_column("title_standardized", "title")

        # Split the dataset into train/validation/test sets
        logger.info("Splitting dataset into train/validation/test splits")
        # Local names deliberately avoid shadowing sklearn's imported train_test_split.
        holdout_split = dataset["train"].train_test_split(test_size=0.2, seed=seed)
        train_val_dataset = holdout_split["train"]
        test_dataset = holdout_split["test"]

        validation_split = train_val_dataset.train_test_split(test_size=0.1, seed=seed)
        train_dataset = validation_split["train"]
        validation_dataset = validation_split["test"]

        dataset = DatasetDict({
            "train": train_dataset,
            "validation": validation_dataset,
            "test": test_dataset
        })
        logger.info("Dataset successfully split into train/validation/test")

        return dataset

    def preprocess_function(self, dataset):
        """Tokenize either a whole ``DatasetDict`` or a single batch.

        Args:
            dataset: A ``DatasetDict`` (every split is mapped in batched mode
                and the same object is updated in place and returned), or a
                batch mapping with "text" and "summary" lists.

        Returns:
            The updated ``DatasetDict``, or the tokenized batch.
        """
        logger.info("Preprocessing dataset")
        if isinstance(dataset, DatasetDict):
            for split in dataset:
                logger.info(f"Preprocessing split: {split}")
                # batched=True is required: _preprocess_single_split iterates
                # batch["text"] as a list of documents. Without it each call
                # receives one string and would iterate it character by character.
                dataset[split] = dataset[split].map(self._preprocess_single_split, batched=True)
            logger.info("All splits preprocessed")
            return dataset
        else:
            logger.info("Preprocessing a single dataset")
            return self._preprocess_single_split(dataset)

    def _preprocess_single_split(self, batch):
        """Tokenize one batch of documents and their target summaries.

        Args:
            batch: Mapping with "text" and "summary" entries, each a list of strings.

        Returns:
            The tokenizer output for the prefixed documents, with the summary
            token ids stored under "labels".

        Raises:
            KeyError: If "text" or "summary" is missing from ``batch``.
        """
        logger.info("Preprocessing a single split batch")
        if "text" not in batch or "summary" not in batch:
            raise KeyError(f"Keys 'text' or 'summary' not found. Available keys: {list(batch.keys())}")

        inputs = [self.prefix + doc for doc in batch["text"]]
        model_inputs = self.tokenizer(
            inputs, padding=True, max_length=self.MAX_INPUT_TOKENS, truncation=True
        )
        # NOTE(review): labels are padded by the tokenizer here, so padded
        # positions carry the tokenizer's pad id — confirm the downstream
        # collator/loss treats them as intended.
        labels = self.tokenizer(
            text_target=batch["summary"], padding=True, max_length=self.MAX_TARGET_TOKENS, truncation=True
        )
        model_inputs["labels"] = labels["input_ids"]

        logger.info("Successfully tokenized inputs and labels for a batch")
        return model_inputs

    def tokenize_dataset(self, dataset):
        """Map the batched preprocessing over ``dataset`` and return the result.

        Args:
            dataset: Object exposing ``map`` (a ``Dataset`` or ``DatasetDict``)
                with "text" and "summary" columns.
        """
        logger.info("Tokenizing dataset")
        tokenized_dataset = dataset.map(self.preprocess_function, batched=True)
        logger.info("Dataset successfully tokenized")
        return tokenized_dataset

    def data_sample_loader(self, dataset: DatasetDict) -> DatasetDict:
        """Return a new ``DatasetDict`` holding a shuffled sample of every split.

        Each split is shuffled with a fixed seed and cut to ``sample_size``
        rows, or to the split's own length when it is smaller. The input
        ``dataset`` is left untouched.

        Args:
            dataset: Splits to sample from.

        Returns:
            DatasetDict with the same keys and at most ``sample_size`` rows each.
        """
        logger.info("Sampling data from dataset")
        # Build a fresh container: assigning into `dataset` would silently
        # shrink the caller's full dataset as well.
        sampled_dataset_random = DatasetDict()
        for split in dataset:
            # Cap at the split size; selecting past the end raises an error.
            n_rows = min(self.sample_size, len(dataset[split]))
            logger.info(f"Sampling {n_rows} examples from split: {split}")
            sampled_dataset_random[split] = dataset[split].shuffle(seed=42).select(range(n_rows))
        logger.info("Successfully loaded sampled dataset")
        return sampled_dataset_random

    def save_dataset(self, dataset: DatasetDict):
        """Write ``dataset`` to ``config.output_dir`` with ``save_to_disk``.

        Args:
            dataset: The ``DatasetDict`` to persist.

        Raises:
            ValueError: If ``dataset`` is not a ``DatasetDict``.
            Exception: Any error raised by ``save_to_disk`` is logged and re-raised.
        """
        save_dir = self.config.output_dir
        if not isinstance(dataset, DatasetDict):
            raise ValueError("Provided dataset is not a DatasetDict object")

        logger.info(f"Saving DatasetDict to directory: {save_dir}")

        try:
            dataset.save_to_disk(save_dir)
            logger.info(f"Dataset successfully saved to {save_dir}")
        except Exception as e:
            logger.error(f"Failed to save dataset to {save_dir}: {e}")
            # Bare raise keeps the original traceback intact.
            raise

In [None]:
# Run the data-transformation stage end to end:
# build config -> load and split -> tokenize -> save -> sample -> save.
try:
    config = ConfigurationManager()
    data_transformation_config = config.get_data_transformation_config()
    data_transformation = DataTransformation(config=data_transformation_config)
    dataset = data_transformation.load_data_into_DatasetDict()
    # tokenize_dataset already maps the batched preprocessing over every split,
    # so a separate preprocess_function pass beforehand would tokenize the data twice.
    tokenized_dataset = data_transformation.tokenize_dataset(dataset)
    data_transformation.save_dataset(tokenized_dataset)
    sampled_dataset = data_transformation.data_sample_loader(tokenized_dataset)
    # NOTE(review): save_dataset always writes to config.output_dir, so this
    # second save lands in the same directory as the one above. Confirm which
    # dataset downstream stages are meant to load.
    data_transformation.save_dataset(sampled_dataset)
except Exception:
    # Record the failure with its traceback, then let the original error propagate.
    logger.exception("Data transformation stage failed")
    raise