In [None]:
## This script trains a model with the APL framework using ResNet-50 features

In [None]:
# import packages
import os, cv2
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import keras
from keras.applications import resnet50
from keras.models import Model, Sequential
from keras.applications.resnet50 import preprocess_input
from keras.utils import np_utils
from tensorflow.keras.layers import Input, UpSampling2D, Flatten, BatchNormalization, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.python.client import device_lib
## check GPU: print the list of local devices reported by TensorFlow's device_lib
## (inspect the printed output to see whether a GPU device is listed)
print(device_lib.list_local_devices())

In [None]:
# load data (first part)
## load the shadow-free image
path_wd = '../' ## working directory; every input/output path is built relative to it
img_path = path_wd + 'output/images/shadow_free.jpg'
img = cv2.imread(img_path)
## cv2.imread does not raise on a missing/unreadable file, it returns None;
## stop here with a clear message instead of failing inside cvtColor
if img is None:
    raise FileNotFoundError('Cannot read image file: ' + img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) ## OpenCV loads channels as BGR; convert to RGB

In [None]:
# choose parameters
## both values are also embedded in the names of the files read and written by this notebook
## number of clusters
n_c = 25
## diameter threshold
THRESH = 20

In [None]:
# load data (second part)
def _read_flat_csv(csv_path):
    """Read a CSV (first column taken as the index) and return its values as a flat 1-D array."""
    table = pd.read_csv(csv_path, index_col=0)
    return table.to_numpy().flatten()

## clustering results for the chosen number of clusters
labels = _read_flat_csv(f'{path_wd}output/KMeans/labels_{n_c}.csv')
## relevance files (PREMON, CECSCH, MANBID) for the chosen threshold and number of clusters
rele_PREMON = _read_flat_csv(f'{path_wd}output/relevance/relevance_PREMON_{THRESH}_{n_c}.csv')
rele_CECSCH = _read_flat_csv(f'{path_wd}output/relevance/relevance_CECSCH_{THRESH}_{n_c}.csv')
rele_MANBID = _read_flat_csv(f'{path_wd}output/relevance/relevance_MANBID_{THRESH}_{n_c}.csv')

In [None]:
# create training set
time_start = datetime.now()
print('Start:', time_start)
## patch grid: the image is cut into PATCH_SIZE x PATCH_SIZE tiles, numbered row by row
## with N_PATCH_COLS tiles per image row
## NOTE(review): presumably the same grid that produced `labels` -- confirm
PATCH_SIZE = 100
N_PATCH_COLS = 96
## one training patch per entry of `labels`
n_img = len(labels)
## fail early with a readable message if the image cannot supply every patch
## (otherwise a truncated slice raises an opaque broadcasting ValueError inside the loop)
rows_needed = -(-n_img // N_PATCH_COLS) ## ceiling division
cols_needed = min(n_img, N_PATCH_COLS)
if img.shape[0] < rows_needed * PATCH_SIZE or img.shape[1] < cols_needed * PATCH_SIZE:
    raise ValueError('Image of shape ' + str(img.shape) + ' is too small for ' + str(n_img)
                     + ' patches of size ' + str(PATCH_SIZE))
## cut the patches out of the image
x_dat = np.zeros((n_img, PATCH_SIZE, PATCH_SIZE, 3), dtype=np.uint8)
for i in range(n_img):
    rr, cc = divmod(i, N_PATCH_COLS) ## grid row / column of patch i
    x_dat[i] = img[(rr*PATCH_SIZE):((rr+1)*PATCH_SIZE), (cc*PATCH_SIZE):((cc+1)*PATCH_SIZE)]
## targets: each patch gets the three relevance values of the cluster it belongs to
## (columns: PREMON, CECSCH, MANBID)
y_dat = np.column_stack([rele_PREMON[labels], rele_CECSCH[labels], rele_MANBID[labels]]).astype(np.float64)
## preprocess the input images with the ResNet-50 preprocessing function
x_dat = preprocess_input(x_dat)
## label rescale (This step is optional)
## each target column is divided by its maximum, so the cluster with the largest relevance gets label 1
for k in range(y_dat.shape[1]):
    col_max = np.max(y_dat[:, k])
    ## a column whose maximum is 0 is left untouched: dividing by 0 would turn it into NaN
    if col_max != 0:
        y_dat[:, k] = y_dat[:, k] * (1 / col_max)
print('Finished:', datetime.now()-time_start)

In [None]:
# build the model
## ResNet-50 backbone with ImageNet weights and without its top layers, for 100x100 RGB input
resnet_model = resnet50.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(100, 100, 3),
)
## optional training setting (disabled): train only the BatchNormalization layers of the backbone,
## i.e. for each layer in resnet_model.layers set
##     layer.trainable = isinstance(layer, BatchNormalization)
## stack: backbone -> global average pooling -> 3 sigmoid outputs
model = Sequential([
    resnet_model,
    GlobalAveragePooling2D(),
    Dense(3, activation='sigmoid'),
])
## compile with Adam, mean squared error as loss and mean absolute error as reported metric
opt = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=opt,
    loss='mean_squared_error',
    metrics=['mean_absolute_error'],
)

In [None]:
# parameter selection
## reproducible 80/20 train-validation split over the patch indices
np.random.seed(2020)
n_img = len(x_dat)
loc_train = np.random.choice(n_img, int(n_img * 0.8), replace=False)
## validation indices = every index not drawn for training (in increasing order);
## in arange the value equals the position, so deleting positions loc_train leaves exactly those
loc_val = np.delete(np.arange(n_img), loc_train)
## model training on the split (currently disabled; the timing below then only covers the prints)
time_start = datetime.now()
print('Start training:', time_start)
#model.fit(x_dat[loc_train], y_dat[loc_train], batch_size=64, epochs=200, validation_data=(x_dat[loc_val], y_dat[loc_val]))
print('Time for model training:', datetime.now()-time_start)

In [None]:
# retrain the model with selected epochs
## build the model on a fresh ImageNet-initialised backbone: reusing `resnet_model` would share its
## weights with the parameter-selection model, so any training done there would leak into this run
resnet_final = resnet50.ResNet50(weights='imagenet', include_top=False, input_shape=(100, 100, 3))
model = Sequential()
model.add(resnet_final)
model.add(GlobalAveragePooling2D())
model.add(Dense(3, activation='sigmoid'))
## compile the model
## NOTE(review): the learning rate here is 1e-3 while the parameter-selection model uses 1e-4 -- confirm this is intended
opt = keras.optimizers.Adam(learning_rate=1e-3)
model.compile(loss='mean_squared_error',optimizer=opt,metrics=['mean_absolute_error'])
## start model training
time_start = datetime.now()
print('Start training:', time_start)
## train on the full data set; no validation_data is passed because every sample is used for
## training, so a "validation" subset of it would only report leaked, over-optimistic metrics
model.fit(x_dat, y_dat, batch_size=64, epochs=50)
## save the model (create the output directory first so saving does not fail on a fresh checkout)
model_dir = path_wd + 'output/models/'
os.makedirs(model_dir, exist_ok=True)
model.save(model_dir + 'APL_' + str(THRESH) + '_' + str(n_c))
print('Time for model training:', datetime.now()-time_start)