In [None]:
import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.layers import Dropout

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense,GlobalAveragePooling2D
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.applications.resnet50 import ResNet50

from tensorflow.keras.applications.mobilenet import preprocess_input
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Pretrained ImageNet backbone without its original classifier; a new
# classification head is stacked on top of its feature map.
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    weights='imagenet',
    include_top=False,
)

# Head: pool the feature map to a vector, apply dropout once, then a ReLU MLP.
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.5, seed=42)(x)
for units in (1024, 1024, 512):
    x = Dense(units, activation='relu')(x)

# Two-way softmax (one unit per class folder used by the data generators).
output = Dense(2, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)

In [None]:
# Freeze the pretrained backbone so that only the new head is trained.
# `model` was built from base_model.input to a head stacked on
# base_model.output, so the backbone's layers come first in `model.layers`.
# Slicing at len(base_model.layers) therefore freezes exactly the backbone,
# whichever backbone is used (154 layers for MobileNetV2). This replaces the
# hard-coded 154 (MobileNetV2) / 174 (ResNet50) values.
# Lower `n` to also fine-tune the top layers of the backbone.
n = len(base_model.layers)
for layer in model.layers[:n]:
    layer.trainable = False
for layer in model.layers[n:]:
    layer.trainable = True

# Compact sanity check instead of printing every layer (~160 lines).
trainable_layers = [layer.name for layer in model.layers if layer.trainable]
print(f"{n} frozen layers; {len(trainable_layers)} trainable: {trainable_layers}")

In [None]:
# Adam with default hyper-parameters; labels are one-hot (class_mode='categorical'),
# hence categorical cross-entropy.
opt = Adam()
model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Training-time augmentation. Its validation_split also defines which files
# belong to subset='validation' in the generator cell.
#
# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1], and the
# imported `preprocess_input` does exactly that. The previous
# `rescale=1.0/255.0` produced [0, 1] inputs, which do not match what the
# frozen backbone was trained on. Keras runs `preprocessing_function` after
# resizing and augmentation, so the brightness shift still sees 0-255 pixels.
# NOTE: inference code must apply the same `preprocess_input`. If the
# backbone is switched to ResNet50, use resnet50's preprocess_input instead.
datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=[0.5, 1.0],
    brightness_range=[0.2, 1.0],  # random brightness factor between 0.2x and 1.0x
    horizontal_flip=True,
    fill_mode='nearest',
    data_format='channels_last',
    validation_split=0.2,  # fraction of each class folder reserved for validation
)

In [None]:
# Root folder containing one sub-directory per class (Faust/, Offen/).
datadirectory = '/home/jan/mycontainer/input/'

Tsize = [224, 224]  # must match the model's input_shape (224, 224, 3)

# Augmented training batches (subset='training' of datagen's split).
train_batch_generator = datagen.flow_from_directory(
    directory=datadirectory,
    classes=['Faust', 'Offen'],
    target_size=Tsize,
    class_mode='categorical',
    batch_size=20,
    color_mode='rgb',
    shuffle=True,
    seed=42,
    subset='training')

# Validation images must NOT be randomly augmented. Otherwise val_accuracy,
# which drives model checkpointing, is measured on distorted images and
# fluctuates from epoch to epoch. Only datagen's deterministic preprocessing
# (rescale / preprocessing_function) and its split fraction are reused here.
# Keras assigns files to subsets by their sorted order within each class
# folder, so the same split fraction keeps training and validation disjoint.
valid_datagen = ImageDataGenerator(
    rescale=datagen.rescale,
    preprocessing_function=datagen.preprocessing_function,
    validation_split=datagen._validation_split,
    data_format='channels_last')

valid_batch_generator = valid_datagen.flow_from_directory(
    directory=datadirectory,
    classes=['Faust', 'Offen'],
    target_size=Tsize,
    class_mode='categorical',
    # batch_size left at the Keras default
    color_mode='rgb',
    shuffle=False,  # sample order does not affect validation metrics
    seed=42,
    subset='validation')

In [None]:
# Full model (architecture + weights) is written here whenever val_accuracy improves.
filepath = '/home/jan/models/RMobileNetV2h5.h5'

# `period=1` is deprecated in TF2 and removed in Keras 3; save_freq='epoch'
# is its replacement (check once per epoch). mode='max' is stated explicitly
# because the monitored metric is an accuracy (higher is better).
Checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=filepath,
    monitor='val_accuracy',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='max',
    save_freq='epoch')

In [None]:
%%time
step_size_train=train_batch_generator.n//train_batch_generator.batch_size
step_size_valid=valid_batch_generator.n//valid_batch_generator.batch_size
# or
step_size_train=60 # If we use Data augmentation with fit_generator, you can use more samples than you have, as they are generated differently on the fly

history=model.fit_generator(generator=train_batch_generator,
                    steps_per_epoch=step_size_train,
                    validation_data=valid_batch_generator,
                    epochs=10,
                    callbacks=[Checkpoint_callback])