In [1]:
import os

# Move the working directory up to the project root so that the relative
# `config/...` and `artifacts/...` paths and the `src.` package imports used by
# the later cells resolve.
# The guard makes this cell idempotent: an unconditional `os.chdir("../")`
# climbs one more level every time the cell is re-run, which breaks every
# later cell that relies on relative paths.
# NOTE(review): assumes the project root is the directory that contains `src/`
# (later cells import `src.textSummarizer`) — confirm against the repo layout.
if not os.path.isdir("src"):
    os.chdir("../")

In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataTransformationConfig:
    """Immutable settings for the data-transformation stage.

    Instances are built by ``ConfigurationManager.get_data_transformation_config``
    and consumed by ``DataTransformation``.
    """
    # Directory the transformed datasets are saved under.
    root_dir: Path
    # Directory expected to hold train.csv, test.csv and validation.csv.
    dataset_dir: Path
    # Identifier passed to BartForConditionalGeneration.from_pretrained.
    model_name: str
    # Identifier passed to BartTokenizer.from_pretrained.
    tokenizer_name: str
    # max_length used when tokenizing the 'article' column.
    max_input_length: int = 1024
    # max_length used when tokenizing the 'highlights' column.
    max_target_length: int = 128

In [None]:
# The star import stays first, exactly as before: the explicit imports that
# follow it take precedence over any same-named symbol it may export.
# Presumably it is the source of CONFIG_FILE_PATH / PARAMS_FILE_PATH used by
# ConfigurationManager — verify against the constants module.
from src.textSummarizer.constants import *
from src.textSummarizer.utils.common import read_yaml, create_directories
from pathlib import Path
import logging
from transformers import BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset


# Logger used by the DataTransformation class defined in a later cell.
logger = logging.getLogger(__name__)

In [20]:
class ConfigurationManager:
    """Loads the project's YAML settings and hands out typed config objects."""

    def __init__(self,
                 config_filepath: str = CONFIG_FILE_PATH,
                 params_filepath: str = PARAMS_FILE_PATH):
        """Load both YAML files and make sure the artifacts root exists.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the parameters YAML.
        """
        # Both files are parsed eagerly and kept on the instance.
        self.config = read_yaml(Path(config_filepath))
        self.params = read_yaml(Path(params_filepath))

        # Every stage writes below this directory, so create it up front.
        create_directories([self.config.artifacts_root])

    def get_data_transformation_config(self) -> DataTransformationConfig:
        """Build the settings object for the data-transformation stage.

        Reads the ``data_transformation`` section of the loaded configuration,
        creates that stage's root directory, and wraps the section's values in
        a ``DataTransformationConfig``.

        Returns:
            The populated ``DataTransformationConfig``.
        """
        section = self.config.data_transformation

        # The stage's output directory must exist before anything is saved.
        create_directories([section.root_dir])

        # Directory values are wrapped in Path; the rest is passed through.
        return DataTransformationConfig(
            root_dir=Path(section.root_dir),
            dataset_dir=Path(section.dataset_dir),
            model_name=section.model_name,
            tokenizer_name=section.tokenizer_name,
            max_input_length=section.max_input_length,
            max_target_length=section.max_target_length,
        )


In [21]:
class DataTransformation:
    """Tokenizes the CSV summarization splits and saves them to disk.

    The class loads ``train.csv``, ``test.csv`` and ``validation.csv`` from
    ``config.dataset_dir``, tokenizes the ``article`` column as model input and
    the ``highlights`` column as labels, and writes each tokenized split under
    ``config.root_dir``.
    """

    # Split names, in the order they are loaded, transformed, saved and returned.
    _SPLITS = ('train', 'test', 'validation')
    # Raw CSV columns dropped after tokenization (only those actually present).
    _RAW_COLUMNS = ('article', 'highlights', 'id')

    def __init__(self, config: DataTransformationConfig):
        """Load the tokenizer and model named in ``config``.

        Args:
            config: Settings for this stage (paths, model/tokenizer names,
                maximum sequence lengths).
        """
        self.config = config
        self.tokenizer = BartTokenizer.from_pretrained(self.config.tokenizer_name)
        # NOTE(review): no method of this class uses `self.model`; it is still
        # loaded eagerly so the attribute stays available to external code.
        # Consider dropping it here if nothing else relies on it.
        self.model = BartForConditionalGeneration.from_pretrained(
            self.config.model_name,
            gradient_checkpointing=True  # Saves memory during training
        )
        logger.info("Initialized DataTransformation with model: %s",
                    self.config.model_name)

    def convert_examples_to_features(self, example_batch):
        """Convert a batch of raw text into tokenized features.

        Args:
            example_batch: Mapping with an ``'article'`` and a ``'highlights'``
                entry, as supplied by ``Dataset.map(..., batched=True)``.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` from the
            tokenized articles and ``'labels'`` from the tokenized highlights.

        Raises:
            Exception: Any tokenizer error is logged and re-raised unchanged.
        """
        try:
            input_encodings = self.tokenizer(
                example_batch['article'],
                max_length=self.config.max_input_length,
                truncation=True,
                padding='max_length'
            )

            # `text_target=` is the documented replacement for the deprecated
            # `as_target_tokenizer()` context manager.
            target_encodings = self.tokenizer(
                text_target=example_batch['highlights'],
                max_length=self.config.max_target_length,
                truncation=True,
                padding='max_length'
            )

            # NOTE(review): padded label positions keep the tokenizer's pad id.
            # If the training loss is meant to ignore padding, they need to be
            # masked (commonly with -100) downstream — confirm with the
            # training stage.
            return {
                'input_ids': input_encodings['input_ids'],
                'attention_mask': input_encodings['attention_mask'],
                'labels': target_encodings['input_ids']
            }
        except Exception as e:
            logger.error("Error in converting examples: %s", str(e))
            raise  # bare raise keeps the original traceback intact

    def _tokenize_split(self, split_dataset):
        """Tokenize one split, dropping whichever raw columns it actually has."""
        present = split_dataset.column_names
        removable = [name for name in self._RAW_COLUMNS if name in present]
        return split_dataset.map(
            self.convert_examples_to_features,
            batched=True,
            remove_columns=removable
        )

    def transform(self):
        """Run the full pipeline for the train, test and validation splits.

        Returns:
            dict mapping ``'train'``, ``'test'`` and ``'validation'`` to the
            tokenized datasets that were saved to disk.

        Raises:
            Exception: Any failure is logged with its traceback and re-raised.
        """
        try:
            logger.info("Starting data transformation")

            # Load datasets from CSV files
            dataset_dir = Path(self.config.dataset_dir)
            dataset_files = {
                split: str(dataset_dir / f"{split}.csv") for split in self._SPLITS
            }
            datasets = load_dataset('csv', data_files=dataset_files)

            logger.info("Loaded datasets: train (%d samples), test (%d samples), validation (%d samples)",
                        len(datasets['train']), len(datasets['test']), len(datasets['validation']))

            # Apply the same transformation to every split
            transformed = {
                split: self._tokenize_split(datasets[split]) for split in self._SPLITS
            }

            logger.info("Dataset transformation completed")

            # Save transformed data
            root_dir = Path(self.config.root_dir)
            output_paths = {
                split: root_dir / f"transformed_{split}_data" for split in self._SPLITS
            }
            for split in self._SPLITS:
                transformed[split].save_to_disk(output_paths[split])

            logger.info("Transformed datasets saved to: %s, %s, %s",
                        output_paths['train'], output_paths['test'], output_paths['validation'])

            return transformed

        except Exception as e:
            logger.exception("Data transformation failed: %s", str(e))
            raise  # bare raise keeps the original traceback intact

In [22]:
# Get configuration
# (constructing the manager reads the YAML files and creates the artifacts
# directory; the getter additionally creates the data_transformation directory)
config_manager = ConfigurationManager()
transform_config = config_manager.get_data_transformation_config()

# Initialize and run transformation
# NOTE: per the recorded output below, a full run took roughly 26 minutes
# (16:35 -> 17:01), dominated by the train-split `Map` step (~24 min).
# transform() tokenizes all three splits, saves them under
# transform_config.root_dir and returns them as a dict keyed by split name.
data_transformation = DataTransformation(transform_config)
transformed_data = data_transformation.transform()

[2025-07-04 16:35:23,484: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-07-04 16:35:23,486: INFO: common: yaml file: config\params.yaml loaded successfully]
[2025-07-04 16:35:23,488: INFO: common: created directory at: artifacts]
[2025-07-04 16:35:23,489: INFO: common: created directory at: artifacts/data_transformation]
[2025-07-04 16:35:26,517: INFO: 314184361: Initialized DataTransformation with model: facebook/bart-base]
[2025-07-04 16:35:26,517: INFO: 314184361: Starting data transformation]


Generating train split: 287113 examples [00:14, 20373.58 examples/s]
Generating test split: 11490 examples [00:00, 19427.87 examples/s]
Generating validation split: 13368 examples [00:00, 19410.67 examples/s]


[2025-07-04 16:35:43,158: INFO: 314184361: Loaded datasets: train (287113 samples), test (11490 samples), validation (13368 samples)]


Map: 100%|██████████| 287113/287113 [23:46<00:00, 201.24 examples/s]
Map: 100%|██████████| 11490/11490 [01:02<00:00, 184.29 examples/s]
Map: 100%|██████████| 13368/13368 [01:09<00:00, 191.37 examples/s]

[2025-07-04 17:01:53,161: INFO: 314184361: Dataset transformation completed]



Saving the dataset (4/4 shards): 100%|██████████| 287113/287113 [00:04<00:00, 62239.97 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 11490/11490 [00:00<00:00, 72601.12 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 13368/13368 [00:00<00:00, 73027.45 examples/s]


[2025-07-04 17:01:58,266: INFO: 314184361: Transformed datasets saved to: artifacts\data_transformation\transformed_train_data, artifacts\data_transformation\transformed_test_data, artifacts\data_transformation\transformed_validation_data]
