In [1]:
import os

# This notebook lives one level below the project root, and later cells load
# config/config.yaml and params.yaml by relative path. Step up only when we are
# not already at the root (params.yaml marks it), so re-running this cell does
# not keep climbing the directory tree.
if not os.path.exists("params.yaml"):
    os.chdir("../")

In [34]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the "prepare base model" pipeline stage.

    Instances are produced by ``ConfigurationManager.get_prepare_base_model_config``
    and consumed by ``PrepareBaseModel``.
    """

    # Working directory for this stage.
    root_dir: Path
    # Destination for the pretrained backbone as downloaded.
    base_model_path: Path
    # Destination for the backbone after its output layer has been replaced.
    updated_base_model_path: Path
    # Number of outputs for the replacement layer (``CLASSES`` in params.yaml).
    params_classes: int

In [18]:
from src.CNNClassifier.constants import * 
from src.CNNClassifier.utils.common import read_yaml, create_directories

In [35]:
class ConfigurationManager:
    """Reads the project YAML files and builds typed per-stage configs.

    Constructing an instance loads the config and params files via ``read_yaml``
    and ensures ``artifacts_root`` exists.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Return the ``prepare_base_model`` section as a ``PrepareBaseModelConfig``.

        Side effect: creates the stage's ``root_dir``. Path-valued settings are
        wrapped in ``Path``; the class count comes from ``params.CLASSES``.
        """
        stage = self.config.prepare_base_model
        create_directories([stage.root_dir])

        # Same field order as the dataclass, so settings are read in that order.
        path_fields = ("root_dir", "base_model_path", "updated_base_model_path")
        paths = {field: Path(getattr(stage, field)) for field in path_fields}

        return PrepareBaseModelConfig(**paths, params_classes=self.params.CLASSES)

In [36]:
# Sanity check: build the stage config and display it. Note that constructing
# ConfigurationManager and calling this method also create the artifacts_root
# and prepare_base_model root_dir directories.
ConfigurationManager().get_prepare_base_model_config()

[2024-10-24 15:40:48,180: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-10-24 15:40:48,184: INFO: common: yaml file: params.yaml loaded successfully]
[2024-10-24 15:40:48,184: INFO: common: created directory at: artifacts]
[2024-10-24 15:40:48,189: INFO: common: created directory at: artifacts/prepare_base_model]


PrepareBaseModelConfig(root_dir=WindowsPath('artifacts/prepare_base_model'), base_model_path=WindowsPath('artifacts/prepare_base_model/base_vgg16.pth'), updated_base_model_path=WindowsPath('artifacts/prepare_base_model/updated_base_vgg16.pth'), params_classes=2)

In [29]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
from torchsummary import summary
from torchvision import models 

In [37]:
class PrepareBaseModel:
    """Builds the VGG16 transfer-learning model described by a config.

    Call ``get_base_model()`` first. It loads ImageNet-pretrained VGG16 and
    saves it to ``config.base_model_path``.

    Then call ``prepare_full_model()``. It freezes the convolutional features,
    swaps in a new output layer, and saves the result to
    ``config.updated_base_model_path``.
    """

    def __init__(self, config: "PrepareBaseModelConfig"):
        self.config = config
        # Populated by get_base_model(); prepare_full_model() requires it.
        self.model = None

    def get_base_model(self):
        """Load pretrained VGG16 into ``self.model`` and save it.

        Returns:
            The loaded model (also stored on ``self.model``).
        """
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained=True``; IMAGENET1K_V1
            # is the weight set that flag selects.
            self.model = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision releases have no weights enums.
            self.model = models.vgg16(pretrained=True)
        self.save_model(self.model, self.config.base_model_path)
        return self.model

    def prepare_full_model(self):
        """Freeze the feature extractor and replace the final classifier layer.

        Only ``model.features`` parameters are frozen; the classifier stays
        trainable. The last classifier layer is replaced by a freshly
        initialised ``nn.Linear`` with ``config.params_classes`` outputs.

        Returns:
            The modified model (the same object as ``self.model``).

        Raises:
            RuntimeError: if ``get_base_model()`` has not been called yet.
        """
        if getattr(self, "model", None) is None:
            raise RuntimeError("get_base_model() must be called before prepare_full_model()")

        for param in self.model.features.parameters():
            param.requires_grad = False

        # For VGG16, classifier[-1] is the same layer as classifier[6]; indexing
        # from the end avoids hard-coding the length of the classifier head.
        old_head = self.model.classifier[-1]
        self.model.classifier[-1] = nn.Linear(old_head.in_features, self.config.params_classes)

        self.save_model(self.model, self.config.updated_base_model_path)
        return self.model

    @staticmethod
    def save_model(model, path):
        """Serialize ``model`` to ``path`` with ``torch.save``, creating parent dirs.

        NOTE: this pickles the whole module rather than its ``state_dict()``, so
        loading it requires the same class definitions to be importable.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model, path)

In [14]:
# Exploratory load of the pretrained VGG16 backbone (PrepareBaseModel.get_base_model
# loads the same backbone and also saves it). ``pretrained=True`` is deprecated
# in torchvision >= 0.13; VGG16_Weights.IMAGENET1K_V1 selects the same weights.
_vgg16_weights = getattr(models, "VGG16_Weights", None)
model = (
    models.vgg16(weights=_vgg16_weights.IMAGENET1K_V1)
    if _vgg16_weights is not None
    else models.vgg16(pretrained=True)
)



In [16]:
# Freeze the convolutional feature extractor; the classifier stays trainable.
model.features.requires_grad_(False)

# Replace VGG16's output layer (classifier[6]) with a fresh 2-output layer.
num_class = 2
model.classifier[6] = nn.Linear(
    in_features=model.classifier[6].in_features,
    out_features=num_class,
)

In [38]:
# Run the full "prepare base model" stage:
#   1. load the stage config;
#   2. load and save the pretrained backbone;
#   3. freeze its features, replace the output layer, and save the updated model.
# Exceptions propagate unchanged; there is nothing to clean up, so no try/except.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
final_model = prepare_base_model.prepare_full_model()

[2024-10-24 15:40:54,790: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-10-24 15:40:54,795: INFO: common: yaml file: params.yaml loaded successfully]
[2024-10-24 15:40:54,799: INFO: common: created directory at: artifacts]
[2024-10-24 15:40:54,799: INFO: common: created directory at: artifacts/prepare_base_model]
