In [1]:
import os
%pwd

'/media/kirti/Dev/DeepLearning/Project/E2E/ChestCancerDetection/research'

In [2]:
os.chdir('../')
%pwd

'/media/kirti/Dev/DeepLearning/Project/E2E/ChestCancerDetection'

In [3]:
from dataclasses import dataclass
from pathlib import Path    

@dataclass(frozen=True)
class PrepareBaseModelConfig:
    root_dir: Path
    base_model_save_path: Path
    update_base_model_path: Path
    params_image_size: list
    params_batch_size: int
    params_epochs: int
    params_classes: int
    params_weights: str
    params_learning_rate: float

In [None]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories


In [5]:
class ConfigManager:
    def __init__(self, config_path: Path = CONFIG_FILE_PATH, 
                 params_path: Path = PARAMS_FILE_PATH):
        """        Initializes the ConfigManager with paths to the configuration and parameters files.
        Args:
            config_path (Path): Path to the configuration file.
            params_path (Path): Path to the parameters file.
        """
        self.config = read_yaml(config_path)
        self.params = read_yaml(params_path)

        create_directories([self.config.artifacts_root])


    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:

        config = self.config.prepare_base_model

        create_directories([config.root_dir])
        prepare_base_model_config = PrepareBaseModelConfig(
            root_dir=Path(config.root_dir),
            base_model_save_path=Path(config.base_model_save_path),
            update_base_model_path=Path(config.update_base_model_path),
            params_image_size=self.params.IMAGE_SIZE,
            params_batch_size=self.params.BATCH_SIZE,
            params_epochs=self.params.EPOCHS,
            params_classes=self.params.CLASSES,
            params_weights=self.params.WEIGHTS,
            params_learning_rate=self.params.LEARNING_RATE
        )

        return prepare_base_model_config
        

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
from torchsummary import summary
from torchvision.models import VGG16_Weights

In [7]:
class PrepareBaseModel:
    def __init__(self, config: PrepareBaseModelConfig):
        """Initializes the PrepareBaseModel with the given configuration.
        
        Args:
            config (PrepareBaseModelConfig): Configuration for preparing the base model.
        """
        self.config = config

    def get_base_model(self):
        """ Return vgg16 model with imagenet weights without top layer."""
        base_model = torchvision.models.vgg16(
            weights=self.config.params_weights,
            progress=True,  
        )

        # Remove the top layer (classifier)
        base_model.classifier = nn.Sequential()

        #save the base model
        base_model_save_path = self.config.base_model_save_path
        os.makedirs(base_model_save_path.parent,exist_ok=True)
        self.save_model(base_model, base_model_save_path)
        print(f"Base model saved at: {base_model_save_path}")


    def update_base_model(self, freeze_all=True, freeze_till=None):
        """Updates the base model by modifying the classifier and setting up the optimizer.
        
        Args:
            freeze_all (bool): Whether to freeze all layers.
            freeze_till (int): Layer index till which to freeze.
        """
        model = self._prepare_full_model(
            model=self.config.base_model_save_path,
            classes=self.config.params_classes,
            freeze_all=freeze_all,
            freeze_till=freeze_till,
            learning_rate=self.config.params_learning_rate
        )

        # Print model summary
        print("Model Summary:")
        summary(model, input_size=(3,224,224), batch_size=self.config.params_batch_size)
        # Save the updated model
        update_base_model_path = self.config.update_base_model_path
        os.makedirs(update_base_model_path.parent, exist_ok=True)
        self.save_model(model, update_base_model_path)
        print(f"Updated base model saved at: {update_base_model_path}")

    @staticmethod
    def _prepare_full_model(model,classes,freeze_all,freeze_till,learning_rate):
        """Prepares the full model by modifying the classifier and setting up the optimizer.
        
        Args:
            model (torch.nn.Module): The base model to modify.
            classes (int): Number of output classes.
            freeze_all (bool): Whether to freeze all layers.
            freeze_till (int): Layer index till which to freeze.
            learning_rate (float): Learning rate for the optimizer.
        
        Returns:
            torch.nn.Module: The modified model with a new classifier.
        """

        model = PrepareBaseModel.load_model(model_path=model)
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False

        elif freeze_till is not None:
            for param in list(model.parameters())[:freeze_till]:
                param.requires_grad = False

        # Modify the classifier
        model.classifier = nn.Sequential(
            nn.Linear(25088, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, classes)
        )

        return model


    @staticmethod
    def save_model(model: torch.nn.Module, path: Path):
        """Saves the model to the specified path.
        
        Args:
            model (torch.nn.Module): The model to save.
            path (Path): The path where the model will be saved.
        """
        torch.save(model.state_dict(), path)

    @staticmethod
    def load_model(model_path: Path):
        """Loads the model from the specified path.
        
        Args:
            model_path (Path): The path from which to load the model.
        Returns:
            torch.nn.Module: The loaded model.
        """
        model = torchvision.models.vgg16(weights=None)
        model.classifier = torch.nn.Sequential()
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict)
        return model


In [8]:
try:
    config = ConfigManager()
    prepare_base_model_config = config.get_prepare_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
    prepare_base_model.get_base_model()
    prepare_base_model.update_base_model(freeze_all=True)

except Exception as e:  
    raise e

[2025-07-09 12:36:46,408|(INFO)| File: common | Message: Created directory: artifacts]
[2025-07-09 12:36:46,428|(INFO)| File: common | Message: Created directory: artifacts/prepare_base_model]


Base model saved at: artifacts/prepare_base_model/base_model.pth


  state_dict = torch.load(model_path)


Model Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [16, 64, 224, 224]           1,792
              ReLU-2         [16, 64, 224, 224]               0
            Conv2d-3         [16, 64, 224, 224]          36,928
              ReLU-4         [16, 64, 224, 224]               0
         MaxPool2d-5         [16, 64, 112, 112]               0
            Conv2d-6        [16, 128, 112, 112]          73,856
              ReLU-7        [16, 128, 112, 112]               0
            Conv2d-8        [16, 128, 112, 112]         147,584
              ReLU-9        [16, 128, 112, 112]               0
        MaxPool2d-10          [16, 128, 56, 56]               0
           Conv2d-11          [16, 256, 56, 56]         295,168
             ReLU-12          [16, 256, 56, 56]               0
           Conv2d-13          [16, 256, 56, 56]         590,080
             ReLU-14    