In [None]:
import tqdm #For progress bar
import tensorflow as tf
from tf import keras
import numpy as np

In [None]:
def rnn_jac(Wxh, Whh, ht, xt, phiprime, bias=None):
    """
    Compute the Jacobian of one vanilla-RNN step with respect to the hidden state.

    For the column-vector recurrence h_next = phi(Wxh @ xt + Whh @ ht + bias) the
    Jacobian is d h_next / d ht = diag(phi'(alpha)) @ Whh, where alpha is the
    pre-activation evaluated at the state ``ht`` the step starts from.

    :param Wxh: input-to-hidden weight matrix (U), shape (hidden_size, input_size)
    :param Whh: hidden-to-hidden weight matrix (V), shape (hidden_size, hidden_size)
    :param ht: hidden state the step is taken from (point of linearisation);
        a flat (hidden_size,) vector or a (hidden_size, 1) column vector
    :param xt: input at this step, same vector convention as ``ht``
    :param phiprime: elementwise derivative of the activation function
    :param bias: optional bias vector added to the pre-activation; ``None`` (default)
        means no bias, which reproduces the previous behaviour
    :return: Jacobian matrix, shape (hidden_size, hidden_size)
    """
    alpha = Wxh @ xt + Whh @ ht
    if bias is not None:
        # Without the bias phi' would be evaluated at the wrong pre-activation.
        alpha = alpha + bias

    # ravel() makes flat and column-vector inputs behave the same.
    d = np.ravel(phiprime(alpha))
    # diag(d) @ Whh: scale row i of Whh by d[i] without building the diagonal matrix.
    return d[:, None] * Whh

def calc_LEs(x_batches, h0, RNNlayer, activation_function_prim=lambda x:np.heaviside(x,1), k_LE=1000):
    """
    Estimate the Lyapunov exponents of a SimpleRNN driven by each input sequence (QR method).

    For every sequence the tangent map J_t = d h_{t+1} / d h_t is accumulated through
    repeated reduced QR factorisations, Q_t R_t = J_t Q_{t-1}. The i-th exponent is the
    time average (1/T) * sum_t log|R_t[i, i]| (natural log, per time step).

    The recurrence is assumed to be h_{t+1} = phi(x_t @ kernel + h_t @ recurrent_kernel + bias).
    RNNlayer.get_weights() must therefore return [kernel (input_size, hidden_size),
    recurrent_kernel (hidden_size, hidden_size)], optionally followed by bias (hidden_size,).
    This is the weight layout of tf.keras.layers.SimpleRNN. phi is RNNlayer.cell.activation.

    :param x_batches: input sequences (batch_size, T, input_size)
    :param h0: initial hidden state (batch_size, hidden_size); a single (hidden_size,)
        vector is broadcast to every sequence
    :param RNNlayer: RNN layer object (e.g., tf.keras.layers.SimpleRNN)
    :param activation_function_prim: elementwise derivative of the layer's activation.
        The default Heaviside step is the derivative of ReLU. For a tanh layer (the Keras
        SimpleRNN default) pass ``lambda a: 1 - np.tanh(a) ** 2``.
    :param k_LE: number of Lyapunov exponents to compute, clipped to [1, hidden_size]
    :return: numpy array (batch_size, k) of exponents per sequence,
        with k = max(min(hidden_size, k_LE), 1)
    :raises ValueError: if the sequences contain no time steps
    """
    # The file-level ``import tqdm`` binds the module, which is not callable.
    from tqdm import tqdm as progress

    x_batches = np.asarray(x_batches, dtype=np.float64)
    batch_size, T, input_size = x_batches.shape
    if T == 0:
        raise ValueError("x_batches must contain at least one time step")

    h0 = np.asarray(h0, dtype=np.float64)
    L = h0.shape[-1]  # hidden size (not the whole shape tuple)
    if h0.ndim == 1:
        h0 = np.broadcast_to(h0, (batch_size, L))

    RNNcell = RNNlayer.cell
    activation = RNNcell.activation

    # Weights do not change during the analysis, so fetch them once.
    weights = RNNlayer.get_weights()
    # Keras uses row vectors (x @ W); rnn_jac uses column vectors (W @ x), hence .T
    Wxh = np.asarray(weights[0], dtype=np.float64).T
    Whh = np.asarray(weights[1], dtype=np.float64).T
    bias = np.asarray(weights[2], dtype=np.float64) if len(weights) > 2 else np.zeros(L)
    # Fold the bias into the input weights so rnn_jac's pre-activation includes it:
    # [Wxh | b] @ [x; 1] = Wxh @ x + b.
    Wxh_aug = np.hstack([Wxh, bias[:, None]])

    # Choose how many exponents to track
    k = max(min(L, k_LE), 1)

    lyaps_batches = np.zeros((batch_size, k))
    for batch in progress(range(batch_size)):
        x = x_batches[batch]
        x_aug = np.hstack([x, np.ones((T, 1))])
        ht = h0[batch]
        # Only the first k tangent directions are needed; reduced QR keeps Q at (L, k).
        Q = np.eye(L)[:, :k]
        cum_lyaps = np.zeros(k)

        for t in progress(range(T), leave=False):
            # Linearise at the state the step starts from: d h_{t+1}/d h_t depends on h_t.
            J = rnn_jac(Wxh_aug, Whh, ht, x_aug[t], activation_function_prim)
            # Recurrent step h_{t+1} = phi(Wxh x_t + Whh h_t + b).
            ht = np.asarray(activation(Wxh @ x[t] + Whh @ ht + bias), dtype=np.float64)

            # Perturbations evolve as J @ Q, so J multiplies from the left.
            Q, R = np.linalg.qr(J @ Q)
            # QR's diagonal sign is arbitrary; the growth factor is its magnitude.
            cum_lyaps += np.log(np.abs(np.diag(R)))
        lyaps_batches[batch] = cum_lyaps / T
    return lyaps_batches
