In [1]:
# Switch the kernel's working directory to the parent folder; the %pwd cell
# that follows reports the resulting location.
# NOTE: '../' is relative to the *current* working directory, so every
# additional execution of this cell climbs one more level. Run it exactly once
# per kernel session (restart the kernel before re-running the notebook).
import os
os.chdir('../')

In [2]:
%pwd

'f:\\NamHoang\\MyProject\\DL\\dog-classification'

In [3]:
from pathlib import Path
from dataclasses import dataclass

In [4]:
@dataclass(frozen=True)
class PrepareModelConfig:
    """Immutable bundle of settings for the model-preparation stage.

    Instances are created by ``ConfigurationManager.get_prepare_model_config``
    from the ``prepare_model`` section of the config file plus the ``WEIGHTS``,
    ``MODEL`` and ``SEED`` entries of the params file. ``frozen=True`` makes
    attribute assignment after construction raise an error.
    """
    # Directory of this stage; it is created before the config object is built.
    root_dir: Path
    # Destination handed to save_model for the freshly loaded base model.
    base_model_path: Path
    # Destination handed to save_model for the model with the replaced head.
    updated_model_path: Path
    # Name of the torchvision weights enum; ".DEFAULT" is appended on lookup.
    params_weights: str
    # Model name passed to torchvision.models.get_model.
    params_model: str
    # Random seed taken from the SEED entry of the params file.
    params_seed: int

In [5]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_dir

In [6]:
class ConfigurationManager:
    """Load the project's YAML settings and build per-stage config objects.

    On construction both the config file (``CONFIG_PATH``) and the params file
    (``PARAMS_PATH``) are read with ``read_yaml`` and the directory named by
    the ``artifacts_root`` config entry is created.
    """

    def __init__(self) -> None:
        """Read the config and params files and create the artifacts root."""
        config_path = CONFIG_PATH
        params_path = PARAMS_PATH

        self.config = read_yaml(config_path)
        self.params = read_yaml(params_path)

        create_dir([self.config['artifacts_root']], verbose=False)

    def get_prepare_model_config(self) -> "PrepareModelConfig":
        """Build the configuration for the model-preparation stage.

        Creates the stage's ``root_dir`` as a side effect.

        Returns:
            PrepareModelConfig: path fields are converted to ``pathlib.Path``
            so they match the types declared on the dataclass; the parameter
            fields are copied from the ``WEIGHTS``, ``MODEL`` and ``SEED``
            entries of the params file.

        Raises:
            KeyError: if a required config or params entry is missing
                (assuming ``read_yaml`` returns a mapping with dict-style
                lookup -- TODO confirm).
        """
        config = self.config['prepare_model']
        # create_dir keeps receiving the raw config value so that its own
        # behaviour (including what it logs) is exactly as before.
        create_dir([config['root_dir']])

        return PrepareModelConfig(
            # The dataclass declares these three fields as Path; coerce the
            # raw YAML values instead of silently storing plain strings.
            root_dir=Path(config['root_dir']),
            base_model_path=Path(config['base_model_path']),
            updated_model_path=Path(config['updated_model_path']),
            params_weights=self.params['WEIGHTS'],
            params_model=self.params['MODEL'],
            params_seed=self.params['SEED'],
        )

In [7]:
from cnnClassifier.utils.common import save_model
import torchvision
import torch
from torchsummary import summary
from torch import nn

In [8]:
class PrepareModel:
    """Load a pretrained torchvision model and attach a new classification head.

    Args:
        config: Stage settings (weights/model names, seed, output paths).
        num_classes: Number of output features of the new head.
    """

    def __init__(self, config: "PrepareModelConfig", num_classes: int = 2) -> None:
        self.config = config
        # Populated lazily by get_base_model().
        self.model = None
        self.num_classes = num_classes

    def get_base_model(self):
        """Instantiate the configured torchvision model and save it.

        The weights are looked up as ``"<params_weights>.DEFAULT"``; the
        resulting model is stored on ``self.model`` and written to
        ``config.base_model_path`` via ``save_model``.
        """
        weights = torchvision.models.get_weight(f"{self.config.params_weights}.DEFAULT")
        self.model = torchvision.models.get_model(self.config.params_model, weights=weights)

        save_model(self.model, self.config.base_model_path)

    def _prepare_full_model(self, num_classes, seed=42):
        """Attach a Dropout + LazyLinear head to the base model.

        Args:
            num_classes: Output features of the new linear layer.
            seed: Value passed to ``torch.manual_seed`` before the head is
                created.

        Returns:
            The modified model (the same object as ``self.model``).
        """
        torch.manual_seed(seed=seed)
        if self.model is None:
            self.get_base_model()
        # NOTE(review): this assigns an attribute named ``classifier``. On
        # architectures whose forward pass uses a differently named head
        # (e.g. ``fc``) the original head would stay in use and this module
        # would be ignored -- confirm against the configured MODEL.
        self.model.classifier = torch.nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.LazyLinear(out_features=num_classes)
        )
        # Prints a layer-by-layer summary for a 3x224x224 input.
        summary(self.model, input_size=(3, 224, 224))
        return self.model

    def updated_model(self):
        """Build the full model with the configured seed and save it.

        The seed comes from ``config.params_seed``; previously the class count
        was passed in the seed position, so the configured seed was ignored.
        """
        full_model = self._prepare_full_model(self.num_classes, self.config.params_seed)

        save_model(full_model, self.config.updated_model_path)

In [9]:
from cnnClassifier.pipeline.state_02_dataloader import DataLoaderPipeline

In [10]:
# Run the model-preparation stage end to end. Any exception raised along the
# way propagates to the notebook unchanged.

# The third value returned by the data-loading pipeline is the collection of
# classes; only its length is needed here.
dataloader = DataLoaderPipeline()
_, _, classes = dataloader.main()

# Read the YAML settings and build this stage's configuration.
config = ConfigurationManager()
prepare_model_config = config.get_prepare_model_config()

# Load and save the base model, then attach the new head and save the result.
prepare_model = PrepareModel(
    config=prepare_model_config,
    num_classes=len(classes),
)
prepare_model.get_base_model()
prepare_model.updated_model()

[2024-01-13 08:47:55,973: INFO: common: Yaml file: config\config.yaml loaded successfully]
[2024-01-13 08:47:55,976: INFO: common: Yaml file: params.yaml loaded successfully]
[2024-01-13 08:47:55,978: INFO: common: Created directory at: artifacts]
[2024-01-13 08:47:55,980: INFO: common: Created directory at: artifacts/dataloader]
[2024-01-13 08:47:56,089: INFO: common: Yaml file: config\config.yaml loaded successfully]
[2024-01-13 08:47:56,092: INFO: common: Yaml file: params.yaml loaded successfully]
[2024-01-13 08:47:56,093: INFO: common: Created directory at: artifacts/prepare_model]




----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 24, 112, 112]             648
       BatchNorm2d-2         [-1, 24, 112, 112]              48
              ReLU-3         [-1, 24, 112, 112]               0
         MaxPool2d-4           [-1, 24, 56, 56]               0
            Conv2d-5           [-1, 24, 28, 28]             216
       BatchNorm2d-6           [-1, 24, 28, 28]              48
            Conv2d-7           [-1, 24, 28, 28]             576
       BatchNorm2d-8           [-1, 24, 28, 28]              48
              ReLU-9           [-1, 24, 28, 28]               0
           Conv2d-10           [-1, 24, 56, 56]             576
      BatchNorm2d-11           [-1, 24, 56, 56]              48
             ReLU-12           [-1, 24, 56, 56]               0
           Conv2d-13           [-1, 24, 28, 28]             216
      BatchNorm2d-14           [-1, 24,