In [1]:
import os

# The notebook kernel starts in the notebook's own directory, but the project's
# relative paths (e.g. config/config.yaml, artifacts/) are resolved from one
# level up. Only move up when the config file is not already reachable, so
# re-running this cell does not keep climbing the directory tree.
if not os.path.isfile(os.path.join("config", "config.yaml")):
    os.chdir("../")

In [56]:
from dataclasses import dataclass
from pathlib import Path

In [99]:
from src.bone_classifier import log
from src.bone_classifier.constants import *
from src.bone_classifier.utils.common import read_yaml, create_directories

In [100]:
@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the prepare-base-model stage.

    Instances are built by ``ConfigurationManager.get_prepare_base_model``:
    the path fields come from the YAML config section and the ``params_*``
    fields from the params file.
    """

    root_dir: Path                  # working directory of this stage
    base_model_path: Path           # where the plain InceptionV3 backbone is saved
    updated_base_model_path: Path   # where backbone + custom head is saved

    params_image_size: list         # passed as InceptionV3 input_shape, e.g. [224, 224, 3]
    params_include_top: bool        # passed as InceptionV3 include_top
    params_weight: str              # passed as InceptionV3 weights (presumably "imagenet"; confirm in params.yaml)
    params_classes: int             # number of target classes
    params_learning_rate: float     # fractional learning rate (previously mis-annotated as int)
    


In [101]:
class ConfigurationManager:
    """Reads the project YAML files and hands out per-stage config objects.

    Construction parses the main config file and the params file, then calls
    ``create_directories`` on the configured ``artifacts_root``.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH,
                 params_filepath=PARAMS_FILE_PATH):
        # Attribute-style access below (``self.config.artifacts_root``) relies
        # on read_yaml returning a box-like object -- presumably ConfigBox.
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model(self) -> PrepareBaseModelConfig:
        """Build the PrepareBaseModelConfig for the prepare-base-model stage.

        Side effect: calls ``create_directories`` on the stage's ``root_dir``.
        """
        stage_cfg = self.config.prepare_base_model
        create_directories([stage_cfg.root_dir])

        hp = self.params
        return PrepareBaseModelConfig(
            root_dir=Path(stage_cfg.root_dir),
            base_model_path=Path(stage_cfg.base_model_path),
            # YAML key is ``updated_model_path``; dataclass field is
            # ``updated_base_model_path``.
            updated_base_model_path=Path(stage_cfg.updated_model_path),
            params_image_size=hp.IMAGE_SIZE,
            params_include_top=hp.INCLUDE_TOP,
            params_weight=hp.WEIGHTS,
            params_classes=hp.CLASSES,
            params_learning_rate=hp.LEARNING_RATE,
        )

In [102]:
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.preprocessing.image import ImageDataGenerator
import keras
from keras.layers import BatchNormalization

In [103]:
class PreparBaseModel:
    """Builds an InceptionV3 backbone and stacks a dense classifier on it.

    Call ``get_base_model()`` first (it creates and saves the backbone on
    ``self.model``), then ``update_base_model()`` to freeze the backbone,
    attach the head and save the combined model.
    """

    def __init__(self, config: PrepareBaseModelConfig) -> None:
        self.config = config

    @staticmethod
    def save_model(path: Path, model: keras.Model):
        """Save ``model`` to ``path`` and log where it was written."""
        model.save(path)
        log.info(f"Base model save sucessfully at path: {path}")

    def get_base_model(self):
        """Instantiate InceptionV3 from the config, store it on ``self.model``
        and save it to ``config.base_model_path``."""
        self.model = InceptionV3(
            include_top=self.config.params_include_top,
            weights=self.config.params_weight,
            input_shape=self.config.params_image_size,
            classes=self.config.params_classes)

        self.save_model(path=self.config.base_model_path, model=self.model)

    def prepare_full_model(self, freez_all, model, freez_till: int, classes):
        """Freeze (part of) ``model`` and attach a dense softmax classifier.

        Args:
            freez_all: if truthy, every layer of ``model`` is made non-trainable.
            model: backbone whose ``output`` feeds the new head.
            freez_till: only consulted when ``freez_all`` is falsy. A positive
                int freezes all layers except the last ``freez_till``; ``None``
                or a non-positive value leaves the backbone trainable.
            classes: number of units in the final softmax layer.

        Returns:
            A new, uncompiled ``keras.models.Model`` mapping ``model.input`` to
            the softmax predictions.
        """
        if freez_all:
            # Freeze each layer via the loop variable (not ``model.trainable``)
            # so the same mechanism works for the partial-freeze branch below.
            for layer in model.layers:
                layer.trainable = False
            log.info("Model freez all")

        elif freez_till is not None and freez_till > 0:
            # Freeze everything except the last ``freez_till`` layers.
            for layer in model.layers[:-freez_till]:
                layer.trainable = False
            log.info(f"Model freetill {freez_till}")

        flatten_in = keras.layers.Flatten()(model.output)
        x = keras.layers.Dense(1000, activation='relu')(flatten_in)
        x = keras.layers.Dropout(0.2)(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Dense(500, activation='relu')(x)
        x = keras.layers.Dense(300, activation='relu')(x)
        x = keras.layers.BatchNormalization()(x)
        # Output width follows the ``classes`` argument rather than a fixed 5.
        prediction = keras.layers.Dense(units=classes, activation='softmax')(x)

        full_model = keras.models.Model(
            inputs=model.input,
            outputs=prediction
        )

        # NOTE(review): the model is returned (and saved) uncompiled, and
        # ``config.params_learning_rate`` is not used in this stage. If it
        # should be compiled here, pass a loss *instance*, e.g.
        # keras.losses.CategoricalCrossentropy(), not the class itself.
        full_model.summary()

        return full_model

    def update_base_model(self):
        """Freeze the backbone built by ``get_base_model``, add the classifier
        head and save the result to ``config.updated_base_model_path``.

        Raises:
            AttributeError: if ``get_base_model()`` has not been called yet.
        """
        if not hasattr(self, "model"):
            raise AttributeError(
                "get_base_model() must be called before update_base_model()")

        self.full_model = self.prepare_full_model(freez_all=True,
                                                  model=self.model,
                                                  freez_till=None,
                                                  classes=self.config.params_classes)

        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)





        

In [104]:
# Run the prepare-base-model stage end to end: load config, build and save the
# backbone, then freeze it, add the head and save the updated model.
try:
    config = ConfigurationManager()
    preparemodel_config = config.get_prepare_base_model()
    prepare_model = PreparBaseModel(preparemodel_config)
    prepare_model.get_base_model()
    prepare_model.update_base_model()
except Exception:
    # Record the failure in the project log, then re-raise with a bare
    # ``raise`` so the original traceback is propagated unchanged.
    log.exception("Prepare base model stage failed")
    raise

[2024-03-15 04:14:37,297 : INFO : common : yaml file : config/config.yaml loaded sucessfully]
[2024-03-15 04:14:37,300 : INFO : common : yaml file : params.yaml loaded sucessfully]
[2024-03-15 04:14:37,300 : INFO : common : Created directories at : artifacts]
[2024-03-15 04:14:37,301 : INFO : common : Created directories at : artifacts/prepare_base_model]
[2024-03-15 04:14:38,031 : INFO : 4172541680 : Base model save sucessfully at path: artifacts/prepare_base_model/model.h5]
[2024-03-15 04:14:38,838 : INFO : 4172541680 : Model freez all]
Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_26 (InputLayer)       [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv2d_2350 (Conv2D)        (None, 111, 1