In [1]:
import os
%pwd

'f:\\ml projects\\brain_tumor_classification\\notebooks'

In [2]:
# Step up from the notebooks/ folder to the project root (see the %pwd outputs),
# presumably so the relative config/artifact paths used below resolve correctly.
# NOTE(review): not idempotent - re-running this cell climbs one more level.
os.chdir(os.pardir)
%pwd

'f:\\ml projects\\brain_tumor_classification'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainingConfig:
    """Read-only settings consumed by the ``ModelTraining`` step.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # ---- filesystem locations ----
    root_dir: Path                 # directory belonging to the training stage
    trained_model_path: Path       # destination used when the fitted model is saved
    updated_base_model: Path       # model file loaded before training starts
    training_data: Path            # directory fed to the training generator
    validation_data: Path          # directory fed to the validation generator

    # ---- hyper-parameters ----
    params_is_augmentation: bool   # switches augmentation of the training images on/off
    params_image_size: list        # every entry except the last is used as the target size
    params_batch_size: int
    params_epochs: int

@dataclass(frozen=True)
class PrepareCallbackConfig:
    """Read-only locations used when building the training callbacks.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    root_dir: Path               # directory belonging to the callbacks stage
    tensorboard_log_dir: Path    # parent folder of the timestamped TensorBoard runs
    checkpoint_model_path: Path  # file path handed to the ModelCheckpoint callback


In [4]:
from brain_tumor_classification.utils.common import read_yaml,create_directories
from brain_tumor_classification.constants import *

class ConfigurationManager:
    """Turns the project's YAML config/params files into typed stage configs.

    The loaded objects are accessed by attribute (``self.config.<section>``), so
    ``read_yaml`` is assumed to return an attribute-style mapping - TODO confirm.
    """

    def __init__(self, config_filepath=CONFIG_FILEPATH, params_filepath=PARAMS_FILEPATH) -> None:
        """Load both YAML files and prepare the artifacts root.

        Args:
            config_filepath: Location of the YAML file with the per-stage sections.
            params_filepath: Location of the YAML file with the training parameters.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Presumably creates the directory if it is missing - helper is defined elsewhere.
        create_directories([self.config.artifacts_root])

    def get_prepare_callback_config(self) -> PrepareCallbackConfig:
        """Build the callback-stage config from the ``prepare_callbacks`` section.

        Returns:
            PrepareCallbackConfig whose fields are ``Path`` objects, matching the
            dataclass annotations and the way ``get_model_training_config`` builds
            its result.
        """
        callbacks_section = self.config.prepare_callbacks

        create_directories([callbacks_section.root_dir])

        return PrepareCallbackConfig(
            root_dir=Path(callbacks_section.root_dir),
            tensorboard_log_dir=Path(callbacks_section.tensorboard_log_dir),
            checkpoint_model_path=Path(callbacks_section.checkpoint_model_path),
        )

    def get_model_training_config(self) -> ModelTrainingConfig:
        """Build the training-stage config.

        Combines the ``model_training``, ``prepare_base_model`` and
        ``data_ingestion`` sections of the config file with the values from the
        params file.

        Returns:
            ModelTrainingConfig with all locations as ``Path`` objects.
        """
        training_section = self.config.model_training
        base_model_section = self.config.prepare_base_model
        params = self.params

        # Both splits live under <unzip_dir>/brain_tumor_mris/.
        dataset_root = Path(self.config.data_ingestion.unzip_dir) / "brain_tumor_mris"

        create_directories([training_section.root_dir])

        return ModelTrainingConfig(
            root_dir=Path(training_section.root_dir),
            trained_model_path=Path(training_section.trained_model_path),
            updated_base_model=Path(base_model_section.updated_base_model_path),
            training_data=dataset_root / "Training",
            validation_data=dataset_root / "Validation",
            params_is_augmentation=params.AUGMENTATION,
            params_image_size=params.IMAGE_SIZE,
            params_batch_size=params.BATCH_SIZE,
            params_epochs=params.EPOCHS,
        )

In [5]:
import tensorflow as tf
from datetime import datetime
from brain_tumor_classification.utils.exception import CustomException
import sys




In [6]:
class PrepareCallback:
    """Builds the Keras callbacks (TensorBoard and ModelCheckpoint) for a training run."""

    def __init__(self, config: PrepareCallbackConfig) -> None:
        """Keep the callback configuration for later use."""
        self.config = config

    @property
    def _create_tensorboard(self):
        """A new TensorBoard callback writing to a timestamped sub-directory.

        The timestamp is taken at access time, so every read of this property
        constructs a fresh callback.
        """
        stamp = datetime.now().strftime('%d_%m_%y_%H_%M_%S')
        run_dir = os.path.join(self.config.tensorboard_log_dir, f"TB_LOG_at{stamp}")
        return tf.keras.callbacks.TensorBoard(log_dir=run_dir)

    @property
    def _create_model_callback(self):
        """A new ModelCheckpoint callback with ``save_best_only`` enabled."""
        return tf.keras.callbacks.ModelCheckpoint(
            self.config.checkpoint_model_path,
            save_best_only=True,
        )

    def create_TB_CP(self):
        """Return ``[tensorboard_callback, checkpoint_callback]``, in that order."""
        tensorboard_cb = self._create_tensorboard
        checkpoint_cb = self._create_model_callback
        return [tensorboard_cb, checkpoint_cb]
    

In [7]:
class ModelTraining:
    """Loads the prepared base model, builds the image generators and fits the model.

    Expected call order: ``get_base_model()`` -> ``train_valid_generator()`` ->
    ``train_model()``. ``train_model`` reads ``self.model``, ``self.training_data``
    and ``self.validation_data``, which are only set by the first two methods.
    """

    def __init__(self, config: ModelTrainingConfig) -> None:
        """Keep the training configuration for later use."""
        self.config = config

    def get_base_model(self):
        """Load the model stored at ``config.updated_base_model`` into ``self.model``."""
        self.model = tf.keras.models.load_model(self.config.updated_base_model)

    def train_valid_generator(self):
        """Create the directory iterators for the training and validation images.

        Sets ``self.training_data`` and ``self.validation_data``. Pixel values are
        rescaled by 1/255 for both splits; shear, zoom and horizontal flips are
        applied to the training split only when ``params_is_augmentation`` is true.

        Raises:
            CustomException: wraps any error raised while building the generators.
        """
        try:
            generator_cls = tf.keras.preprocessing.image.ImageDataGenerator

            # Validation images are only ever rescaled.
            valid_datagenerator = generator_cls(rescale=1./255)

            if self.config.params_is_augmentation:
                train_datagenerator = generator_cls(
                    rescale=1./255,
                    shear_range=0.2,
                    zoom_range=0.2,
                    horizontal_flip=True,
                )
            else:
                # Without augmentation both splits share the same generator.
                train_datagenerator = valid_datagenerator

            data_flow_kwargs = dict(
                # Drop the last entry of the configured image size
                # (assumed to be the channel count - TODO confirm).
                target_size=self.config.params_image_size[:-1],
                batch_size=self.config.params_batch_size,
                class_mode='categorical',
            )

            self.training_data = train_datagenerator.flow_from_directory(
                directory=self.config.training_data,
                **data_flow_kwargs
            )
            self.validation_data = valid_datagenerator.flow_from_directory(
                directory=self.config.validation_data,
                **data_flow_kwargs
            )

        except Exception as e:
            # Explicit chaining keeps the original traceback attached as the cause.
            raise CustomException(e, sys) from e

    def train_model(self, callbacks_list: list):
        """Fit ``self.model`` on the generators and save it to ``trained_model_path``.

        Args:
            callbacks_list: Keras callbacks forwarded to ``model.fit``.
        """
        # Floor division gives 0 when a split holds fewer images than one batch,
        # which would leave fit() with no step to run; always do at least one.
        self.steps_per_epochs = max(
            1, self.training_data.samples // self.training_data.batch_size
        )
        self.validation_steps = max(
            1, self.validation_data.samples // self.validation_data.batch_size
        )

        self.model.fit(
            self.training_data,
            validation_data=self.validation_data,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epochs,
            validation_steps=self.validation_steps,
            callbacks=callbacks_list,
        )

        self.model.save(self.config.trained_model_path)

In [None]:
# Run the whole training stage: build the callbacks, load the base model,
# create the data generators, then fit and save the model.
try:
    # Reads the config/params YAML files from their default paths.
    config = ConfigurationManager()
    get_callbacks_config = config.get_prepare_callback_config()
    prepare_callbacks = PrepareCallback(config = get_callbacks_config)
    # List holding the TensorBoard callback followed by the ModelCheckpoint callback.
    callbacks_list =  prepare_callbacks.create_TB_CP()

    training_config = config.get_model_training_config()
    model_training = ModelTraining(config = training_config)
    # Order matters: train_model() uses the model and generators set by the two calls before it.
    model_training.get_base_model()
    model_training.train_valid_generator()
    model_training.train_model(callbacks_list = callbacks_list)
except Exception as e:
    # Re-raise any failure as the project's CustomException.
    raise CustomException(e,sys)