In [1]:
!pip install tensorflow



In [7]:
import os
import tensorflow as tf
from pathlib import Path

In [17]:
class PrepareBaseModel:
    """Build a VGG16 transfer-learning model with a new classification head.

    Workflow (see ``update_base_model``):
      1. Load ``tf.keras.applications.vgg16.VGG16`` with the configured weights
         and input shape, and save it to ``base_model_path``.
      2. Freeze all or part of it, append ``Flatten`` + ``Dense(classes)`` with
         the configured activation, compile with SGD and categorical
         cross-entropy, and save the result to ``updated_model_path``.

    Every constructor argument is optional; the defaults reproduce the
    original hard-coded configuration.
    """

    def __init__(
        self,
        classes=22,
        freeze_all=True,
        freeze_till=None,
        learning_rate=0.01,
        input_shape=(224, 224, 3),
        weights='imagenet',
        include_top=False,
        activation="softmax",
        updated_model_path="../artifacts/prepare_base_model/base_model_updated.h5",
        base_model_path="../artifacts/prepare_base_model/base_model.h5",
    ):
        """Store the model-building configuration.

        Args:
            classes: Number of output units in the new Dense head.
            freeze_all: If true, every layer of the base model is frozen.
            freeze_till: Used only when ``freeze_all`` is false. If a positive
                int, all base-model layers except the last ``freeze_till`` are
                frozen. ``None`` or <= 0 freezes nothing.
            learning_rate: Learning rate for the SGD optimizer.
            input_shape: Input shape passed to VGG16.
            weights: ``weights`` argument passed to VGG16.
            include_top: ``include_top`` argument passed to VGG16.
            activation: Activation of the new Dense head.
            updated_model_path: Where the compiled full model is saved.
            base_model_path: Where the unmodified base model is saved.
        """
        self.params_classes = classes
        self.params_freeze_all = freeze_all
        self.params_freeze_till = freeze_till
        self.params_learning_rate = learning_rate
        self.input_shape = input_shape
        self.params_weights = weights
        self.include_top = include_top
        self.params_activation = activation
        # NOTE(review): `.h5` is Keras' legacy HDF5 format; newer Keras versions
        # recommend the native `.keras` format instead.
        self.updated_model_path = updated_model_path
        self.base_model_path = base_model_path

    def load_base_model(self):
        """Create the VGG16 base model, store it on ``self.base_model`` and save it."""
        self.base_model = tf.keras.applications.vgg16.VGG16(
            include_top=self.include_top,
            weights=self.params_weights,
            input_shape=self.input_shape,
        )
        self.save_model(path=self.base_model_path, model=self.base_model)

    def prepare_full_model(self, model, classes, freeze_all, freeze_till, learning_rate):
        """Freeze (part of) ``model``, attach a classification head, and compile.

        Args:
            model: Keras base model whose ``layers`` are (partially) frozen.
            classes: Number of units in the new Dense output layer.
            freeze_all: If true, every layer in ``model.layers`` is frozen.
            freeze_till: When ``freeze_all`` is false and this is a positive
                int, every layer except the last ``freeze_till`` is frozen,
                leaving those last layers trainable for fine-tuning.
            learning_rate: Learning rate for the SGD optimizer.

        Side effects:
            Mutates the ``trainable`` flag of layers of ``model``, sets
            ``self.full_model`` to the compiled model, and prints its summary.
        """
        # Freeze each *layer*; the original set `model.trainable` inside the
        # loop, which froze the entire model even when only a prefix of the
        # layers was meant to be frozen (`freeze_till`).
        if freeze_all:
            for layer in model.layers:
                layer.trainable = False
        elif (freeze_till is not None) and (freeze_till > 0):
            for layer in model.layers[:-freeze_till]:
                layer.trainable = False

        flatten_in = tf.keras.layers.Flatten()(model.output)
        prediction = tf.keras.layers.Dense(
            units=classes,
            activation=self.params_activation,
        )(flatten_in)

        self.full_model = tf.keras.models.Model(
            inputs=model.input,
            outputs=prediction,
        )

        self.full_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=["accuracy"],
        )

        self.full_model.summary()

    def save_model(self, path: Path, model: tf.keras.Model):
        """Save ``model`` to ``path``, creating the parent directory if missing.

        Args:
            path: Destination file path (``str`` or ``Path``).
            model: The Keras model to save.
        """
        # model.save does not create missing directories, so a fresh checkout
        # without ../artifacts/prepare_base_model would fail here.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        model.save(path)

    def update_base_model(self):
        """Run the full pipeline: load/save the base model, then build/save the full model."""
        self.load_base_model()
        self.prepare_full_model(
            model=self.base_model,
            classes=self.params_classes,
            freeze_all=self.params_freeze_all,
            freeze_till=self.params_freeze_till,
            learning_rate=self.params_learning_rate,
        )
        self.save_model(path=self.updated_model_path, model=self.full_model)

In [19]:
# Build the preparer with its default configuration, then load VGG16, save it,
# attach and compile the classification head, and save the updated model.
prepare_base_model = PrepareBaseModel()
prepare_base_model.update_base_model()



