In [55]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D

data = np.load("gridsearch_score_data/CNN_score_accuracy.npy")
# Columns (in order): learning rate, regularization parameter, batch size, accuracy.
# Learning rate and regularization are plotted on a log10 scale.
eta = np.log10(data[:, 0])
lmb = np.log10(data[:, 1])
batch_size = data[:, 2]
# batch_size = np.log2(data[:, 2]) - 2  # use instead if equally spaced batch-size levels are preferred
accuracy = data[:, 3]

# Distinct examined values of each hyperparameter (sorted), used as axis ticks.
eta_val = np.unique(eta)
lmb_val = np.unique(lmb)
batch_val = np.unique(batch_size)

sns.set_style("whitegrid", {'axes.grid': False})

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(eta, lmb, batch_size, c=accuracy, s=48, cmap='RdBu', zorder=2, alpha=1)
ax.set_title("CNN accuracy")
ax.set_xlabel(r"learning rate, $\log(\eta)$")
ax.set_ylabel(r"regularization parameter, $\log(\lambda)$")
ax.set_zlabel("batch size")
ax.xaxis.set_ticks(eta_val)
ax.yaxis.set_ticks(lmb_val)
ax.zaxis.set_ticks(batch_val)

# Vertical drop lines from each point down to a common floor plane. The z axis
# is batch size, so the floor is expressed in batch-size units (never in
# accuracy units, which would be meaningless on this axis).
z_floor = min(0.0, batch_size.min())
for x, y, z in zip(eta, lmb, batch_size):
    ax.plot([x, x], [y, y], [z_floor, z], color="#bcbcbc", zorder=1)

# Colorbar built from the scatter itself, so its colormap and limits always
# match the plotted points (the scatter autoscales to the accuracy min/max).
vmin, vmax = accuracy.min(), accuracy.max()
cbar = fig.colorbar(scatter, ax=ax, orientation='vertical', label='accuracy', pad=0.15, shrink=0.75)

# Ticks every `tick_step` starting at the first hundredth >= vmin, plus a final
# tick at the last hundredth <= vmax. Regular ticks within half a step of that
# final tick are dropped so labels neither duplicate nor overlap; rounding
# removes np.arange float noise (e.g. 0.30000000000000004).
tick_step = 0.05
tick_lo = np.ceil(100 * vmin) / 100
tick_hi = np.floor(100 * vmax) / 100
cbar_ticks = np.round(np.arange(tick_lo, tick_hi, tick_step), 2)
cbar_ticks = cbar_ticks[cbar_ticks < tick_hi - tick_step / 2]
cbar_ticks = np.append(cbar_ticks, tick_hi)
cbar.ax.set_yticks(cbar_ticks)
cbar.ax.tick_params(labelsize=12)

# Mark and label the best configuration (highest accuracy; first one on ties).
best_idx = accuracy.argmax()
best_param = [eta[best_idx], lmb[best_idx], batch_size[best_idx]]
ax.plot(*best_param, marker='x', c='k', markersize=10, zorder=100)
ax.text(best_param[0] + 0.08, best_param[1] + 0.03, best_param[2] + 0.02, f'{vmax:.2F}', c='k', fontsize=10, zorder=101)

# Save figure (default format, PNG unless rcParams say otherwise) and release it.
fig.savefig("gridsearch_accuracy")
plt.close(fig)


In [59]:
# Same 3D grid-search plot as the accuracy cell, but for the loss scores.
# Lower loss is better, so the colormap is reversed and the minimum is marked.

data = np.load("gridsearch_score_data/CNN_score_loss.npy")
# Columns (in order): learning rate, regularization parameter, batch size, loss.
# Learning rate and regularization are plotted on a log10 scale.
eta = np.log10(data[:, 0])
lmb = np.log10(data[:, 1])
batch_size = data[:, 2]
# batch_size = np.log2(data[:, 2]) - 2  # use instead if equally spaced batch-size levels are preferred
loss = data[:, 3]

# Distinct examined values of each hyperparameter (sorted), used as axis ticks.
eta_val = np.unique(eta)
lmb_val = np.unique(lmb)
batch_val = np.unique(batch_size)

sns.set_style("whitegrid", {'axes.grid': False})

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(eta, lmb, batch_size, c=loss, s=48, cmap='RdBu_r', zorder=2, alpha=1)
ax.set_title("CNN loss")
ax.set_xlabel(r"learning rate, $\log(\eta)$")
ax.set_ylabel(r"regularization parameter, $\log(\lambda)$")
ax.set_zlabel("batch size")
ax.xaxis.set_ticks(eta_val)
ax.yaxis.set_ticks(lmb_val)
ax.zaxis.set_ticks(batch_val)

# Vertical drop lines from each point down to a common floor plane. The z axis
# is batch size, so the floor is expressed in batch-size units (never in loss
# units, which would be meaningless on this axis).
z_floor = min(0.0, batch_size.min())
for x, y, z in zip(eta, lmb, batch_size):
    ax.plot([x, x], [y, y], [z_floor, z], color="#bcbcbc", zorder=1)

# Colorbar built from the scatter itself, so its colormap and limits always
# match the plotted points (the scatter autoscales to the loss min/max).
vmin, vmax = loss.min(), loss.max()
cbar = fig.colorbar(scatter, ax=ax, orientation='vertical', label='loss', pad=0.15, shrink=0.75)

# Ticks every `tick_step` starting at the first hundredth >= vmin, plus a final
# tick at the last hundredth <= vmax. Regular ticks within half a step of that
# final tick are dropped so labels neither duplicate nor overlap (np.arange can
# include its stop value through float rounding); rounding removes float noise.
tick_step = 0.05
tick_lo = np.ceil(100 * vmin) / 100
tick_hi = np.floor(100 * vmax) / 100
cbar_ticks = np.round(np.arange(tick_lo, tick_hi, tick_step), 2)
cbar_ticks = cbar_ticks[cbar_ticks < tick_hi - tick_step / 2]
cbar_ticks = np.append(cbar_ticks, tick_hi)
cbar.ax.set_yticks(cbar_ticks)
cbar.ax.tick_params(labelsize=12)

# Mark and label the best configuration (lowest loss; first one on ties).
best_idx = loss.argmin()
best_param = [eta[best_idx], lmb[best_idx], batch_size[best_idx]]
ax.plot(*best_param, marker='x', c='k', markersize=10, zorder=100)
ax.text(best_param[0] + 0.08, best_param[1] + 0.03, best_param[2] + 0.02, f'{vmin:.2F}', c='k', fontsize=10, zorder=101)

# Save figure (default format, PNG unless rcParams say otherwise) and release it.
fig.savefig("gridsearch_loss")
plt.close(fig)