In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

In [None]:
# Report how many GPU devices TensorFlow can see in this kernel.
gpu_count = len(tf.config.list_physical_devices(device_type='GPU'))
print("Num GPUs Available: ", gpu_count)

In [None]:
# Disable Jedi-based tab completion so IPython uses its built-in completer instead
# (presumably a workaround for slow or broken completion — confirm it is still needed).
%config IPCompleter.use_jedi = False

In [None]:
# Load the full dataset: array `D` holds the designs (network inputs) and
# array `R` the responses (network targets).
fname = 'dataset.npz'
with np.load(fname) as data:
    designs, responses = data['D'], data['R']

# Input width and output width of the network are taken from the last axis.
n_grating_layers = designs.shape[-1]
n_freqs = responses.shape[-1]

# Hold out 10% of the samples (fixed seed for reproducibility); this split is
# what model.fit later uses as validation data.
Dtrain, Dtest, Rtrain, Rtest = train_test_split(
    designs, responses, test_size=0.1, random_state=42)
print(f"Train set contains {len(Dtrain)} samples")
print(f"Validation set contains {len(Dtest)} samples")

In [None]:
activation = keras.activations.relu
# Architecture 3: a fully connected MLP mapping a design vector of length
# n_grating_layers to n_freqs outputs, through three ReLU hidden layers
# (500, 200, 200 units). The sigmoid output keeps every prediction in (0, 1);
# presumably the responses are normalized to that range — TODO confirm.
model = keras.Sequential(
    [layers.Input((n_grating_layers,))]
    + [layers.Dense(units, activation=activation) for units in (500, 200, 200)]
    + [layers.Dense(n_freqs, activation='sigmoid')]
)
model.summary()

In [None]:
# Train against mean-squared error between predicted and target responses,
# using Adam with Keras' default settings.
loss = keras.losses.MeanSquaredError()
model.compile(
    loss=loss,
    optimizer='adam',
)

In [None]:
# Epoch counter read and updated by the training cell below, so re-running that
# cell continues the epoch numbering. Re-run this cell to reset the counter;
# note that this does NOT reset the model's weights (rebuild the model for that).
initial_epoch = 0

In [None]:
# Train (or continue training) for a further `n_epochs` epochs.
# Keras treats `epochs` as the absolute index at which training stops, not as a
# count, so the target must be offset by `initial_epoch`. Otherwise a resumed
# run would train only a single epoch.
n_epochs = 100
info = model.fit(Dtrain, Rtrain,
                 batch_size=128,
                 epochs=initial_epoch + n_epochs,
                 initial_epoch=initial_epoch,
                 validation_data=(Dtest, Rtest),
                 # Keras validates after 0-based epoch e when (e + 1) % validation_freq == 0.
                 validation_freq=5)
# info.epoch lists the 0-based epoch indices just run; resume from the next one
# (using info.epoch[-1] alone would repeat the last epoch's index).
initial_epoch = info.epoch[-1] + 1

In [None]:
# Learning curves for the most recent `fit` call. Losses are multiplied by
# n_freqs to turn the per-output MSE into squared error summed over frequencies.
val_freq = 5  # must match `validation_freq` in the training cell
val_loss = info.history['val_loss']
# Named `train_loss` so the loss object defined in the compile cell is not clobbered.
train_loss = info.history['loss']
# Keras validates on 0-based epoch e when (e + 1) % validation_freq == 0,
# i.e. epochs 4, 9, 14, ... (not 0, 5, 10, ...). This also holds for resumed runs.
val_epochs = [e for e in info.epoch if (e + 1) % val_freq == 0]
assert len(val_epochs) == len(val_loss), \
    "val_freq does not match the validation_freq used in model.fit"
fig, ax = plt.subplots()
ax.plot(info.epoch, n_freqs * np.array(train_loss), label='Loss')
ax.plot(val_epochs, n_freqs * np.array(val_loss), label='Validation Loss')
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Squared error summed over frequencies')
plt.show()

In [None]:
# Compare DNN predictions with the true responses for randomly chosen
# validation designs (dashed = DNN, solid = target, same colour per sample).
n_examples = 1
# An index array (not a scalar) keeps the batch dimension when slicing.
idx = np.random.randint(0, Dtest.shape[0], n_examples)
dnn_responses = model(Dtest[idx]).numpy()
# Named `true_responses` so the full-dataset array `responses` is not overwritten.
true_responses = Rtest[idx]
fig, ax = plt.subplots()
for sample, predicted, target in zip(idx, dnn_responses, true_responses):
    line, = ax.plot(predicted, '--', label=f'DNN (sample {sample})')
    ax.plot(target, '-', color=line.get_color(), label=f'Target (sample {sample})')
ax.set_xlabel('Frequency index')
ax.set_ylabel('Response')
ax.legend()
plt.show()