# In [1]:
import os

# In [2]:
# Switch the working directory one level up (presumably the project root, so
# that the ``src`` package and relative config paths resolve — confirm).
_parent_dir = '../'
os.chdir(_parent_dir)

# In [3]:
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class Model_Trainer_Config:
    """Immutable path settings for the model-training stage.

    Attributes:
        root_dir: stage output directory (the configuration manager requests
            its creation before building this object).
        train_path: location of the training data — presumably an image
            directory; confirm against the data-preparation stage.
        test_path: location of the test/validation data — same caveat.
        data_path: optional extra data location; defaults to ``None`` so the
            config can be built with or without it.
        model_path: optional directory for the saved models; defaults to
            ``None`` (a consumer may then fall back to ``root_dir``).
    """

    root_dir: Path
    train_path: Path
    test_path: Path
    # Optional fields come last so positional construction stays compatible.
    data_path: Optional[Path] = None
    model_path: Optional[Path] = None

# In [4]:
from src.constant import *
def create_directories(file_path):
    """Create each directory in *file_path* (parents included) if it is missing.

    Args:
        file_path: an iterable of directory paths (callers in this notebook
            pass a one-element list); a single ``str``/``bytes``/``os.PathLike``
            path is also accepted.

    Returns:
        None. Existing directories are left untouched.
    """
    # Wrap a lone path so a plain string is not iterated character by character.
    if isinstance(file_path, (str, bytes, os.PathLike)):
        file_path = [file_path]
    for directory in file_path:
        os.makedirs(directory, exist_ok=True)
def read_yaml(file_path):
    """Placeholder for a YAML loader; fails loudly instead of returning ``None``.

    The previous stub silently returned ``None``, so callers crashed later
    with an unrelated ``AttributeError``. Replace this with a real loader.

    Args:
        file_path: path of the YAML file that was requested.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        f"read_yaml is a placeholder and cannot load {file_path!r}; "
        "replace it with a real YAML loader."
    )
class Configuration_Manager:
    """Load the config, params and schema files and build per-stage configs."""

    def __init__(self, config_path=CONFIG_PATH, params_path=PARAMS_PATH, schema_path=SCHEMA_PATH):
        """Read the three configuration files with ``read_yaml``.

        Args:
            config_path: path of the main config file.
            params_path: path of the hyper-parameter file.
            schema_path: path of the schema file.
        """
        self.config = read_yaml(config_path)
        self.params = read_yaml(params_path)
        self.schema = read_yaml(schema_path)
        # Request creation of the top-level artifacts directory up front.
        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self) -> tuple:
        """Build the model-trainer config and return it with the full params.

        Returns:
            ``(Model_Trainer_Config, params)``; ``Model_Trainer`` unpacks
            this pair in its constructor.
        """
        config = self.config.Model_Trainer
        create_directories([config.root_dir])
        # Only pass fields the dataclass declares; ``data_path`` is not one of
        # its required fields and made construction raise TypeError.
        model_trainer_config = Model_Trainer_Config(
            root_dir=config.root_dir,
            train_path=config.train_path,
            test_path=config.test_path,
        )
        return model_trainer_config, self.params


# Backward-compatible alias: this function used to live at module level and
# was called with a ``Configuration_Manager`` instance as ``self``.
get_model_trainer_config = Configuration_Manager.get_model_trainer_config

# In [None]:
# NOTE(review): confirm ``image`` is exported by tensorflow.keras.preprocessing.image
# in the installed TF version; it is not used in the visible code.
from tensorflow.keras.preprocessing.image import ImageDataGenerator,load_img,image
from tensorflow.keras.applications import VGG16,ResNet50
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from tensorflow.keras.optimizers import Adagrad,SGD,Adam
from tensorflow.keras import regularizers
# Note: this ``load_img`` rebinds the one imported from preprocessing.image above.
from tensorflow.keras.utils import img_to_array,load_img
from tensorflow.keras.layers import Input,Flatten,Dense,Dropout,BatchNormalization
# ``Sequential`` is a model class: it lives in keras.models, not keras.layers,
# so importing it from layers raised ImportError.
from tensorflow.keras.models import Sequential
import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from src.utils.common import logger

# (height, width, channels) image size — presumably meant to match
# params['VGG16']['input_tensor']['input_shape']; TODO confirm.
TENSOR_SIZE = (224,224,3)

class Model_Trainer:
    """Build, train and save VGG16- and ResNet50-based image classifiers.

    ``__init__`` instantiates both backbones from the ``VGG16`` / ``ResNet50``
    sections of the params mapping. ``get_models`` freezes them and adds a
    dense classification head. ``initiate_model_trainer`` fits both models
    and saves them as ``.h5`` files.
    """

    def __init__(self, model_trainer_config, params=None):
        """Store the configuration and instantiate the two backbones.

        Args:
            model_trainer_config: the ``(Model_Trainer_Config, params)`` pair
                returned by ``Configuration_Manager.get_model_trainer_config``,
                or a bare ``Model_Trainer_Config`` when ``params`` is given.
            params: optional params mapping; when omitted it is unpacked from
                ``model_trainer_config`` as described above.
        """
        if params is None:
            config, params = model_trainer_config
        else:
            config = model_trainer_config
        self.config, self.params = config, params

        # Both backbones use the input shape configured under VGG16.
        input_shape = tuple(self.params['VGG16']['input_tensor']['input_shape'])
        # Keras ``Input`` takes ``shape=``; ``input_shape=`` is rejected.
        self.vgg_base_model = VGG16(
            input_tensor=Input(shape=input_shape),
            weights=self.params['VGG16']['weights'],
            include_top=self.params['VGG16']['include_top'],
        )
        self.resnet_base_model = ResNet50(
            input_tensor=Input(shape=input_shape),
            weights=self.params['ResNet50']['weights'],
            include_top=self.params['ResNet50']['include_top'],
        )

    @staticmethod
    def _build_classifier(base_model, num_classes):
        """Stack a Flatten + Dense(256) + softmax head on *base_model* and compile it."""
        model = Sequential()
        model.add(base_model)
        model.add(Flatten())
        model.add(Dense(256, activation='relu'))
        model.add(Dense(num_classes, activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    def get_models(self, num_classes=4):
        """Freeze both backbones and wrap each in a compiled classifier.

        Args:
            num_classes: width of the softmax output layer (default 4, the value
                previously hard-coded). Presumably must equal the number of
                classes in the training data — confirm.

        Returns:
            ``(vgg_based_model, resnet_based_model)``.
        """
        # Freeze the backbone layers so that only the new head's weights train.
        for base_model in (self.vgg_base_model, self.resnet_base_model):
            for layer in base_model.layers:
                layer.trainable = False

        vgg_based_model = self._build_classifier(self.vgg_base_model, num_classes)
        resnet_based_model = self._build_classifier(self.resnet_base_model, num_classes)
        return (vgg_based_model, resnet_based_model)

    def _fit(self, model, params_key, training_set, testing_set):
        """Fit *model* for the epochs configured under ``params[params_key]``."""
        # ``validation_data`` is the Keras keyword; ``validation_set`` is rejected.
        model.fit(
            training_set,
            epochs=self.params[params_key]['epochs'],
            validation_data=testing_set,
            steps_per_epoch=len(training_set),
            validation_steps=len(testing_set),
        )
        logger.info(f'Finished Training {params_key} Model')

    def initiate_model_trainer(self, training_set, testing_set):
        """Train the VGG16 and ResNet50 classifiers, then save both.

        Args:
            training_set: training data accepted by ``Model.fit`` that also
                supports ``len()`` (used for ``steps_per_epoch``), e.g. a Keras
                directory iterator.
            testing_set: validation data with the same requirements (``len()``
                is used for ``validation_steps``).

        The models are written as ``vgg_model.h5`` and ``resnet_model.h5``.
        They go to ``config.model_path`` when that attribute is set, otherwise
        to ``config.root_dir``; the directory is created if needed.
        """
        vgg_model, resnet_model = self.get_models()

        self._fit(vgg_model, 'VGG16', training_set, testing_set)
        self._fit(resnet_model, 'ResNet50', training_set, testing_set)

        model_dir = getattr(self.config, 'model_path', None) or self.config.root_dir
        os.makedirs(model_dir, exist_ok=True)
        vgg_model.save(os.path.join(model_dir, 'vgg_model.h5'))
        resnet_model.save(os.path.join(model_dir, 'resnet_model.h5'))
        logger.info(f'Saved both models at: {model_dir}')

# In [None]:
# Build the configuration, create train/test image iterators and train both models.
try:
    config = Configuration_Manager()
    model_trainer_config = config.get_model_trainer_config()
    # get_model_trainer_config returns a (Model_Trainer_Config, params) pair.
    trainer_config, params = model_trainer_config
    # Feed images at the (height, width) the backbones are built with.
    image_size = tuple(params['VGG16']['input_tensor']['input_shape'])[:2]
    # NOTE(review): assumes train_path/test_path are directories with one
    # sub-folder per class, and applies no rescaling / preprocess_input —
    # confirm both against the data-preparation stage.
    data_generator = ImageDataGenerator()
    # 'categorical' one-hot labels match the categorical_crossentropy loss.
    training_set = data_generator.flow_from_directory(
        trainer_config.train_path, target_size=image_size, class_mode='categorical')
    testing_set = data_generator.flow_from_directory(
        trainer_config.test_path, target_size=image_size, class_mode='categorical')
    model_trainer = Model_Trainer(model_trainer_config)
    model_trainer.initiate_model_trainer(training_set, testing_set)
except Exception:
    # Record the full traceback before propagating the failure.
    logger.exception('Model training pipeline failed')
    raise