In [None]:
import tensorflow as tf
import poisson_CNN_old as poisson_CNN
import numpy as np
import matplotlib.pyplot as plt
import warnings
import matplotlib as mpl

# Notebook-wide configuration.
# Silence every Python warning for the rest of the session.
warnings.filterwarnings(action='ignore')
# Make float32 the default Keras float type (read back below via tf.keras.backend.floatx()).
tf.keras.backend.set_floatx('float32')

# Pieces of the file names written by the plotting cells when `saving` is True.
# NOTE(review): img_path is an absolute, machine-specific directory - adjust before saving figures elsewhere.
case = 'dbcnn_'
examplename = 'tset_dbcnn_2'
img_path = '/home/ago14/storage/ali/manuscript/results/training_set_examples/'

In [None]:
#Load model
from IPython.display import clear_output

# Regularizers are switched off; assign e.g. tf.keras.regularizers.L1L2(l2=1e-5) to enable them.
kr = None
br = None
mod = poisson_CNN.models.Dirichlet_BC_NN(
    data_format = 'channels_first',
    initial_kernel_size = 13,
    final_kernel_size = 3,
    conv1d_final_channels = 256,
    kernel_regularizer = kr,
    bias_regularizer = br,
    mae_component_weight = 5e+1,
    mse_component_weight = 5e+2,
    n_quadpts = 34,
)
# Call the model once on random dummy inputs before loading the weights
# (presumably so that its variables are created first - TODO confirm).
# The three inputs are a (10,1,74) tensor, a (10,1) tensor and a random integer in [48, 96).
mod((
    tf.random.uniform((10,1,74), dtype = tf.keras.backend.floatx()),
    tf.random.uniform((10,1), dtype = tf.keras.backend.floatx()),
    tf.constant(np.random.randint(48,96)),
))
# Discard whatever the dummy call printed.
clear_output()
mod.load_weights('DBCNN_direct_192_228_5e-3_5e-2_2.h5')

In [None]:
#Set up data generator
# Range of the randomly chosen output shape, one [min, max] pair per spatial dimension
# (previously tried alternative: [[120,260],[120,260]]).
output_shape_range = [[192,228],[192,228]]
# Range of the randomly chosen grid spacing (previously tried alternative: [0.1,0.3]).
rdxr = [0.005,0.05]
# Smoothness range handed to the generator, one independent [5,8] list per boundary.
brsr = {side: [5,8] for side in ('left', 'top', 'right', 'bottom')}
dg = poisson_CNN.dataset.generators.numerical_dataset_generator(
    batch_size = 20,
    batches_per_epoch = 50,
    rhses = 'zero',
    return_rhs = False,
    return_boundaries = True,
    return_keras_style = True,
    nonzero_boundaries = ['left'],
    exclude_zero_boundaries = True,
    return_dx = True,
    return_shape = True,
    random_output_shape_range = output_shape_range,
    random_dx_range = rdxr,
    randomize_boundary_smoothness = True,
    boundary_random_smoothness_range = brsr,
)

In [None]:
#Training - do not execute just to try out the model!
def loss(yt, yp):
    """Sum of the Keras MSE and mean-squared-logarithmic-error losses."""
    return tf.keras.losses.mse(yt,yp) + tf.keras.losses.mean_squared_logarithmic_error(yt,yp)
# NOTE(review): `loss` is defined but the compile() call below passes the built-in 'mse' instead,
# although the checkpoint file name mentions "logerror" - confirm which loss is intended.

savefile = 'DBCNN_direct_192_228_5e-3_5e-2_logerror.h5'
cb = [
    # Write the weights to `savefile` whenever the monitored 'mse' metric improves.
    tf.keras.callbacks.ModelCheckpoint(
        savefile,
        monitor = 'mse',
        verbose = 1,
        save_best_only = True,
        save_weights_only = True,
    ),
    # Lower the learning rate (never below 1e-8) after 5 epochs without improvement of the loss.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor = 'loss',
        min_lr = 1e-8,
        verbose = True,
        patience = 5,
    ),
]
mod.compile(
    loss = 'mse',
    optimizer = tf.keras.optimizers.Adam(learning_rate = 5e-6),
    metrics = ['mse', 'mae'],
)

mod.run_eagerly = True
mod.fit_generator(generator = dg, epochs = 5000, callbacks = cb)

In [None]:
#Get new batch of data, use model to predict
# Item 10 of the generator is an (inputs, solution) pair.
inp, soln = dg[10]
# Forward pass of the loaded model on the whole batch.
pred = mod(inp)

In [None]:
#Evaluate batch error statistics
# Ground truth cast to the Keras float type so it can be combined with the prediction.
soln_cast = tf.cast(soln, tf.keras.backend.floatx())
# Pointwise relative error |pred - soln| / |soln| as a numpy array (boolean masks are applied to it below).
q = np.abs(pred - soln_cast)/np.abs(soln_cast)
rms = tf.sqrt(tf.reduce_mean((pred - soln_cast)**2))
mae = tf.reduce_mean(tf.abs(pred - soln_cast))
# Only grid points whose relative error is below 1 (i.e. 100%) enter the mean.
mean_abs_pct_error = 100*float(tf.reduce_mean(q[q < 1]))
pct_below_10 = 100*float(np.sum(q < 0.1)/q.size)
print('Mean abs % error: ' + str(mean_abs_pct_error))
print('% of gridpts with less than 10% error: ' + str(pct_below_10))
print('RMS error: ' + str(float(rms)))
print('MAE: ' + str(float(mae)))

In [None]:
#Ground truth plotting
# (matplotlib is already imported as `mpl` in the first cell.)
mpl.rcParams['figure.dpi']=300

saving = False

# Index, within the batch, of the sample shown by this cell and by every plotting cell below.
# It was never assigned anywhere in the notebook, so a fresh "Restart & Run All" stopped here
# with a NameError. Any integer in range(soln.shape[0]) is valid.
p_r = 0

# Grid spacing of the selected sample (inp[1] is indexed by the batch position).
dx = float(inp[1][p_r,0])
nx, ny = int(soln.shape[-2]), int(soln.shape[-1])
# Coordinates are built from the dx of the plotted sample; the previous version always used
# the dx of sample 0, which only matches when p_r == 0.
x, y = np.meshgrid(np.linspace(0, nx*dx, nx), np.linspace(0, ny*dx, ny), indexing = 'ij')

# Mean relative error of this sample, restricted to grid points with relative error below 1.
print('Mean abs % error of sample: ' + str(100*float(tf.reduce_mean(q[p_r,0,...][q[p_r,0,...] < 1]))))
z = soln[p_r,0,...]
print('dx: ' + str(dx))
# Mean absolute error of the sample divided by the largest |ground truth| value of the sample.
print(float(tf.reduce_mean(tf.keras.losses.mae(pred[p_r,...],tf.cast(soln[p_r,...],tf.keras.backend.floatx())))/tf.cast(tf.reduce_max(tf.abs(soln[p_r,...])),tf.keras.backend.floatx())))
print(z.shape)
# Symmetric colour limits so that zero sits at the centre of the diverging colormap.
z_min, z_max = -np.abs(z).max(), np.abs(z).max()
fig, ax = plt.subplots()
c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('Ground truth')
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

# Dashed guide lines at x = 1/8 and 7/8 and at y = 1/4 and 2/4 of the domain extent; the
# constant-x / constant-y cells below slice the fields at the corresponding index fractions.
x1 = np.max(x)/8
x2 = 7*np.max(x)/8
y1 = 1*np.max(y)/4
y2 = 2*np.max(y)/4
ax.axvline(x=x1, color = 'm', linestyle = '--')
ax.axvline(x=x2, color = 'm', linestyle = '--')
ax.axhline(y=y1, color = 'm', linestyle = '--')
ax.axhline(y=y2, color = 'm', linestyle = '--')

# File-name fragment shared by all plotting cells. It is now built unconditionally so that
# switching `saving` on in a later cell no longer fails on an undefined `identifier`.
identifier = '_' + str(nx) + 'x' + str(ny) + '_dx' + ('{:.2e}'.format(dx)) + '_'

if saving:
    thisimage = 'groundtruth'
    plt.savefig(img_path + examplename + '/' + examplename + identifier + case + thisimage + '.png', bbox_inches = 'tight')


plt.show()

In [None]:
#Prediction plotting
# Build the grid from the dx of the plotted sample p_r, consistent with the error-map cells;
# the previous version used the dx of sample 0 regardless of which sample was plotted.
x, y = np.meshgrid(np.linspace(0, soln.shape[-2]*inp[1][p_r,0], soln.shape[-2]), np.linspace(0, soln.shape[-1]*inp[1][p_r,0], soln.shape[-1]), indexing = 'ij')
pred = mod(inp)
z = pred[p_r,0,...]
print(pred.shape)
# Use the colour scale of the ground-truth plot (symmetric around zero, spanning the largest
# |ground truth| value of the sample). It is recomputed here instead of reusing the globals
# z_min / z_max, which the error-map cells overwrite with different limits.
z_max = np.abs(soln[p_r,0,...]).max()
z_min = -z_max
fig, ax = plt.subplots()
c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
ax.set_title('DBCNN Prediction')
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

if saving:
    thisimage = 'pred'
    plt.savefig(img_path + examplename + '/' + examplename + identifier + case + thisimage + '.png', bbox_inches = 'tight')


plt.show()

In [None]:
#BC plotting
# Profile of the ground truth along the first x index, plotted against the y coordinates.
xpos = 0
plt.plot(y[0,:], soln[p_r,0,xpos,:], label = 'Ground truth')
plt.title('Imposed boundary condition')
plt.legend()

if saving:
    thisimage = 'bc'
    bc_image_file = ''.join((img_path, examplename, '/', examplename, identifier, case, thisimage, '.png'))
    plt.savefig(bc_image_file, bbox_inches = 'tight')

plt.show()

In [None]:
#Plots at constant y (change the entries of y_slice_indices to adjust the y values)
# Slices are taken at 1/4 and 2/4 of the number of grid points in y.
y_slice_indices = (soln.shape[-1]//4, 2*soln.shape[-1]//4)
print(np.squeeze(pred[p_r,0,:,y_slice_indices[0]]).shape)
for ypos in y_slice_indices:
    # Physical y coordinate of the slice, shown in the legend.
    y_value = '{:.4f}'.format((ypos * inp[1][p_r,0]).numpy())
    plt.plot(x[:,0], pred[p_r,0,:,ypos], label = 'Pred. y=' + y_value)
    plt.plot(x[:,0], soln[p_r,0,:,ypos], label = 'Gnd. tru. y=' + y_value)
plt.title('Values of the solution at constant y')
plt.legend()

if saving:
    thisimage = 'consty'
    consty_image_file = ''.join((img_path, examplename, '/', examplename, identifier, case, thisimage, '.png'))
    plt.savefig(consty_image_file, bbox_inches = 'tight')

plt.show()

In [None]:
#Plots at constant x
# Slices are taken at 1/8 and 7/8 of the number of grid points in x.
x_slice_indices = (soln.shape[-2]//8, 7*soln.shape[-2]//8)
print(np.squeeze(pred[p_r,0,x_slice_indices[0],:]).shape)
for xpos in x_slice_indices:
    # Physical x coordinate of the slice, shown in the legend.
    x_value = '{:.4f}'.format((xpos * inp[1][p_r,0]).numpy())
    plt.plot(y[0,:], pred[p_r,0,xpos,:], label = 'Pred. x=' + x_value)
    plt.plot(y[0,:], soln[p_r,0,xpos,:], label = 'Gnd. tru. x=' + x_value)
plt.title('Values of the solution at constant x')
plt.legend()

if saving:
    thisimage = 'constx'
    constx_image_file = ''.join((img_path, examplename, '/', examplename, identifier, case, thisimage, '.png'))
    plt.savefig(constx_image_file, bbox_inches = 'tight')

plt.show()

In [None]:
#Percentage error map plot
# Coordinate grid built from the dx of the plotted sample.
x_coords = np.linspace(0, soln.shape[-2]*inp[1][p_r,0], soln.shape[-2])
y_coords = np.linspace(0, soln.shape[-1]*inp[1][p_r,0], soln.shape[-1])
x, y = np.meshgrid(x_coords, y_coords, indexing = 'ij')
# Pointwise relative error of the sample, in percent.
soln_sample = tf.cast(soln[p_r,0,...], tf.keras.backend.floatx())
z = 100*tf.abs((pred[p_r,0,...] - soln_sample)/soln_sample)
# The colour scale is clipped to the 0-100% range.
z_min, z_max = 0,100
# Mean of the percentage errors, leaving out entries that are not below 1e8.
z_values = np.array(z)
print(np.mean(z_values[z_values < 100000000]))
fig, ax = plt.subplots()
c = ax.pcolormesh(x, y, z, cmap='Blues', vmin=z_min, vmax=z_max)
ax.set_title('Absolute error percentage')
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

if saving:
    thisimage = 'percentageerrormap'
    pct_image_file = ''.join((img_path, examplename, '/', examplename, identifier, case, thisimage, '.png'))
    plt.savefig(pct_image_file, bbox_inches = 'tight')

plt.show()

In [None]:
#Error map plot
# Coordinate grid built from the dx of the plotted sample.
x_coords = np.linspace(0, soln.shape[-2]*inp[1][p_r,0], soln.shape[-2])
y_coords = np.linspace(0, soln.shape[-1]*inp[1][p_r,0], soln.shape[-1])
x, y = np.meshgrid(x_coords, y_coords, indexing = 'ij')
# Signed pointwise error of the sample (prediction minus ground truth).
z = pred[p_r,0,...] - tf.cast(soln[p_r,0,...], tf.keras.backend.floatx())
# Colour limits at +/-30% of the largest absolute error, so larger errors saturate.
# (Fixed limits of +/-1 or +/-0.8 were used as alternatives.)
largest_abs_error = np.abs(z).max()
z_min, z_max = -0.3*largest_abs_error, 0.3*largest_abs_error
print(np.mean(np.abs(z)))
fig, ax = plt.subplots()
c = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=z_min, vmax=z_max)
# NOTE(review): the title says "Absolute error" although the plotted field is signed.
ax.set_title('Absolute error')
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)

if saving:
    thisimage = 'errormap'
    error_image_file = ''.join((img_path, examplename, '/', examplename, identifier, case, thisimage, '.png'))
    plt.savefig(error_image_file, bbox_inches = 'tight')

plt.show()

In [None]:
#Evaluate error statistics over many batches (can take long!)
# `nepochs` batches are requested from the generator; with the batch_size = 20 configured in
# the generator cell this amounts to 60 * 20 = 1200 samples.
# NOTE(review): every iteration asks for item 10 - this only yields independent statistics if
# the generator produces fresh random data on each call; confirm against the generator.
from IPython.display import clear_output
rmses = []
maes = []
mapes = []
pgpb10s = []
nepochs = 60
progress_bar_width = 20
for k in range(nepochs):
    if k%5 == 0:
        # Number of completed bar segments. This is computed with a single round() instead of
        # the former round-then-floor-divide on floats, which lost a segment whenever the
        # float quotient fell just below an integer (e.g. 0.25 // 0.05 == 4.0).
        progress = int(round(progress_bar_width*k/nepochs))
        clear_output()
        # The bar has the same width on every update, including the very first one.
        print('[' + '='*progress + '>' + ' '*(progress_bar_width - progress) + ']')

    inp600, soln600 = dg.__getitem__(10)
    pred600 = mod(inp600)
    soln600_cast = tf.cast(soln600, tf.keras.backend.floatx())
    err600 = pred600 - soln600_cast
    # Loop-local names are used so that the globals q / rms / mae of the single-batch cell,
    # which the plotting cells index with p_r, are not overwritten by another batch.
    q600 = np.abs(err600)/np.abs(soln600_cast)
    # Store plain Python floats rather than tensors.
    rmses.append(float(tf.sqrt(tf.reduce_mean(err600**2))))
    maes.append(float(tf.reduce_mean(tf.abs(err600))))
    # Only grid points whose relative error is below 1 (i.e. 100%) enter the mean.
    mapes.append(100*float(tf.reduce_mean(q600[q600 < 1])))
    pgpb10s.append(100*float(np.sum(q600 < 0.1)/np.prod(q600.shape)))
print('Mean abs % error: ' + str(np.mean(mapes)))
print('% of gridpts with less than 10% error: ' + str(np.mean(pgpb10s)))
print('RMS error: ' + str(np.mean(rmses)))
print('MAE: ' + str(np.mean(maes)))