In [None]:
import os
import sys
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.models import model_from_json
from SANet import SANet
import shutil
from utils import load_img, gen_x_y, eval_loss, gen_paths, ssim_eucli_loss, random_cropping


# Settings
net = 'SANet'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
%matplotlib inline
dataset = "B"
LOSS = 'ssim_eucli_loss'
LOSS_train = eval(LOSS)
with_validation = False
lr = 1e-4

img_paths_test, img_paths_train = gen_paths(
    path_file_root='data/paths_train_val_test',
    dataset=dataset,
    with_validation=with_validation
)
# img_paths_test, img_paths_train = img_paths_test[:10], img_paths_train[:10]
if with_validation:
    img_paths_train = list(set(img_paths_train) - set(img_paths_val))
    x_val, y_val, img_paths_val = gen_x_y(img_paths_val, 'val')
    print(len(x_val), len(y_val), len(img_paths_val))
    x_test, y_test, img_paths_test = gen_x_y(img_paths_test, 'test')
    print('Test data size:', len(x_test), len(y_test), len(img_paths_test))
else:
    x_val, y_val, img_paths_val = gen_x_y(img_paths_test[:16], 'test')
    print('Validation data size:', len(x_val), len(y_val), len(img_paths_val))
x_train, y_train, img_paths_train = gen_x_y(img_paths_train[:], 'train', augmentation_methods=['ori', 'flip'])
print('Train data size:', len(x_train), len(y_train), len(img_paths_train))
weights_dir = 'weights_' + dataset
if os.path.exists(weights_dir):
    shutil.rmtree(weights_dir)
os.makedirs(weights_dir)

In [None]:
# Model
# Build the network; IN=False presumably turns off instance normalization — confirm in SANet.py.
model = SANet(IN=False)
model.summary()
# Adam at the configured learning rate, optimizing the loss selected in the settings cell.
optimizer = Adam(lr=lr)
model.compile(loss=LOSS_train, optimizer=optimizer)

In [None]:
# Settings
# To resume from an earlier checkpoint, call model.load_weights('SANet_best.hdf5') here.
for needed_dir in (weights_dir, 'models'):
    if not os.path.exists(needed_dir):
        os.makedirs(needed_dir)

# Two ways to choose which parts of the model are trainable.
# (a) Per-branch scheme (each module has four branches): one list of operator indices
#     to freeze per branch. Only its length is used below.
branches_untrainable = [
    [2, 3, 4, 6, 7, 8, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 27, 28, 29, 30, 31, 32, 34, 35, 36, 38, 39, 40, 43, 44, 45, 46, 47, 48, 50, 51, 52, 54, 55, 56],
    [2, 3, 4, 6, 7, 8, 12, 13, 15, 16, 17, 19, 20, 22, 23, 24, 28, 29, 31, 32, 33, 35, 36, 37, 39, 40, 44, 45, 47, 48, 49, 51, 52, 53, 55, 56],
    [1, 3, 4, 5, 7, 8, 11, 13, 14, 16, 17, 18, 20, 21, 22, 24, 27, 29, 30, 32, 33, 34, 36, 37, 38, 40, 43, 45, 46, 48, 49, 50, 52, 53, 54, 56],
    [1, 2, 4, 5, 6, 8, 11, 12, 14, 15, 17, 18, 19, 21, 22, 23, 27, 28, 30, 31, 33, 34, 35, 37, 38, 39, 43, 44, 46, 47, 49, 50, 51, 53, 54, 55]
]
# (b) Per-module scheme: each entry is an inclusive range of operator indices that is
#     made trainable on its own (presumably the encoder modules followed by the
#     decoder — confirm against SANet.py).
branches_trainable = [
    list(range(first, last + 1))
    for first, last in ((1, 4), (5, 11), (12, 18), (19, 25), (26, 33))
]

# Latest validation metrics (sentinel values until the first evaluation), the best
# MAE seen so far, and one history list per metric.
lossMAE, mae = 1e5, 1e5
lossMDMD = lossMAPE = lossMSE = -1
lossesMDMD, lossesMAE, lossesMAPE, lossesMSE = [], [], [], []
counter_train = 0
val_rate = 0.2

# Fixed validation sample used for the periodic prediction plots.
path_val_display = img_paths_val[0]
x_val_display = load_img(path_val_display)
y_val_display = np.squeeze(y_val[0])

# Persist the architecture next to the weights so checkpoints can be reloaded later.
with open('./models/{}.json'.format(net), 'w') as json_file:
    json_file.write(model.to_json())

# (MAE threshold, validation rate) pairs read by the training loop: below entry 0's
# threshold validate every 10% of an epoch, below entry 1's every 5%; entry 2 is the
# fallback used when recent MAEs stay above its threshold.
save_frequencies = (
    [(90, 0.1), (80, 0.05), (95, 0.25)] if dataset == 'A'
    else [(30, 0.1), (16, 0.05), (32, 0.25)]
)

if_train_seperatly = True
if not if_train_seperatly:
    epoch_train = 200
else:
    module_step = 1
    # NOTE(review): the warm-up length uses len(branches_untrainable) (4) while the
    # training loop indexes branches_trainable (5 groups), so the last group never gets
    # its own warm-up stage — confirm this is intended.
    epoch_train_seperatly = len(branches_untrainable) * module_step
    epoch_train = 100 + epoch_train_seperatly
    

# Training
# Optional warm-up: while epoch < epoch_train_seperatly, only one operator group of
# branches_trainable is trainable at a time (each for module_step epochs), to keep the
# network from collapsing to all-zero outputs. After that every layer is unfrozen and
# the whole model is fine-tuned at lr / 10. Within an epoch, samples are fitted one at
# a time; validation runs every val_rate fraction of the epoch, and improving
# checkpoints are written to weights_dir.
time_st = time.time()
for epoch in range(epoch_train):
    if if_train_seperatly:
        if epoch < epoch_train_seperatly and epoch % module_step == 0:
            # Stage k covers epochs [k*module_step, (k+1)*module_step). Index by stage,
            # not by raw epoch, so module_step > 1 stays inside the group list.
            branch_trainable = branches_trainable[epoch // module_step]
            # Layer 0 is left untouched. Every other layer name is expected to end in
            # `_<int>`; that integer decides group membership.
            for layer in model.layers[1:]:
                layer.trainable = int(layer.name.split('_')[-1]) in branch_trainable
            # Recompile so the new trainable flags take effect (fresh optimizer state).
            model.compile(optimizer=Adam(lr=lr), loss=LOSS_train)
        elif epoch == epoch_train_seperatly:
            # Warm-up finished: unfreeze everything and fine-tune with a smaller step.
            for layer in model.layers[1:]:
                layer.trainable = True
            model.compile(optimizer=Adam(lr=lr / 10), loss=LOSS_train)

    for i in range(len(x_train)):
        # Validate more often as the MAE improves.
        if lossMAE < save_frequencies[0][0]:
            val_rate = save_frequencies[0][1]
            if lossMAE < save_frequencies[1][0]:
                val_rate = save_frequencies[1][1]
        # After 50 validations, if the median of the last 20 MAEs is still above the
        # last entry's threshold, fall back to that entry's rate.
        if (len(lossesMAE) > 50 and val_rate <= save_frequencies[-1][-1]
                and np.median(lossesMAE[-20:]) > save_frequencies[-1][0]):
            val_rate = save_frequencies[-1][-1]
        x_, y_ = x_train[i], y_train[i]
        model.fit(x_, y_, batch_size=1, verbose=0)
        # (Earlier experiments also fitted random_cropping() crops of the sample and of
        # its horizontal flip, x_[:, :, ::-1, :], at this point.)
        counter_train += 1
        # Clamp to 1: for a very small training set int(len * rate) is 0, which would
        # make the modulo below divide by zero.
        val_every = max(1, int(len(x_train) * val_rate))
        if counter_train % val_every == 0:
            # Calc loss
            lossMDMD, lossMAE, lossMAPE, lossMSE = eval_loss(model, x_val, y_val)
            lossesMDMD.append(lossMDMD)
            lossesMAE.append(lossMAE)
            lossesMAPE.append(lossMAPE)
            lossesMSE.append(lossMSE)
            lossMAE, lossMAPE, lossMDMD, lossMSE = round(lossMAE, 3), round(lossMAPE, 3), round(lossMDMD, 3), round(lossMSE, 3)
            improved = lossMAE < mae
            # Snapshot every improvement and every result comfortably below the tighter
            # threshold, but only a genuine improvement may update `mae` and the
            # *_best file, so that file always holds the lowest validation MAE.
            if improved or lossMAE < save_frequencies[1][0] * 0.9:
                model.save_weights(
                    os.path.join(weights_dir, '{}_MAE{}_MSE{}_MAPE{}_MDMD{}_epoch{}-{}.hdf5'.format(
                        net, str(lossMAE), str(lossMSE), str(lossMAPE), str(lossMDMD), epoch, counter_train % len(x_train)
                    ))
                )
            if improved:
                mae = lossMAE
                model.save_weights(os.path.join(weights_dir, '{}_best.hdf5'.format(net)))
            if counter_train % (len(x_train) * 2) == 0:
                # show prediction
                pred = np.squeeze(model.predict(np.expand_dims(x_val_display, axis=0)))
                fg, (ax_x_ori, ax_y, ax_pred) = plt.subplots(1, 3, figsize=(20, 4))
                ax_x_ori.imshow(cv2.cvtColor(cv2.imread(path_val_display), cv2.COLOR_BGR2RGB))
                ax_x_ori.set_title('Original Image')
                ax_y.imshow(y_val_display, cmap=plt.cm.jet)
                ax_y.set_title('Ground_truth: ' + str(np.sum(y_val_display)))
                ax_pred.imshow(pred, cmap=plt.cm.jet)
                ax_pred.set_title('Prediction: ' + str(np.sum(pred)))
                plt.suptitle('Loss = ' + str(lossMAE))
                plt.show()
            if counter_train % len(x_train) == 0 and (epoch + 1) % 4 == 0:
                # plot val_loss
                plt.plot(lossesMDMD, 'r')
                plt.plot(lossesMSE, 'c')
                plt.plot(lossesMAE, 'b')
                plt.plot(lossesMAPE, 'y')
                plt.legend(['Loss_Density_Map_Distance', 'LossMSE', 'LossMAE', 'LossMAPE'])
                plt.title('Loss')
                plt.show()
        minutes, seconds = divmod(int(time.time() - time_st), 60)
        sys.stdout.write('In epoch {}_{}, with MAE={}, MSE={}, MAPE={}, MDMD={}, time consuming={}m-{}s\r'.format(
            epoch, counter_train % len(x_train), str(lossMAE), str(lossMSE), str(lossMAPE), str(lossMDMD),
            minutes, seconds
        ))
        sys.stdout.flush()
end_time_of_train = '_'.join('-'.join(time.ctime().split()).split(':'))
MAE_min = str(round(np.min(lossesMAE), 3))
# Rename this run's checkpoint directory so the next run (which clears
# weights_<dataset> at start-up) cannot delete it.
shutil.move(weights_dir, 'weights_{}_{}_bestMAE{}_{}'.format(dataset, LOSS, MAE_min, end_time_of_train))

In [None]:
# Persist the validation-loss histories and a combined plot, then archive the directory
# under a name carrying the loss, best MAE and end time of the run.
loss_dir = 'losses_' + dataset
if not os.path.exists(loss_dir):
    os.makedirs(loss_dir)
# (file suffix, history, line colour, legend label) for each tracked validation metric.
loss_histories = [
    ('DMD', lossesMDMD, 'r', 'Loss_Density_Map_Distance'),
    ('MAE', lossesMAE, 'b', 'Loss_MAE'),
    ('MAPE', lossesMAPE, 'y', 'LossMAPE'),
    ('MSE', lossesMSE, 'c', 'LossMSE'),
]
for metric, history, colour, _ in loss_histories:
    np.savetxt(os.path.join(loss_dir, 'loss_{}.txt'.format(metric)), history)
    plt.plot(history, colour)
plt.legend([label for _, _, _, label in loss_histories])
plt.title('Loss -- {} epochs'.format(epoch_train))
plt.xlabel('Validation step')
# Save inside loss_dir; the former './loss_<dataset>/' target is never created, so
# savefig failed there.
plt.savefig(os.path.join(loss_dir, 'loss_{}_{}.jpg'.format(dataset, end_time_of_train)))
plt.show()
shutil.move(loss_dir, 'losses_{}_{}_bestMAE{}_{}'.format(dataset, LOSS, MAE_min, end_time_of_train))

In [None]:
# Reload this run's best checkpoint and show its prediction on the display image.
# When training finishes, the training cell renames weights_<dataset> to this
# timestamped directory, so the file is no longer under weights_<dataset>/.
best_weights_path = os.path.join(
    'weights_{}_{}_bestMAE{}_{}'.format(dataset, LOSS, MAE_min, end_time_of_train),
    '{}_best.hdf5'.format(net),
)
model.load_weights(best_weights_path)
pred = np.squeeze(model.predict(np.expand_dims(x_val_display, axis=0)))
fg, (ax_x_ori, ax_y, ax_pred) = plt.subplots(1, 3, figsize=(20, 4))
ax_x_ori.imshow(cv2.cvtColor(cv2.imread(path_val_display), cv2.COLOR_BGR2RGB))
ax_x_ori.set_title('Original Image')
ax_y.imshow(y_val_display, cmap=plt.cm.jet)
ax_y.set_title('Ground_truth: ' + str(np.sum(y_val_display)))
ax_pred.imshow(pred, cmap=plt.cm.jet)
ax_pred.set_title('Prediction: ' + str(np.sum(pred)))
# lossMAE is the last validation MAE recorded during training, not necessarily the
# MAE of the checkpoint loaded above.
plt.suptitle('Loss = ' + str(lossMAE))
plt.show()