In [1]:
import os

In [2]:
# Show the current working directory (the research/ notebook folder at this point).
%pwd

'C:\\Users\\DIKSHANT PATEL\\Kidney-Disease-Classification\\research'

In [3]:
# Move from the research/ notebook folder up to the project root so relative
# paths such as config/config.yaml and artifacts/ resolve. The guard makes the
# cell idempotent: re-running it no longer keeps climbing the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("..")

In [4]:
# Confirm the working directory is now the project root.
%pwd

'C:\\Users\\DIKSHANT PATEL\\Kidney-Disease-Classification'

In [5]:
from dataclasses import dataclass
from typing import Any

@dataclass
class CallbacksConfig:
    """Settings needed to build the Keras training callbacks.

    Instances are produced by ``ConfigurationManager.get_callbacks_config``.
    """

    # Checkpoint file path (taken from the training.checkpoint.filepath param).
    checkpoint_path: str
    # Keyword arguments forwarded to EarlyStopping.
    early_stopping_params: dict
    # Keyword arguments forwarded to ReduceLROnPlateau.
    reduce_lr_params: dict
    # Keyword arguments forwarded to ModelCheckpoint.
    checkpoint_params: dict

In [6]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Loads the project YAML files and builds config objects from them."""

    def __init__(self,
                 config_filepath=CONFIG_FILE_PATH,
                 params_filepath=PARAMS_FILE_PATH):
        """Read the config and params YAML files and create the artifacts root.

        Args:
            config_filepath: Path to the main config YAML (default CONFIG_FILE_PATH).
            params_filepath: Path to the params YAML (default PARAMS_FILE_PATH).
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # NOTE(review): key is spelled 'artifacts_roots' — confirm it matches config.yaml.
        artifacts_dir = self.config.artifacts_roots
        create_directories([artifacts_dir])

    def get_callbacks_config(self) -> CallbacksConfig:
        """Build a CallbacksConfig from the ``callbacks`` config section and the
        ``training`` params section.

        Side effect: creates the ``callbacks.checkpoint_dir`` directory.
        """
        callbacks_section = self.config.callbacks
        training = self.params.training

        create_directories([callbacks_section.checkpoint_dir])

        checkpoint = training.checkpoint
        return CallbacksConfig(
            checkpoint_path=checkpoint.filepath,
            early_stopping_params=training.early_stopping,
            reduce_lr_params=training.reduce_lr,
            checkpoint_params=checkpoint,
        )

In [8]:
import numpy as np
from tensorflow.keras.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    ReduceLROnPlateau,
    Callback
)
from sklearn.utils.class_weight import compute_class_weight
from cnnClassifier import logger

In [13]:
class CustomObjectiveLogger(Callback):
    """Keras callback that derives and logs a custom ``val_objective`` metric.

    The objective is ``val_accuracy - |val_loss - loss|``: validation accuracy
    penalised by the gap between validation and training loss. The value is
    written back into the epoch ``logs`` dict under ``'val_objective'``,
    presumably so callbacks later in the same list (e.g. EarlyStopping) can
    monitor it — this relies on Keras passing the same logs dict to each
    callback; confirm for the Keras version in use.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Compute, store and log the custom objective for the finished epoch.

        Args:
            epoch: Epoch index supplied by Keras (displayed 1-based in the log).
            logs: Metric dict for the epoch. Missing metrics default to 0.
        """
        # logs defaults to None in the signature; guard so .get() cannot fail.
        if logs is None:
            logs = {}

        acc_val = logs.get('val_accuracy', 0)
        acc_train = logs.get('accuracy', 0)
        loss_val = logs.get('val_loss', 0)
        loss_train = logs.get('loss', 0)

        # Penalise the train/validation loss gap (a proxy for overfitting).
        loss_diff = abs(loss_val - loss_train)
        objective_value = acc_val - loss_diff
        logs['val_objective'] = objective_value

        logger.info(
            f"[Epoch {epoch + 1:03d}] Custom Objective = {objective_value:.6f},\n "
            f"Train Acc = {acc_train:.4f}, Train Loss = {loss_train:.4f}, "
            f"Val Acc = {acc_val:.4f}, Val Loss = {loss_val:.4f}"
        )
        
class CallbackHandler:
    """Builds the class weights and Keras callbacks used for training."""

    def __init__(self, config: CallbacksConfig, ori_training_set):
        """
        Args:
            config: Callback settings. Each ``*_params`` mapping is passed as
                keyword arguments to the matching Keras callback.
            ori_training_set: Training data source. Only its ``classes``
                attribute (per-sample class labels) is read — presumably a
                Keras DataFrameIterator; confirm against DataLoader.
        """
        self.config = config
        self.ori_training_set = ori_training_set

    def get_class_weights(self):
        """Compute 'balanced' class weights from the training labels.

        Returns:
            dict mapping each distinct label in ``ori_training_set.classes`` to
            its weight, using native Python scalars (not numpy types).
        """
        logger.info("Computing class weights...")
        labels = self.ori_training_set.classes
        # Compute the distinct classes once and reuse them for both the
        # weight computation and the dict keys.
        classes = np.unique(labels)
        class_weights = compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=labels
        )
        # .tolist() converts numpy scalars to native Python types, which keeps
        # the log line readable and the dict easy to serialise.
        class_weights_dict = dict(zip(classes.tolist(), class_weights.tolist()))
        logger.info(f"Class Weights: {class_weights_dict}")
        return class_weights_dict

    def get_callbacks(self):
        """Instantiate the training callbacks from the config parameters.

        Returns:
            list: ``[custom_logger, early_stopping, reduce_lr, checkpoint]``.
            CustomObjectiveLogger is placed first so that the ``val_objective``
            it writes into ``logs`` is presumably visible to the callbacks that
            follow (relevant if they monitor ``val_objective`` — confirm
            against params.yaml).
        """
        logger.info("Preparing callbacks...")

        checkpoint = ModelCheckpoint(
            **self.config.checkpoint_params
        )

        early_stopping = EarlyStopping(
            **self.config.early_stopping_params
        )

        reduce_lr = ReduceLROnPlateau(
            **self.config.reduce_lr_params
        )

        custom_logger = CustomObjectiveLogger()

        return [custom_logger, early_stopping, reduce_lr, checkpoint]

In [14]:
from cnnClassifier.config.configuration import ConfigurationManager
from cnnClassifier.components.data_loader import DataLoader
from cnnClassifier.components.callback import CallbacksConfig


In [15]:
# Wire up the pipeline: config -> data generators -> class weights and callbacks.
# NOTE: ConfigurationManager and CallbacksConfig here are the package versions
# imported in the cell above, which shadow the notebook definitions earlier on;
# CallbackHandler is the notebook's own definition.
# No try/except: errors propagate with their original, unaltered traceback.
config = ConfigurationManager()

data_loader_config = config.get_data_loader_config()
data_loader = DataLoader(config=data_loader_config)
train_generator, val_generator, test_generator, train_df, ori_train = data_loader.get_generators()

callbacks_config = config.get_callbacks_config()
handler = CallbackHandler(config=callbacks_config, ori_training_set=ori_train)
class_weights = handler.get_class_weights()
callbacks = handler.get_callbacks()

[2025-04-22 09:12:16,918: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-04-22 09:12:16,928: INFO: common: yaml file: params.yaml loaded successfully]
[2025-04-22 09:12:16,929: INFO: common: created directory at: artifacts]
[2025-04-22 09:12:16,930: INFO: data_loader: Loading dataframe from file: artifacts/data_split/train.csv]
[2025-04-22 09:12:16,955: INFO: data_loader: Dataframe loaded successfully with 8712 records.]
[2025-04-22 09:12:16,955: INFO: data_loader: Loading dataframe from file: artifacts/data_split/val.csv]
[2025-04-22 09:12:16,960: INFO: data_loader: Dataframe loaded successfully with 1121 records.]
[2025-04-22 09:12:16,961: INFO: data_loader: Loading dataframe from file: artifacts/data_split/test.csv]
[2025-04-22 09:12:16,970: INFO: data_loader: Dataframe loaded successfully with 2613 records.]
[2025-04-22 09:12:17,015: INFO: data_loader: Generators created successfully.]
Found 8712 validated image filenames belonging to 4 classes.
Found 8712 v