In [1]:
import os


# Move the working directory one level up from where the kernel was started,
# so that relative paths used further down resolve from the parent directory.
#
# A bare ``os.chdir("../")`` is relative to the *current* directory, so running
# this cell a second time would climb one more level. Remember the starting
# directory on the first run and always move relative to it, which makes the
# cell safe to re-run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(NOTEBOOK_START_DIR, ".."))

In [2]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings consumed by the model-training stage.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # Directory for this stage's outputs; created by ConfigurationManager.
    root_dir: Path
    # Destination the trained model is saved to at the end of training.
    trained_model_file_path: Path
    # Location of the prepared base model that is loaded before training.
    updated_base_model_path: Path
    # Directory handed to ``flow_from_directory`` for both data subsets.
    # NOTE(review): ConfigurationManager fills this from ``os.scandir(...).path``,
    # which is presumably a ``str`` rather than a ``Path`` - confirm.
    data_path: Path
    # All elements except the last are used as the generators' target size
    # (presumably height, width, channels - confirm against params file).
    image_size: list
    # Number of epochs passed to ``fit``.
    epochs: int
    # Batch size used by both the training and validation generators.
    batch_size: int
    # When true, the training generator applies random image augmentation.
    augmentation: bool

In [3]:
from cnn_classifier.constants import *
from cnn_classifier.utils.common import read_yaml, create_directories


class ConfigurationManager:
    """Read the project YAML files and build per-stage configuration objects."""

    def __init__(
        self,
        config_file_path: Path = CONFIG_FILE_PATH,
        params_file_path: Path = PARAMS_FILE_PATH,
    ):
        """Load both YAML files and make sure the artifacts root exists.

        Args:
            config_file_path: Location of the main configuration YAML.
            params_file_path: Location of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Assemble the configuration for the model-training stage.

        The dataset directory is the first sub-directory (in sorted order) of
        the data-ingestion unzip directory. The trainer's root directory is
        created as a side effect.

        Returns:
            A populated ``ModelTrainerConfig``.

        Raises:
            FileNotFoundError: If the unzip directory contains no
                sub-directory to use as the dataset root.
        """
        cfg = self.config.model_trainer
        params = self.params.params

        updated_base_model_path = self.config.prepare_base_model.updated_base_model_path

        unzip_dir = self.config.data_ingestion.unzip_dir
        # ``os.scandir`` yields entries in arbitrary order; sort so the chosen
        # dataset directory is the same on every run, and use ``with`` so the
        # directory handle is always released.
        with os.scandir(unzip_dir) as entries:
            dataset_dirs = sorted(entry.path for entry in entries if entry.is_dir())

        # Fail with an actionable message instead of a bare IndexError.
        if not dataset_dirs:
            raise FileNotFoundError(
                f"No dataset directory found inside: {unzip_dir}"
            )
        data_path = dataset_dirs[0]

        create_directories([cfg.root_dir])

        return ModelTrainerConfig(
            root_dir=cfg.root_dir,
            trained_model_file_path=cfg.trained_model_file_path,
            updated_base_model_path=updated_base_model_path,
            data_path=data_path,
            image_size=params.IMAGE_SIZE,
            epochs=params.EPOCHS,
            batch_size=params.BATCH_SIZE,
            augmentation=params.AUGMENTATION,
        )

In [4]:
import tensorflow as tf


class ModelTrainer:
    """Train the prepared base model on images read from a directory tree.

    Expected call order: ``get_base_model()``, ``train_val_generator()``,
    then ``train()``. ``train()`` reads attributes set by the first two.
    """

    def __init__(self, config: ModelTrainerConfig):
        """Store the training configuration.

        Args:
            config: Paths and hyper-parameters for this training run.
        """
        self.config = config

    def get_base_model(self):
        """Load the prepared base model from disk into ``self.model``."""
        self.model = tf.keras.models.load_model(self.config.updated_base_model_path)

    def train_val_generator(self):
        """Create ``self.val_generator`` and ``self.train_generator``.

        Both generators rescale pixel values by 1/255 and use the same 20%
        validation split over ``config.data_path``. The validation generator
        is never shuffled or augmented; the training generator is shuffled and
        is augmented only when ``config.augmentation`` is true.
        """
        data_generator_kwargs = dict(rescale=1 / 255, validation_split=0.20)
        data_flow_kwargs = dict(
            # Drop the last element of image_size (presumably the channel
            # count - confirm) to obtain the target size.
            target_size=self.config.image_size[:-1],
            batch_size=self.config.batch_size,
            interpolation="bilinear",
        )

        val_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
            **data_generator_kwargs
        )

        self.val_generator = val_data_generator.flow_from_directory(
            directory=self.config.data_path,
            subset="validation",
            shuffle=False,
            **data_flow_kwargs,
        )

        if self.config.augmentation:
            train_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=40,
                horizontal_flip=True,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                **data_generator_kwargs,
            )
        else:
            # Without augmentation the plain rescaling generator is reused.
            train_data_generator = val_data_generator

        self.train_generator = train_data_generator.flow_from_directory(
            directory=self.config.data_path,
            subset="training",
            shuffle=True,
            **data_flow_kwargs,
        )

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Persist ``model`` to ``path``."""
        model.save(path)

    @staticmethod
    def _steps_for(generator, subset_name: str) -> int:
        """Return the number of full batches in ``generator``, at least one.

        Raises:
            ValueError: If the generator found no images at all.
        """
        if generator.samples == 0:
            raise ValueError(f"No images found for the {subset_name} subset")
        # Floor division gives 0 when the subset is smaller than one batch;
        # clamp so that at least one (partial) batch is consumed.
        return max(1, generator.samples // generator.batch_size)

    def train(self):
        """Fit the model on the generators and save the trained model.

        Sets ``self.steps_per_epoch`` and ``self.validation_steps`` before
        fitting, then saves to ``config.trained_model_file_path``.

        Raises:
            AttributeError: If ``get_base_model()`` or ``train_val_generator()``
                has not been called first.
            ValueError: If either data subset contains no images.
        """
        self.steps_per_epoch = self._steps_for(self.train_generator, "training")
        self.validation_steps = self._steps_for(self.val_generator, "validation")

        self.model.fit(
            self.train_generator,
            epochs=self.config.epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
            validation_data=self.val_generator,
        )

        self.save_model(path=self.config.trained_model_file_path, model=self.model)

In [5]:
from cnn_classifier import logger


class ModelTrainerPipeline:
    """Run the model-training stage from configuration to saved model."""

    def run_pipeline(self):
        """Build the trainer from configuration and run training.

        Any exception is logged and then re-raised to the caller.
        """
        try:
            logger.info("Model training started")

            # Resolve the stage configuration and hand it to the trainer.
            trainer = ModelTrainer(
                config=ConfigurationManager().get_model_trainer_config()
            )

            # Load the base model, prepare the data generators, then fit.
            trainer.get_base_model()
            trainer.train_val_generator()
            trainer.train()

            logger.info("Model training ended")

        except Exception as e:
            logger.error(f"Model training failed: {e}")
            raise e

In [6]:
# Run the training stage: builds the configuration, loads the base model,
# prepares the data generators, fits the model and saves the result.
model_trainer_pipeline = ModelTrainerPipeline()
model_trainer_pipeline.run_pipeline()

[ 2024-02-28 01:26:24,194 ] 8 3068857550 cnn_classifier -  INFO - 
Model training started
[ 2024-02-28 01:26:24,196 ] 34 common cnn_classifier -  INFO - Loaded YAML file successfully from: config/config.yaml
[ 2024-02-28 01:26:24,198 ] 34 common cnn_classifier -  INFO - Loaded YAML file successfully from: params.yaml
[ 2024-02-28 01:26:24,198 ] 55 common cnn_classifier -  INFO - Created directory at: artifacts
[ 2024-02-28 01:26:24,199 ] 55 common cnn_classifier -  INFO - Created directory at: artifacts/trained_model
Found 68 images belonging to 2 classes.
Found 275 images belonging to 2 classes.
[ 2024-02-28 01:27:01,391 ] 15 3068857550 cnn_classifier -  INFO - Model training ended

