In [1]:
import os

In [2]:
%pwd

'c:\\Users\\Administrator\\Desktop\\Kidney-Disease-Classification\\research'

In [3]:
# Step up from the notebook folder to its parent (the %pwd output above shows the
# kernel starting in ".../research"); presumably this is the project root that the
# relative config/artifact paths expect -- confirm.
# Guarded on the folder name so that re-running this cell does not climb one
# directory further each time.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
%pwd


'c:\\Users\\Administrator\\Desktop\\Kidney-Disease-Classification'

In [5]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of the paths and hyperparameters used by the training stage.

    Instances are built by ``ConfigurationManager.get_training_config`` from the
    YAML config and params files.
    """
    # Working directory of the training stage (taken from config.training.root_dir).
    root_dir: Path
    # Destination file for the trained model (config.training.trained_model_path).
    trained_model_path: Path
    # Saved model that training starts from
    # (config.prepare_base_model.updated_base_model_path).
    updated_base_model_path: Path
    # Directory holding the image dataset used for training and validation.
    training_data: Path
    # Number of epochs (params.EPOCHS).
    params_epochs: int
    # Mini-batch size (params.BATCH_SIZE).
    params_batch_size: int
    # Data-augmentation flag (params.AUGMENTATION).
    params_is_augmentation: bool
    # Model input size (params.IMAGE_SIZE); presumably [height, width, channels] -- confirm.
    params_image_size: list

In [6]:
from CNN_Classifier.constants import *
from CNN_Classifier.utils.common import read_yaml, create_directories
import tensorflow as tf

In [7]:
class ConfigurationManager:
    """Loads the YAML config and params files and builds per-stage config objects."""

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the hyperparameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        """Assemble the ``TrainingConfig`` for the training stage.

        Creates the training root directory as a side effect.

        Returns:
            A ``TrainingConfig`` populated from the ``training``,
            ``prepare_base_model`` and ``data_ingestion`` config sections and
            from the params file.
        """
        train_section = self.config.training
        base_model_section = self.config.prepare_base_model
        hyperparams = self.params

        # The dataset lives two fixed folder levels below the unzip directory.
        dataset_dir = os.path.join(
            self.config.data_ingestion.unzip_dir,
            "kidney-ct-scan-image",
            "CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone",
        )

        create_directories([Path(train_section.root_dir)])

        fields = {
            "root_dir": Path(train_section.root_dir),
            "trained_model_path": Path(train_section.trained_model_path),
            "updated_base_model_path": Path(base_model_section.updated_base_model_path),
            "training_data": Path(dataset_dir),
            "params_epochs": hyperparams.EPOCHS,
            "params_batch_size": hyperparams.BATCH_SIZE,
            "params_is_augmentation": hyperparams.AUGMENTATION,
            "params_image_size": hyperparams.IMAGE_SIZE,
        }
        return TrainingConfig(**fields)

In [8]:
import os
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf
import time

In [None]:


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    ReduceLROnPlateau,
    TensorBoard
)

from sklearn.metrics import confusion_matrix, classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight


# ============================
# CONFIG
# ============================
# NOTE(review): this re-declares the same fields as the TrainingConfig dataclass
# defined earlier in this notebook; executing this cell rebinds the name to a new
# (field-for-field identical) class.
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of the paths and hyperparameters used by the training stage."""
    # Working directory of the training stage.
    root_dir: Path
    # Destination file for the trained model.
    trained_model_path: Path
    # Saved model that training starts from.
    updated_base_model_path: Path
    # Directory holding the image dataset used for training and validation.
    training_data: Path
    # Number of epochs.
    params_epochs: int
    # Mini-batch size.
    params_batch_size: int
    # Data-augmentation flag.
    params_is_augmentation: bool
    # Model input size; presumably [height, width, channels] -- confirm.
    params_image_size: list


# ============================
# TRAINING CLASS
# ============================
class Training:
    """Two-phase fine-tuning of a saved Keras model, plus validation-set evaluation.

    Intended call order: ``get_base_model()`` -> ``train_valid_generator()`` ->
    ``train()`` -> ``evaluate_model()``. Each step stores its result on the
    instance (``model``, ``train_generator``, ``valid_generator``) for the next.
    """

    def __init__(self, config: "TrainingConfig"):
        """Keep the training configuration; nothing is loaded yet."""
        self.config = config

    # ----------------------------
    # Load Model
    # ----------------------------
    def get_base_model(self):
        """Load the model saved at ``config.updated_base_model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )

    # ----------------------------
    # Data Generators
    # ----------------------------
    def train_valid_generator(self):
        """Create ``self.train_generator`` and ``self.valid_generator``.

        Both read ``config.training_data`` with a fixed 80/20 training/validation
        split and ResNet50 ``preprocess_input``. Random augmentation is applied to
        the training stream only when ``config.params_is_augmentation`` is true.
        """
        datagenerator_kwargs = dict(
            preprocessing_function=preprocess_input,
            validation_split=0.20
        )

        dataflow_kwargs = dict(
            # Drop the last entry of the image size (presumably the channel
            # count) to obtain the 2-D target size.
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear"
        )

        valid_datagen = ImageDataGenerator(**datagenerator_kwargs)

        if self.config.params_is_augmentation:
            train_datagen = ImageDataGenerator(
                rotation_range=15,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.15,
                shear_range=0.1,
                brightness_range=[0.8, 1.2],
                horizontal_flip=True,
                **datagenerator_kwargs
            )
        else:
            # Augmentation disabled: training images get the same
            # preprocessing-only pipeline as validation images.
            train_datagen = valid_datagen

        # Validation order must stay fixed so predictions line up with
        # ``valid_generator.classes`` in evaluate_model().
        self.valid_generator = valid_datagen.flow_from_directory(
            directory=self.config.training_data,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs
        )

        self.train_generator = train_datagen.flow_from_directory(
            directory=self.config.training_data,
            subset="training",
            shuffle=True,
            **dataflow_kwargs
        )

    # ----------------------------
    # Compile Model Helper
    # ----------------------------
    def _compile_model(self, lr):
        """(Re)compile ``self.model`` with Adam at learning rate ``lr``.

        Called after every change to the layers' ``trainable`` flags.
        """
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss="categorical_crossentropy",
            metrics=[
                "accuracy",
                tf.keras.metrics.Precision(name="precision"),
                tf.keras.metrics.Recall(name="recall"),
                tf.keras.metrics.AUC(name="auc")
            ]
        )

    # ----------------------------
    # Freezing Helpers
    # ----------------------------
    @staticmethod
    def _head_start_index(layers):
        """Return the index of the last layer in ``layers`` that owns weights.

        Layers from that index to the end are treated as the classifier head.
        Returns ``len(layers)`` when no layer has weights.
        """
        for index in range(len(layers) - 1, -1, -1):
            if layers[index].weights:
                return index
        return len(layers)

    def _freeze_backbone(self, head_layers=None):
        """Freeze every layer of ``self.model`` except the classifier head.

        Args:
            head_layers: Number of trailing layers to keep trainable. ``None``
                keeps the last weighted layer and everything after it, which
                guarantees at least the output layer is trained. Pass an
                explicit count if the head has several weighted layers.
        """
        layers = self.model.layers
        if head_layers is None:
            head_start = self._head_start_index(layers)
        else:
            head_start = max(len(layers) - head_layers, 0)

        for index, layer in enumerate(layers):
            layer.trainable = index >= head_start

    # ----------------------------
    # Training Helpers
    # ----------------------------
    def _class_weights(self):
        """Return ``{class_index: weight}`` using sklearn's "balanced" scheme.

        Weights are keyed by the real class index rather than by position, so a
        class with no training samples cannot shift the others. Such a class
        gets a neutral weight of 1.0 so the keys still cover every index in
        ``class_indices`` (assumption: some Keras versions reject a dict with
        gaps in its keys -- confirm).
        """
        labels = self.train_generator.classes
        present = np.unique(labels)
        balanced = compute_class_weight(
            class_weight="balanced",
            classes=present,
            y=labels
        )

        weights = {
            index: 1.0
            for index in range(len(self.train_generator.class_indices))
        }
        weights.update(
            {int(cls): float(weight) for cls, weight in zip(present, balanced)}
        )
        return weights

    def _build_callbacks(self):
        """Build the callback list shared by both training phases."""
        return [
            ModelCheckpoint(
                filepath=self.config.trained_model_path,
                monitor="val_accuracy",
                save_best_only=True,
                mode="max",
                verbose=1
            ),
            EarlyStopping(
                monitor="val_loss",
                patience=5,
                restore_best_weights=True,
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor="val_loss",
                factor=0.5,
                patience=2,
                min_lr=1e-6,
                verbose=1
            ),
            TensorBoard(log_dir="logs/fit")
        ]

    # ----------------------------
    # Training
    # ----------------------------
    def train(self, head_epochs=5, head_lr=1e-4, fine_tune_lr=1e-5,
              head_layers=None, fine_tune_layers=50):
        """Train in two phases: classifier head first, then the backbone tail.

        Args:
            head_epochs: Epochs for phase 1.
            head_lr: Learning rate for phase 1.
            fine_tune_lr: Learning rate for phase 2.
            head_layers: Number of trailing layers trained in phase 1
                (``None`` = auto-detect, see ``_freeze_backbone``).
            fine_tune_layers: Number of trailing layers unfrozen in phase 2
                (BatchNormalization layers stay frozen).

        Phase 2 runs for ``config.params_epochs`` epochs. Requires
        ``get_base_model()`` and ``train_valid_generator()`` to have been called.
        """
        class_weight_dict = self._class_weights()
        callbacks = self._build_callbacks()

        # ============================
        # PHASE 1: Train Classifier Head
        # ============================
        print("\n🔵 Phase 1: Training classifier head")

        # Freeze the backbone but leave the head trainable; freezing every
        # layer would leave this phase with nothing to train.
        self._freeze_backbone(head_layers)
        self._compile_model(lr=head_lr)

        self.model.fit(
            self.train_generator,
            epochs=head_epochs,
            validation_data=self.valid_generator,
            callbacks=callbacks,
            class_weight=class_weight_dict
        )

        # ============================
        # PHASE 2: Fine-tuning Backbone
        # ============================
        print("\n🟢 Phase 2: Fine-tuning backbone")

        # Guard the zero case explicitly: layers[-0:] would select every layer.
        if fine_tune_layers > 0:
            tail = self.model.layers[-fine_tune_layers:]
        else:
            tail = []

        for layer in tail:
            # BatchNormalization layers are deliberately left frozen.
            if not isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = True

        self._compile_model(lr=fine_tune_lr)

        self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            validation_data=self.valid_generator,
            callbacks=callbacks,
            class_weight=class_weight_dict
        )

    # ----------------------------
    # Evaluation
    # ----------------------------
    def evaluate_model(self):
        """Print and plot validation metrics for ``self.model``.

        Prints the confusion matrix, the per-class classification report and
        the weighted F1 score, then shows the confusion matrix as an image.
        """
        class_names = list(self.valid_generator.class_indices.keys())
        label_ids = np.arange(len(class_names))

        y_true = self.valid_generator.classes
        y_pred_prob = self.model.predict(self.valid_generator, verbose=1)
        y_pred = np.argmax(y_pred_prob, axis=1)

        # Explicit labels keep the matrix and report aligned with class_names
        # even when a class is missing from the validation split.
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        print("\nCONFUSION MATRIX")
        print(cm)

        print("\nCLASSIFICATION REPORT")
        print(
            classification_report(
                y_true,
                y_pred,
                labels=label_ids,
                target_names=class_names
            )
        )

        f1 = f1_score(y_true, y_pred, average="weighted")
        print("Weighted F1 Score:", f1)

        # Plot Confusion Matrix
        fig, ax = plt.subplots(figsize=(8, 6))
        image = ax.imshow(cm)
        ax.set_title("Confusion Matrix")
        fig.colorbar(image, ax=ax)
        ax.set_xticks(label_ids)
        ax.set_xticklabels(class_names, rotation=45)
        ax.set_yticks(label_ids)
        ax.set_yticklabels(class_names)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        fig.tight_layout()
        plt.show()
