In [31]:
import optuna
import joblib
import matplotlib.pyplot as plt
import json
import ast

#### optuna search

In [2]:
# Location of the saved study file. The path is hard-coded to the log6
# experiment and does not follow LOG_NAME, which is set in a later cell.
OPTUNA_STUDY_PATH = '../experiments/log6/optuna_study.pkl'
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only load study files from a trusted source.
# Presumably the object is an optuna Study (optuna is imported above) -- confirm.
study = joblib.load(OPTUNA_STUDY_PATH)

In [None]:
# Display the parameters of the first entry in study.best_trials.
# NOTE(review): presumably best_trials is optuna's list of best (Pareto-front)
# trials, so index 0 is just the first of them -- confirm for this study.
study.best_trials[0].params

#### training with the best config

In [511]:
# Experiment to inspect. Besides selecting the log file, LOG_NAME is the key
# into the per-log plot settings (l_off, l_lim, s_ticks, s_off) used by the
# plotting cells.
LOG_NAME = '7'

TRAIN_LOG_PATH = f'../experiments/log{LOG_NAME}/run_logs.txt'
STEP = 20 # 1 | 20  -- keep every STEP-th parsed record

# Read everything up front so the file is closed before parsing starts.
with open(TRAIN_LOG_PATH, 'r', encoding='utf-8') as fd:
    log_lines = fd.readlines()

# Every non-blank line is expected to be a Python dict literal describing one
# epoch. Parsing stops at the first line that cannot be evaluated, and the
# records collected so far are kept.
epoch_records = []
for line_no, raw_line in enumerate(log_lines):
    text = raw_line.strip()
    if not text:
        # literal_eval('') raises SyntaxError; a blank line carries no record.
        continue
    try:
        epoch_records.append(ast.literal_eval(text))
    except (ValueError, SyntaxError) as exc:
        # ValueError: the line is valid Python but not a plain literal.
        # SyntaxError: the line is not valid Python at all (e.g. truncated).
        print(f'line {line_no}: cannot parse log entry ({exc!r}); stopping here')
        break

# Column-wise view of the sampled records, consumed by the plotting cells.
info = {
    'train_l': [],
    'eval_l': [],
    'psnr': [],
    'ssim': [],
    'scc': [],
    'sam': [],
    'epoch': []
}

# Keep records 0, STEP, 2*STEP, ... (positions among the parsed records).
for record in epoch_records[::STEP]:
    scores = record['scores']
    info['epoch'].append(record['epoch'])
    info['train_l'].append(record['train_loss'])
    info['eval_l'].append(record['eval_loss'])
    info['psnr'].append(scores['psnr'])
    info['ssim'].append(scores['ssim'])
    info['scc'].append(scores['scc'])
    info['sam'].append(scores['sam'])

In [512]:
# Plot settings keyed by LOG_NAME.

# [dx, dy] for the best-eval-loss label in the loss plot; both components are
# subtracted from the marker position there.
l_off = {
    '1': [10, 0.005],
    '2': [12, 0.005],
    '3': [12, 0.002],
    **{log_id: [12, 0.001] for log_id in ('4', '5', '6')},
    '7': [-20, -0.003],
}

# [lower, upper] y-axis limits of the loss plot.
l_lim = {
    **{log_id: [-0.005, 0.1] for log_id in ('1', '2')},
    '3': [0, 0.03],
    **{log_id: [0, 0.02] for log_id in ('4', '5', '6', '7')},
}

# x-axis tick positions shared by the loss plot and the score plots.
s_ticks = {str(log_no): list(range(0, 101, 10)) for log_no in range(1, 7)}
s_ticks['7'] = list(range(0, 1521, 200))

In [None]:
# Train / eval loss curves for the selected log, with the lowest sampled
# eval loss marked in red.

plt.figure(figsize=(6, 4))

epochs = info['epoch']
eval_losses = info['eval_l']

plt.plot(epochs, eval_losses, label='eval_loss', c='blue')

# best_epoch is a position in the sampled lists (not an epoch number); the
# score-plot cell reuses it to mark the same record.
best_epoch = eval_losses.index(min(eval_losses))
best_x = epochs[best_epoch]
best_y = eval_losses[best_epoch]
x_shift, y_shift = l_off[LOG_NAME]

plt.scatter(best_x, best_y, c='red')
plt.text(best_x - x_shift, best_y - y_shift, f"{round(best_y, 4)}", c='red')
plt.axvline(x=best_x, color='r', linestyle='dashed')

plt.plot(epochs, info['train_l'], label='train_loss', c='orange')

y_low, y_high = l_lim[LOG_NAME]
plt.ylim(y_low, y_high)
plt.xticks(s_ticks[LOG_NAME])

plt.ylabel('loss')
plt.xlabel('epoch')

plt.grid()
plt.legend()
plt.savefig("loss_plot.jpeg")
plt.show()

In [514]:
# [dx, dy] added to the best-record marker position to place its text label,
# per LOG_NAME and per metric panel of the score plots.
s_off = {
    '1': dict(psnr=[-13, -0.4], ssim=[-10, -0.02], scc=[-10, -0.005], sam=[-10, 0.005]),
    '2': dict(psnr=[-10, -0.05], ssim=[-10, -0.05], scc=[-10, -0.01], sam=[2, 0]),
    '3': dict(psnr=[-12, -2.3], ssim=[-10, -0.05], scc=[-8, -0.05], sam=[2, 0.015]),
    '4': dict(psnr=[2, -3], ssim=[2, -0.08], scc=[-10, -0.1], sam=[2, 0.02]),
    '5': dict(psnr=[2, -3], ssim=[2, -0.08], scc=[2, -0.1], sam=[-10, 0.02]),
    '6': dict(psnr=[-12, -3.2], ssim=[-10, -0.1], scc=[-10, -0.1], sam=[-10, 0.03]),
    '7': dict(psnr=[-200, -0.5], ssim=[30, -0.1], scc=[30, -0.1], sam=[30, 0.02]),
}

In [None]:
def _plot_score_panel(ax, epochs, values, best_idx, ticks, text_offset, ylabel):
    """Draw one metric curve on ``ax`` and mark the record at ``best_idx``.

    Args:
        ax: matplotlib Axes to draw on.
        epochs: x values, one per sampled record.
        values: metric values, aligned with ``epochs``.
        best_idx: position (in ``epochs`` / ``values``) of the record to mark.
        ticks: x-axis tick positions.
        text_offset: ``[dx, dy]`` added to the marker position to place the label.
        ylabel: y-axis label.
    """
    best_x = epochs[best_idx]
    best_y = values[best_idx]
    dx, dy = text_offset

    ax.plot(epochs, values, c='blue')
    ax.set_ylabel(ylabel)
    ax.set_xlabel("epoch")
    ax.grid()
    ax.set_xticks(ticks)
    # Red dashed line, dot and rounded value at the marked record.
    ax.axvline(x=best_x, color='r', linestyle='dashed')
    ax.scatter(best_x, best_y, c='red')
    ax.text(best_x + dx, best_y + dy, f"{round(best_y, 2)}", c='red')


figure, axis = plt.subplots(2, 2)
figure.set_figwidth(10)
figure.tight_layout(pad=2.8)

# (key in info / s_off, y-axis label, position in the 2x2 grid)
score_panels = [
    ('psnr', "PSNR-score", (0, 0)),
    ('ssim', "SSIM-score", (0, 1)),
    ('scc', "SCC-score", (1, 0)),
    ('sam', "SAM-score", (1, 1)),
]

# best_epoch comes from the loss-plot cell: the position of the lowest sampled
# eval loss, so every panel marks the same record.
for metric, ylabel, (row, col) in score_panels:
    _plot_score_panel(
        axis[row, col],
        info['epoch'],
        info[metric],
        best_epoch,
        s_ticks[LOG_NAME],
        s_off[LOG_NAME][metric],
        ylabel,
    )

figure.savefig("scores_plot.jpeg")
plt.show()