In [89]:
import math
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

In [90]:
seq1 = torch.randn(50, 1)
seq2 = torch.randn(60, 1)
seq3 = torch.randn(2, 1)
seq4 = torch.randn(10, 1)

In [91]:
help(pad_sequence)

Help on function pad_sequence in module torch.nn.utils.rnn:

pad_sequence(sequences, batch_first=False, padding_value=0.0)
    Pad a list of variable length Tensors with ``padding_value``
    
    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is list of
    sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
    otherwise.
    
    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.
    
    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])
    
    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x

In [92]:
padded_sequence = pad_sequence(sequences=[seq1, 
                                          seq2,
                                          seq3,
                                          seq4], batch_first=True, padding_value=-1000)
lens = [seq.size(0) for seq in [seq1, seq2, seq3, seq4]]

In [93]:
lens

[50, 60, 2, 10]

In [94]:
padded_sequence.size()

torch.Size([4, 60, 1])

In [95]:
padded_sequence[0]

tensor([[-6.4193e-01],
        [ 1.2174e+00],
        [-3.0194e-01],
        [-1.2784e+00],
        [-6.8767e-01],
        [ 8.0987e-01],
        [-8.5511e-02],
        [-3.3943e-01],
        [ 5.8664e-01],
        [-2.1247e-01],
        [ 2.2592e-01],
        [ 2.4989e+00],
        [-9.6000e-01],
        [-4.2262e-01],
        [ 6.4699e-01],
        [-1.8023e+00],
        [ 5.4731e-01],
        [ 3.2625e-01],
        [-9.6760e-01],
        [ 1.3899e+00],
        [ 1.0241e+00],
        [ 3.7540e+00],
        [-7.3970e-01],
        [ 6.8482e-01],
        [ 1.0996e+00],
        [ 1.4202e+00],
        [ 3.0433e-01],
        [ 7.7881e-01],
        [-1.2369e-01],
        [-1.0651e+00],
        [ 6.8292e-01],
        [-7.2316e-01],
        [ 7.0480e-01],
        [-6.4630e-01],
        [ 9.4085e-01],
        [ 4.8894e-01],
        [ 3.6962e-01],
        [-2.6323e-01],
        [ 7.8854e-01],
        [-1.1166e+00],
        [ 1.1667e+00],
        [-3.1645e-01],
        [-2.6030e-01],
        [-3

In [96]:
packed_padded_sequence = pack_padded_sequence(input=padded_sequence, lengths=lens, batch_first=True, enforce_sorted=False)

In [97]:
packed_padded_sequence.data.size()

torch.Size([122, 1])

In [98]:
rnn = nn.GRU(input_size=1, hidden_size=2, num_layers=1, 
             bias=True, batch_first=True, dropout=0.0, 
             bidirectional=False)

In [99]:
out, hn = rnn.forward(packed_padded_sequence)

In [100]:
out.data.size()

torch.Size([122, 2])

In [101]:
fc1 = nn.Linear(2, 1)

In [102]:
hidden_states_unpacked, lens_unpacked = pad_packed_sequence(sequence=out, batch_first=True, padding_value=-1000)

In [103]:
hidden_states_unpacked.shape

torch.Size([4, 60, 2])

In [104]:
hidden_states_unpacked[0], lens_unpacked[0], hidden_states_unpacked[0,:lens_unpacked[0],:]

(tensor([[-2.1862e-01, -1.2859e-02],
         [-2.6092e-01,  3.4482e-01],
         [-3.2502e-01,  1.9340e-01],
         [-3.8103e-01, -8.8216e-02],
         [-4.2686e-01, -7.9546e-02],
         [-3.5748e-01,  2.1930e-01],
         [-3.7610e-01,  1.8971e-01],
         [-3.9780e-01,  1.2041e-01],
         [-3.5671e-01,  2.8033e-01],
         [-3.7753e-01,  1.8695e-01],
         [-3.7080e-01,  2.4100e-01],
         [-1.2542e-01,  5.9188e-01],
         [-2.2503e-01,  1.1186e-01],
         [-3.2497e-01,  7.4123e-02],
         [-3.2873e-01,  2.7648e-01],
         [-3.7665e-01, -1.6827e-01],
         [-3.6958e-01,  1.3300e-01],
         [-3.6408e-01,  2.3779e-01],
         [-4.0322e-01, -1.6404e-03],
         [-2.8486e-01,  3.5215e-01],
         [-2.6827e-01,  4.6728e-01],
         [ 6.4164e-02,  7.6619e-01],
         [-9.7058e-02,  2.0611e-01],
         [-2.4161e-01,  3.5978e-01],
         [-2.4850e-01,  4.8555e-01],
         [-2.1482e-01,  5.8316e-01],
         [-2.6862e-01,  4.1707e-01],
 

In [105]:
means = fc1(hidden_states_unpacked)

In [106]:
means.size()

torch.Size([4, 60, 1])

In [107]:
a = torch.randn(3,)
b = a / (1.0 + a**2)
print(a, b)

tensor([ 1.4895, -0.4597,  1.0238]) tensor([ 0.4628, -0.3795,  0.4999])


In [108]:
means[3]

tensor([[ 4.3440e-01],
        [ 5.2169e-01],
        [ 3.0555e-01],
        [ 4.8630e-01],
        [ 5.5042e-01],
        [ 6.7441e-01],
        [ 6.3157e-01],
        [ 4.1873e-01],
        [ 4.6184e-01],
        [ 4.0588e-01],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6.2850e+02],
        [-6

## Experimenting with refactored dynamics functions

In [115]:
def linear_dynamics_fn(x_k):
    gamma = 0.8
    n_states = x_k.shape[-1]
    F_mat = torch.eye(n_states) + torch.concatenate((torch.zeros(n_states, 1), 
                                                torch.concatenate((torch.ones(1, n_states-1), 
                                                                torch.zeros(n_states-1, n_states-1)), 
                                                               dim=0)
                                                ), 
                                                dim=1)
        
    F_mat = F_mat.type(torch.FloatTensor) * gamma
    assert (torch.eig(F_mat)[0] <= 1.0).all() == True, "System is not stable!"
    x_k_plus_1 = F_mat @ x_k
    return x_k_plus_1

def nonlinear1d_dynamics_fn(x_k, k, a=0.5, b=25.0, c=8.0):
    x_k_plus_1 = a * x_k + b * (x_k / (1.0 + x_k**2)) + c * torch.cos(1.2 * (k+1)) 
    return x_k_plus_1

def lorenz63_dynamics_fn(x_k, J=5, delta=0.02):
    
    n_states = x_k.shape[-1]
    A_mat = torch.Tensor([
        [-10.0, 10.0, 0.0],
        [28.0, - 1.0, -x_k[0]],
        [0.0, x_k[0], -8.0/3] 
    ]).type(torch.FloatTensor) 
    F_k = torch.eye(n_states, n_states).type(torch.FloatTensor)
    for j in range(1, J+1):
        #self.F += np.linalg.matrix_power(self.A_fn(x)*self.delta, j) / np.math.factorial(j)
        F_k += torch.matrix_power(A_mat*delta, j) / math.factorial(j)

    x_k_plus_1 = F_k @ x_k
    return x_k_plus_1

def chen_dynamics_fn(x_k, J=5, delta=0.002, alpha=0.0):
    
    n_states = x_k.shape[-1]
    A_mat = torch.Tensor([
        [-35.0, 35.0, 0.0],
        [-7.0, 28.0, -x_k[0]],
        [0.0, x_k[0], -3.0] 
    ]).type(torch.FloatTensor) 
    F_k = torch.eye(n_states, n_states).type(torch.FloatTensor)
    for j in range(1, J+1):
        #self.F += np.linalg.matrix_power(self.A_fn(x)*self.delta, j) / np.math.factorial(j)
        F_k += torch.matrix_power(A_mat*delta, j) / math.factorial(j)

    x_k_plus_1 = F_k @ x_k
    return x_k_plus_1

def rossler_dynamics_fn(x_k, J=5, delta=0.008, a=0.2, b=0.2, c=5.7):
    
    n_states = x_k.shape[-1]
    A_mat = torch.Tensor([
        [0, -1, -1],
        [1, a, 0],
        [0, 0, (b / x_k[2]) + (x_k[0] - c)]
    ]).type(torch.FloatTensor) 
    F_k = torch.eye(n_states, n_states).type(torch.FloatTensor) 
    for j in range(1, J+1):
        #self.F += np.linalg.matrix_power(self.A_fn(x)*self.delta, j) / np.math.factorial(j)
        F_k += torch.matrix_power(A_mat*delta, j) / math.factorial(j)

    x_k_plus_1 = F_k @ x_k
    return x_k_plus_1 

def lorenz96_process_model(T_time, n_states, method='RK45', N=20, F_mu=8.0, delta=0.01, sigma_e2_dB=-10.0):

    def L96(t, x, N=20, F_mu=8.0, sigma_e2=.1):
        """Lorenz 96 model with constant forcing
        Adapted from: https://www.wikiwand.com/en/Lorenz_96_model 
        """
        # Setting up vector
        d = np.zeros(N)
        # Loops over indices (with operations and Python underflow indexing handling edge cases)
        F_N = np.random.normal(loc=F_mu, scale=np.sqrt(sigma_e2), size=(N,)) # Incorporating Process noise through the forcing constant
        for i in range(N):
            #print(F_N[i])
            d[i] = (x[(i + 1) % N] - x[i - 2]) * x[i - 1] - x[i] + F_N[i]
        return d

    sigma_e2 = dB_to_lin(sigma_e2_dB)
    x0 = F_mu * np.ones(n_states)  # Initial state (equilibrium)
    x0[0] += delta  # Add small perturbation to the first variable
    sol = solve_ivp(L96, 
                    t_span=(0.0, T_time), 
                    y0=x0, 
                    args=(n_states, F_mu, sigma_e2,), 
                    method=method, 
                    t_eval=np.arange(0.0, T_time, delta), 
                    max_step=delta)

    x_lorenz96 = torch.from_numpy(np.concatenate((sol.y.T, x0.reshape((1, -1))), axis=0)).type(torch.FloatTensor)
    return x_lorenz96

def get_dynamics_fn_dict(fn_name):

    DYNAMICS_FN_LIST = {
        "linear": linear_dynamics_fn,
        "nonlinear1d": nonlinear1d_dynamics_fn,
        "lorenz63": lorenz63_dynamics_fn,
        "chen": chen_dynamics_fn,
        "rossler": rossler_dynamics_fn
    }

    return DYNAMICS_FN_LIST[fn_name]

In [117]:
params_dict = {
    "nonlinear1d": {
        "a":0.5,
        "b":25.0,
        "c":8.0
    }
}
dynamics_fn_nonlinear = get_dynamics_fn_dict(fn_name="nonlinear1d")

In [118]:
dynamics_fn_nonlinear

<function __main__.nonlinear1d_dynamics_fn(x_k, k, a=0.5, b=25.0, c=8.0)>