# Convolutional autoencoders (CAE) and LSTMs for PDE surrogates

*Original author*: Romit Maulik

*Additional edits*: Kyle Felker

Maulik, R., Lusch, B., & Balaprakash, P. (2021). Reduced-order modeling of advection-dominated systems with recurrent neural networks and convolutional autoencoders. Physics of Fluids, 33(3), 037106.
https://aip.scitation.org/doi/abs/10.1063/5.0039986

## Shallow water equations (SWE)

![Shallow water equations](media/Shallow_water_waves.gif "Shallow water equations")
<center><a href="https://en.wikipedia.org/wiki/Shallow_water_equations"><b>Wikipedia</b></a>: Output from a shallow-water equation model of water in a bathtub. The water experiences five splashes which generate surface gravity waves that propagate away from the splash locations and reflect off the bathtub walls. </center>

### Conservative form of SWE:

$$
\begin{align}
\frac{\partial (\rho \eta) }{\partial t} &+ \frac{\partial (\rho \eta u)}{\partial x} + \frac{\partial (\rho \eta v)}{\partial y} = 0,\\[3pt]
\frac{\partial (\rho \eta u)}{\partial t} &+ \frac{\partial}{\partial x}\left( \rho \eta u^2 + \frac{1}{2}\rho g \eta^2 \right) + \frac{\partial (\rho \eta u v)}{\partial y} = 0,\\[3pt]
\frac{\partial (\rho \eta v)}{\partial t} &+ \frac{\partial (\rho \eta uv)}{\partial x} + \frac{\partial}{\partial y}\left(\rho \eta v^2 + \frac{1}{2}\rho g \eta ^2\right) = 0.
\end{align}
$$

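where $\eta(x,y,t)$ is the fluid column height, $(u,v)$ is the horizontal flow velocity, $\rho$ is the fluid density, and $g$ is the gravitational acceleration. The equations describe an inviscid, incompressible fluid whose horizontal length scale is much larger than its vertical length scale. When Coriolis forces are included, they are widely used in atmospheric and oceanic modeling.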

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Set seeds
np.random.seed(10)
tf.random.set_seed(10)

from tensorflow.keras.layers import Input, Dense, LSTM, Lambda, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPooling2D

from tensorflow.keras import optimizers, models, regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import load_model, Sequential, Model
from tensorflow.keras.regularizers import l1
from tensorflow.keras.utils import plot_model

from scipy.signal import savgol_filter

import seaborn as sns
import re
import matplotlib as mpl
import matplotlib.font_manager as fm
import matplotlib.ticker as mtick
mpl.rcParams['figure.dpi'] = 300

from IPython.display import Image

### Load SWE snapshot data

Generated using a finite volume simulation:
> Our full-order model uses a fourth-order accurate Runge–Kutta temporal integration scheme and a fifth-order accurate weighted essentially nonoscillatory (WENO) scheme for computing state reconstructions at cell faces. The Rusanov Riemann solver is utilized for flux reconstruction after cell-face quantities are calculated.



In [None]:
# Each .npy file holds one simulation (a single initial condition). The arrays are
# transposed on load so that axis 0 indexes the time snapshot (the unpacking step
# below iterates over rows as timesteps).
swe_data = np.load('datasets/train.npy').T
swe_data_v = np.load('datasets/validation.npy').T
print(swe_data.shape)

In [None]:
# timestep, y, x, variable (channel) -- channels last
def _unpack_snapshots(flat, grid=64, nvars=3):
    """Unpack flat snapshot rows into a (n_steps, grid, grid, nvars) float64 array.

    Each row of *flat* stores the ``nvars`` fields back to back, each one a
    ``grid x grid`` array flattened in row-major order. Every field is transposed
    (its two grid axes swapped) and the variables are moved to the last axis.
    The number of snapshots is taken from the data itself rather than a fixed
    400, so longer trajectories do not raise IndexError and shorter ones are not
    padded with all-zero snapshots (which would skew the normalization and the
    LSTM training windows).
    """
    flat = np.asarray(flat)
    n_steps = flat.shape[0]
    # (step, var, a, b): field `var` of row `step`, reshaped row-major to grid x grid
    fields = flat[:, :nvars * grid * grid].reshape(n_steps, nvars, grid, grid)
    # (step, var, a, b) -> (step, b, a, var): transpose each field, variables last
    return np.ascontiguousarray(fields.transpose(0, 3, 2, 1), dtype=np.float64)


swe_train_data = _unpack_snapshots(swe_data)
swe_valid_data = _unpack_snapshots(swe_data_v)
    

# Just to keep things simple, cut datasets in half by number of timesteps
# swe_train_data = swe_train_data[0:200,:,:,:]
# swe_valid_data = swe_valid_data[0:200,:,:,:]

# Normalize inputs: min-max scale q_1 = rho*eta of each dataset to [0, 1].
#
# Only channel 0 is rescaled; the momentum channels q_2 = rho*eta*u and
# q_3 = rho*eta*v stay in raw units. (A loop over all three channels that reads
# the channel-0 extrema after channel 0 has been rescaled in place sees min 0 and
# max 1, i.e. subtracts 0 and divides by 1 -- so this explicit form yields the
# same values for finite, non-constant data, without a NaN or constant rho*eta
# field leaking NaNs into q_2/q_3.)
# NOTE(review): each dataset is scaled by its *own* extrema, so validation inputs
# are not on exactly the training scale; using training statistics instead would
# change the validation inputs relative to earlier runs.
for snapshots in (swe_train_data, swe_valid_data):
    rho_eta = snapshots[:, :, :, 0]
    lo, hi = np.min(rho_eta), np.max(rho_eta)
    snapshots[:, :, :, 0] = (rho_eta - lo) / (hi - lo)

# Visualize the three normalized validation fields at a single snapshot.
time = 0
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 15))

panel_titles = (r'$q_1 =\rho \eta$', r'$q_2 = \rho\eta u$', r'$q_3 = \rho\eta v$')
for var, (panel, title) in enumerate(zip(ax, panel_titles)):
    panel.imshow(swe_valid_data[time, :, :, var])
    panel.set_title(title)

# Row label: the snapshot index, written horizontally left of the first panel.
ax[0].set_ylabel(f"$t = {time}$", rotation=0, labelpad=30, fontsize=16)

In [None]:
# Visualize every 10th training snapshot among the first 200, one figure per snapshot.
q_titles = (r'$q_1 = \rho\eta$', r'$q_2 = \rho\eta u$', r'$q_3 = \rho\eta v$')
for time in range(0, 200, 10):
    fig, ax = plt.subplots(nrows=1, ncols=3)
    for var in range(3):
        ax[var].imshow(swe_train_data[time, :, :, var])
        ax[var].set_title(q_titles[var])
    ax[0].set_ylabel(f"$t = {time}$", rotation=0, labelpad=30, fontsize=16)
    plt.show()

### ML Presets and Custom Functions

In [None]:
# Adam learning rate used for the autoencoder.
lrate = 1e-3
# Run mode: 'train' fits the networks (the `if mode == 'train':` guards);
# 'test' skips fitting and relies on previously saved weights.
mode = 'test'  # 'train'

def mean_absolute_error(y_pred,y_true):
    """Return the mean absolute difference between *y_true* and *y_pred*.

    Used as a Keras metric. Keras metric functions are called as
    ``fn(y_true, y_pred)``, so the parameter names here are reversed relative to
    that convention; this is harmless because the absolute difference is symmetric.
    """
    abs_err = K.abs(y_true - y_pred)
    return K.mean(abs_err)

def max_absolute_error(y_pred,y_true):
    """Return the largest elementwise absolute difference (worst-case error).

    Used as a Keras metric. As with ``mean_absolute_error``, the argument order is
    reversed relative to Keras' ``(y_true, y_pred)`` convention, which does not
    matter because the absolute difference is symmetric.
    """
    abs_err = K.abs(y_true - y_pred)
    return K.max(abs_err)

### Autoencoder

In [None]:
# Convolutional autoencoder: checkpoint path for its weights.
weights_filepath = 'saved_models/SWE_CAE_Weights.h5'

## Encoder: five [3x3 conv -> 2x2 max-pool] stages shrink the 64x64x3 field to a
## 2x2x1 latent state (spatial size 64 -> 32 -> 16 -> 8 -> 4 -> 2).
encoder_inputs = Input(shape=(64,64,3),name='Field')

enc_stages = []
h = encoder_inputs
for n_filters in (30, 20, 10, 15):
    x = Conv2D(n_filters, kernel_size=(3, 3), activation='relu', padding='same')(h)
    h = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
    enc_stages.append(h)
enc_l2, enc_l3, enc_l4, enc_l5 = enc_stages

# Last stage: single-channel, linear (no activation) conv before the final pooling.
x = Conv2D(1, kernel_size=(3, 3), activation=None, padding='same')(enc_l5)
encoded = MaxPooling2D(pool_size=(2, 2), padding='same')(x)

encoder = Model(inputs=encoder_inputs, outputs=encoded)
    
## Decoder: five [3x3 conv -> 2x2 upsample] stages grow the 2x2x1 latent state
## back to 64x64; a final linear conv maps the features to the 3 variables.
decoder_inputs = Input(shape=(2,2,1),name='decoded')

dec_stages = []
h = decoder_inputs
for n_filters in (1, 5, 10, 20, 30):
    x = Conv2D(n_filters, kernel_size=(3, 3), activation='relu', padding='same')(h)
    h = UpSampling2D(size=(2, 2))(x)
    dec_stages.append(h)
dec_l1, dec_l2, dec_l3, dec_l4, dec_l5 = dec_stages

decoded = Conv2D(3, kernel_size=(3, 3), activation=None, padding='same')(dec_l5)

decoder = Model(inputs=decoder_inputs, outputs=decoded)

## Autoencoder: the encoder followed by the decoder, trained end to end.
ae_outputs = decoder(encoder(encoder_inputs))

model = Model(inputs=encoder_inputs,outputs=ae_outputs,name='CAE')

# design network
# `epsilon=None` and `decay=0.0` are intentionally not passed. With the legacy TF2
# optimizers they were no-ops (epsilon=None fell back to the 1e-7 default, and
# decay=0 is the default). The optimizers introduced in TF 2.11 reject `decay`
# and cannot use `epsilon=None`.
my_adam = optimizers.Adam(learning_rate=lrate, beta_1=0.9, beta_2=0.999, amsgrad=False)

# Save only the weights with the lowest training loss seen so far.
checkpoint = ModelCheckpoint(weights_filepath, monitor='loss', verbose=1, save_best_only=True, mode='min',save_weights_only=True)
# Available but not enabled: append to callbacks_list to stop early on a loss plateau.
earlystopping = EarlyStopping(monitor='loss', min_delta=0, patience=10, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
callbacks_list = [checkpoint]

# fit network
model.compile(optimizer=my_adam,loss='mean_squared_error',metrics=[mean_absolute_error,max_absolute_error])
model.summary()

num_epochs = 5000
batch_size = 4

if mode == 'train':
    # Autoencoder training: the target is the input itself.
    train_history = model.fit(x=swe_train_data, y=swe_train_data, epochs=num_epochs, batch_size=batch_size, callbacks=callbacks_list)

In [None]:
# Render the encoder graph (with tensor shapes) to a PNG and display it inline.
encoder_png = 'encoder-model.png'
plot_model(encoder, to_file=encoder_png, show_shapes=True, show_layer_names=True)
Image(encoder_png)

In [None]:
# Render the decoder graph (with tensor shapes) to a PNG and display it inline.
decoder_png = 'decoder-model.png'
plot_model(decoder, to_file=decoder_png, show_shapes=True, show_layer_names=True)
Image(decoder_png)

### Check accuracy

In [None]:
# Compare one training snapshot with its autoencoder reconstruction.
time = 100

model.load_weights(weights_filepath)
from scipy.ndimage import gaussian_filter
# Slice with time:time+1 to keep the batch axis that predict() expects.
recoded_1 = model.predict(swe_train_data[time:time+1,:,:,:])

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(14, 12))

true_images = []
for var in range(3):
    # Top row: ground truth. Bottom row: reconstruction, smoothed with a Gaussian
    # filter (sigma=2) for display -- the truth is shown unfiltered.
    true_im = ax[0, var].imshow(swe_train_data[time, :, :, var], label='input')
    recon_im = ax[1, var].imshow(gaussian_filter(recoded_1[0, :, :, var], sigma=2), label='decoded')
    # Each colorbar must come from the image it annotates: imshow autoscales every
    # panel to its own data range, so the truth's mappable does not describe the
    # reconstruction's colors.
    fig.colorbar(true_im, ax=ax[0, var], fraction=0.046, pad=0.04)
    fig.colorbar(recon_im, ax=ax[1, var], fraction=0.046, pad=0.04)
    ax[0, var].set_title(rf'True $q_{var + 1}$')
    ax[1, var].set_title(rf'Reconstructed $q_{var + 1}$')
    true_images.append(true_im)
cs1, cs2, cs3 = true_images

for panel in ax.flat:
    panel.set_xlabel('x')
    panel.set_ylabel('y')

plt.subplots_adjust(wspace=0.5, hspace=-0.3)
plt.show()

### Generate encoded data for LSTM learning

In [None]:
# Encode every training snapshot to its 2x2x1 latent state.
encoded = K.eval(encoder(swe_train_data[:,:,:,:].astype('float32')))

# Visualize the latent state at two instants.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(7, 6))
for panel, time in zip(ax, (98, 198)):
    cs = panel.imshow(encoded[time, :, :, 0])
    panel.set_title(f'$t = {time}$')
    # Use each panel's own image for its colorbar: the two latent states are
    # autoscaled independently by imshow.
    fig.colorbar(cs, ax=panel, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

In [None]:
# Flatten each 2x2x1 latent state into a 4-vector: one row per snapshot.
encoded = np.reshape(encoded, (-1, 4))

In [None]:
# Time history of each of the four latent coordinates over the training trajectory.
plt.figure()
for mode_idx in range(4):
    plt.plot(encoded[:, mode_idx], label=f'Mode {mode_idx + 1}')
plt.legend()
plt.show()

### Train LSTM

In [None]:
time_window = 15 # The sliding window size of the LSTM, for truncated backpropagation through time (BPTT)
lstm_training_data = np.copy(encoded)
num_train_snapshots = 1  # number of training simulations (not used below)
total_size = np.shape(lstm_training_data)[0]
n_samples = total_size - time_window

# Each sample pairs `time_window` consecutive latent states (input) with the
# latent state that immediately follows them (target).
input_seq = np.zeros(shape=(n_samples, time_window, 4))
output_seq = np.zeros(shape=(n_samples, 4))
for s in range(n_samples):
    t = s + time_window
    input_seq[s] = lstm_training_data[s:t]
    output_seq[s] = lstm_training_data[t]
print(f"Total samples = {n_samples}")

In [None]:
# Model architecture: two stacked LSTM layers and a linear read-out that maps a
# window of `time_window` latent 4-vectors to the next latent 4-vector.
lstm_model = models.Sequential()

# Within a window the hidden state is always carried from one timestep to the next
# (that is what truncated BPTT over `time_window` steps trains). Keras' `stateful`
# flag is different: it controls whether one batch's final state seeds the next batch.
lstm_model.add(LSTM(20,input_shape=(time_window, 4),return_sequences=True, stateful=False))
# Only the first layer needs `input_shape`; Sequential infers the rest.
# return_sequences=False: only the last timestep's output reaches the Dense layer.
lstm_model.add(LSTM(20, return_sequences=False))
lstm_model.add(Dense(4, activation=None))

# Equivalent list-style construction:
# lstm_model = models.Sequential(
# [
#     LSTM(20,input_shape=(time_window, 4),return_sequences=True),
#     LSTM(20,return_sequences=False),
#     Dense(4, activation=None)
# ])


# training parameters
num_epochs = 3000
batch_size = 64


# design network
lstm_filepath = './saved_models/SWE_LSTM_Weights.h5'
# `epsilon=None` and `decay=0.0` are intentionally not passed: they were no-ops with
# the legacy TF2 optimizers, while the TF >= 2.11 optimizers reject `decay` and
# cannot use `epsilon=None`.
lstm_adam = optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)
# Save only the weights with the lowest training loss seen so far.
checkpoint = ModelCheckpoint(lstm_filepath, monitor='loss', verbose=0, save_best_only=True, mode='min',save_weights_only=True)
# Available but not enabled: append to lstm_callbacks_list to stop early on a loss plateau.
earlystopping = EarlyStopping(monitor='loss', min_delta=0, patience=5, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
lstm_callbacks_list = [checkpoint]

# fit network
lstm_model.compile(optimizer=lstm_adam,loss='mean_squared_error',metrics=[mean_absolute_error,max_absolute_error])

#mode = 'train'
if mode == 'train':
    lstm_train_history = lstm_model.fit(input_seq, output_seq, epochs=num_epochs, batch_size=batch_size, callbacks=lstm_callbacks_list)

![RNN sequence input and output](media/rnn_sequences.png)
Adapted from [Andrej Karpathy's blog](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)

See https://stackoverflow.com/questions/38714959/understanding-keras-lstms for a helpful summary of the Keras API for `LSTM()`

In [None]:
# Render the LSTM graph (with tensor shapes) to a PNG and display it inline.
lstm_png = 'lstm-model.png'
plot_model(lstm_model, to_file=lstm_png, show_shapes=True, show_layer_names=True)
Image(lstm_png)

### Test LSTM with parameter information

In [None]:
#mode='test'
if mode == 'test':
    lstm_model.load_weights(lstm_filepath)

# Encode the validation trajectory and flatten each 2x2x1 latent state to a 4-vector.
encoded_valid = K.eval(encoder(swe_valid_data[:,:,:,:].astype('float32')))
encoded_valid = encoded_valid.reshape(-1,4)
lstm_testing_data = np.copy(encoded_valid)

# How the LSTM input window is refilled at each step:
#   True  -> one-step-ahead prediction: every window holds the *true* latent
#            states (the results this cell has always produced).
#   False -> autoregressive rollout: only the first window comes from the truth;
#            afterwards each prediction is shifted into the window and fed back.
teacher_forcing = True

# Size the prediction from the validation trajectory itself (not from the training
# set's `total_size`), so trajectories of different length are handled.
n_valid = np.shape(lstm_testing_data)[0]

input_seq = np.zeros(shape=(1, time_window, 4))
# Rows [0, time_window) have no prediction and stay zero.
output_seq_pred = np.zeros(shape=(n_valid, 4))

input_seq[0] = lstm_testing_data[0:time_window]
for t in range(time_window, n_valid):
    if teacher_forcing:
        input_seq[0] = lstm_testing_data[t - time_window:t]
    # Calling the model directly avoids predict()'s per-call overhead for one tiny batch.
    output_seq_pred[t] = np.asarray(lstm_model(input_seq, training=False))[0]
    # Slide the window one step and append the newest prediction (only used for
    # the next step when teacher_forcing is False).
    input_seq[0, :-1] = input_seq[0, 1:]
    input_seq[0, -1] = output_seq_pred[t]

### Check quality in latent space for testing data

In [None]:
# True vs. LSTM-predicted latent trajectories, one figure per latent coordinate.
# The first `time_window` steps have no prediction and are skipped.
for coord in range(4):
    plt.figure(figsize=(7, 6))
    plt.plot(lstm_testing_data[time_window:, coord], 'r', label='True', linewidth=3)
    plt.plot(output_seq_pred[time_window:, coord], 'b--', label='Predicted', linewidth=3)

    # A single legend suffices: every figure uses the same line styles.
    if coord == 0:
        plt.legend()
    plt.tight_layout()
    plt.xlim([0, 200])  # show only the first 200 predicted steps
    plt.show()

### Evolution in physical space

Apply decoder to LSTM-evolved latent space

In [None]:
# Reshape the predicted latent 4-vectors back into 2x2x1 images for the decoder.
# A separate name keeps `output_seq_pred` as (n_steps, 4), so the latent-space plots
# can be re-run after this cell. The shape is passed positionally because NumPy 2.1
# deprecated the `newshape=` keyword.
latent_pred = np.reshape(output_seq_pred, (-1, 2, 2, 1))
# Feed it through decoder
decoded_valid = K.eval(decoder(latent_pred.astype('float32')))

# Check evolution through spot checks
time = 50

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(14, 12))

true_images = []
for var in range(3):
    # Top row: truth. Bottom row: decoded LSTM prediction, Gaussian-smoothed
    # (sigma=2) for display only.
    true_im = ax[0, var].imshow(swe_valid_data[time, :, :, var], label='Truth')
    pred_im = ax[1, var].imshow(gaussian_filter(decoded_valid[time, :, :, var], sigma=2), label='Prediction')
    # Each colorbar comes from the image it annotates, because imshow autoscales
    # every panel to its own data range.
    fig.colorbar(true_im, ax=ax[0, var], fraction=0.046, pad=0.04)
    fig.colorbar(pred_im, ax=ax[1, var], fraction=0.046, pad=0.04)
    ax[0, var].set_title(rf'True $q_{var + 1}$')
    ax[1, var].set_title(rf'Reconstructed $q_{var + 1}$')
    true_images.append(true_im)
cs1, cs2, cs3 = true_images

for panel in ax.flat:
    panel.set_xlabel('x')
    panel.set_ylabel('y')

plt.subplots_adjust(wspace=0.5, hspace=-0.3)
plt.show()

### A posteriori analysis

In [None]:
x = np.linspace(-1/2, 1/2, 64)  # Array with x-points
y = np.linspace(-1/2, 1/2, 64)  # Array with y-points

# Meshgrid for plotting
X, Y = np.meshgrid(x, y)
time = 199

from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib


def _plot_rho_eta_surface(field):
    """Draw *field* (one 64x64 array) as a 3-D surface over (X, Y) and show it.

    The truth and the prediction share this styling and the fixed z-range
    [0.1, 0.3], so the two figures can be compared directly.

    Returns:
        (fig, ax, surf): the figure, its 3-D axes, and the surface artist.
    """
    fig = plt.figure(figsize=(11, 7))
    # add_subplot(projection='3d') instead of Axes3D(fig): since Matplotlib 3.6 the
    # Axes3D constructor no longer attaches itself to the figure (blank output).
    ax = fig.add_subplot(projection='3d')
    surf = ax.plot_surface(X, Y, field, rstride=1, cstride=1,
                           cmap=plt.cm.jet, linewidth=0, antialiased=True)

    ax.set_xlabel('x [m]', fontsize=16)
    ax.set_ylabel('y [m]', fontsize=16)
    ax.xaxis.labelpad = 30
    ax.yaxis.labelpad = 30
    ax.tick_params(axis='both', which='major', pad=15)

    ax.set_zticks([0.1, 0.15, 0.20, 0.25, 0.3])
    ax.set_zlim((0.1, 0.3))

    plt.show()
    return fig, ax, surf


# q_1 at `time`: truth first, then the decoded LSTM prediction (smoothed for display).
fig, ax, surf = _plot_rho_eta_surface(swe_valid_data[time, :, :, 0])
fig, ax, surf = _plot_rho_eta_surface(gaussian_filter(decoded_valid[time, :, :, 0], sigma=2))

### ONNX portability

In [None]:
import tf2onnx

In [None]:
# Export the LSTM to ONNX. The input signature fixes the window length and the
# latent size, and leaves the batch dimension dynamic (None).
lstm_input_spec = tf.TensorSpec((None, time_window, 4))
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
    lstm_model,
    input_signature=[lstm_input_spec],
    opset=10,
    output_path='./lstm.onnx',
)