In [None]:
import os

In [None]:
os.getcwd()  # show the current working directory (plain-Python form of `%pwd`)

In [None]:
# Move one level up from the directory the notebook started in, presumably the
# project root, so the project's relative paths resolve. The starting directory
# is remembered in the notebook namespace, so re-running this cell is idempotent
# instead of climbing one more level on every execution.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.dirname(_NOTEBOOK_DIR))

In [None]:
os.getcwd()  # confirm the working directory after the chdir above (same value `%pwd` shows)

## Entity

In [None]:
from dataclasses import  dataclass
from pathlib import  Path

In [None]:
@dataclass(frozen=True)
class Training:
    """Immutable settings for the model-training stage.

    Instances are produced by ``ConfigManager.get_model_training_config``
    from the ``model_training`` config section and the ``resnet50`` params
    section. Lower-case fields are filesystem locations; UPPER_CASE fields
    are training hyper-parameters.
    """

    # ---- filesystem locations -------------------------------------------
    root_dir: Path            # working directory for this stage (created by ConfigManager)
    model_path: Path          # destination for the trained model
    model_metrics_path: Path  # destination for the training metrics / history
    data: Path                # image directory read with flow_from_directory
    updated_model_path: Path  # previously prepared model that training starts from

    # ---- hyper-parameters -----------------------------------------------
    INPUT_SHAPE: list         # presumably [height, width, channels]; [:-1] is used as image size
    BATCH_SIZE: int
    SHUFFLE: bool
    VALIDATION_SPLIT: float   # fraction of images meant to be held out for validation
    LABEL_MODEL: str          # forwarded to Keras as `class_mode`
    EPOCHS: int
    AUGMENTED: bool           # toggles random image augmentation for the training subset

    

## Config Manager

In [None]:
from plant_disease_clf.utils.common import  create_directories, read_yaml
from plant_disease_clf.constants import  *

In [None]:
class ConfigManager:
    """Read the project YAML files and build per-stage configuration entities.

    Args:
        config_file_path: Path of the main config YAML (default ``CONFIG_FILE_PATH``).
        params_file_path: Path of the hyper-parameter YAML (default ``PARAMS_FILE_PATH``).
    """

    def __init__(self,
        config_file_path = CONFIG_FILE_PATH,
        params_file_path = PARAMS_FILE_PATH):

        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        # Ensure the top-level artifacts directory exists before any stage writes to it.
        create_directories([self.config.artifacts_root])

    def get_model_training_config(self) -> Training:
        """Build the ``Training`` entity for the model-training stage.

        Reads ``config.model_training`` and ``params.resnet50`` and creates
        ``model_training.root_dir`` on disk. ``self.config`` and ``self.params``
        are left untouched, so this method can be called repeatedly and other
        sections of the loaded YAML remain accessible afterwards.

        Returns:
            A frozen ``Training`` instance.
        """
        # Use locals: reassigning self.config/self.params here would make a
        # second call fail (the sub-section has no `model_training` key).
        training_config = self.config.model_training
        training_params = self.params.resnet50

        create_directories([training_config.root_dir])

        return Training(
            root_dir = Path(training_config.root_dir),
            model_path = Path(training_config.model_path),
            model_metrics_path = Path(training_config.model_metrics_path),
            data = Path(training_config.data),
            updated_model_path = Path(training_config.updated_model_path),
            INPUT_SHAPE = training_params.INPUT_SHAPE,
            BATCH_SIZE = training_params.BATCH_SIZE,
            SHUFFLE = training_params.SHUFFLE,
            VALIDATION_SPLIT = training_params.VALIDATION_SPLIT,
            LABEL_MODEL = training_params.LABEL_MODEL,
            EPOCHS = training_params.EPOCHS,
            AUGMENTED = training_params.AUGMENTED
        )



## Components

In [None]:
import os,sys
from pathlib import Path
from plant_disease_clf.logger import  logging
from plant_disease_clf.exception import  CustomException

from plant_disease_clf.utils.common import  save_json

import tensorflow as tf
from plant_disease_clf.pipeline.stage_03_prepare_callbacks import  CallbacksPipeline

In [None]:
class ModelTraining:
    """Train the prepared ("updated") Keras model on an image directory.

    Images are streamed from ``config.data`` with Keras
    ``ImageDataGenerator.flow_from_directory`` and split into training and
    validation subsets according to ``config.VALIDATION_SPLIT``.
    """

    # Seed for the directory iterators so shuffling (and augmentation) is reproducible.
    SEED = 42

    def __init__(self, config: Training):
        self.config = config

    def get_updated_model(self):
        """Load and return the Keras model stored at ``config.updated_model_path``."""
        return tf.keras.models.load_model(self.config.updated_model_path)

    def _get_data_flows(self):
        """Build and return the ``(train_data, val_data)`` directory iterators."""
        # Both generators must share the same validation_split: Keras assigns the
        # first `split` fraction of each class's files to "validation" and the
        # rest to "training", so mismatched values would overlap or drop files.
        # NOTE(review): 1/255 rescaling assumes the saved model expects [0, 1]
        # inputs; Keras ResNet50 normally uses `resnet50.preprocess_input` — confirm.
        generator_kwargs = dict(
            rescale = 1./255,
            validation_split = self.config.VALIDATION_SPLIT,
        )

        flow_kwargs = dict(
            directory = self.config.data,
            class_mode = self.config.LABEL_MODEL,
            # Drop the trailing element of INPUT_SHAPE (presumably channels);
            # flow_from_directory wants (height, width).
            target_size = self.config.INPUT_SHAPE[:-1],
            batch_size = self.config.BATCH_SIZE,
            interpolation = "bilinear",
            seed = self.SEED,
        )

        valid_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(**generator_kwargs)
        val_data = valid_data_gen.flow_from_directory(
            subset = "validation",
            shuffle = self.config.SHUFFLE,
            **flow_kwargs,
        )

        # Random augmentation is applied to the training subset only.
        augment_kwargs = dict(
            rotation_range = 40,
            horizontal_flip = True,
            width_shift_range = 0.2,
            height_shift_range = 0.2,
            shear_range = 0.2,
            zoom_range = 0.2,
        ) if self.config.AUGMENTED else {}

        train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            **augment_kwargs, **generator_kwargs
        )
        train_data = train_data_gen.flow_from_directory(
            subset = "training",
            shuffle = True,
            **flow_kwargs,
        )
        return train_data, val_data

    def train(self):
        """Fit the updated model, then save the trained model and its metrics.

        Sets ``self.steps_per_epoch`` and ``self.validation_steps``, writes the
        model to ``config.model_path`` and ``history.history`` to
        ``config.model_metrics_path``.

        Returns:
            The Keras ``History`` object returned by ``model.fit``.

        Raises:
            CustomException: wrapping any error raised while building the data
                flows, loading the model, training or saving.
        """
        try:
            logging.info(f"Training the model with the following config: {self.config}")

            callbacks_list = CallbacksPipeline().main()

            train_data, val_data = self._get_data_flows()

            # Fail with a clear message instead of an obscure Keras error.
            if train_data.samples == 0:
                raise ValueError(f"No training images found under {self.config.data}")
            if val_data.samples == 0:
                raise ValueError(
                    f"No validation images found under {self.config.data} "
                    f"(VALIDATION_SPLIT={self.config.VALIDATION_SPLIT})"
                )

            model = self.get_updated_model()

            logging.info(f"Callbacks: {callbacks_list}")
            model.summary()  # prints itself; it returns None, so do not wrap in print()

            # len() of a Keras directory iterator is ceil(samples / batch_size):
            # the final partial batch is kept, and there is always >= 1 step
            # (floor division gave 0 when samples < batch size).
            self.steps_per_epoch = len(train_data)
            self.validation_steps = len(val_data)

            history = model.fit(
                train_data,
                steps_per_epoch = self.steps_per_epoch,
                epochs = self.config.EPOCHS,
                validation_data = val_data,
                validation_steps = self.validation_steps,
                callbacks = callbacks_list
            )

            # NOTE(review): if the callbacks include a checkpoint that writes to
            # model_path, this overwrites it with the last-epoch model — confirm.
            tf.keras.models.save_model(model, self.config.model_path)
            save_json(self.config.model_metrics_path, history.history)
            logging.info("Model training completed successfully")
            return history

        except Exception as e:
            logging.error(f"Model training failed: {e}")
            raise CustomException(e,sys)




## Pipeline

In [None]:


# Run the full training stage: load config, build the entity, train and save.
try:
    config_manager = ConfigManager()
    config = config_manager.get_model_training_config()
    model_training = ModelTraining(config)
    model_training.train()

except CustomException:
    # Already logged and wrapped inside ModelTraining.train(); re-raise as-is
    # instead of nesting a second CustomException around it.
    raise
except Exception as e:
    raise CustomException(e,sys) from e