In [1]:
import os
os.chdir("../")
%pwd

'd:\\python-projects\\chest-cancer-classification'

In [2]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    root_dir: Path
    base_model_path: Path
    updated_base_model_path: Path
    params_image_size: list
    params_learning_rate: float
    params_include_top: bool
    params_weights: str
    params_classes: int

In [3]:
from src.cnnClassifier.constants import CONFIG_FILE_PATH, PARAMS_FILE_PATH
from src.cnnClassifier.utils.common import read_yaml, create_directories

In [4]:
class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([Path(self.config.artifacts_root)])


    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        config = self.config.prepare_base_model
        
        create_directories([Path(config.root_dir)])

        prepare_base_model_config = PrepareBaseModelConfig(
            root_dir=Path(config.root_dir),
            base_model_path=Path(config.base_model_path),
            updated_base_model_path=Path(config.updated_base_model_path),
            params_image_size=self.params.IMAGE_SIZE,
            params_learning_rate=self.params.LEARNING_RATE,
            params_include_top=self.params.INCLUDE_TOP,
            params_weights=self.params.WEIGHTS,
            params_classes=self.params.CLASSES
        )

        return prepare_base_model_config

In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
import pytorch_lightning as pl
from torchinfo import summary

In [6]:

class PrepareBaseModel:
    def __init__(self, config: PrepareBaseModelConfig):
        super().__init__()
        self.config = config


    def get_base_model(self):
        self.base_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        if not self.config.params_include_top:
            self.base_model.classifier = nn.Identity()

        self.save_model(path=self.config.base_model_path, model=self.base_model)



    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till):
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif (freeze_till is not None) and (freeze_till > 0):
            for idx, child in enumerate(model.features):
                if idx < freeze_till:
                    for param in child.parameters():
                        param.requires_grad = False

        model.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(25088, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, classes),
            nn.Softmax(dim=1)
        )
        
        print(summary(model,input_size=(1, 3, 224, 224)))
        return model



    def update_base_model(self):
        self.full_model = self._prepare_full_model(
            model=self.base_model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
        )
        self.save_model(self.config.updated_base_model_path, self.full_model)

    @staticmethod
    def save_model(path, model):
        torch.save(model, path)

In [7]:
from torch.nn import nn
class MyImageClassifier(pl.LightningModule):
    def __init__(self, model: nn.Module, config: PrepareBaseModelConfig):
        super().__init__()
        self.model = model
        self.config = config
        self.loss_fn = nn.CrossEntropyLoss()
        self.save_hyperparameters(ignore=['model'])

    
    def forward(self, x):
        return self.model(x)
    

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss


    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        preds = torch.argmax(logits, dim=1)
        accuracy = (preds == y).float().mean()
        self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.params_learning_rate)
        return optimizer

ImportError: cannot import name 'nn' from 'torch.nn' (d:\python-projects\chest-cancer-classification\venv\Lib\site-packages\torch\nn\__init__.py)

In [26]:
try:
    config = ConfigurationManager()
    prepare_base_model_config = config.get_prepare_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
    prepare_base_model.get_base_model()
    prepare_base_model.update_base_model()
except Exception as e:
    raise e

[2025-06-05 20:45:41,254] [13] [common] - INFO - YAML file loaded successfully: config\config.yaml
[2025-06-05 20:45:41,259] [13] [common] - INFO - YAML file loaded successfully: params.yaml
[2025-06-05 20:45:41,262] [26] [common] - INFO - Created directory at: artifacts
[2025-06-05 20:45:41,265] [26] [common] - INFO - Created directory at: artifacts\prepare_base_model
Layer (type:depth-idx)                   Output Shape              Param #
VGG                                      [1, 2]                    --
├─Sequential: 1-1                        [1, 512, 7, 7]            --
│    └─Conv2d: 2-1                       [1, 64, 224, 224]         (1,792)
│    └─ReLU: 2-2                         [1, 64, 224, 224]         --
│    └─Conv2d: 2-3                       [1, 64, 224, 224]         (36,928)
│    └─ReLU: 2-4                         [1, 64, 224, 224]         --
│    └─MaxPool2d: 2-5                    [1, 64, 112, 112]         --
│    └─Conv2d: 2-6                       [1, 128, 11