In [1]:
import os

In [2]:
# Show the kernel's current working directory before changing it
# (the recorded output below shows the notebook's own `research` folder).
%pwd

'd:\\Artificial Intelligence and Machine Learning\\End2End Projects\\Paddy Doctor - Paddy Disease Classification\\paddy-doctor\\research'

In [3]:
os.chdir("../")
%pwd

'd:\\Artificial Intelligence and Machine Learning\\End2End Projects\\Paddy Doctor - Paddy Disease Classification\\paddy-doctor'

In [4]:
## Define entity
from dataclasses import dataclass
from pathlib import Path

@dataclass
class PrepareBaseModelConfig:
    """Plain container for the settings of the prepare-base-model stage.

    The fields are populated by ``ConfigurationManager`` and read by
    ``PrepareBaseModel``. The dataclass performs no validation or type
    coercion; annotations are informational only.
    """

    # --- artifact locations -------------------------------------------------
    root_dir: Path                  # stage directory
    base_model_path: Path           # where the base model's state dict is saved
    updated_base_model_path: Path   # where the re-headed model's state dict is saved

    # --- backbone selection -------------------------------------------------
    params_weights: str             # name used to look up the pretrained weights
    params_input_shape: list        # not read by the code visible in this notebook

    # --- fine-tuning controls -----------------------------------------------
    freeze_all: bool                # freeze every backbone parameter when True
    freeze_till: int                # count of trailing parameter tensors left trainable
    learning_rate: float            # optimizer learning rate
    classes: int                    # output width of the new classification head

    # --- regularisation -----------------------------------------------------
    dropout_rate: float             # dropout probability inside the new head
    l2_weight_decay: float          # passed to the optimizer as weight_decay
    l1_weight_decay: float          # not read by the code visible in this notebook




In [5]:
from paddydoctor.constants import *
from paddydoctor.utils.common import *

In [6]:
## Define Configuration Manager

class ConfigurationManager:
    """Reads the project's config and params YAML files and builds stage configs.

    Both files are read once in the constructor with ``read_yaml``; the results
    are accessed by attribute (e.g. ``self.config.artifacts_root``).
    """

    def __init__(self,
                 config_filepath = CONFIG_FILEPATH,
                 params_filepath = PARAMS_FILEPATH):
        """Load both YAML files and set up the artifacts root.

        Args:
            config_filepath: Location of the config YAML (defaults to CONFIG_FILEPATH).
            params_filepath: Location of the params YAML (defaults to PARAMS_FILEPATH).
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        # Presumably creates the directory if it is missing - confirm in
        # paddydoctor.utils.common.create_directories_files.
        create_directories_files([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Build the ``PrepareBaseModelConfig`` for the prepare-base-model stage.

        The three location fields are wrapped in ``Path`` so the values match
        the types the entity declares; every hyper-parameter is passed through
        unchanged from the params file.

        Returns:
            PrepareBaseModelConfig: settings for ``PrepareBaseModel``.
        """
        stage = self.config.prepare_base_model
        params = self.params

        create_directories_files([stage.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_weights=params.weights,
            params_input_shape=params.input_shape,
            freeze_all=params.freeze_all,
            freeze_till=params.freeze_till,
            learning_rate=params.learning_rate,
            classes=params.classes,
            dropout_rate=params.dropout_rate,
            l2_weight_decay=params.l2_weight_decay,
            l1_weight_decay=params.l1_weight_decay,
        )






In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchinfo import summary

In [18]:
## Update Components
class PrepareBaseModel:
    """Downloads a pretrained ViT-B/16, swaps in a new classification head and
    saves the state dict of both the base and the updated model.

    Typical use: ``get_base_model()`` followed by ``update_base_model()``.
    """

    def __init__(self, config: "PrepareBaseModelConfig"):
        # The annotation is a string so the class can be defined regardless of
        # the order in which notebook cells were executed.
        self.config = config

    def get_base_model(self):
        """Instantiate the pretrained backbone and save its state dict.

        Sets ``self.model`` and writes to ``config.base_model_path``.

        Raises:
            KeyError: if ``config.params_weights`` is not a supported weights name.
        """
        weights_map = {"IMAGENET1K_V1": ViT_B_16_Weights.IMAGENET1K_V1,
                       "IMAGENET1K_SWAG_E2E_V1": ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1}

        weights_name = self.config.params_weights
        if weights_name not in weights_map:
            # Same exception type as the plain dict lookup, but with a message
            # that tells the user which values are accepted.
            raise KeyError(
                f"Unsupported weights '{weights_name}'; expected one of {sorted(weights_map)}"
            )

        self.model = vit_b_16(weights=weights_map[weights_name])
        self.save_model(path = self.config.base_model_path, model = self.model)

    ## A static method does not rely on instance or class state, so it cannot use `self`.
    ## It can be called without instantiating the class, e.g.
    ## PrepareBaseModel._prepare_full_model(...)

    @staticmethod
    def _prepare_full_model(model,
                            classes:int,
                            freeze_all:bool,
                            freeze_till:int,
                            learning_rate:float,
                            dropout_rate: float,
                            l2_weight_decay: float):
        """Freeze backbone parameters, attach a new head, and build optimizer/loss.

        Args:
            model: Module whose ``heads`` attribute is replaced. ``hidden_dim``
                and ``image_size`` are read when present (torchvision's
                VisionTransformer provides both); otherwise 768 / 224 are used.
            classes: Number of output classes of the new head.
            freeze_all: If True, every parameter present before the head swap
                is frozen.
            freeze_till: Used only when ``freeze_all`` is False and the value is
                > 0: all but the last ``freeze_till`` parameter tensors of the
                incoming model are frozen. The count includes the old head's
                tensors, which are discarded by the swap below.
            learning_rate: Adam learning rate.
            dropout_rate: Dropout probability used twice inside the new head.
            l2_weight_decay: Passed to Adam as ``weight_decay``.

        Returns:
            tuple: ``(model, optimizer, criterion)``. The model outputs raw
            logits; apply ``softmax`` at inference time if probabilities are
            needed.
        """
        # Freezing happens BEFORE the head swap so the freshly created head
        # keeps requires_grad=True in every mode.
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif freeze_till is not None and freeze_till > 0:
            for param in list(model.parameters())[:-freeze_till]:
                param.requires_grad = False

        # Replace the classification head.
        # The head ends in a plain Linear layer (no Softmax): nn.CrossEntropyLoss
        # applies log-softmax itself, so a Softmax here would be applied twice
        # and flatten the gradients. Softmax has no parameters, so the saved
        # state-dict keys are unaffected by its removal.
        in_features = getattr(model, "hidden_dim", 768)
        model.heads = nn.Sequential(nn.Linear(in_features, 1028),
                                    nn.ReLU(),
                                    nn.Dropout(dropout_rate),
                                    nn.Linear(1028, 128),
                                    nn.ReLU(),
                                    nn.Dropout(dropout_rate),
                                    nn.Linear(128, classes))

        # Define optimizer (frozen parameters never receive gradients, so Adam
        # leaves them untouched).
        optimizer = optim.Adam(params = model.parameters(),
                               lr = learning_rate,
                               weight_decay = l2_weight_decay)
        # Define loss function (expects unnormalised logits).
        criterion = nn.CrossEntropyLoss()

        # Use the resolution the model was built for; the SWAG_E2E weights
        # create a 384x384 model, for which a hard-coded 224 input would fail.
        image_size = getattr(model, "image_size", 224)
        print(summary(model = model,
                input_size = (32, 3, image_size, image_size),
                col_names=["input_size", "output_size", "num_params", "trainable"],
                col_width = 20, row_settings = ["var_names"]))

        return model, optimizer, criterion

    def update_base_model(self):
        """Build the fine-tuning model from ``self.model`` and save its state dict.

        Requires ``get_base_model()`` to have been called first. Sets
        ``self.full_model``, ``self.optimizer`` and ``self.criterion`` and
        writes to ``config.updated_base_model_path``.
        """
        self.full_model, self.optimizer, self.criterion = self._prepare_full_model(
            model = self.model,
            classes = self.config.classes,
            freeze_all = self.config.freeze_all,
            freeze_till = self.config.freeze_till,
            learning_rate = self.config.learning_rate,
            dropout_rate = self.config.dropout_rate,
            l2_weight_decay = self.config.l2_weight_decay,
        )
        self.save_model(self.config.updated_base_model_path, self.full_model)

    @staticmethod
    def save_model(path, model):
        """Serialize ``model.state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(model.state_dict(), path)
        
        

In [19]:
## Create pipeline
# Run the stage end to end: load configuration, download and save the base
# model, then build and save the re-headed model.
# No try/except wrapper: catching an exception only to re-raise it unchanged
# adds nothing, so errors are left to propagate with their original traceback.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()

[2024-09-17 13:55:27,006: INFO: common: config\config.yaml loaded successfully]
[2024-09-17 13:55:27,009: INFO: common: params\params.yaml loaded successfully]
[2024-09-17 13:55:27,012: INFO: common: Directories and Files successfully created]
[2024-09-17 13:55:27,013: INFO: common: Directories and Files successfully created]
Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 10]             768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              Partial
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential