VAEs dynamiques : https://arxiv.org/abs/2008.12595

# Training Deep Kalman Filter

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# from tests.unit_tests import test_brick_1, test_brick_2, test_brick_3, test_brick_4, test_brick_5, test_brick_6
# from libs.dkf import DeepKalmanFilter, loss_function

In [None]:
def seed_everything(seed=42):
    """
    Set seed for reproducibility.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set seed for reproducibility
seed_everything(42)

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    dtype = torch.cuda.FloatTensor
else:
    device = torch.device('cpu')
    dtype = torch.FloatTensor

print(f"Using {device}")

torch.set_default_dtype(torch.float64)

if device.type == 'cuda':
    print('GPU Name:', torch.cuda.get_device_name(0))
    print('Total GPU Memory:', round(torch.cuda.get_device_properties(0).total_memory/1024**3,1), 'GB')

# Modèle

### Structure SSM (Single State Model pour variables latentes + VAE pour observations):

- $z_t$ variables latentes forment une chaîne de Markov, transition $p(z_t \vert z_{t-1})$
- $x_t$ observations, modèle $p_{\theta_x}(x_t \vert z_t)$
- NB : pas de commande/input $u_t$ ici

### Deep Kalman Filter :

\begin{align}
p_{\theta_z}(z_t \vert z_{t-1}) &= \mathcal{N}(z_t \vert \mu_{\theta_z}(z_{t-1}), \text{diag}(\sigma_{\theta_z}^{2}(z_{t-1}))) \\
d_z(z_{t-1}) &= [ \mu_{\theta_z}(z_{t-1}), \sigma_{\theta_z}(z_{t-1}) ] \\
p_{\theta_x}(x_t \vert z_{t}) &= \mathcal{N}(x_t \vert \mu_{\theta_x}(z_{t}), \text{diag}(\sigma_{\theta_x}^{2}(z_{t}))) \\
d_x(z_{t}) &= [ \mu_{\theta_x}(z_{t}), \sigma_{\theta_x}(z_{t}) ] \\
\end{align}

où $d_x, d_z$ sont des réseaux.

### Modèle inférence

Le "true posterior" s'écrit :
\begin{align}
p_{\theta}(z_{1:T} \vert x_{1:T}) &= \prod_{t=1}^T p_{\theta} (z_t \vert z_{1:t-1}, x_{1:T} ) \\
&= \prod_{t=1}^T p_{\theta} (z_t \vert z_{t-1}, x_{t:T} )
\end{align}

où la première écriture est l'application de la chain rule, et la deuxième est un résultat de D-séparation (latentes à dépendance Markovienne).

On choisit comme approximation du posterior (=encodeur) une formulation calquée sur le vrai posterior :

\begin{align}
q_{\phi}(z_{1:T} \vert x_{1:T}) &= \prod_{t=1}^T q_{\phi} (z_{t} \vert z_{t-1}, x_{t:T})
\end{align}

On voit que l'inférence prend en compte les observations futures $x_{t:T}$ (comme le Kalman smoother par exemple)

# Implémentation de l'inférence

- **backward RNN** (dans les faits, un LSTM) pour encoder les $x_{t:T}$ par les hidden states $h_t$ : 

\begin{align}
h_t = \text{LSTM}(h_{t+1}, x_t)
\end{align}

- **combiner** (réseau MLP) pour aggréger $h_t$ et $z_{t-1}

\begin{align}
g_t = \text{Combiner}(h_t, z_{t-1})
\end{align}

- **Encoder** (réseau MLP) pour inférer les paramètres du posterior:

\begin{align}
e_z(g_t) &= [ \mu_\phi(g_t), \sigma_\phi(g_t)] \\
q_\phi(z_t \vert g_t) &= \mathcal{N}(z_t \vert \mu_{\phi}(g_t), \text{diag}(\sigma_\phi^2(g_t)))
\end{align}

NB : il existe d'autres formulations du posterior approximé $q_\phi$, qui peuvent faire intervenir un forward LSTM.

# Training

Le modèle s'entraîne en maximisant un ELBO, dont la formulation générique se simplifie dans le cas du DKF en :

\begin{align}
\mathcal{L}(\theta, \phi; X) &= \sum_{t=1}^T \mathbb{E}_{q_\phi(z_t \vert x_{1:T})} \log(p_{\theta_x}(x_t \vert z_t)) -
\sum_{t=1}^T \mathbb{E}_{q_\phi(z_{t-1} \vert x_{1:T})} \text{D}_{\text{KL}} \left[ q_\phi(z_t \vert z_{-1}, x_{t:T}) \vert\vert 
p_{\theta_z}(z_t \vert z_{t-1}) \right]
\end{align}

Les deux termes s'explicitent de la façon suivante (avec $D$ dimension de l'espace des observations) :

\begin{align}
p_{\theta_x}(x_t \vert z_t) &= \mathcal{N}(x_t \vert \mu_{\theta_x}(z_t), \text{diag}(\sigma_{\theta_x}^2(z_t))) \\
\log{p_{\theta_x}(x_t \vert z_t)} &= -\frac{D}{2} \log{2\pi} - \frac{1}{2}\log{\vert \text{diag}(\sigma_{\theta_x}^2(z_t)) \vert} - 
\frac{1}{2} \left[ (x_t - \mu_{\theta_x}(z_t))^T (\text{diag}(\sigma_{\theta_x}^2(z_t)))^{-1} (x_t - \mu_{\theta_x}(z_t)) \right] \\
&= \frac{1}{2} \left( \sum_{i=1}^D \log{\sigma_{\theta_x}^2(z_t)}\vert_{i} + (x_t - \mu_{\theta_x}(z_t))^T \text{diag} \frac{1}{\sigma_{\theta_x}^2(z_t)} (x_t - \mu_{\theta_x}(z_t)) \right)
\end{align}

Et la KL entre les deux Gaussiennes:

\begin{align}
q_\phi(z_t \vert z_{t-1}, x_{t:T}) &= \mathcal{N}(z_t \vert \mu_{\phi}(g_t), \text{diag}(\sigma_\phi^2(g_t))) \\
p_{\theta_z}(z_t \vert z_{t-1}) &= \mathcal{N}(z_t \vert \mu_{\theta_z}(z_{t-1}), \text{diag}(\sigma_{\theta_z}^{2}(z_{t-1}))) \\
\end{align}

a une close form (avec $Z$ dimension de l'espace latent):

\begin{align}
\text{D}_{\text{KL}}(q_\phi \vert\vert p_{\theta_z}) &= \frac{1}{2} \left[ \text{Tr}(\text{diag}(\sigma_{\theta_z}^{2})^{-1} \text{diag}(\sigma_\phi^2) ) + (\mu_{\theta_z} - \mu_\phi)^T (\text{diag}(\sigma_{\theta_z}^{2})^{-1}) (\mu_{\theta_z} - \mu_\phi) +
\log{\frac{\vert \text{diag}(\sigma_{\theta_z}^{2})\vert}{\vert \text{diag}(\sigma_\phi^2) \vert} } \right] \\
&= \frac{1}{2}\left[ \sum_{i=1}^Z \log{\sigma_{\theta_z}^{2}}\vert_i - \sum_{i=1}^Z \log{\sigma_{\phi}^{2}}\vert_i +
 (\mu_{\theta_z} - \mu_\phi)^T \text{diag}(\frac{1}{\sigma_{\theta_z}^{2}}) (\mu_{\theta_z} - \mu_\phi) + \sum_{i=1}^D \frac{\sigma_{\phi}^{2}\vert_i} {\sigma_{\theta_z}^{2}\vert_i} - Z 
\right]
\end{align}

# Code

In [None]:
X_DIM = 1 # Dimension of the observation space
Z_DIM = 16 # Dimension of the latent space
H_DIM = 16 # Dimension of the hidden state of the LSTM network(s)
G_DIM = 8 # Dimension of the output of the combiner
INTERMEDIATE_LAYER_DIM = 16 # Dimension of the intermediate layers of the MLPs

### Briques de base

In [None]:
#--- brick 1 : backward LSTM -----------------------------

class BackwardLSTM(nn.Module):
    """
    Backward LSTM module.
    """
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(BackwardLSTM, self).__init__()
        
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(
            input_size,   # dimension of the observation space
            hidden_size,  # dimension of the hidden state of the LSTM network
            num_layers=num_layers, # number of layers of the LSTM network
            batch_first=False, # using the default PyTorch LSTM implementation, expecting input shape (seq_len, batch, input_size)
            bidirectional=False # unidirectional LSTM to start with
            )

    def forward(self, x):
        # Reverse the input sequence - axis 0 is the time axis here
        x_reversed = torch.flip(x, [0])
        # Pass through LSTM
        # using initial hidden state and cell state as zeros
        out, _ = self.lstm(x_reversed)
        # Reverse the output sequence
        out_reversed = torch.flip(out, [0])
        # return output shape (seq_len, batch, hidden_size)
        
        return out_reversed

In [None]:
#--- brick 2 : combiner -----------------------------
#
# this combines the latent variable at time t-1
# with the hidden state from the backward LSTM at time t,
# to compute a tX_DIM = 1 # Dimension of the observation space
Z_DIM = 16 # Dimension of the latent space
H_DIM = 16 # Dimension of the hidden state of the LSTM network(s)
G_DIM = 8 # Dimension of the output of the combiner
INTERMEDIATE_LAYER_DIM = 16 # Dimension of the intermediate layers of the MLPsensor g at time t, that will be used
# to compute the parameters of the approximate posterior distribution
# of the latent variable
#

class CombinerMLP(nn.Module):
    """Combiner module. Takes the hidden state of the backward LSTM at time t
    and the latent variable at time t-1, to compute a tensor g at time t,
    that will be used to compute the parameters of the approximate posterior
    distribution of the latent variable.

    Args:
        nn (_type_): _description_
    """
    
    def __init__(self, 
                 latent_dim=Z_DIM, 
                 hidden_dim=H_DIM, 
                 output_dim=G_DIM,
                 layers_dim = None,  # list of layers dimensions, without the input dimnesion, without the output dimension
                 activation = 'tanh',
                 inter_dim = INTERMEDIATE_LAYER_DIM,
                 ):
        super(CombinerMLP, self).__init__()
        
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        if activation == 'tanh':
            self.activation_fn = nn.Tanh()
        elif activation == 'relu':
            self.activation_fn = nn.ReLU()
        else:
            raise ValueError(f"Activation function {activation} not supported. Use 'tanh' or 'relu'.")
        self.inter_dim = inter_dim
        self.layers_dim = layers_dim
        
        if self.layers_dim is None:
            self.layers_dim = [self.inter_dim]
        else:
            self.layers_dim = layers_dim
            
        # explicitly define the MLP layers
        layers = []
        for i, dim in enumerate(self.layers_dim):
            if i==0:  #first layer, latent_dim + hidden_dim => layers_dim[0]
                layers.append(nn.Linear(latent_dim + hidden_dim, dim))
            else:  # all other layers
                layers.append(nn.Linear(self.layers_dim[i-1], dim))
            layers.append(self.activation_fn)
        # last layer : layers_dim[-1] => output_dim
        layers.append(nn.Linear(self.layers_dim[-1], output_dim))
            
        # build the MLP
        self.mlp = nn.Sequential(*layers)
            
        
    def forward(self, h, z):
        """
        Forward pass of the combiner module.
        Args:
            h: hidden state of the backward LSTM at time t
            shape (batch, hidden_dim)
            z: latent variable at time t-1
            shape (batch, latent_dim)
        Returns:
            g: tensor g at time t
            shape (batch, output_dim)
        """
        
        # Concatenate the hidden state and the latent variable on their dimension
        x = torch.cat((h, z), dim=-1)
        
        # Pass through MLP
        g = self.mlp(x)
        
        return g     

In [None]:
#--- brick 3 : Encoder -----------------------------
#
# This computes the parameters of the approximate posterior distribution
# of the latent vatiable at time t.
# The approximate posterior distribution is a Gaussian distribution,
# we use a MLP to compute the mean and the log of the variance.
#

class EncoderMLP(nn.Module):
    """Encoder module. Computes the parameters of the approximate posterior
    distribution of the latent variable at time t. The approximate posterior
    distribution is a Gaussian distribution, we use a MLP to compute the mean
    and the log of the variance.

    Args:
        nn (_type_): _description_
    """
    
    def __init__(self, 
                 latent_dim=Z_DIM, # Dimension of the latent space
                 combiner_dim=G_DIM, # Dimension of the combiner output
                 inter_dim=INTERMEDIATE_LAYER_DIM, # Dimension of the intermediate layers
                 layers_dim = None, # Dimension of the MLP layers (without inout nor output)
                 activation = 'tanh', # Activation function
    ):
        super(EncoderMLP, self).__init__()
        
        self.latent_dim = latent_dim
        self.combiner_dim = combiner_dim
        if activation == 'tanh':
            self.activation_fn = nn.Tanh()
        elif activation == 'relu':
            self.activation_fn = nn.ReLU()
        else:
            raise ValueError(f"Activation function {activation} not supported. Use 'tanh' or 'relu'.")
        self.inter_dim = inter_dim
        
        if layers_dim is None:
            self.layers_dim = [self.inter_dim]
        else:
            self.layers_dim = layers_dim
            
        # explicitly define the MLP layers
        layers = []
        for i, dim in enumerate(self.layers_dim):
            if i==0:
                layers.append(nn.Linear(combiner_dim, dim))
            else:
                layers.append(nn.Linear(layers_dim[i-1], dim))
            layers.append(self.activation_fn)
            
        # last layer is linear, no activation
        layers.append(nn.Linear(self.layers_dim[-1], 2 * latent_dim)) 
                    
        # build the MLP
        self.mlp = nn.Sequential(*layers)
        
    def forward(self, g):
        """
        Forward pass of the encoder module.
        
        Args:
            g: tensor g at time t
            shape (batch, combiner_dim)
            
        Returns:
            mu: mean of the approximate posterior distribution
            shape (batch, latent_dim)
            logvar: log of the variance of the approximate posterior distribution
            shape (batch, latent_dim)
        """
        
        # Pass through MLP
        out = self.mlp(g)
        
        # Split the output into mean and log variance
        # each with shape (batch, latent_dim)
        mu, logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]
        
        return mu, logvar

In [None]:
#--- brick 4 : Latent Space Transition -----------------------------       
#
# This computes the parameters of the transition distribution
# of the latent variable at time t. Ie the prior distribution, 
# before inference.
# The transition distribution is a Gaussian distribution,
# we use a MLP to compute the mean and the log of the variance.
#

class LatentSpaceTransitionMLP(nn.Module):
    """Latent space transition module. Computes the parameters of the
    transition distribution of the latent variable at time t. The transition
    distribution is a Gaussian distribution, we use a MLP to compute the mean
    and the log of the variance.

    Args:
        nn (_type_): _description_
    """
    
    def __init__(self, 
                 latent_dim=Z_DIM, # Dimension of the latent space
                 inter_dim=INTERMEDIATE_LAYER_DIM, # Dimension of the intermediate layers
                 layers_dim = None, # Dimension of the MLP layers
                 activation = 'tanh', # Activation function
    ):
        super(LatentSpaceTransitionMLP, self).__init__()
        
        self.latent_dim = latent_dim
        if activation == 'tanh':
            self.activation_fn = nn.Tanh()
        elif activation == 'relu':
            self.activation_fn = nn.ReLU()
        else:
            raise ValueError(f"Activation function {activation} not supported. Use 'tanh' or 'relu'.")
        self.inter_dim = inter_dim
        
        if layers_dim is None:
            layers_dim = [self.inter_dim]
            
        # explicitly define the MLP layers
        layers = []
        for i, dim in enumerate(layers_dim):
            if i==0:
                layers.append(nn.Linear(latent_dim, dim))
            else:
                layers.append(nn.Linear(layers_dim[i-1], dim))
            layers.append(self.activation_fn)
            
        # last layer is linear, no activation
        layers.append(nn.Linear(layers_dim[-1], 2 * latent_dim)) 
                    
        # build the MLP
        self.mlp = nn.Sequential(*layers)
               
    def forward(self, z):
        """
        Forward pass of the latent space transition module.
        
        Args:
            z: latent variable at time t-1
            shape (seq_len, batch, latent_dim)
            
        Returns:
            mu: mean of the transition distribution
            shape (seq_len, batch, latent_dim)
            logvar: log of the variance of the transition distribution
            shape (seq_len, batch, latent_dim)
        """
        
        # Pass through MLP
        out = self.mlp(z)
        
        # Split the output into mean and log variance
        # each with shape (batch, latent_dim)
        mu, logvar = out[:, :, :self.latent_dim], out[:, :, self.latent_dim:]
        
        return mu, logvar

In [None]:
#--- brick 5 : Decoder (ie Observation Model) -----------------------------
#
# This computes the parameters of the distribution of 
# the observed variable 'x', given the latent variable 'z'.
# The distribution is a Gaussian distribution,
# we use a MLP to compute the mean and the log of the variance.
#

class DecoderMLP(nn.Module):
    """Decoder module. Computes the parameters of the distribution of the
    observed variable 'x', given the latent variable 'z'. The distribution is
    a Gaussian distribution, we use a MLP to compute the mean and the log of
    the variance.

    Args:
        nn (_type_): _description_
    
    """
    
    def __init__(self, 
                 latent_dim=Z_DIM, # Dimension of the latent space
                 observation_dim=X_DIM, # Dimension of the observation space
                 inter_dim=INTERMEDIATE_LAYER_DIM, # Dimension of the intermediate layers
                 layers_dim = None, # Dimension of the MLP layers
                 activation = 'tanh', # Activation function
    ):
        super(DecoderMLP, self).__init__()
        
        self.latent_dim = latent_dim
        self.observation_dim = observation_dim
        self.inter_dim = inter_dim
        
        if activation == 'tanh':
            self.activation_fn = nn.Tanh()
        elif activation == 'relu':
            self.activation_fn = nn.ReLU()
        else:
            raise ValueError(f"Activation function {activation} not supported. Use 'tanh' or 'relu'.")
        
        if layers_dim is None:
            layers_dim = [self.inter_dim] # one layer per default
            
        # explicitly define the MLP layers
        layers = []
        for i, dim in enumerate(layers_dim):
            if i==0:
                layers.append(nn.Linear(latent_dim, dim))
            else:
                layers.append(nn.Linear(layers_dim[i-1], dim))
            layers.append(self.activation_fn)
            
        # last layer is linear, no activation
        layers.append(nn.Linear(layers_dim[-1], 2 * self.observation_dim)) 
                    
        # build the MLP
        self.mlp = nn.Sequential(*layers)
        
    def forward(self, z):
        """
        Forward pass of the decoder module.
        Args:
            z: latent variable at time t
            shape (seq_len, batch, latent_dim)
        Returns:
            mu: mean of the distribution of the observed variable
            shape (seq_len, batch, observation_dim)
            logvar: log of the variance of the distribution of the observed variable
            shape (seq_len, batch, observation_dim)
        """
        # Pass through MLP
        out = self.mlp(z)
        
        # Split the output into mean and log variance
        # each with shape (batch, observation_dim)
        mu, logvar = out[:, :, :self.observation_dim], out[:, :, self.observation_dim:]
        
        return mu, logvar

In [None]:
# --- brick 6 : Sampler with reparameterization trick -----------------------------
#
# This samples from a normal distribution of given mean and log variance
# using the reparameterization trick.

class Sampler(nn.Module):
    """Sampler module. Samples from a normal distribution of given mean and
    log variance using the reparameterization trick.

    Args:
        nn (_type_): _description_
    """
    
    def __init__(self):
        super(Sampler, self).__init__()
        
    def forward(self, mu, logvar):
        """
        Forward pass of the sampler module.
        
        Args:
            mu: mean of the distribution
            shape (batch, dim)
            logvar: log of the variance of the distribution
            shape (batch, dim)
            
        Returns:
            v: sampled variables
            shape (batch, dim)
        """
        
        # Sample from a normal distribution using the reparameterization trick
        std = torch.exp(0.5 * logvar)  # standard deviation
        eps = torch.randn_like(std)  # random noise
        v = mu + eps * std  # sampled variables
        
        return v

### Class DeepKalman Filter

In [None]:
class DeepKalmanFilter(nn.Module):
    """
    Deep Kalman Filter (DKF) module. Implements the DKF algorithm.
    
    Args:
        nn (_type_): _description_
        
    Returns:
        _type_: _description_
    """
    
    def __init__(self,
                 input_dim=X_DIM, # Dimension of the observation space
                 latent_dim=Z_DIM, # Dimension of the latent space
                 hidden_dim=H_DIM, # Dimension of the hidden state of the LSTM network
                 combiner_dim=G_DIM, # Dimension of the combiner output
                 inter_dim=INTERMEDIATE_LAYER_DIM, # Dimension of the intermediate layers
                 activation='tanh', # Activation function
                 num_layers=1, # Number of layers of the LSTM network
                 device='cpu' # Device to use (cpu or cuda)
                 ):
        super(DeepKalmanFilter, self).__init__()
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.combiner_dim = combiner_dim
        self.inter_dim = inter_dim
        self.device = device
        
        # define the modules
        
        self.backward_lstm = BackwardLSTM(
            input_size=self.input_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers
        )
        
        self.combiner = CombinerMLP(
            latent_dim=self.latent_dim,
            hidden_dim=self.hidden_dim,
            output_dim=self.combiner_dim,
            activation=activation,
            layers_dim=None, # list of layers dimensions, without the input dimension, without the output dimension
            inter_dim=self.inter_dim
        )
        
        self.encoder = EncoderMLP(
            latent_dim=self.latent_dim,
            combiner_dim=self.combiner_dim,
            inter_dim=self.inter_dim,
            activation=activation,
            layers_dim=None, # list of layers dimensions, without the input dimension, without the output dimension
        )
        
        self.latent_space_transition = LatentSpaceTransitionMLP(
            latent_dim=self.latent_dim,
            inter_dim=self.inter_dim,
            activation=activation,
            layers_dim=None, # list of layers dimensions, without the input dimension, without the output dimension
        )
        
        self.decoder = DecoderMLP(
            latent_dim=self.latent_dim,
            observation_dim=self.input_dim,
            inter_dim=self.inter_dim,
            activation=activation,
            layers_dim=None, # list of layers dimensions, without the input dimension, without the output dimension
        )
        
        self.sampler = Sampler()
        
    def forward(self, x):
        """
        Forward pass of the Deep Kalman Filter. Runs one step inference
        
        Args:
            x: input sequence
            shape (seq_len, batch, input_dim)
        Returns:
        
        """
        
        # we assume that the input sequence is of shape (seq_len, batch, input_dim)
        seq_len, batch_size, input_dim = x.shape
        assert input_dim == self.input_dim, f"Input dimension {input_dim} does not match the expected dimension {self.input_dim}"
        
        # initialize the latent variable at time t=0
        # NB : in INRIA code : self.register_buffer
        # "If you have parameters in your model, which should be saved and restored in the state_dict, 
        # but not trained by the optimizer, you should register them as buffers.
        # Buffers won’t be returned in model.parameters(), 
        # so that the optimizer won’t have a change to update them.#
        z0 = torch.zeros(batch_size, self.latent_dim).to(self.device)
        # initialize the hidden state of the backward LSTM at time t=0
        # NB : they are not used in a first version of this code
        h0 = torch.zeros(batch_size, self.hidden_dim).to(self.device)
        c0 = torch.zeros(batch_size, self.hidden_dim).to(self.device)
        # initialize the outputs
        # mu_x_s, logvar_x_s = torch.zeros(seq_len, batch_size, self.input_dim).to(self.device), torch.zeros(seq_len, batch_size, self.input_dim).to(self.device)
        mu_z_s, logvar_z_s = torch.zeros(seq_len, batch_size, self.latent_dim).to(self.device), torch.zeros(seq_len, batch_size, self.latent_dim).to(self.device)
        # mu_z_transition_s, logvar_z_transition_s = torch.zeros(seq_len, batch_size, self.latent_dim).to(self.device), torch.zeros(seq_len, batch_size, self.latent_dim).to(self.device)
        
        # run the backward LSTM on the input sequence
        # outputs are the hidden states, shape (seq_len, batch, hidden_dim)
        h_t_s = self.backward_lstm(x)
        
        # loop to compute the approximate posterior distribution of the latent variables z_t
        # given the observations x_t
        # initialize the sequence of sampled latent variables z_t
        sampled_z_t_s = torch.zeros(seq_len, batch_size, self.latent_dim).to(self.device)
        
        for t in range(seq_len):
            # at time t, get z_t-1 and h_t
            if t == 0:
                sampled_z_t_1 = z0
            else:
                sampled_z_t_1 = sampled_z_t_s[t-1]
            h_t = h_t_s[t]
            # compute g_t
            g_t = self.combiner(h_t, sampled_z_t_1)
            # compute the parameters of the approximate posterior distribution
            mu_z, logvar_z = self.encoder(g_t)
            mu_z_s[t], logvar_z_s[t] = mu_z, logvar_z
            # sample z_t and store it
            sampled_z_t = self.sampler(mu_z, logvar_z)
            sampled_z_t_s[t] = sampled_z_t
            
        # compute the parameters of the transition distribution
        z_t_lagged = torch.cat([z0.unsqueeze(0), sampled_z_t_s[:-1]])  # lagged z_t
        mu_z_transition_s, logvar_z_transition_s = self.latent_space_transition(z_t_lagged)
        
        # compute the parameters of the observation distribution
        mu_x_s, logvar_x_s = self.decoder(sampled_z_t_s)
            
        # return the outputs
        return x, mu_x_s, logvar_x_s, mu_z_s, logvar_z_s, mu_z_transition_s, logvar_z_transition_s
    
    def __repr__(self):
        
        msg = f"DeepKalmanFilter(input_dim={self.input_dim}, latent_dim={self.latent_dim}, hidden_dim={self.hidden_dim}, combiner_dim={self.combiner_dim}, inter_dim={self.inter_dim})"
        msg += f"\n{self.backward_lstm}"
        msg += f"\n{self.combiner}"
        msg += f"\n{self.encoder}"
        msg += f"\n{self.latent_space_transition}"
        msg += f"\n{self.decoder}"
        msg += f"\n{self.sampler}"
        
        return msg


### Loss function

In [None]:
def loss_function(x, x_hat, x_hat_logvar, z_mean, z_logvar,
                  z_transition_mean, z_transition_logvar, beta=1.0):
    """
    Compute the total loss for a variational autoencoder (VAE) with a weighted 
    reconstruction loss and a Kullback-Leibler (KL) divergence term.

    Parameters:
    -----------
    x : torch.Tensor
        Ground truth data with shape (seq_len, batch_size, x_dim).
    x_hat : torch.Tensor
        Reconstructed data from the VAE with shape
        (seq_len, batch_size, x_dim).
    x_hat_logvar : torch.Tensor
        Log variance of the reconstructed data with shape
        (seq_len, batch_size, x_dim).
    z_mean : torch.Tensor
        Mean of the latent variable distribution with shape 
        (seq_len, batch_size, x_dim).
    z_logvar : torch.Tensor
        Log variance of the latent variable distribution with shape 
        (seq_len, batch_size, x_dim).
    z_transition_mean : torch.Tensor
        Mean of the transition distribution in the latent space with shape 
        (seq_len, batch_size, x_dim).
    z_transition_logvar : torch.Tensor
        Log variance of the transition distribution in the latent space with 
        shape (seq_len, batch_size, x_dim).
    beta : float
        Weighting factor for the KL divergence term.
    loss_type : str
        Type of reconstruction loss to use. Options:
        - 'mse': Mean Squared Error (MSE) loss.
        - 'weighted_mse': Weighted Mean Squared Error (MSE) loss.

    Returns:
    --------
    total_loss : torch.Tensor
        The total loss, which is the sum of the reconstruction loss and the 
        KL divergence loss.

    Notes:
    ------
    - The "reconstruction loss" is based on formula above
    - The KL divergence loss measures the difference between the latent
      variable distribution and the transition distribution in the latent space.
    - Both losses are normalized by the sequence length (`seq_len`) and
      averaged over the batch.
    - The total loss is a combination of the reconstruction loss and the 
      KL divergence loss, weighted by the `beta` parameter.
    """
    
    seq_len, batch_size, x_dim = x.shape
    
    # Compute the reconstruction loss
    var = x_hat_logvar.exp()
    loss = torch.div((x - x_hat)**2, var)
            
    loss += x_hat_logvar
    loss = loss.sum(dim=2)  # Sum over the x_dim
    loss = loss.sum(dim=0)  # Sum over the sequence length
    loss = loss.mean()  # Mean over the batch
    reconstruction_loss = loss / seq_len
           
    # Compute the KL divergence loss
    kl_loss = (z_transition_logvar - z_logvar +
               torch.div((z_logvar.exp() + 
                         (z_transition_mean - z_mean).pow(2)),
                         z_transition_logvar.exp()))
    
    kl_loss = kl_loss.sum(dim=2)  # Sum over the z_dim
    kl_loss = kl_loss.sum(dim=0)  # Sum over the sequence length
    kl_loss = kl_loss.mean()  # Mean over the batch
    kl_loss = kl_loss / seq_len
                
    # Combine the reconstruction loss and the KL divergence loss
    total_loss = reconstruction_loss + beta * kl_loss
    
    return reconstruction_loss, kl_loss, total_loss

# Toy Case : Data Generation for Time Series Forecasting

In [None]:
def generate_time_series(batch_size, n_steps, noise=0.05):
    """Utility function to generate time series data.

    Args:
        batch_size (int): number of time series to generate (btach size)
        n_steps (_type_): length of each time series
    """
    
    f1,f2,o1,o2 = np.random.rand(4, batch_size, 1)  # return 4 values for each time series
    time = np.linspace(0, 1, n_steps)  # time vector
    
    series = 0.8 * np.sin((time - o1) * (f1 * 100 + 10)) # first sine wave
    series += 0.2 * np.sin((time - o1) * (f1 * 20 + 20)) # second sine wave
    series += noise * (np.random.randn(batch_size, n_steps) - 0.5)  # add noise
    
    return series

In [None]:
n_steps = 500
n_series = 10000
s = generate_time_series(n_series, n_steps+1)

In [None]:
N = 3
fig, axs = plt.subplots(N, 1, figsize=(16, 3 * N))
for i in range(N):
    axs[i].plot(s[i], color='blue', marker="x", linewidth=1)
    axs[i].set_title(f"Time series {i+1}")
    axs[i].set_xlabel("Time")
    axs[i].set_ylabel("Value")
    axs[i].grid(True)
plt.tight_layout()
plt.show()

In [None]:
cutoff = int(0.8 * n_series)

X_train, y_train = s[:cutoff,:n_steps], s[:cutoff,-1]
X_valid, y_valid = s[cutoff:,:n_steps], s[cutoff:,-1]

In [None]:
# form datasets, dataloaders, etc

BATCH_SIZE = 8192

from torch.utils.data import Dataset, DataLoader

class TimeSeriesDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X).to(device)
        self.y = torch.tensor(y).to(device)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
    
train_dataset = TimeSeriesDataset(X_train, y_train)
valid_dataset = TimeSeriesDataset(X_valid, y_valid)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)

Baseline metrics

In [None]:
# Predicted value is the last value of the time series

y_pred = X_valid[:,-1]
print(f"{np.mean(np.sqrt((y_valid - y_pred) ** 2)):.4f} RMSE")

In [None]:
N = 3
fig, axs = plt.subplots(N, 1, figsize=(16, 2 * N))
for i in range(N):
    input = torch.tensor(X_valid[i], device=device)
    target = torch.tensor(y_valid[i], device=device)
    output = y_pred[i]
    target = target.cpu().detach().numpy()
    axs[i].plot(input.cpu().detach().numpy(), color='blue', marker="x", linewidth=1, label="input")
    axs[i].scatter(n_steps, target, color='red', marker="o", linewidth=1, label="ground truth")
    axs[i].scatter(n_steps, output, color='green', marker="*", linewidth=1, label="prediction")
    axs[i].set_title(f"Time series {i+1}")
    axs[i].set_xlabel("Time")
    axs[i].set_ylabel("Value")
    axs[i].legend()
    axs[i].grid(True)
plt.tight_layout()
plt.show()

### Training DKF

In [None]:
xdim = 1
latent_dim = 16
h_dim = 16
combiner_dim = 8

In [None]:
dkf = DeepKalmanFilter(
    input_dim = xdim,
    latent_dim = latent_dim,
    hidden_dim = h_dim,
    combiner_dim = combiner_dim,
    num_layers = 1,
    device=device
).to(device)

print(dkf)

In [None]:
optimizer = torch.optim.Adam(dkf.parameters(), lr=5e-4)
loss_fn = loss_function

In [None]:
# Training step : perform training for one epoch

def train_step(model, optimizer, criterion, train_loader=train_loader, device=device):
    ### training step
    model.train()
    optimizer.zero_grad()
    ### loop on training data
    rec_loss = 0
    kl_loss = 0
    epoch_loss = 0
    
    for input, _ in train_loader:
        input = input.to(device).unsqueeze(-1)  # add a feature dimension
        input = input.permute(1, 0, 2)  # permute to (seq_len, batch_size, input_dim)

        _, mu_x_s, logvar_x_s, mu_z_s, logvar_z_s, mu_z_transition_s, logvar_z_transition_s = model(input)
        
        rec_loss, kl_loss, total_loss = criterion(input, mu_x_s, logvar_x_s, mu_z_s, logvar_z_s, mu_z_transition_s, logvar_z_transition_s)
        
        total_loss.backward()
        optimizer.step()
              
        rec_loss += rec_loss.item()
        kl_loss += kl_loss.item()
        epoch_loss += total_loss.item()
        
    epoch_loss /= len(train_loader)
    rec_loss /= len(train_loader)
    kl_loss /= len(train_loader)
    
    return rec_loss, kl_loss, epoch_loss

In [None]:
num_epochs = 100

In [None]:
rec_losses = []
kl_losses = []
epoch_losses = []

for i in range(num_epochs):
    
    # run the training step
    rec_loss, kl_loss, epoch_loss = train_step(dkf, optimizer, loss_fn)
    # log results
    rec_losses.append(rec_loss)
    kl_losses.append(kl_loss)
    epoch_losses.append(epoch_loss)
    # Print the losses for this epoch
    if (i+1) % 10 == 0:
        print(f"Epoch {i+1:>5}/{num_epochs} - Rec Loss: {rec_loss:.4e}, KL Loss: {kl_loss:.4e}, Total Loss: {epoch_loss:.4e}")

In [None]:
# Plot the losses

fig, axs = plt.subplots(1, 3, figsize=(16, 4))
axs[0].plot(torch.tensor(rec_losses).cpu().detach(), label='Rec Loss', color='blue')
axs[0].set_title('Reconstruction Loss')
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Loss')
axs[0].legend()
axs[0].grid()
axs[1].plot(torch.tensor(kl_losses).cpu().detach(), label='KL Loss', color='orange')
axs[1].set_title('KL Loss')
axs[1].set_xlabel('Epochs')
axs[1].set_ylabel('Loss')
axs[1].legend()
axs[1].grid()
axs[2].plot(torch.tensor(epoch_losses).cpu().detach(), label='Total Loss', color='green')
axs[2].set_title('Total Loss')
axs[2].set_xlabel('Epochs')
axs[2].set_ylabel('Loss')
axs[2].legend()
axs[2].grid()
plt.tight_layout()
plt.show()

In [None]:
# def test_step(model, criterion, valid_loader=valid_loader):
#     ### testing step
#     model.eval()
#     epoch_loss = 0
#     with torch.no_grad():
#         for input, target in valid_loader:
#             input = input.to(device).unsqueeze(-1)  # add a feature dimension
#             target = target.to(device).view(-1, 1)
#             output = model(input)
#             loss = criterion(output, target)
#             epoch_loss += loss.item()
#     epoch_loss /= len(valid_loader)
#     return epoch_loss

In [None]:
# def train_rnn_model(model, num_epochs=20, batch_size=32):
#     print(f"Start training RNN model for {num_epochs} epochs")
#     for i in range(num_epochs):
#         # loop on training data
#         train_step_loss = train_step(model, optimizer, criterion)
#         train_losses.append(train_step_loss)
#         # test step
#         test_step_loss = test_step(model, criterion)
#         valid_losses.append(test_step_loss)
#         print(f"epoch {i+1}/{num_epochs}, training loss = {train_step_loss:.4e}, validation loss = {test_step_loss:.4e}")
#     print("\nTraining finished")
#     return train_losses, valid_losses

In [None]:
# rnn = RNNModel(
#     input_dim=1,
#     output_dim=1,
#     hidden_dim=64,
#     num_layers=1,
#     batch_first=True,
#     device=device,
#     dtype=dtype
# ).to(device)

# print(rnn)

# lr = 1e-5
# optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
# criterion = nn.MSELoss()
# train_losses = []
# valid_losses = []
# num_epochs = 50

# train_losses, valid_losses = train_rnn_model(rnn, num_epochs=num_epochs, batch_size=32)

# plt.plot(train_losses, label="train")
# plt.plot(valid_losses, label="valid")
# plt.legend()
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.title("Training and Validation Loss")
# plt.grid()
# plt.show()

In [None]:
# y_pred = rnn(torch.tensor(X_valid).to(device).unsqueeze(-1))
# y_pred = y_pred.cpu().detach().numpy()

# print(f"\n{np.mean(np.sqrt((y_valid - y_pred) ** 2)):.4f} RMSE")

In [None]:
# N = 5
# fig, axs = plt.subplots(N, 1, figsize=(16, 3 * N))
# for i in range(N):
#     input = torch.tensor(X_valid[i], device=device).unsqueeze(1).unsqueeze(0)
#     # print(f"input has shape {input.shape}")
#     target = torch.tensor(y_valid[i], device=device).view(-1,1)
#     # print(f"target has shape {target.shape}")
#     output = rnn(input)
#     output = output.cpu().detach().numpy()
#     # print(f"output has shape {output.shape}")
#     target = target.cpu().detach().numpy()
#     axs[i].plot(input.squeeze().cpu().detach().numpy(), color='blue', marker="x", linewidth=1, label="input")
#     axs[i].scatter(n_steps, target, color='red', marker="o", linewidth=1, label="ground truth")
#     axs[i].scatter(n_steps, output, color='green', marker="*", linewidth=1, label="prediction")
#     axs[i].set_title(f"Time series {i+1}")
#     axs[i].set_xlabel("Time")
#     axs[i].set_ylabel("Value")
#     axs[i].legend()
#     axs[i].grid(True)
# plt.tight_layout()
# plt.show()

### Forecast N steps ahead

In [None]:
# N_AHEAD = 20
# n_series = 50000
# cutoff = int(n_series * 0.8)

# series = generate_time_series(n_series, n_steps + N_AHEAD)

# X_train, y_train = series[:cutoff, :n_steps], series[:cutoff, -N_AHEAD:]
# X_test, y_test = series[cutoff:, :n_steps], series[cutoff:, -N_AHEAD:]

# print(f"X_train shape: {X_train.shape}")
# print(f"y_train shape: {y_train.shape}")    
# print(f"X_test shape: {X_test.shape}")
# print(f"y_test shape: {y_test.shape}")

# train_dataset = TimeSeriesDataset(X_train, y_train)
# test_dataset = TimeSeriesDataset(X_test, y_test)

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [None]:
# class RNNModelLookAhead(nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_dim, n_ahead=N_AHEAD, num_layers=1, batch_first=True, device=device, dtype=dtype):
#         """Constructor for RNN.

#         Args:
#             input_dim (_type_): dimensionality of the input
#             hidden_dim (_type_): dimensionality of the hidden state
#             n_ahead (_type_, optional): number of time steps to predict. Defaults to N_AHEAD.
#             output_dim (_type_, optional): dimensionality of the output.
#             num_layers (int, optional): number of recurrent layers. Defaults to 1.
#             batch_first (bool, optional): whether batch dim is first or not. Defaults to True.
#                 1. batch_first=True: (batch, seq, feature_dimension)
#                 2. batch_first=False: (seq, batch, feature_dimension)
#             bidirectional (bool, optional): if True, becomes a bidriectional RNN. Defaults to False.
#                 1. bidirectional=True: num_directions=2, (batch, seq, hidden_dim * 2)
#                 2. bidirectional=False: num_directions=1, (batch, seq, hidden_dim)
#             device (_type_, optional): _description_. Defaults to device.
#             dtype (_type_, optional): _description_. Defaults to dtype.
#         """
#         super(RNNModelLookAhead, self).__init__()
        
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.hidden_dim = hidden_dim
#         self.num_layers = num_layers
#         self.batch_first = batch_first
#         self.bidirectional = False
#         self.n_ahead = n_ahead
        
#         self.rnn = nn.RNN(
#             input_size=input_dim,
#             hidden_size=hidden_dim,
#             num_layers=num_layers,
#             batch_first=batch_first,
#             bidirectional=self.bidirectional
#         )
#         self.fc = nn.Linear(hidden_dim, n_ahead*output_dim)
    
#     def forward(self, x):
#         # first, initialize the hidden state
#         h0 = torch.zeros((self.num_layers, x.size(0), self.hidden_dim), requires_grad=True).to(device)
#         # INPUT : x : (batch, sequence_length, input_feature_dimension)
#         x, _ = self.rnn(x, h0) 
#         # OUTPUT: N = 10
# fig, ax  = plt.subplots(N, 1, figsize=(16, 3 * N))
# x_shift = X_test.shape[-1]

# for i in range(N):
#     input = torch.tensor(X_test[i], device=device).unsqueeze(0).unsqueeze(-1)
#     # print(f"input has shape {input.shape}")
#     target = torch.tensor(y_test[i], device=device).view(-1, N_AHEAD, 1)
#     # print(f"target has shape {target.shape}")
#     output = rnn(input)
#     output = output.cpu().detach().numpy()
#     # print(f"output has shape {output.shape}")
#     target = target.cpu().detach().numpy()
    
#     ax[i].plot(input.squeeze().cpu().detach().numpy(), color='blue', marker="x", linewidth=1, label="input")
#     ax[i].plot(np.arange(len(target.squeeze()))+x_shift, target.squeeze(), color='red', marker="o", linewidth=1, label="ground truth")
#     ax[i].plot(np.arange(len(target.squeeze()))+x_shift, output.squeeze(), color='green', marker="*", linewidth=1, label="prediction")
#     ax[i].set_title(f"Time series {i+1}")
#     ax[i].set_xlabel("Time")
#     ax[i].set_ylabel("Value")
#     ax[i].legend()
#     ax[i].grid(True)

# plt.tight_layout()
# plt.show()
#         # - output : (batch, sequence_length, hidden_dimension * num_directions)
#         # - h_n : (num_layers * num_directions, batch, hidden_dimension) (hidden state for last time step)
#         x = self.fc(x[:, -1, :])  # take the last time step
#         x = x.view(-1, self.n_ahead, self.output_dim)
#         # OUTPUT: x : (batch, output_dimension)
#         return x

In [None]:
# rnn = RNNModelLookAhead(
#     input_dim=1,
#     output_dim=1,
#     n_ahead=N_AHEAD,
#     hidden_dim=128,
#     num_layers=4,
#     batch_first=True,
#     device=device,
#     dtype=dtype
# ).to(device)

# print(rnn)

In [None]:
# # Test dimensions

# x = torch.randn(32, 50, 1).to(device)
# y = rnn(x)
# print(f"input shape: {x.shape}")
# print(f"output shape: {y.shape}")

In [None]:
# lr = 1e-5
# optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
# criterion = nn.MSELoss()
# num_epochs = 50

In [None]:
# train_losses = []
# valid_losses = []

# print(f"Start training RNN model for {num_epochs} epochs")

# for i in range(num_epochs):
#     # loop on training data
#     rnn.train()
#     optimizer.zero_grad()
#     ### loop on training data
#     epoch_loss = 0
#     for input, target in train_loader:
#         input = input.to(device).unsqueeze(-1)  # add a feature dimension
#         # print(f"input has shape {input.shape}")
#         target = target.to(device).view(-1, N_AHEAD, 1)
#         # print(f"target has shape {target.shape}")
#         output = rnn(input)
#         # print(F"output has shape {output.shape}")
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()
#     epoch_loss /= len(train_loader) 
#     train_losses.append(epoch_loss)
    
#     # test step
#     rnn.eval()
#     epoch_loss = 0
#     with torch.no_grad():
#         for input, target in test_loader:
#             input = input.to(device).unsqueeze(-1)  # add a feature dimension
#             target = target.to(device).view(-1, N_AHEAD, 1)
#             output = rnn(input)
#             loss = criterion(output, target)
#             epoch_loss += loss.item()
#     epoch_loss /= len(test_loader)
#     valid_losses.append(epoch_loss)
    
#     # report out
#     print(f"epoch {i+1:>4}/{num_epochs}, training loss = {train_losses[-1]:.4e}, validation loss = {valid_losses[-1]:.4e}")

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(16, 3))
# ax.plot(train_losses, label="train")
# ax.plot(valid_losses, label="valid")
# ax.legend()
# ax.set_xlabel("Epoch")
# ax.set_xticks(np.arange(0, num_epochs+1, 2))
# ax.set_xticklabels(np.arange(0, num_epochs+1, 2))
# ax.set_ylabel("Loss")
# ax.set_title("Training and Validation Loss")
# ax.grid()
# plt.show()

In [None]:
# y_pred = rnn(torch.tensor(X_test).to(device).unsqueeze(-1))
# y_pred = y_pred.cpu().detach().numpy().squeeze()
# # print(y_pred.shape)
# # print(y_test.shape)

# print(f"\n{np.mean(np.sqrt((y_test - y_pred) ** 2)):.4f} RMSE")

In [None]:
# N = 10
# fig, ax  = plt.subplots(N, 1, figsize=(16, 3 * N))
# x_shift = X_test.shape[-1]

# for i in range(N):
#     input = torch.tensor(X_test[i], device=device).unsqueeze(0).unsqueeze(-1)
#     # print(f"input has shape {input.shape}")
#     target = torch.tensor(y_test[i], device=device).view(-1, N_AHEAD, 1)
#     # print(f"target has shape {target.shape}")
#     output = rnn(input)
#     output = output.cpu().detach().numpy()
#     # print(f"output has shape {output.shape}")
#     target = target.cpu().detach().numpy()
    
#     ax[i].plot(input.squeeze().cpu().detach().numpy(), color='blue', marker="x", linewidth=1, label="input")
#     ax[i].plot(np.arange(len(target.squeeze()))+x_shift, target.squeeze(), color='red', marker="o", linewidth=1, label="ground truth")
#     ax[i].plot(np.arange(len(target.squeeze()))+x_shift, output.squeeze(), color='green', marker="*", linewidth=1, label="prediction")
#     ax[i].set_title(f"Time series {i+1}")
#     ax[i].set_xlabel("Time")
#     ax[i].set_ylabel("Value")
#     ax[i].legend()
#     ax[i].grid(True)

# plt.tight_layout()
# plt.show()

### Bidirectionnal RNN

https://www.geeksforgeeks.org/bidirectional-recurrent-neural-network/

https://www.kaggle.com/code/amansherjadakhan/introduction-to-bidirectional-rnn