In [None]:
import os

# Switch the process working directory to the parent of the current one.
# NOTE(review): presumably so relative paths used later resolve from the
# project root rather than the notebook's folder; confirm the layout.
os.chdir("../")

In [None]:
%pwd

In [None]:
import os
from ecan.constants import *
from ecan.entity.config_entity import DataIngestionConfig, TrainingConfig
from ecan.utils.common import read_yaml, create_directories


class ConfigurationManager:
    """Read the config/params YAML files and build per-stage config objects.

    The two files are parsed with ``read_yaml`` and kept on ``self.config``
    and ``self.params``. Each ``get_*_config`` method picks the relevant
    section and returns the matching entity from
    ``ecan.entity.config_entity``.

    NOTE(review): ``Path`` is not imported explicitly in this cell; it is
    expected to arrive through ``from ecan.constants import *`` - confirm.
    """

    def __init__(
            self,
            config_filepath=CONFIG_FILE_PATH,
            params_filepath=PARAMS_FILE_PATH):
        """Parse both YAML files and hand the artifacts root to create_directories.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_data_ingestion_config(self) -> DataIngestionConfig:
        """Build the data-ingestion settings from the ``data_ingestion`` section.

        ``root_dir`` is passed to ``create_directories`` before the config
        object is built.

        Returns:
            A ``DataIngestionConfig`` whose filesystem fields are ``Path``
            objects and whose ``source_URL`` is the raw value from the YAML.
        """
        config = self.config.data_ingestion

        create_directories([config.root_dir])

        return DataIngestionConfig(
            root_dir=Path(config.root_dir),
            # Deliberately NOT wrapped in Path(): pathlib collapses the "//"
            # after the scheme ("https://host" -> "https:/host"), which
            # corrupts the URL for any downloader that receives it.
            source_URL=config.source_URL,
            local_data_file=Path(config.local_data_file),
            unzip_dir=Path(config.unzip_dir)
        )

    def get_training_config(self) -> TrainingConfig:
        """Build the training settings from the ``training`` section and params.

        The training data location is the ``dataset`` entry under
        ``prepare_base_model.root_dir``; the training ``root_dir`` is passed
        to ``create_directories`` before the config object is built.

        Returns:
            A ``TrainingConfig`` combining paths from the config file with
            ``N_EPOCHS``, ``BATCH_SIZE``, ``SCALE_FACTOR`` and ``N_FRAMES``
            from the params file.
        """
        config = self.config.training
        base_model = self.config.prepare_base_model
        training_data = Path(base_model.root_dir) / 'dataset'
        create_directories([config.root_dir])

        training_config = TrainingConfig(
            root_dir=Path(config.root_dir),
            trained_model_path=Path(config.trained_model_path),
            updated_base_model_path=Path(base_model.updated_base_model_path),
            training_data=training_data,
            params_epochs=self.params.N_EPOCHS,
            params_batch_size=self.params.BATCH_SIZE,
            params_scale_factor=self.params.SCALE_FACTOR,
            params_nFrame=self.params.N_FRAMES
        )

        return training_config


In [None]:
from ecan.config.configuration import ConfigurationManager
from ecan.components.model_training import DataTraining
from ecan import logger

# Label interpolated into the "started"/"completed" log lines of the
# __main__ guard at the bottom of this cell.
STAGE_NAME = "Data Training Stage"


class DataTrainingPipeline:
    """Pipeline stage that feeds the training configuration to ``DataTraining``."""

    def __init__(self):
        """No state is kept on the instance; everything is built in ``main``."""

    def main(self):
        """Resolve the training config and run ``DataTraining.main`` with it."""
        training_config = ConfigurationManager().get_training_config()
        DataTraining(config=training_config).main()


if __name__ == '__main__':
    # Script entry point: run the training stage between two log banners.
    try:
        logger.info("*******************")
        logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
        pipeline = DataTrainingPipeline()
        pipeline.main()
        logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
    except Exception as err:
        # Record the traceback through the project logger, then propagate
        # the same exception to the caller.
        logger.exception(err)
        raise err
