In [None]:
# Imports for the whole notebook. The global NumPy and TensorFlow random
# generators are each seeded immediately after the library is imported.
import numpy as np
np.random.seed(10)  # seed NumPy's global RNG
import tensorflow as tf
tf.random.set_seed(10)  # seed TensorFlow's global RNG with the same value
# NOTE(review): StandardScaler and matplotlib are not used in the cells visible
# here - possibly used further down; confirm before removing.
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ReduceLROnPlateau,EarlyStopping
import matplotlib.pyplot as plt

In [None]:
# In physical space use 'raw_train.reshape(-1,121,281)' to visualize the
# original (full-resolution) fields.
num_snapshots = 12564
grid_shape = (121, 281)        # grid size of one stored snapshot
row_stride, col_stride = 2, 4  # Subsample to fit on GPU

# The order of this list fixes the column order of the concatenated training
# matrix and the order of scaler_list, so keep it stable.
var_names = ['z500',
             'u250','v250','t250',
             'u850','v850','t850',
             'blh','tcwv']


def load_subsampled_snapshots(name):
    """Load the training snapshots of one variable, subsampled and flattened.

    Parameters
    ----------
    name : str
        Variable tag used in the file name '../Raw_Dataset/train_<name>_snap.npy'.

    Returns
    -------
    numpy.ndarray
        Array of shape (num_snapshots, n_points): every snapshot is viewed on
        the `grid_shape` grid, thinned with `row_stride`/`col_stride` and
        flattened again.

    Raises
    ------
    ValueError
        If the file does not hold exactly `num_snapshots` snapshots. Without
        this check a file with a multiple of `num_snapshots` would be reshaped
        silently, merging several snapshots into one row.
    """
    snapshots = np.load('../Raw_Dataset/train_' + name + '_snap.npy').reshape(-1, *grid_shape)
    if snapshots.shape[0] != num_snapshots:
        raise ValueError(
            'Expected %d snapshots for %s, found %d'
            % (num_snapshots, name, snapshots.shape[0])
        )
    return snapshots[:, ::row_stride, ::col_stride].reshape(num_snapshots, -1)


# Load and scale one variable at a time so that only a single raw array is
# alive at any moment; the unscaled data is released as soon as it is scaled.
var_list = []
scaler_list = []
for name in var_names:
    scaler = MinMaxScaler()  # one independent scaler per variable
    var_list.append(scaler.fit_transform(load_subsampled_snapshots(name)))
    scaler_list.append(scaler)  # kept so the scaling can be inverted later

In [None]:
# Join the scaled variables column-wise: one row per snapshot, with the
# flattened grid points of each variable placed side by side in var_list order.
# Every entry of var_list is 2-D, so axis 1 is the last (feature) axis.
raw_train = np.concatenate(var_list, axis=1)

In [None]:
input_window = 14   # consecutive snapshots given to the model as input
output_window = 7   # snapshots immediately after the input that form the target

# A sample that starts at row s uses rows [s, s + input_window) as input and the
# following output_window rows as target. The last valid start is therefore
# N - input_window - output_window (inclusive), which gives the count below.
num_samples = raw_train.shape[0] - input_window - output_window + 1
if num_samples <= 0:
    raise ValueError(
        'Need at least %d snapshots to build one sample, got %d'
        % (input_window + output_window, raw_train.shape[0])
    )

# np.asarray copies the slices into new contiguous arrays of shape
# (num_samples, window, n_features).
train_inputs = np.asarray(
    [raw_train[s:s+input_window] for s in range(num_samples)]
)
train_outputs = np.asarray(
    [raw_train[s+input_window:s+input_window+output_window] for s in range(num_samples)]
)

# The windowed arrays are independent copies, so the source matrix can go.
del raw_train

In [None]:
# Size of the per-snapshot encoding used by the autoencoder part of the model.
encode_dim = 180

# Number of features in one snapshot, read from the first input sample and the
# first target sample respectively.
embed_dim, output_dim = (
    samples[0].shape[-1] for samples in (train_inputs, train_outputs)
)

In [None]:
# Depth of the per-snapshot dense encoder and decoder stacks.
num_ae_encoder_layers = num_ae_decoder_layers = 3
# Number of sequence-returning LSTM layers on the encoder and decoder side.
num_lstm_cells_encoder = num_lstm_cells_decoder = 2

In [None]:
ff_dim = 100  # Units per direction in every bidirectional LSTM encoder layer
dropout_rate = 0.0  # NOTE(review): defined but not used by any layer in this cell


def _shape_list(tensor):
    """Return the static shape of a symbolic Keras tensor as a plain list.

    Uses the `shape` attribute rather than `get_shape().as_list()` so that it
    works whether `shape` is a TensorShape or a plain tuple.
    """
    return list(tensor.shape)


inputs = layers.Input(shape=(input_window,embed_dim))

# ---- Layer definitions -------------------------------------------------------
# Dense layers applied to every time step independently: physical -> encoded.
ae_encoding_layers = [
    layers.TimeDistributed(layers.Dense(encode_dim,activation='elu'))
    for _ in range(num_ae_encoder_layers)
]

# Bidirectional LSTM stack that keeps the full sequence between layers.
lstm_encoder_cells = [
    layers.Bidirectional(layers.LSTM(ff_dim,activation='elu',return_sequences=True))
    for _ in range(num_lstm_cells_encoder)
]

# Last encoder layer returns only its final state (no return_sequences), which
# is then repeated once per forecast step.
lstm_encoder_final = layers.Bidirectional(layers.LSTM(ff_dim,activation='elu'))
lstm_repeater_layer = layers.RepeatVector(output_window)

# The decoder LSTMs emit encode_dim features, i.e. the same width as the dense
# encoder output, so the dense decoder below can be applied to either of them.
lstm_decoder_cells = [
    layers.LSTM(encode_dim,activation='elu',return_sequences=True)
    for _ in range(num_lstm_cells_decoder)
]

# Dense layers applied to every time step independently: encoded -> physical.
ae_decoding_layers = [
    layers.TimeDistributed(layers.Dense(embed_dim,activation='elu'))
    for _ in range(num_ae_decoder_layers)
]

# ---- Forecast branch ---------------------------------------------------------
# Encode from physical space
print('Input shape:',_shape_list(inputs))

x = inputs
for encoding_layer in ae_encoding_layers:
    x = encoding_layer(x)
encoded = x

print('AE Encoded shape:',_shape_list(encoded))

x = encoded
for encoder_cell in lstm_encoder_cells:
    x = encoder_cell(x)

x = lstm_encoder_final(x)

print('LSTM Encoded shape:',_shape_list(x))

x = lstm_repeater_layer(x)

for decoder_cell in lstm_decoder_cells:
    x = decoder_cell(x)

print('LSTM Decoded shape:',_shape_list(x))

for decoding_layer in ae_decoding_layers:
    x = decoding_layer(x)

outputs = x

print('AE+LSTM Output shape:',_shape_list(outputs))

# ---- Reconstruction branch ---------------------------------------------------
# The very same decoding layer objects are applied directly to the encoded
# input, so this branch shares its decoder weights with the forecast branch.
decoded = encoded
for decoding_layer in ae_decoding_layers:
    decoded = decoding_layer(decoded)

print('AE Output shape:',_shape_list(decoded))

# Two outputs: the forecast sequence and the reconstruction of the input window.
model = tf.keras.Model(inputs=inputs, outputs=[outputs,decoded])

In [None]:
initial_lr = 0.0001  # starting learning rate for Adam
min_lr = 0.000001    # floor for the plateau scheduler

# The floor has to be strictly below the starting rate: ReduceLROnPlateau only
# lowers the rate while it is above min_lr, so a floor equal to the starting
# rate would turn the callback into a no-op.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=min_lr)

# Stop after 20 epochs without validation improvement and roll the model back
# to the best epoch's weights; otherwise it would keep the weights of the last
# (non-improving) epoch.
early_stop = EarlyStopping(monitor='val_loss',patience=20,restore_best_weights=True)

# The same MSE loss is applied to both model outputs, weighted equally.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=initial_lr),loss='mean_squared_error',loss_weights=[1,1])

In [None]:
# Print the layer-by-layer overview of the assembled two-output model.
model.summary()

In [None]:
# Train both heads together: the first target is the future window (forecast
# output), the second is the input window itself (reconstruction output).
# Keras carves the validation data off as a fraction of the supplied arrays.
history = model.fit(
    train_inputs,
    [train_outputs, train_inputs],
    epochs=250,
    batch_size=2,
    callbacks=[reduce_lr, early_stop],
    validation_split=0.2,
)