In [1]:
import os

# Work from the parent of the notebook's folder (presumably the project root,
# where the relative `config/` and `artifacts/` paths used later resolve).
# The starting directory is remembered on the first execution so that
# re-running this cell lands in the same place instead of climbing one
# directory higher on every run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataTransformationConfig:
    """Immutable settings consumed by the data-transformation stage.

    Attributes:
        root_dir: Directory under which the tokenized datasets are saved.
        data_path: Location of the CSV file read as input.
        tokenizer_path: Value handed to the tokenizer's ``from_pretrained``.
        max_length: Length used for truncation and padding when tokenizing.
    """

    root_dir: Path
    data_path: Path
    tokenizer_path: Path
    max_length: int

In [3]:
from machinetranslation.constants import *
from machinetranslation.utils.common import read_yaml,create_directories

class ConfigurationManager:
    """Loads the YAML configuration files and builds per-stage config objects."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Path of the main configuration YAML file.
            params_filepath: Path of the parameters YAML file.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_data_transformation_config(self) -> DataTransformationConfig:
        """Build the config for the data-transformation stage.

        Creates the stage's ``root_dir`` as a side effect.

        Returns:
            A ``DataTransformationConfig`` filled from the
            ``data_transformation`` section of the configuration file.
        """
        config = self.config.data_transformation

        create_directories([config.root_dir])

        # The key was historically spelled `max_lenght`. Prefer the correct
        # spelling, but keep accepting the legacy one so existing config
        # files continue to load.
        try:
            max_length = config.max_length
        except (AttributeError, KeyError):
            max_length = config.max_lenght

        return DataTransformationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            tokenizer_path=config.tokenizer_path,
            max_length=max_length,
        )

In [4]:

from datasets import Dataset
from transformers import M2M100Tokenizer
import pandas as pd
from machinetranslation.logging import logger
from typing import Tuple

class DataTransformation:
    """Tokenizes a parallel-text CSV and splits it into train/eval datasets.

    The CSV is expected to provide an ``en`` column (model input) and a ``bn``
    column (used to build the labels).
    """

    # CSV column names; they are also used as the M2M100 language codes.
    SOURCE_LANG = "en"
    TARGET_LANG = "bn"

    def __init__(self, config: DataTransformationConfig):
        """Store the config and load the tokenizer it points to."""
        self.config = config
        self.tokenizer = M2M100Tokenizer.from_pretrained(self.config.tokenizer_path)
        # M2M100 prefixes each sequence with a language token. Without an
        # explicit target language the labels would be encoded with the
        # source-language prefix, so set both sides up front.
        self.tokenizer.src_lang = self.SOURCE_LANG
        self.tokenizer.tgt_lang = self.TARGET_LANG

    def read_data(self) -> pd.DataFrame:
        """Read the CSV at ``config.data_path`` and log its shape."""
        df = pd.read_csv(self.config.data_path)
        logger.info(f"Data shape: {df.shape}")
        return df

    def preprocess_function(self, examples):
        """Tokenize one batch of sentence pairs.

        Args:
            examples: Batch mapping with ``en`` and ``bn`` lists of strings.

        Returns:
            The tokenizer output for the source text, with the target token
            ids stored under ``labels``. Both sides are truncated and padded
            to ``config.max_length``.
        """
        # `text_target` makes the tokenizer encode the targets in target-language
        # mode and stores their input ids under "labels".
        # NOTE(review): because of padding="max_length" the labels contain pad
        # token ids; confirm the training step masks them (commonly with -100).
        return self.tokenizer(
            examples[self.SOURCE_LANG],
            text_target=examples[self.TARGET_LANG],
            max_length=self.config.max_length,
            truncation=True,
            padding="max_length",
        )

    def transform_data(self, train_fraction: float = 0.8, seed: int = 42) -> Tuple[Dataset, Dataset]:
        """Load, tokenize, shuffle and split the data.

        Args:
            train_fraction: Share of the rows assigned to the training split.
            seed: Seed for the shuffle that precedes the split.

        Returns:
            ``(train_dataset, eval_dataset)``; the two splits are disjoint.
        """
        df = self.read_data()

        # Rows with a missing side cannot be tokenized (pandas reads empty
        # cells as NaN), so drop them instead of failing inside the map.
        complete = df.dropna(subset=[self.SOURCE_LANG, self.TARGET_LANG])
        n_dropped = len(df) - len(complete)
        if n_dropped:
            logger.warning(f"Dropped {n_dropped} rows with missing text")

        dataset = Dataset.from_pandas(complete, preserve_index=False)
        tokenized_dataset = dataset.map(
            self.preprocess_function,
            batched=True,
            remove_columns=dataset.column_names,
        )

        # Shuffle once and slice the same permutation for both splits.
        shuffled = tokenized_dataset.shuffle(seed=seed)
        split_at = int(len(shuffled) * train_fraction)
        train_dataset = shuffled.select(range(split_at))
        eval_dataset = shuffled.select(range(split_at, len(shuffled)))

        return train_dataset, eval_dataset

    def save_datasets(self, train_dataset: Dataset, eval_dataset: Dataset) -> None:
        """Write both splits under ``config.root_dir`` as ``train`` and ``eval``."""
        train_dataset.save_to_disk(os.path.join(self.config.root_dir, "train"))
        eval_dataset.save_to_disk(os.path.join(self.config.root_dir, "eval"))
        logger.info(f"Datasets saved to {self.config.root_dir}")


[2024-08-09 17:53:19,139: INFO: config: PyTorch version 2.2.2+cu121 available.]
[2024-08-09 17:53:19,141: INFO: config: TensorFlow version 2.12.0 available.]


In [6]:
# Run the data-transformation stage end to end: load the configuration,
# tokenize and split the data, then persist both splits.
# Exceptions are left to propagate so the notebook shows the original traceback.
config = ConfigurationManager()
data_transformation_config = config.get_data_transformation_config()
data_transformation = DataTransformation(config=data_transformation_config)
train_dataset, eval_dataset = data_transformation.transform_data()
data_transformation.save_datasets(train_dataset, eval_dataset)

[2024-08-09 17:53:59,525: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-09 17:53:59,527: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-09 17:53:59,529: INFO: common: created directory at: artifacts]
[2024-08-09 17:53:59,530: INFO: common: created directory at: artifacts/data_transformation]
[2024-08-09 17:53:59,889: INFO: 693474816: Data shape: (2000, 2)]


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1600 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/400 [00:00<?, ? examples/s]

[2024-08-09 17:54:01,761: INFO: 693474816: Datasets saved to artifacts/data_transformation]
