# IIA project GG3: Neural Data Analysis

Easter 2023<br>
Project Leader: Yashar Ahmadian (ya311)


# Forward-Backward algorithm

$\newcommand{\valpha}{\vec{\alpha}}$
$\newcommand{\vbeta}{\vec{\beta}}$
$\newcommand{\talpha}{\tilde{\alpha}}$
$\newcommand{\tbeta}{\tilde{\beta}}$
$\newcommand{\T}{\mathcal{T}}$
$\newcommand{\J}{\mathcal{J}}$

The forward-backward algorithm is a message-passing algorithm (and an example of dynamic programming) used
for calculating the posterior probabilities of the hidden states of an HMM at different times, conditioned on a sequence of observations. Here we denote the hidden states by $s_t$ and the observations by $n_t$. 

The goal is to calculate the posterior probability $P(s_t | n_{1:T})$. By the definition of conditional probability,
this is given by $P(s_t , n_{1:T})/ P(n_{1:T})$. Thus, up to normalization (found by summing the probability over the $K$ values of $s_t$), we need to evaluate the joint probability $P(s_t, n_{1:T})$. By the product rule of probability theory, the latter can be written as the product of $P(n_{t+1:T} | s_t)$ and $P(s_t, n_{1:t})$. (In addition to the product rule, we have also used the Markov property to replace $P(n_{t+1:T}| s_t, n_{1:t})$ with $P(n_{t+1:T} | s_t)$.)
If we define

$\alpha_t^s := P(s_t =s, n_{1:t})$

$\beta_t^s := P(n_{t+1:T} | s_t=s)$

then we have 

$P(s_t = s| n_{1:T}) \propto \alpha_t^s\,\, \beta_t^s$.

The advantage of this factorization is that $\valpha_t$ and $\vbeta_t$ (both $K$-dimensional vectors, with components $\alpha_t^s$ and $\beta_t^s$) satisfy recursion relations that allow them to be computed efficiently, with $O(TK^2)$ operations instead of a sum over all $K^T$ hidden-state sequences. These recursions can be derived by starting from the definitions of $\valpha_t$ and $\vbeta_t$ given above, and using the sum and product rules of probability theory, as well as the Markov property of the model. If we let $l_{t}^s$ denote the conditional observation probability $P(n_t| s_t = s)$, and $\T_{s,s'} = P(s_{t+1}=s| s_t=s')$ denote the transition probabilities, we then find (see below for proof):

$\alpha_{t+1}^s = l_{t+1}^s 
\sum_{s'=1}^K \T_{s,s'} \alpha_{t}^{s'}
\qquad \qquad\quad$          (1)

and 

$\beta_{t}^s =
\sum_{s'=1}^K \beta_{t+1}^{s'} l_{t+1}^{s'} \T_{s',s} 
\qquad \qquad\qquad$          (2)

where $s'$ is summed over the $K$ possible states in each case. Note that the $\alpha$-recursion runs forward in time, while the $\beta$-recursion runs backward; hence the name of the algorithm. 
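To make the two sums concrete, here is a minimal NumPy sketch of a single step of each recursion, written with explicit loops so that each line mirrors (1) and (2). The function names `forward_step` and `backward_step` are hypothetical (they are not taken from `inference.py`), and we assume the transition matrix is stored as `trans[s, s'] = P(s_{t+1}=s | s_t=s')`, so that its columns sum to 1. The observation probabilities are arbitrary stand-in numbers.

```python
import numpy as np


def forward_step(alpha_t, l_next, trans):
    """Eq. (1): alpha_{t+1}^s = l_{t+1}^s * sum_{s'} T[s, s'] * alpha_t^{s'}."""
    K = len(alpha_t)
    alpha_next = np.zeros(K)
    for s in range(K):
        total = 0.0
        for sp in range(K):  # sum over the previous state s'
            total += trans[s, sp] * alpha_t[sp]
        alpha_next[s] = l_next[s] * total
    return alpha_next


def backward_step(beta_next, l_next, trans):
    """Eq. (2): beta_t^s = sum_{s'} beta_{t+1}^{s'} * l_{t+1}^{s'} * T[s', s]."""
    K = len(beta_next)
    beta_t = np.zeros(K)
    for s in range(K):
        for sp in range(K):  # sum over the next state s'
            beta_t[s] += beta_next[sp] * l_next[sp] * trans[sp, s]
    return beta_t


# Two-state example; trans[s, s'] = P(s_{t+1} = s | s_t = s'), columns sum to 1.
trans = np.array([[0.9, 0.2],
                  [0.1, 0.8]])
l_next = np.array([0.5, 0.25])                                  # stand-in values of l_{t+1}^s
alpha_next = forward_step(np.array([0.3, 0.1]), l_next, trans)  # ≈ [0.145, 0.0275]
beta_t = backward_step(np.ones(2), l_next, trans)               # ≈ [0.475, 0.3]
```

Note the transposed index order in the two sums: the forward step uses row $s$ of $\T$, while the backward step uses column $s$.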

Defining the time-dependent matrix $\J_t$ via $[\J_t]_{s,s'} = l_{t}^s \T_{s,s'}$, and treating the $\valpha$'s and $\vbeta$'s as column and row vectors, respectively, we can write the recursion relations more compactly as 

$
\valpha_{t+1} = \J_{t+1} \valpha_t
$

$
\vbeta_{t} = \vbeta_{t+1}\J_{t+1}
$

which also shows that these recursions are discrete-time, linear, time-inhomogeneous dynamical systems (time-inhomogeneous because $\J_t$ depends on $t$ through the observations). 
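In code, the matrix form amounts to scaling row $s$ of $\T$ by $l_t^s$ and then using ordinary matrix-vector products. The sketch below uses hypothetical helper names and assumes the storage convention `trans[s, s'] = P(s_{t+1}=s | s_t=s')`; the observation probabilities are stand-in numbers.

```python
import numpy as np


def make_J(l_t, trans):
    """[J_t]_{s,s'} = l_t^s * T[s, s']: row s of the transition matrix scaled by l_t^s."""
    return l_t[:, None] * trans


def forward_step_matrix(alpha_t, l_next, trans):
    # alpha_{t+1} = J_{t+1} alpha_t, with alpha_t acting as a column vector
    return make_J(l_next, trans) @ alpha_t


def backward_step_matrix(beta_next, l_next, trans):
    # beta_t = beta_{t+1} J_{t+1}, with beta_{t+1} acting as a row vector
    return beta_next @ make_J(l_next, trans)


trans = np.array([[0.9, 0.2],
                  [0.1, 0.8]])   # trans[s, s'] = P(s_{t+1} = s | s_t = s')
l_next = np.array([0.5, 0.25])   # stand-in values of l_{t+1}^s
J = make_J(l_next, trans)        # ≈ [[0.45, 0.1], [0.025, 0.2]]
alpha_next = forward_step_matrix(np.array([0.3, 0.1]), l_next, trans)  # ≈ [0.145, 0.0275]
beta_t = backward_step_matrix(np.ones(2), l_next, trans)               # ≈ [0.475, 0.3]
```

For a 1-D NumPy array `v`, the `@` operator treats `v` as a column vector in `J @ v` and as a row vector in `v @ J`, so no explicit transposes are needed.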

Finally, the initial conditions for the two recursions are

$\alpha_{t=1}^s = l_{t=1}^s \pi^s$

$\beta_{t=T}^s = 1$

The first follows immediately from the definition and the product rule, $\alpha_1^s = P(n_1, s_1=s) = P(n_1 | s_1=s) P(s_1=s)$, where $\pi^s = P(s_1=s)$ is the initial state distribution. The second follows because the set of future observations at the last time step, $n_{T+1:T}$, is empty, and thus its (conditional) probability (given $s_T$) is 1. 
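Putting the initial conditions and the recursions (1) and (2) together gives the full forward-backward pass. Below is a minimal probability-space sketch; the function name `forward_backward` and the array layout (`lik[t, s]` holding $l_{t+1}^s$ with a 0-based time index) are our own assumptions, not the interface of `inference.py`. Random numbers stand in for the observation probabilities.

```python
import numpy as np


def forward_backward(pi, trans, lik):
    """Probability-space forward-backward pass (illustrative sketch).

    pi[s]        = P(s_1 = s)
    trans[s, s'] = P(s_{t+1} = s | s_t = s')
    lik[t, s]    = P(n_{t+1} | s_{t+1} = s)   (0-based row t)
    Returns alpha, beta and the posterior P(s_t = s | n_{1:T}), each (T, K).
    """
    T, K = lik.shape
    alpha = np.zeros((T, K))
    beta = np.zeros((T, K))

    alpha[0] = lik[0] * pi                      # alpha_1^s = l_1^s pi^s
    for t in range(T - 1):                      # forward recursion, eq. (1)
        alpha[t + 1] = lik[t + 1] * (trans @ alpha[t])

    beta[T - 1] = 1.0                           # beta_T^s = 1
    for t in range(T - 2, -1, -1):              # backward recursion, eq. (2)
        beta[t] = (beta[t + 1] * lik[t + 1]) @ trans

    posterior = alpha * beta
    posterior /= posterior.sum(axis=1, keepdims=True)  # normalize over the K states
    return alpha, beta, posterior


rng = np.random.default_rng(0)
K, T = 3, 6
pi = np.full(K, 1.0 / K)
trans = rng.random((K, K))
trans /= trans.sum(axis=0, keepdims=True)       # columns sum to 1
lik = rng.random((T, K))                        # stand-in values of l_t^s
alpha, beta, post = forward_backward(pi, trans, lik)
```

As a consistency check, $\sum_s \alpha_t^s \beta_t^s = \sum_s P(s_t=s, n_{1:T}) = P(n_{1:T})$ for every $t$, so the row sums of `alpha * beta` should all agree up to floating-point error; this common value is exactly the normalization constant discussed at the start of this section.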


**Proof of the recursion equations:** To prove (1), we start from the definition of $\valpha_t$, and use the sum rule  followed by the product rule to write

$
\alpha_{t+1}^s = P(s_{t+1} = s, n_{t+1}, n_{1:t}) = \sum_{s'} P(s_{t+1}=s, n_{t+1}, s_t=s', n_{1:t})
$

$
\qquad = \sum_{s'} P(s_{t+1}=s, n_{t+1} | s_t=s', n_{1:t})\, P( s_t=s', n_{1:t}) 
$

$ \qquad = \sum_{s'} P(s_{t+1}=s, n_{t+1} | s_t=s', n_{1:t})\,\, \alpha_{t}^{s'}
\qquad\qquad$ (3)

Now, due to the *Markov property*, the conditioning on $n_{1:t}$ can be dropped in the left factor of the summands in the last expression:
$P(s_{t+1}=s, n_{t+1} | s_t=s', n_{1:t}) = P(s_{t+1}=s, n_{t+1} | s_t=s')$. Using the product rule once more, we write this as 
$P(n_{t+1}| s_{t+1}=s, s_t=s') P(s_{t+1}=s| s_t=s')$. We then use the *conditional independence property of the observations* (CIPO) in the HMM (i.e., the fact that, conditioned on $s_{t+1}$, $n_{t+1}$ is independent of $s_t$) to drop the conditioning on $s_t$ in the left factor, substitute into (3), and obtain 

$ \alpha_{t+1}^s = \sum_{s'} P(n_{t+1}| s_{t+1}=s) P(s_{t+1}=s| s_t=s')\,\, \alpha_{t}^{s'}
$

Finally, plugging in the definitions $l_{t}^s = P(n_t | s_t = s)$ and $\T_{s,s'} = P(s_{t+1}=s| s_t=s')$, and pulling $l_{t+1}^s$ (which does not depend on $s'$) out of the sum, we obtain (1). 
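The derivation can also be checked numerically on a tiny model: summing the joint probability $P(s_{1:t}, n_{1:t}) = \pi^{s_1} l_1^{s_1} \prod_{k=2}^{t} \T_{s_k, s_{k-1}} l_k^{s_k}$ over all earlier states by brute force should reproduce the recursion. This is a sketch with hypothetical function names and random stand-in values for $l_t^s$; the brute-force sum has $K^t$ terms, so it is only usable for very short sequences.

```python
import itertools

import numpy as np


def alpha_brute_force(pi, trans, lik, t):
    """alpha_t^s = P(s_t = s, n_{1:t}), summing the joint over all paths s_{1:t}.

    t is 1-based as in the text; lik[k, s] = P(n_{k+1} | s_{k+1} = s);
    trans[s, s'] = P(s_{t+1} = s | s_t = s').
    """
    K = len(pi)
    alpha = np.zeros(K)
    for path in itertools.product(range(K), repeat=t):
        p = pi[path[0]] * lik[0, path[0]]
        for k in range(1, t):
            p *= trans[path[k], path[k - 1]] * lik[k, path[k]]
        alpha[path[-1]] += p          # marginalize over all states except s_t
    return alpha


def alpha_recursive(pi, trans, lik, t):
    """Same quantity from the initial condition and eq. (1)."""
    alpha = lik[0] * pi
    for k in range(1, t):
        alpha = lik[k] * (trans @ alpha)
    return alpha


rng = np.random.default_rng(1)
K, T = 3, 5
pi = rng.random(K)
pi /= pi.sum()
trans = rng.random((K, K))
trans /= trans.sum(axis=0, keepdims=True)   # columns sum to 1
lik = rng.random((T, K))                    # stand-in values of l_t^s
for t in range(1, T + 1):
    assert np.allclose(alpha_brute_force(pi, trans, lik, t),
                       alpha_recursive(pi, trans, lik, t))
```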


To prove (2), we start from the definition of $\vbeta_t$ and use the sum rule, the product rule, and then the *Markov property* to write:

$\beta_{t}^s = P(n_{t+1:T}| s_t=s) = \sum_{s'} P(n_{t+1:T}, s_{t+1}=s' | s_t=s)$

$\qquad = \sum_{s'} P(n_{t+1:T}| s_{t+1}=s' , s_t=s)P(s_{t+1}=s' | s_t=s)$

$\qquad = \sum_{s'} P(n_{t+1:T}| s_{t+1}=s')P(s_{t+1}=s' | s_t=s)$

Using the product rule and the Markov property again, we can write the left factors as $P(n_{t+1:T}| s_{t+1}=s') = P(n_{t+2:T}| s_{t+1}=s', n_{t+1})P(n_{t+1}| s_{t+1}=s') = P(n_{t+2:T}| s_{t+1}=s')P(n_{t+1}| s_{t+1}=s') $. Substituting this expression, together with the definitions of $\vbeta_{t+1}$, $\vec{l}_{t+1}$ and $\T$, we obtain (2).
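The same kind of brute-force check works for $\vbeta_t$: fix $s_t = s$ and sum $\prod_{k=t+1}^{T} \T_{s_k, s_{k-1}} l_k^{s_k}$ over all future paths $s_{t+1:T}$. The sketch below (hypothetical names, random stand-in likelihoods) also confirms that for $t = T$ the sum contains a single empty product, giving $\beta_T^s = 1$.

```python
import itertools

import numpy as np


def beta_brute_force(trans, lik, t):
    """beta_t^s = P(n_{t+1:T} | s_t = s), summing over all future paths s_{t+1:T}.

    t is 1-based; lik has shape (T, K) with lik[k, s] = P(n_{k+1} | s_{k+1} = s).
    """
    T, K = lik.shape
    beta = np.zeros(K)
    for s in range(K):
        for future in itertools.product(range(K), repeat=T - t):
            p, prev = 1.0, s
            for j, state in enumerate(future):
                # future[j] is s_{t+1+j}; its likelihood lives in 0-based row t + j
                p *= trans[state, prev] * lik[t + j, state]
                prev = state
            beta[s] += p
    return beta


def beta_recursive(trans, lik, t):
    """Same quantity from beta_T = 1 and eq. (2)."""
    T, K = lik.shape
    beta = np.ones(K)
    for k in range(T - 1, t - 1, -1):   # uses l_T, l_{T-1}, ..., l_{t+1}
        beta = (beta * lik[k]) @ trans
    return beta


rng = np.random.default_rng(2)
K, T = 3, 5
trans = rng.random((K, K))
trans /= trans.sum(axis=0, keepdims=True)   # trans[s, s'] = P(s_{t+1}=s | s_t=s')
lik = rng.random((T, K))                    # stand-in values of l_t^s
for t in range(1, T + 1):
    assert np.allclose(beta_brute_force(trans, lik, t), beta_recursive(trans, lik, t))
```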


## Code implementation

The code in `inference.py` (adapted from the [SSM package](https://github.com/lindermanlab/ssm) by the Linderman lab) implements the above recursions in terms of the logs of $\valpha$, $\vbeta$, and the observation and transition probabilities. Working with logs avoids numerical underflow: $\alpha_t^s$ is a joint probability of $t$ observations and typically shrinks exponentially with $t$. If we let $\talpha$, $\tbeta$, $\tilde{\T}$ and $ll_t$ denote the logs of $\alpha$, $\beta$, $\T$ and $l_t$, we can write the recursion equations as 

$\talpha_{t+1}^s = ll_{t+1}^s  + \log 
\sum_{s'=1}^K \exp( \tilde{\T}_{s,s'} + \talpha_{t}^{s'})
$

and 

$\tbeta_{t}^s = \log
\sum_{s'=1}^K \exp(\tbeta_{t+1}^{s'} + ll_{t+1}^{s'} + \tilde{\T}_{s',s})
$
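As an illustration of these log-space recursions (not the actual code of `inference.py`), here is a NumPy sketch with a hand-written, numerically stable log-sum-exp helper; `scipy.special.logsumexp` provides the same operation if SciPy is available. The names `logsumexp` and `log_forward_backward` are our own, and the transition array is assumed to be stored as `log_trans[s, s'] = log P(s_{t+1}=s | s_t=s')`.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_forward_backward(log_pi, log_trans, ll):
    """Log-space forward-backward (illustrative sketch).

    log_pi[s]        = log P(s_1 = s)
    log_trans[s, s'] = log P(s_{t+1} = s | s_t = s')
    ll[t, s]         = log P(n_{t+1} | s_{t+1} = s)   (0-based row t)
    """
    T, K = ll.shape
    log_alpha = np.zeros((T, K))
    log_beta = np.zeros((T, K))   # row T-1 stays 0 = log(beta_T)
    log_alpha[0] = ll[0] + log_pi
    for t in range(T - 1):
        # talpha_{t+1}^s = ll_{t+1}^s + log sum_{s'} exp(tT[s, s'] + talpha_t^{s'})
        log_alpha[t + 1] = ll[t + 1] + logsumexp(log_trans + log_alpha[t][None, :], axis=1)
    for t in range(T - 2, -1, -1):
        # tbeta_t^s = log sum_{s'} exp(tbeta_{t+1}^{s'} + ll_{t+1}^{s'} + tT[s', s])
        log_beta[t] = logsumexp((log_beta[t + 1] + ll[t + 1])[:, None] + log_trans, axis=0)
    return log_alpha, log_beta


rng = np.random.default_rng(3)
K, T = 3, 2000
pi = np.full(K, 1.0 / K)
trans = rng.random((K, K))
trans /= trans.sum(axis=0, keepdims=True)   # columns sum to 1
lik = 1e-3 * rng.random((T, K))             # small stand-in values of l_t^s
log_alpha, log_beta = log_forward_backward(np.log(pi), np.log(trans), np.log(lik))
# In probability space alpha_t would underflow to 0.0 long before t = 2000;
# log_alpha and log_beta remain finite.
```

Subtracting the maximum before exponentiating is what keeps each log-sum-exp finite even when all the individual terms are far below the smallest representable double.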

## Model log-likelihood

By definition, $\alpha_T^s = P(s_T=s, n_{1:T})$, so summing it over all $s$ yields $P(n_{1:T})$. 
Recalling that the model parameters, $\Theta$, were implicitly conditioned on in all of the above probabilities, the latter probability is nothing but the model likelihood 

$P(\text{observed data} | \Theta) = P(n_{1:T} | \Theta)$.

Thus the model likelihood can also be calculated using only the forward pass half of the forward-backward algorithm.

The function `hmm_normalizer` of `inference.py` calculates the model **log**-likelihood using the forward pass.
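The sketch below shows the idea: run only the log-space forward pass and log-sum-exp the final $\talpha_T$ over states, then compare against a brute-force sum of the joint over all $K^T$ paths for a tiny model. The function `log_likelihood` is a hypothetical stand-in; we do not reproduce the signature or conventions of `hmm_normalizer` here, and we assume `log_trans[s, s'] = log P(s_{t+1}=s | s_t=s')` as in the text.

```python
import itertools

import numpy as np


def log_likelihood(log_pi, log_trans, ll):
    """log P(n_{1:T}) = log sum_s alpha_T^s, using only the log-space forward pass."""
    T, K = ll.shape
    log_alpha = ll[0] + log_pi
    for t in range(1, T):
        a = log_trans + log_alpha[None, :]          # a[s, s'] = tT[s, s'] + talpha^{s'}
        m = a.max(axis=1)                           # stabilize the log-sum-exp over s'
        log_alpha = ll[t] + m + np.log(np.exp(a - m[:, None]).sum(axis=1))
    m = log_alpha.max()
    return m + np.log(np.exp(log_alpha - m).sum())  # log-sum-exp over s_T


def log_likelihood_brute_force(pi, trans, lik):
    """Reference value: sum the joint P(s_{1:T}, n_{1:T}) over all K**T paths."""
    T, K = lik.shape
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        p = pi[path[0]] * lik[0, path[0]]
        for k in range(1, T):
            p *= trans[path[k], path[k - 1]] * lik[k, path[k]]
        total += p
    return np.log(total)


rng = np.random.default_rng(4)
K, T = 2, 6
pi = rng.random(K)
pi /= pi.sum()
trans = rng.random((K, K))
trans /= trans.sum(axis=0, keepdims=True)   # columns sum to 1
lik = rng.random((T, K))                    # stand-in values of l_t^s
ll_fwd = log_likelihood(np.log(pi), np.log(trans), np.log(lik))
ll_ref = log_likelihood_brute_force(pi, trans, lik)
```

The forward pass costs $O(TK^2)$, while the reference sum grows as $K^T$, which is why the latter is only practical as a test on very short sequences.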