# In [None]:
import pandas as pd
import numpy as np

from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from sklearn.metrics import mean_absolute_error


def load_train(path):
    """Return an augmented iterator over the training subset of the dataset.

    Labels are read from ``<path>labels.csv``; images are read from
    ``<path>final_files/``. Because paths are built by plain string
    concatenation, ``path`` must already end with a directory separator.

    Each image named in the ``file_name`` column is resized to 224x224 and
    scaled to [0, 1]. Its ``real_age`` value is yielded as a raw regression
    target (``class_mode='raw'``) in batches of 32.

    Augmentation applies random horizontal flips and width/height shifts of
    up to 20%. ``subset='training'`` with ``validation_split=0.25`` selects the
    rows not reserved for validation.
    """
    labels = pd.read_csv(path + 'labels.csv')

    augmenting_gen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=0.25,
        horizontal_flip=True,
        width_shift_range=0.2,
        height_shift_range=0.2,
    )

    # Iterator settings gathered in one place; the fixed seed makes the
    # shuffling order reproducible between runs.
    flow_options = dict(
        directory=path + 'final_files/',
        x_col='file_name',
        y_col='real_age',
        target_size=(224, 224),
        batch_size=32,
        class_mode='raw',
        subset='training',
        seed=12345,
    )
    return augmenting_gen.flow_from_dataframe(dataframe=labels, **flow_options)

def load_test(path):
    """Return a non-augmented iterator over the validation subset.

    Reads ``<path>labels.csv`` and loads images from ``<path>final_files/``.
    Paths are built by string concatenation, so ``path`` must end with a
    separator.

    Images are only rescaled to [0, 1] and resized to 224x224. Targets are
    the raw ``real_age`` values, and the batch size is 32.

    ``validation_split=0.25`` matches the value used by ``load_train``, so
    ``subset='validation'`` yields exactly the rows that the training subset
    leaves out.
    """
    frame = pd.read_csv(path + 'labels.csv')
    plain_gen = ImageDataGenerator(rescale=1. / 255, validation_split=0.25)

    return plain_gen.flow_from_dataframe(
        dataframe=frame,
        directory=path + 'final_files/',
        x_col='file_name',
        y_col='real_age',
        target_size=(224, 224),
        batch_size=32,
        class_mode='raw',
        subset='validation',
        seed=12345,
    )

def create_model(input_shape, learning_rate=0.0001):
    """Build and compile a ResNet50-based model with a single regression output.

    The ResNet50 backbone has no classification head and its weights are
    loaded from a local ``.h5`` file. It is followed by global average pooling
    and one ReLU-activated dense unit, so predictions are never negative.

    Args:
        input_shape: Shape of one input image, passed straight to ``ResNet50``
            (e.g. ``(224, 224, 3)`` -- TODO confirm against the data loaders).
        learning_rate: Step size for the Adam optimizer. The default of
            ``0.0001`` reproduces the previously hard-coded value.

    Returns:
        A compiled ``Sequential`` model that minimises mean squared error and
        reports mean absolute error (``'mae'``).
    """
    backbone = ResNet50(input_shape=input_shape,
                        weights='/datasets/keras_models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
                        include_top=False)

    model = Sequential()
    model.add(backbone)
    model.add(GlobalAveragePooling2D())
    # ReLU clamps the single regression output to be non-negative.
    model.add(Dense(1, activation='relu'))

    # `lr` is a deprecated alias that newer Keras releases reject;
    # `learning_rate` is the supported keyword.
    optimizer = Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='mean_squared_error',
                  metrics=['mae'])
    return model

def train_model(model, train_data, test_data, batch_size=16, epochs=15,
                steps_per_epoch=None, validation_steps=None):
    """Fit ``model`` on ``train_data``, validate on ``test_data``, and return it.

    ``epochs``, ``steps_per_epoch`` and ``validation_steps`` are forwarded
    unchanged to ``model.fit``. Training uses ``verbose=2`` (one line per
    epoch) and ``shuffle=True``.

    ``batch_size`` is accepted for interface compatibility but is NOT passed
    to ``fit``. The batch size presumably comes from the input iterators
    instead -- confirm with the caller.
    """
    fit_kwargs = {
        'validation_data': test_data,
        'epochs': epochs,
        'steps_per_epoch': steps_per_epoch,
        'validation_steps': validation_steps,
        'verbose': 2,
        'shuffle': True,
    }
    model.fit(train_data, **fit_kwargs)
    return model