Burgers' equation in time: implementing the Fourier Neural Operator (FNO) model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import prevision as prv
import tensorflow as tf

In [None]:
# Staircase inverse-time decay of the learning rate. Per the Keras docs this
# evaluates to 0.01 / (1 + 2 * floor(step / 100)), where `step` is the
# optimizer's update counter (one per batch, not per epoch).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.01,
    decay_steps=100,
    decay_rate=2,
    staircase=True,
)

def custom_loss(y_true, y_pred):
    """Training loss: cosine-similarity term plus a Huber term weighted by 30.

    Per the Keras documentation, ``tf.keras.losses.cosine_similarity`` returns
    the *negated* cosine similarity (values near -1 mean well-aligned vectors).
    Minimising the sum therefore rewards both matching direction and small
    pointwise error.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor; expected to match the shape of ``y_true``.

    Returns:
        The elementwise sum of the two Keras loss outputs.
    """
    huber_term = tf.keras.losses.huber(y_true, y_pred)
    cosine_term = tf.keras.losses.cosine_similarity(y_true, y_pred)
    return cosine_term + 30 * huber_term

def get_optimizer():
    """Build a new Adam optimizer whose learning rate follows ``lr_schedule``.

    Returns:
        A freshly constructed ``tf.keras.optimizers.Adam`` instance.
    """
    return tf.keras.optimizers.Adam(learning_rate=lr_schedule)

def get_callbacks(name, patience=500, min_delta=0.0001, restore_best_weights=True):
    """Build the list of Keras callbacks used during training.

    Args:
        name: Model name. Not used by the current callbacks; accepted so
            existing callers keep working.
        patience: Number of epochs without a ``val_loss`` improvement of at
            least ``min_delta`` before training stops.
        min_delta: Minimum decrease in ``val_loss`` that counts as an
            improvement.
        restore_best_weights: If True, the weights from the epoch with the
            lowest ``val_loss`` are restored. Without this, and with a large
            patience, the model kept after training can be up to ``patience``
            epochs past its best validation loss.

    Returns:
        A list of ``tf.keras.callbacks.Callback`` instances.
    """
    # NOTE(review): in some older Keras versions the best weights are restored
    # only when early stopping actually triggers, not when training simply
    # reaches the epoch limit — confirm against the installed version.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=patience,
        min_delta=min_delta,
        restore_best_weights=restore_best_weights,
    )
    return [early_stopping]

def compile_and_fit(model, name, x_train, y_train, x_test, y_test, optimizer=None, max_epochs=1000, batch_size=32):
    """Compile ``model`` with ``custom_loss`` and train it, validating on the test split.

    Args:
        model: A Keras model to compile and train.
        name: Model name, forwarded to ``get_callbacks``.
        x_train, y_train: Training inputs and targets.
        x_test, y_test: Validation inputs and targets passed as
            ``validation_data``.
        optimizer: Optimizer to use. Defaults to ``get_optimizer()``.
        max_epochs: Upper bound on training epochs; early stopping may end
            training sooner.
        batch_size: Mini-batch size for ``model.fit``.

    Returns:
        The ``History`` object returned by ``model.fit``. Its ``history`` dict
        holds per-epoch ``loss``, ``val_loss``, ``mean_squared_error`` and
        ``val_mean_squared_error``.
    """
    if optimizer is None:
        optimizer = get_optimizer()

    model.compile(
        optimizer=optimizer,
        loss=custom_loss,
        # Use a real Keras *metric* rather than a loss object. A loss object
        # would need the positional "auto" reduction, which newer Keras
        # versions reject. A loss object wrapped as a metric also averages
        # per-batch values instead of per-sample ones. The explicit name keeps
        # the history key 'mean_squared_error' unchanged.
        metrics=[tf.keras.metrics.MeanSquaredError(name='mean_squared_error')],
    )

    model.summary()

    history = model.fit(
        x_train,
        y_train,
        epochs=max_epochs,
        validation_data=(x_test, y_test),
        callbacks=get_callbacks(name),
        batch_size=batch_size,
        verbose=2,
    )

    return history

# Loading the pre-generated train/test datasets
folder = 'Burgers_time'
data_dir = f'../data/{folder}'
input_train = np.load(f'{data_dir}/input_train.npy')
input_test = np.load(f'{data_dir}/input_test.npy')
output_train = np.load(f'{data_dir}/output_train.npy')
output_test = np.load(f'{data_dir}/output_test.npy')

# Per-sample input/output sizes. Assumes the arrays are laid out as
# (n_samples, n_features) — TODO confirm against the data generator.
INPUTDIM = (input_train.shape[1],)
OUTPUTDIM = (output_train.shape[1],)

# Creating the model
model = prv.FNO(INPUTDIM, OUTPUTDIM, p_dim=512, n=3, k_max=17, verbose=True,
                model_name='Burgers_time_FNO', dropout=0.01, kernel_reg=0.001)
history = compile_and_fit(model, model.name,
                          x_train=input_train, y_train=output_train,
                          x_test=input_test, y_test=output_test,
                          batch_size=60, max_epochs=5000)

bc = history.history['loss']
val_bc = history.history['val_loss']

# Number epochs from 1: an x value of 0 cannot be placed on a logarithmic axis.
plt.semilogx(np.arange(1, len(bc) + 1), bc, linewidth=2)
plt.semilogx(np.arange(1, len(val_bc) + 1), val_bc, '--', linewidth=2)
plt.title('FNO learning')
plt.grid(True)
plt.xlabel('Epochs')
plt.ylabel('Loss')
# NOTE(review): fixed y-range. Loss values above 0 (e.g. early epochs, where
# the weighted Huber term dominates) fall outside this window and are not shown.
plt.ylim([-1, 0])
plt.legend(('Loss training set', 'Loss validation set'))
plt.show()

model.save(f'{data_dir}/Burgers_time_FNO.h5')