In [1]:
# Show the directory the kernel started in, before moving to the project root.
%pwd

'd:\\ITI\\DS Track\\Deep Learning\\Projects\\Kidney Classification\\Kidney-Disease-Classification\\research'

In [2]:
import os
from pathlib import Path

# The notebook lives in research/, but the config files are read through paths relative
# to the project root. Step up only while we are still inside research/, so re-running
# this cell does not keep climbing the directory tree.
if Path.cwd().name == "research":
    os.chdir(Path.cwd().parent)

os.getcwd()

'd:\\ITI\\DS Track\\Deep Learning\\Projects\\Kidney Classification\\Kidney-Disease-Classification'

# Update config.yaml (done)


# Update params.yaml

# Update the entity (configuration dataclass)


In [49]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
# frozen=True makes instances read-only: assigning to a field (or adding a new
# attribute) after construction raises dataclasses.FrozenInstanceError.

class PrepareBaseModelConfig:
    """Settings for the prepare-base-model stage.

    Built by ConfigurationManager.get_prepare_base_model: the paths come from the
    ``prepare_base_model`` section of the config YAML, the ``params_*`` values from
    the params YAML.
    """
    root_dir:Path  # output directory of this stage
    base_model_path:Path  # where the pretrained backbone's state_dict is saved
    updated_base_model_path:Path  # where the backbone + new classifier head state_dict is saved
    params_image_size:list  # IMAGE_SIZE; presumably [height, width, channels] - TODO confirm
    params_learning_rate:float  # LEARNING_RATE; used as the Adam learning rate
    params_including_top:int  # INCLUDING_TOP; likely a boolean flag despite the int annotation - TODO confirm; not read by PrepareBaseModel in this notebook
    params_weight:str  # WEIGHTS; not read by PrepareBaseModel in this notebook (it always loads pretrained weights)
    params_classes:int  # CLASSES; number of outputs of the new classifier head
    params_model_name:str  # MODEL_NAME; timm model name passed to timm.create_model


In [5]:
from src.CNNClassifierKidneyDiseases.constants import *
from src.CNNClassifierKidneyDiseases.utils.common import read_yaml , create_directories


In [50]:
class ConfigurationManager:
    """Read the config and params YAML files and build per-stage config objects from them."""

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        """Parse both YAML files and make sure the artifact root directory exists.

        Args:
            config_filepath: path of the pipeline configuration YAML (per-stage paths).
            params_filepath: path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifact_dirs = [self.config.artifact_root]
        create_directories(artifact_dirs)

    def get_prepare_base_model(self) -> PrepareBaseModelConfig:
        """Return the prepare-base-model stage settings, creating its root directory first."""
        stage = self.config.prepare_base_model
        hp = self.params

        create_directories([stage.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_image_size=hp.IMAGE_SIZE,
            params_learning_rate=hp.LEARNING_RATE,
            params_classes=hp.CLASSES,
            params_weight=hp.WEIGHTS,
            params_including_top=hp.INCLUDING_TOP,
            params_model_name=hp.MODEL_NAME,
        )
    


# Update the components


In [18]:
import os
from zipfile import ZipFile
import torch 
import urllib.request as requests
import timm
import tqdm 
import torch.nn as nn
from torch.nn import Module
import torch.optim as optim
from typing import Optional


In [None]:
class PrepareBaseModel:
    """Load a pretrained timm backbone and replace its classifier with a new head.

    Typical use is ``get_base_model()`` followed by ``update_base_model()``. Each step
    saves the model's ``state_dict`` to the path configured for it.
    """

    def __init__(self, config: "PrepareBaseModelConfig"):
        self.config = config

    def get_base_model(self):
        """Create the pretrained timm model named in the config and save its weights.

        The model is kept on ``self.model`` for ``update_base_model``.
        """
        self.model = timm.create_model(self.config.params_model_name, pretrained=True)

        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def _prepare_full_model(
        model: Module,
        classes: int,
        freeze_all: bool,
        freeze_till: Optional[int],
        learning_rate: float,
        classifier_attr: str = "classif",
    ):
        """Freeze (part of) ``model``, attach a new classifier head, build optimizer and loss.

        PyTorch counterpart of the Keras ``_prepare_full_model`` helper.

        Args:
            model: backbone whose top-level attribute ``classifier_attr`` holds the
                current classification layer. That layer must expose ``in_features``
                (e.g. ``classif`` on timm's inception_resnet_v2).
            classes: number of output classes of the new head.
            freeze_all: if True, freeze every existing parameter of ``model``.
            freeze_till: only used when ``freeze_all`` is False. If > 0, the last
                ``freeze_till`` top-level child modules (in registration order, not
                counting the head being replaced) stay trainable and all earlier
                children are frozen. ``None`` or <= 0 leaves the backbone trainable.
            learning_rate: Adam learning rate.
            classifier_attr: name of the top-level attribute holding the classifier.

        Returns:
            ``(model, criterion, optimizer)``: the same ``model`` object modified in
            place, a ``CrossEntropyLoss`` and an ``Adam`` optimizer.
        """
        if freeze_all:
            # Freeze every existing parameter, including any registered directly on `model`.
            for param in model.parameters():
                param.requires_grad = False
        elif freeze_till is not None and freeze_till > 0:
            # Mirrors Keras `model.layers[:-freeze_till]`; parameter-less children
            # (pooling, dropout) are counted like any other layer.
            backbone_layers = [
                child for name, child in model.named_children() if name != classifier_attr
            ]
            for layer in backbone_layers[:-freeze_till]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Width of the features the current head consumes; the new head must accept it.
        in_features = getattr(model, classifier_attr).in_features

        # New classifier head.
        classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, classes)
        )

        # nn.Module.__setattr__ registers the new head as a submodule.
        setattr(model, classifier_attr, classifier)

        # Freshly built layers are trainable by default; set it explicitly for clarity.
        for param in classifier.parameters():
            param.requires_grad = True

        # Frozen parameters receive no gradient, and Adam skips parameters whose .grad is None.
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable parameters: {trainable_params}/{total_params}")

        return model, criterion, optimizer

    def update_base_model(self):
        """Attach the new head to ``self.model`` (backbone fully frozen) and save its weights.

        Raises:
            AttributeError: if ``get_base_model`` has not been called first.
        """
        base_model = getattr(self, "model", None)
        if base_model is None:
            raise AttributeError(
                "PrepareBaseModel.update_base_model() requires get_base_model() to be called first"
            )

        self.full_model, self.criterion, self.optimizer = self._prepare_full_model(
            model=base_model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            learning_rate=self.config.params_learning_rate
        )

        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)

    @staticmethod
    def save_model(path: Path, model: torch.nn.Module):
        """Save ``model.state_dict()`` to ``path``, creating the parent directory if needed.

        Only the weights are stored; the architecture must be rebuilt before loading them.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)

    

In [56]:
# Build the stage config, load the pretrained backbone, attach the new classifier head
# and save both checkpoints. No try/except: a failure propagates with its original
# traceback and stops the cell (the previous `except Exception as e: raise e` added nothing).
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()

[2025-07-06 01:04:25,508: INFO: common: yaml file config\config.yaml loaded successfully]
[2025-07-06 01:04:25,513: INFO: common: yaml file params.yaml loaded successfully]
[2025-07-06 01:04:25,515: INFO: common: Created Directory at : artifacts]
[2025-07-06 01:04:25,516: INFO: common: Created Directory at : artifacts/prepare_base_model]
[2025-07-06 01:04:26,016: INFO: _builder: Loading pretrained weights from Hugging Face hub (timm/inception_resnet_v2.tf_in1k)]
[2025-07-06 01:04:26,199: INFO: _hub: [timm/inception_resnet_v2.tf_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.]
Trainable parameters: 393986/54700450


In [52]:
# Inspect the backbone's layer names; the classifier head replaced above is its `classif` attribute.
model = timm.create_model(model_name="inception_resnet_v2", pretrained=True)
model


[2025-07-06 00:52:25,966: INFO: _builder: Loading pretrained weights from Hugging Face hub (timm/inception_resnet_v2.tf_in1k)]
[2025-07-06 00:52:26,216: INFO: _hub: [timm/inception_resnet_v2.tf_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.]


InceptionResnetV2(
  (conv2d_1a): ConvNormAct(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNormAct2d(
      32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU(inplace=True)
    )
  )
  (conv2d_2a): ConvNormAct(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU(inplace=True)
    )
  )
  (conv2d_2b): ConvNormAct(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU(inplace=True)
    )
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): ConvNormAct(
    (conv): Conv2d(64, 80, kernel_size=(1, 