In [None]:
# Standard import
import numpy as np
import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
# Machine learning libraries
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError

from tqdm.notebook import tqdm
from tqdm.keras import TqdmCallback

In [None]:
# Custom plots file and tensorflow models
from Tools import Plot
from Tools import Custom_models

In [None]:
# Seed the NumPy and TensorFlow random generators so reruns are reproducible
seed = 6
np.random.seed(seed)
tf.random.set_seed(2 * seed)

In [None]:
# Load the dataset from disk: column 0 holds the features, column 1 the labels
data = np.loadtxt('Data/tanh.txt')
features, labels = data[:, 0], data[:, 1]

In [None]:
# Random train/test split: 90% of the samples for training, the remaining 10% for testing
N = len(features)
stop = round(0.9 * N)            # index where the training part of the permutation ends
perm = np.random.permutation(N)  # shuffled sample indices
train_idx, test_idx = perm[:stop], perm[stop:]
features_training, labels_training = features[train_idx], labels[train_idx]
features_testing, labels_testing = features[test_idx], labels[test_idx]

In [None]:
# Learning-rate sweep: build, compile and train a fresh model for every rate,
# keeping each Keras History object so the loss curves can be compared afterwards.
rates = np.linspace(1e-4, 1e-2, 5)  # learning rates to try
N_epochs = 200                      # epochs per training run
N_node = 100                        # passed to the model as N_nodes
N_out = 1                           # passed to the model as N_output
name = 'model_1D'
hist = []                           # one History per entry of `rates`, in the same order

for rate in tqdm(rates):
    opt = Adam(learning_rate=rate)
    model = Custom_models.Model_1D_1_layer(N_nodes=N_node, N_output=N_out, name=name)
    model.compile(loss=MeanSquaredError(), optimizer=opt)
    # Keras' own log is silenced (verbose=0); progress is reported by a new
    # TqdmCallback created for each run. 10% of the training data is held out
    # for validation.
    history = model.fit(
        features_training,
        labels_training,
        epochs=N_epochs,
        validation_split=0.1,
        verbose=0,
        shuffle=True,
        initial_epoch=0,
        callbacks=[TqdmCallback()],
    )
    hist.append(history)

In [None]:
# Training loss per epoch for every learning rate of the sweep (log-scaled y axis).
# The curves are driven by the recorded histories and the x-range by N_epochs,
# rather than by hard-coded counts, so the figure stays in sync with the sweep
# configuration if the number of rates or epochs is changed.
plt.figure()
for rate, run in zip(rates, hist):
    plt.semilogy(run.epoch, run.history['loss'], label=f'rate = {rate}')
plt.xlabel('Epoch')
plt.ylabel('Mean squared error')
plt.xlim([0, N_epochs])
plt.legend()
plt.title('Training loss')
plt.grid()
plt.show()

In [None]:
# Validation loss per epoch for every learning rate of the sweep (log-scaled y axis).
# The curves are driven by the recorded histories and the x-range by N_epochs,
# rather than by hard-coded counts, so the figure stays in sync with the sweep
# configuration if the number of rates or epochs is changed.
plt.figure()
for rate, run in zip(rates, hist):
    plt.semilogy(run.epoch, run.history['val_loss'], label=f'rate = {rate}')
plt.xlabel('Epoch')
plt.ylabel('Mean squared error')
plt.xlim([0, N_epochs])
plt.legend()
plt.title('Validation loss')
plt.grid()
plt.show()