# Training a conditional recurrent neural network for pulse propagation
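This notebook shows how to train a conditional recurrent LSTM network, built with the `cond_rnn` package, that predicts the next step of pulse propagation from a sequence of previous steps and a condition. Training continues from a previously saved model.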

#### Import relevant modules

In [6]:
import tensorflow as tf
from cond_rnn import ConditionalRecurrent
import matplotlib.pyplot as plt

#### Define the model name and load it
When loading a saved model that includes a `ConditionalRecurrent` layer, the custom layer must be passed through `custom_objects`; otherwise Keras cannot rebuild it from the saved file.

In [9]:
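# Path to the previously saved model (HDF5 format)
clean_model_name = 'cond_LSTM.h5'
# Register the custom layer so that Keras can rebuild it from the file
model = tf.keras.models.load_model(clean_model_name, custom_objects={'ConditionalRecurrent': ConditionalRecurrent})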

#### Define the training and testing datasets
This part depends on how your data are stored and on what your condition represents, so it is left blank here.
The general idea behind the data preparation is as follows.

Each prediction of the network takes two inputs, a sequence and a condition, and returns a single output:

- `train_x`, with shape `[None, sequence_length, n_features]`: the input sequences, where `None` is the number of samples, `sequence_length` is the number of consecutive observations that make up one input sequence, and `n_features` is the number of points in a single observation.
- `train_c`, with shape `[None, cond_features]`: one condition per input sequence, where `cond_features` is the number of points in a single condition.
- `train_y`, with shape `[None, n_features]`: the target, i.e. the single observation at the step that follows the input sequence.

The same logic applies when there is more than one condition: each condition is prepared in the same way, with one row per input sequence.
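As a minimal sketch of this windowing, assume one pulse simulation is stored as a NumPy array `evolution` of shape `[n_steps, n_features]` (one row per propagation step) and its conditions as an array `conditions` of shape `[n_steps, cond_features]`. These names, the helper `make_windows`, and the choice of pairing each window with the condition of the step it predicts are assumptions for illustration, not taken from the original data pipeline.

```python
import numpy as np


def make_windows(evolution, conditions, sequence_length):
    """Hypothetical helper: turn one pulse evolution into (x, c, y) samples.

    evolution:  array of shape [n_steps, n_features], one observation per step
    conditions: array of shape [n_steps, cond_features], one condition per step
    Assumption: each window is paired with the condition at the step it predicts.
    """
    n_steps = evolution.shape[0]
    xs, cs, ys = [], [], []
    for start in range(n_steps - sequence_length):
        target = start + sequence_length      # index of the step to predict
        xs.append(evolution[start:target])    # [sequence_length, n_features]
        cs.append(conditions[target])         # [cond_features]
        ys.append(evolution[target])          # [n_features]
    return np.array(xs), np.array(cs), np.array(ys)


# Usage with synthetic data: 50 propagation steps, 64 points per observation
# and 2 condition values per step (all sizes are arbitrary).
rng = np.random.default_rng(0)
evolution = rng.normal(size=(50, 64))
conditions = rng.normal(size=(50, 2))
x_demo, c_demo, y_demo = make_windows(evolution, conditions, sequence_length=10)
print(x_demo.shape, c_demo.shape, y_demo.shape)  # (40, 10, 64) (40, 2) (40, 64)
```

With several simulations, the helper can be called on each one and the results joined with `np.concatenate(..., axis=0)`.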

In [None]:
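# Training data (fill in with NumPy arrays of the shapes described above)
train_x = []  # [None, sequence_length, n_features]
train_c = []  # [None, cond_features]
train_y = []  # [None, n_features]

# Test data, used for validation during training
test_x = []
test_c = []
test_y = []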
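Before training, it can help to check that the arrays agree with the shapes described above. A minimal sketch, using a hypothetical `check_shapes` helper and synthetic arrays in place of real data:

```python
import numpy as np


def check_shapes(x, c, y):
    """Hypothetical helper: verify that (x, c, y) match the expected layout.

    x: [n_samples, sequence_length, n_features]
    c: [n_samples, cond_features]
    y: [n_samples, n_features]
    """
    x, c, y = np.asarray(x), np.asarray(c), np.asarray(y)
    if x.ndim != 3 or c.ndim != 2 or y.ndim != 2:
        raise ValueError(f"unexpected ranks: x={x.shape}, c={c.shape}, y={y.shape}")
    if not (len(x) == len(c) == len(y)):
        raise ValueError("x, c and y must contain the same number of samples")
    if x.shape[2] != y.shape[1]:
        raise ValueError("each target must have n_features points, like the inputs")
    # Return the sizes the model expects: sequence_length, n_features, cond_features
    return x.shape[1], x.shape[2], c.shape[1]


# Synthetic stand-ins for the training arrays (sizes are arbitrary)
rng = np.random.default_rng(1)
x_demo = rng.normal(size=(40, 10, 64))
c_demo = rng.normal(size=(40, 2))
y_demo = rng.normal(size=(40, 64))
print(check_shapes(x_demo, c_demo, y_demo))  # (10, 64, 2)
```

Calling it on the empty placeholder lists raises a `ValueError`, a reminder that they must be filled in before training.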

#### Define the training parameters

In [None]:
# Batch size
b_size = 150
# Number of epochs
n_epochs = 200
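With NumPy inputs, Keras splits the training set into batches of `b_size` samples per epoch, so the number of weight updates per epoch is the number of samples divided by the batch size, rounded up (the last batch may be smaller). The sample count below is a made-up figure used only to illustrate the arithmetic.

```python
import math

b_size = 150        # batch size, as set above
n_epochs = 200      # number of epochs, as set above
n_samples = 10_000  # hypothetical number of training sequences

steps_per_epoch = math.ceil(n_samples / b_size)   # full batches plus one partial batch
last_batch = n_samples - (steps_per_epoch - 1) * b_size
total_updates = steps_per_epoch * n_epochs
print(steps_per_epoch, last_batch, total_updates)  # 67 100 13400
```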

#### Train the model

In [None]:
# Train the model on the sequences and their conditions
history = model.fit(
    x=[train_x, train_c],
    y=train_y,
    batch_size=b_size,
    epochs=n_epochs,
    validation_data=([test_x, test_c], test_y),
    verbose=1,
)

# Plot the training and validation loss per epoch
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
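`history.history` is a dictionary that maps each metric name to a list with one value per epoch, so the epoch with the lowest validation loss can be read from it directly. A minimal sketch follows; the loss values are invented placeholders standing in for `history.history`. For long runs such as 200 epochs, Keras' `tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)` callback is an alternative worth considering.

```python
# Placeholder values with the same structure as history.history
# (one entry per epoch); replace with history.history after training.
history_dict = {
    "loss":     [0.90, 0.50, 0.30, 0.22, 0.18],
    "val_loss": [0.95, 0.60, 0.41, 0.43, 0.47],
}

val_loss = history_dict["val_loss"]
# Index (0-based) of the epoch with the smallest validation loss
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
print(f"best epoch: {best_epoch + 1}, val_loss = {val_loss[best_epoch]:.2f}")
# best epoch: 3, val_loss = 0.41
```

Here the training loss keeps falling after epoch 3 while the validation loss rises, the usual sign of overfitting.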