
# **SAD2, lab 8:** Expectation-Maximization for Hidden Markov Models

*Justyna Król*



## **Expectation–Maximization (EM) Algorithm**

In many real-world problems we want to estimate model parameters even though part of the data is hidden or unobserved. Examples include clustering, mixture models, gene annotation, and Hidden Markov Models (HMMs). In such settings, the likelihood function becomes difficult to optimize directly because we must marginalize over all possible latent variable configurations.

Let:

$X$ = observed data

$Z$ = latent (unobserved) variables

$\theta$ = model parameters we want to estimate

We want to maximize the log-likelihood:

$$ \log p(X \mid \theta) = \log \sum_{Z} p(X, Z \mid \theta) $$

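Direct maximization is often intractable: the sum over $Z$ runs over every joint configuration of the latent variables (for an HMM, every possible hidden path, whose number grows exponentially with the sequence length), and the logarithm of a sum does not split into separate terms for the individual parameters.
EM sidesteps this by alternating two steps: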

#### **1) E-Step (Expectation)**

Compute $Q(\theta \mid \theta^{(t)})$, defined as the expected value of the complete-data log-likelihood $\log p(X, Z \mid \theta)$ with respect to the conditional distribution of $Z$ given $X$ under the current parameter estimates $\theta^{(t)}$:

$$
Q(\theta \mid \theta^{(t)})
= \mathbb{E}_{Z \sim p(\cdot \mid X, \theta^{(t)})} \left[ \log p(X, Z \mid \theta) \right]
:= \sum_{Z} \log p(X, Z \mid \theta) \, p(Z \mid X, \theta^{(t)}) .
$$


#### **2) M-Step (Maximization)**

Update parameters by maximizing the $Q$ function:

$$ \theta^{(t+1)} = \arg\max_{\theta} Q(\theta \mid \theta^{(t)}) $$

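This yields new parameter estimates; the E- and M-steps are then repeated until convergence (for example, until the change in log-likelihood falls below a tolerance). Each iteration is guaranteed not to decrease $\log p(X \mid \theta)$.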
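To make the two steps concrete, here is a minimal sketch of EM for a simple latent-variable model that is not an HMM: a mixture of two biased coins, where each trial of `n_flips` tosses uses one coin chosen at random and the coin identity is the latent $Z$. The function name `em_coin_mixture` and the data are made up for illustration. The E-step computes $p(Z \mid X, \theta^{(t)})$, and the M-step is the closed-form maximizer of $Q$ for this particular model.

```python
import numpy as np

def em_coin_mixture(heads, n_flips, w, p, n_iter=20):
    """EM for a mixture of two biased coins (illustrative sketch).

    heads:   number of heads in each trial (one latent coin per trial)
    n_flips: number of flips per trial
    w:       initial mixing weights P(Z = k), shape (2,)
    p:       initial head probabilities of the two coins, shape (2,)
    Returns the final (w, p) and the log-likelihood before each update.
    """
    heads = np.asarray(heads, dtype=float)
    tails = n_flips - heads
    w, p = np.asarray(w, dtype=float), np.asarray(p, dtype=float)
    logliks = []
    for _ in range(n_iter):
        # E-step: log[w_k * p_k^h * (1 - p_k)^(n - h)] per trial and coin;
        # the binomial coefficient is omitted because it does not depend on theta.
        log_joint = (np.log(w)[None, :]
                     + heads[:, None] * np.log(p)[None, :]
                     + tails[:, None] * np.log1p(-p)[None, :])
        log_px = np.logaddexp(log_joint[:, 0], log_joint[:, 1])  # log p(x | theta) per trial
        logliks.append(log_px.sum())
        resp = np.exp(log_joint - log_px[:, None])                # P(Z = k | x, theta)
        # M-step: closed-form argmax of Q(theta | theta_t) for this model
        w = resp.mean(axis=0)
        p = (resp * heads[:, None]).sum(axis=0) / (resp.sum(axis=0) * n_flips)
    return w, p, logliks

w, p, logliks = em_coin_mixture(heads=[5, 9, 8, 4, 7], n_flips=10, w=[0.5, 0.5], p=[0.6, 0.5])
print(np.round(w, 3), np.round(p, 3))
print(np.all(np.diff(logliks) >= -1e-9))  # checks that the log-likelihood never decreased
```

Dropping the binomial coefficient shifts every recorded log-likelihood by the same constant, so the monotonicity check is unaffected.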


## Reminder: Hidden Markov Models

>*Definition.*
A **Hidden Markov Model** $\mathcal{M}$ is a triplet  
$$
\mathcal{M} = (\Sigma, Q, \Theta),
$$
where:

- $\Sigma$ is an alphabet of observable symbols,  
- $Q$ is a finite set of hidden states,  
- $\Theta = (\pi, T, E)$ is a collection of probability distributions consisting of:  

  - **initial state probabilities** $\pi_i$ for each $i \in Q$:
  $$ \pi_i = \mathbb{P}(Z_1 = i) $$

  - **transition probabilities** $t_{i,j}$ for $i, j \in Q$:
  $$ t_{i,j} = \mathbb{P}(Z_n = j \mid Z_{n-1} = i) $$

  - **emission probabilities** $e_j(s)$ for $j \in Q$ and $s \in \Sigma$:
  $$ e_j(s) = \mathbb{P}(X_n = s \mid Z_n = j) $$

Here,  
- $X = (X_1, X_2, \ldots)$ is the sequence of observable symbols over $\Sigma$, and  
- $Z = (Z_1, Z_2, \ldots)$ is the sequence of hidden states over $Q$.
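The definition is generative: draw $Z_1$ from $\pi$, each $Z_n$ from row $Z_{n-1}$ of $T$, and each $X_n$ from row $Z_n$ of $E$. A minimal NumPy sketch of this sampling process, with states and symbols encoded as integers starting at 0 (the name `sample_hmm` is hypothetical), could look like this:

```python
import numpy as np

def sample_hmm(pi, T, E, length, rng):
    """Draw one (X, Z) pair of the given length from an HMM with parameters (pi, T, E)."""
    pi, T, E = (np.asarray(a, dtype=float) for a in (pi, T, E))
    K, M = E.shape
    Z = np.empty(length, dtype=int)
    X = np.empty(length, dtype=int)
    Z[0] = rng.choice(K, p=pi)                 # Z_1 ~ pi
    X[0] = rng.choice(M, p=E[Z[0]])            # X_1 ~ e_{Z_1}(.)
    for n in range(1, length):
        Z[n] = rng.choice(K, p=T[Z[n - 1]])    # Z_n ~ t_{Z_{n-1}, .}
        X[n] = rng.choice(M, p=E[Z[n]])        # X_n ~ e_{Z_n}(.)
    return X, Z

rng = np.random.default_rng(0)
pi = [0.5, 0.5]
T = [[0.9, 0.1], [0.2, 0.8]]
E = [[0.5, 0.5], [0.1, 0.9]]
X, Z = sample_hmm(pi, T, E, length=500, rng=rng)
```

The sampled `X` is a NumPy array; the `baum_welch_log` function provided later in this lab expects a list of integer tensors, so a conversion such as `torch.tensor(X)` would be needed before training.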


## EM for HMMs

**Motivation:**  
If the parameters $\theta$ of an HMM are known, the posterior distribution of each latent state $Z_n$ given the observations $X$ can be computed efficiently with the **forward–backward algorithm**.  
Conversely, if the latent states $Z$ were known, estimating the parameters $\theta$ would be straightforward: we could count how often transitions and emissions occur and normalize these counts to obtain updated transition, emission, and initial probabilities.

This motivates an iterative procedure when both $\theta$ and $Z$ are unknown:

1. Initialize the parameters $\theta$.
2. **E-step:** Using the forward–backward algorithm, compute the posterior distribution of the latent states $Z$ given the current parameters $\theta$.
3. **M-step:** Update $\theta$ using the expected sufficient statistics of $Z$ computed in the E-step.
4. Repeat steps 2–3 until convergence.

Each EM iteration is guaranteed not to decrease the data log-likelihood, and the procedure converges to a stationary point, typically a local maximum, which need not be the global one.
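The counting argument from the motivation can be sketched as follows for a single sequence whose hidden path `Z` is known (the helper `counts_mle` is hypothetical, and states and symbols are encoded as integers starting at 0). Baum–Welch keeps exactly this normalize-the-counts structure but replaces the hard counts with expected counts computed from the posteriors.

```python
import numpy as np

def counts_mle(X, Z, K, M):
    """Maximum-likelihood (pi, T, E) for one sequence whose hidden path Z is observed.

    Assumes every state appears at some position other than the last one;
    otherwise a row of T (or E) has zero counts and cannot be normalized.
    """
    X, Z = np.asarray(X), np.asarray(Z)
    pi = np.zeros(K)
    pi[Z[0]] = 1.0                               # one sequence: all mass on the observed Z_1
    trans = np.zeros((K, K))
    np.add.at(trans, (Z[:-1], Z[1:]), 1.0)       # trans[i, j] = #{n : Z_n = i, Z_{n+1} = j}
    emit = np.zeros((K, M))
    np.add.at(emit, (Z, X), 1.0)                 # emit[j, s] = #{n : Z_n = j, X_n = s}
    T = trans / trans.sum(axis=1, keepdims=True)
    E = emit / emit.sum(axis=1, keepdims=True)
    return pi, T, E

pi, T, E = counts_mle(X=[0, 1, 1, 1, 0], Z=[0, 0, 1, 1, 0], K=2, M=2)
```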


## **Baum–Welch Algorithm**

The **Baum–Welch algorithm** is an EM procedure for HMMs that iteratively updates parameters $\theta = (\pi, T, E)$ when the latent states $Z$ are unknown.

**Steps:**

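1. **E-step:** Compute the posterior probabilities of hidden states using the **forward–backward algorithm**:

$$
\gamma_n(i) = \mathbb{P}(Z_n = i \mid X, \theta),
$$

$$
\xi_n(i,j) = \mathbb{P}(Z_n = i, Z_{n+1} = j \mid X, \theta).
$$

2. **M-step:** Update the parameters using these expected counts:

- Initial state probabilities:
$$
\pi_i^{\text{new}} = \gamma_1(i)
$$

- Transition probabilities:
$$
t_{ij}^{\text{new}} = \frac{\sum_{n=1}^{N-1} \xi_n(i,j)}{\sum_{n=1}^{N-1} \gamma_n(i)}
$$

- Emission probabilities:
$$
e_j(s)^{\text{new}} = \frac{\sum_{n=1}^{N} \gamma_n(j) \mathbf{1}_{\{X_n = s\}}}{\sum_{n=1}^{N} \gamma_n(j)}
$$

Since $\gamma$ and $\xi$ are already posterior probabilities (normalized by $\mathbb{P}(X)$), the denominators are the expected number of transitions out of state $i$ and the expected number of visits to state $j$, respectively; this makes every row of the updated $T$ and $E$ sum to 1.

Repeat the E-step and M-step until convergence. As with EM in general, the data log-likelihood never decreases from one iteration to the next, and the algorithm converges to a local optimum (more precisely, a stationary point) that depends on the initialization.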
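For short sequences, one full iteration fits in a few lines of NumPy without log-space tricks. The sketch below (the name `baum_welch_step` and the toy data are illustrative assumptions) computes unscaled forward and backward tables, forms $\gamma$ and $\xi$, and applies the three update formulas of the M-step; time indices are 0-based in the code. On long sequences the unscaled tables underflow, which is why the lab code works in log space.

```python
import numpy as np

def baum_welch_step(pi, T, E, X):
    """One Baum–Welch iteration on a single short sequence (plain probabilities, no scaling).

    Returns the updated (pi, T, E) and log P(X) under the *input* parameters.
    """
    X = np.asarray(X)
    K, N, M = len(pi), len(X), E.shape[1]
    # Forward: f[i, n] = P(X_1..X_n, Z_n = i); backward: b[i, n] = P(X_{n+1}..X_N | Z_n = i)
    f = np.zeros((K, N))
    b = np.ones((K, N))
    f[:, 0] = pi * E[:, X[0]]
    for n in range(1, N):
        f[:, n] = E[:, X[n]] * (f[:, n - 1] @ T)
    for n in range(N - 2, -1, -1):
        b[:, n] = T @ (E[:, X[n + 1]] * b[:, n + 1])
    px = f[:, -1].sum()
    # E-step: gamma[i, n] and xi[i, n, j]
    gamma = f * b / px
    xi = f[:, :-1, None] * T[:, None, :] * (E[:, X[1:]] * b[:, 1:]).T[None, :, :] / px
    # M-step: normalized expected counts
    pi_new = gamma[:, 0]
    T_new = xi.sum(axis=1) / gamma[:, :-1].sum(axis=1, keepdims=True)
    E_new = np.stack([gamma[:, X == s].sum(axis=1) for s in range(M)], axis=1)
    E_new /= gamma.sum(axis=1, keepdims=True)
    return pi_new, T_new, E_new, np.log(px)

pi = np.array([0.6, 0.4])
T = np.array([[0.7, 0.3], [0.4, 0.6]])
E = np.array([[0.8, 0.2], [0.3, 0.7]])
X = [0, 1, 1, 0, 1, 1, 1, 0, 0, 1]
logliks = []
for _ in range(15):
    pi, T, E, ll = baum_welch_step(pi, T, E, X)
    logliks.append(ll)
```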


## **Theoretical task 1:**

Let $X_1, \dots, X_N$ be a sequence of observations in an HMM.  
Given:  

- the current parameter estimates $\theta = (\pi, T, E)$,  
- the output of the forward function $f_{i, n} = \mathbb{P}(X_{1:n}, Z_n = i )$,  
- the output of the backward function $b_{i, n} = \mathbb{P}(X_{n+1:N} \mid Z_n = i)$,  

prove that the joint posterior of consecutive hidden states is

$$
\xi_n(i,j) = \mathbb{P}(Z_n = i, Z_{n+1} = j \mid X) =  \frac{f_{i,n} \, t_{ij} \, e_j(X_{n+1}) \, b_{j,n+1}}{\mathbb{P}(X)}
$$
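Independently of the proof, the identity can be checked numerically on a tiny HMM by comparing it with brute-force enumeration over all $|Q|^N$ hidden paths. The helper names below are hypothetical, time indices are 0-based, and enumeration is only feasible for very short sequences.

```python
import itertools
import numpy as np

def forward_backward(pi, T, E, X):
    """Unscaled forward f[i, n] and backward b[i, n] tables (0-based time index n)."""
    K, N = len(pi), len(X)
    f = np.zeros((K, N))
    b = np.ones((K, N))
    f[:, 0] = pi * E[:, X[0]]
    for n in range(1, N):
        f[:, n] = E[:, X[n]] * (f[:, n - 1] @ T)
    for n in range(N - 2, -1, -1):
        b[:, n] = T @ (E[:, X[n + 1]] * b[:, n + 1])
    return f, b

def xi_by_enumeration(pi, T, E, X, n):
    """P(Z_n = i, Z_{n+1} = j | X) by summing P(X, Z) over all hidden paths."""
    K = len(pi)
    joint = np.zeros((K, K))
    for Z in itertools.product(range(K), repeat=len(X)):
        p = pi[Z[0]] * E[Z[0], X[0]]
        for m in range(1, len(X)):
            p *= T[Z[m - 1], Z[m]] * E[Z[m], X[m]]
        joint[Z[n], Z[n + 1]] += p
    return joint / joint.sum()          # joint.sum() equals P(X)

def xi_by_formula(pi, T, E, X, n):
    """f_{i,n} t_ij e_j(X_{n+1}) b_{j,n+1} / P(X), the identity from Theoretical task 1."""
    f, b = forward_backward(pi, T, E, X)
    px = f[:, -1].sum()
    return f[:, n, None] * T * (E[:, X[n + 1]] * b[:, n + 1])[None, :] / px

pi = np.array([0.5, 0.5])
T = np.array([[0.9, 0.1], [0.2, 0.8]])
E = np.array([[0.5, 0.5], [0.1, 0.9]])
X = [0, 1, 1, 0, 1, 1]
for n in range(len(X) - 1):
    print(n, np.abs(xi_by_formula(pi, T, E, X, n) - xi_by_enumeration(pi, T, E, X, n)).max())
```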




### Using Log Probabilities in Forward, Backward, and Baum–Welch Algorithms

In this lab, we work with **log probabilities** to improve numerical stability on long sequences. Accordingly, the **forward, backward, and Baum–Welch algorithms** are implemented in log space.  

- In the **forward algorithm**, we compute the logs of the forward variables, $\log f_{i,n}$, replacing products of probabilities with sums of logs and sums of probabilities with the log-sum-exp operation.  
- Similarly, the **backward algorithm** propagates $\log b_{i,n}$ backward in log space.  
- In the **Baum–Welch updates**, the posteriors $\gamma$ and $\xi$ are first computed in log space (by subtracting $\log \mathbb{P}(X)$) and only then exponentiated; because they are normalized probabilities, exponentiating them is numerically safe.  

Working in log space prevents underflow: the probability of a long sequence is a product of many factors smaller than 1 and quickly drops below the smallest representable floating-point number, whereas its logarithm stays in a moderate range.
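A small PyTorch illustration of both failure modes, assuming `float64` tensors and arbitrary numbers: a product of many probabilities underflows to zero, and so does the naive $\log(e^a + e^b)$ for very negative $a$ and $b$, whereas a sum of logs and `torch.logsumexp` stay finite.

```python
import torch

probs = torch.full((2000,), 0.1, dtype=torch.float64)
naive_product = torch.prod(probs)              # 1e-2000 is below the float64 range -> 0.0
log_product = torch.sum(torch.log(probs))      # 2000 * log(0.1), about -4605.17

# Adding probabilities stored as logs: log(a + b) from log(a) and log(b)
log_a = torch.tensor(-1000.0, dtype=torch.float64)
log_b = torch.tensor(-1001.0, dtype=torch.float64)
naive_sum = torch.log(torch.exp(log_a) + torch.exp(log_b))          # exp underflows -> log(0) = -inf
stable_sum = torch.logsumexp(torch.stack([log_a, log_b]), dim=0)    # -1000 + log(1 + e^-1), about -999.687

print(naive_product.item(), log_product.item())
print(naive_sum.item(), stable_sum.item())
```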


## Baum–Welch Algorithm for Multiple Sequences

The standard Baum–Welch algorithm works on a single observation sequence, but it can be generalized to handle **multiple independent sequences** $X^{(1)}, X^{(2)}, \dots, X^{(M)}$ generated from the same HMM.

**Algorithm Steps:**

1. **E-step:**  
   For each sequence $X^{(m)}$, compute the forward and backward variables and the expected counts of:  

   - Transitions:
   $$
   \xi_n^{(m)}(i,j) = \mathbb{P}(Z_n = i, Z_{n+1} = j \mid X^{(m)}, \theta)
   $$

   - States:
   $$
   \gamma_n^{(m)}(i) = \mathbb{P}(Z_n = i \mid X^{(m)}, \theta)
   $$

2. **M-step:**  
   Update the HMM parameters by **summing over all sequences**:

   $$
   \pi_i^{\text{new}} = \frac{1}{M} \sum_{m=1}^M \gamma_1^{(m)}(i)
   $$

   $$
   t_{ij}^{\text{new}} = \frac{\sum_{m=1}^M \sum_{n=1}^{N_m-1} \xi_n^{(m)}(i,j)}{\sum_{m=1}^M \sum_{n=1}^{N_m-1} \gamma_n^{(m)}(i)}
   $$

   $$
   e_j(s)^{\text{new}} = \frac{\sum_{m=1}^M \sum_{n=1}^{N_m} \gamma_n^{(m)}(j) \mathbf{1}_{\{X_n^{(m)} = s\}}}{\sum_{m=1}^M \sum_{n=1}^{N_m} \gamma_n^{(m)}(j)}
   $$

3. **Iterate** the E-step and M-step until convergence.

**Notes:**

- Each sequence contributes independently to the expected counts.  
- Summing over sequences ensures that parameter updates are informed by all observed data.  
- The log-likelihood of the full dataset (the sum of the per-sequence log-likelihoods) does not decrease from one iteration to the next.
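A minimal NumPy sketch of this pooled M-step, assuming the per-sequence posteriors have already been computed. The function `pooled_m_step` and the check helper `hard_posteriors` are hypothetical, and `n_symbols` denotes the alphabet size, to avoid a clash with $M$, the number of sequences. With one-hot posteriors the updates reduce to pooled transition and emission counts, which gives an easy correctness check.

```python
import numpy as np

def pooled_m_step(gammas, xis, sequences, n_symbols):
    """Multi-sequence Baum–Welch M-step from per-sequence posteriors.

    gammas[m]:    array (K, N_m), gamma_n^{(m)}(i)
    xis[m]:       array (N_m - 1, K, K), xi_n^{(m)}(i, j)
    sequences[m]: int array (N_m,), observed symbols X^{(m)}
    n_symbols:    alphabet size |Sigma|
    """
    K = gammas[0].shape[0]
    pi_num = np.zeros(K)
    trans_num, trans_den = np.zeros((K, K)), np.zeros(K)
    emit_num, emit_den = np.zeros((K, n_symbols)), np.zeros(K)
    for gamma, xi, x in zip(gammas, xis, sequences):
        x = np.asarray(x)
        pi_num += gamma[:, 0]                       # expected initial states
        trans_num += xi.sum(axis=0)                 # expected i -> j transitions
        trans_den += gamma[:, :-1].sum(axis=1)      # expected transitions out of i
        for s in range(n_symbols):
            emit_num[:, s] += gamma[:, x == s].sum(axis=1)
        emit_den += gamma.sum(axis=1)               # expected visits to j
    return pi_num / len(gammas), trans_num / trans_den[:, None], emit_num / emit_den[:, None]

def hard_posteriors(Z, K):
    """One-hot gamma (K, N) and xi (N - 1, K, K) for a known path Z (hypothetical check helper)."""
    onehot = np.eye(K)[np.asarray(Z)]               # (N, K)
    return onehot.T, onehot[:-1, :, None] * onehot[1:, None, :]

seqs = [np.array([0, 1, 1]), np.array([1, 1, 0])]
paths = [np.array([0, 0, 1]), np.array([1, 1, 0])]
posts = [hard_posteriors(Z, 2) for Z in paths]
pi, T, E = pooled_m_step([g for g, _ in posts], [x for _, x in posts], seqs, n_symbols=2)
```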


### **Exercises**

You are provided with a working implementation of the multi-sequence **Baum–Welch algorithm** for Hidden Markov Models. Your goal is to examine the algorithm's convergence behavior.

First, examine the code below.

```python
import torch
import numpy as np
import matplotlib.pyplot as plt

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)

def forward_log(T_log, E_log, pi_log, sequence):
    """
    Forward algorithm in log space.
    T_log: log transition matrix (K x K)
    E_log: log emission matrix  (K x M)
    pi_log: log initial probabilities (K)
    sequence: observed sequence (list or tensor of ints)
    """
    K = E_log.shape[0]
    N = len(sequence)

    f_log = torch.zeros((K, N), dtype=torch.float32)

    # Initialization
    f_log[:, 0] = pi_log + E_log[:, sequence[0]]

    # Recursion
    for i in range(1, N):
        for l in range(K):
            f_log[l, i] = E_log[l, sequence[i]] + torch.logsumexp(f_log[:, i-1] + T_log[:, l], dim=0)

    loglik = torch.logsumexp(f_log[:, -1], dim=0)

    return f_log, loglik


def backward_log(T_log, E_log, sequence):
    """
    Backward algorithm in log space.
    T_log: log transition matrix (K x K)
    E_log: log emission matrix (K x M)
    sequence: observed sequence
    """
    K = T_log.shape[0]
    N = len(sequence)

    b_log = torch.zeros((K, N), dtype=torch.float32)

    # Initialization
    b_log[:, N - 1] = 0.

    # Recursion: b_log[k, i] = log sum_l t_kl * e_l(x_{i+1}) * b[l, i+1]
    for i in range(N - 2, -1, -1):
        for k in range(K):
            b_log[k, i] = torch.logsumexp(T_log[k, :] + E_log[:, sequence[i + 1]] + b_log[:, i + 1], dim=0)

    return b_log


def get_posterior_log(T_log, E_log, pi_log, sequence):
    """
    Compute posterior p(z_i | x) from log forward and log backward scores
    Posterior is returned in ordinary prob space.
    """
    f_log, loglik = forward_log(T_log, E_log, pi_log, sequence)
    b_log = backward_log(T_log, E_log, sequence)

    log_posterior = f_log + b_log

    # Convert to normalized probabilities
    log_posterior -= torch.logsumexp(log_posterior, dim=0, keepdim=True)

    posterior = torch.exp(log_posterior)
    return posterior


def baum_welch_log(sequences, K, M, init_T, init_E, init_pi, max_iter=50, tol=1e-12):
    """
    Baum–Welch training for multiple sequences using log-space forward/backward.

    sequences: list of int tensors (each length N_n)
    K: number of hidden states
    M: number of emission symbols
    """
    T_log = torch.log(init_T + 1e-12)
    E_log = torch.log(init_E + 1e-12)
    pi_log = torch.log(init_pi + 1e-12)

    last_loglik = -float("inf")
    logliks = []
    for iteration in range(max_iter):
        # Accumulators for expected counts
        gamma_sum_for_transitions = torch.zeros((K,))
        gamma_init_sum = torch.zeros((K,))
        emission_sum = torch.zeros((K, M))
        xi_sum = torch.zeros((K, K))

        total_loglik = 0.0

        # E-step: calculate log_likelihood
        for seq in sequences:
            # Forward / Backward in log space
            f_log, loglik = forward_log(T_log, E_log, pi_log, seq)
            b_log = backward_log(T_log, E_log, seq)

            total_loglik += loglik.item()

            # Posterior state probabilities
            gamma_log = f_log + b_log - loglik
            gamma = torch.exp(gamma_log)         # (K × N)

            # Posterior transition probabilities: xi
            N = len(seq)
            for i in range(N - 1):
                xi_log = (
                    f_log[:, i].unsqueeze(1)
                    + T_log
                    + E_log[:, seq[i + 1]].unsqueeze(0)
                    + b_log[:, i + 1].unsqueeze(0)
                    - loglik
                )
                xi = torch.exp(xi_log)  # (K × K)
                xi_sum += xi

            # Accumulate expected counts
            gamma_sum_for_transitions += torch.sum(gamma[:, :-1], dim=1)
            gamma_init_sum += gamma[:, 0]
            for i, x in enumerate(seq):
                emission_sum[:, x] += gamma[:, i]

        # M-step: update probabilities
        pi = gamma_init_sum / torch.sum(gamma_init_sum)
        T = xi_sum / gamma_sum_for_transitions.unsqueeze(1)
        E = emission_sum / torch.sum(emission_sum, dim=1, keepdim=True)

        # Convert to log-space for next iteration
        T_log = torch.log(T + 1e-12)
        E_log = torch.log(E + 1e-12)
        pi_log = torch.log(pi + 1e-12)

        # Check convergence. Forward/backward run in float32, so log-likelihood changes
        # below float32 rounding error are not meaningful; with tol=1e-12 the loop
        # effectively stops only when the log-likelihood no longer changes.
        if iteration>5 and np.abs(total_loglik - last_loglik) < tol:
            print(f"Converged after {iteration + 1} iterations.")
            break

        last_loglik = total_loglik
        logliks.append(total_loglik)
        print(f"Iter {iteration + 1}: log-likelihood = {total_loglik:.4f}")
        print("T: ", torch.round(T, decimals=4))
        print("E: ", torch.round(E, decimals=4))
        print("pi: ", torch.round(pi, decimals=4))

    return T, E, pi, logliks
```

### **Exercise 1: Investigate the impact of parameter initialization on the algorithm's convergence**

1. **Generate observation sequences**  
- Use a known HMM with specified parameters $(\pi, T, E)$.
$$
  \pi = [0.5, 0.5], \quad
  T = \begin{bmatrix} 0.9 & 0.1 \\ 0.2 & 0.8 \end{bmatrix}, \quad
  E = \begin{bmatrix} 0.5 & 0.5 \\ 0.1 & 0.9 \end{bmatrix}
$$
- Generate a set of 5 observation sequences from this model, each of length 500.

2. **Examine the impact of initialization**  
  - Run the Baum–Welch algorithm multiple times using
    - 15 different random initial parameter values.  
    - uniform initialization, e.g., all transition and emission probabilities equal and rows summing to 1:  
$$
    \pi = [0.5, 0.5], \quad
    T = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix}, \quad
    E = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix}
$$
  - Create a line plot to:
    - Compare how different initializations affect **convergence speed** and **final log-likelihood**.  
    - Compare the calculated log-likelihoods to the log-likelihood of the data under the **true parameters** used to generate it.
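One possible way to draw the 15 random initializations, sketched under the assumption that each row of $T$ and $E$ and the vector $\pi$ are drawn independently and uniformly from the probability simplex (a flat Dirichlet distribution). The function name `random_hmm_params` is hypothetical; it returns tensors in the `(init_T, init_E, init_pi)` order used by `baum_welch_log`.

```python
import torch

def random_hmm_params(K, M):
    """Random (T, E, pi): every row drawn from a flat Dirichlet, i.e. uniformly on the simplex."""
    T = torch.distributions.Dirichlet(torch.ones(K)).sample((K,))   # (K, K), rows sum to 1
    E = torch.distributions.Dirichlet(torch.ones(M)).sample((K,))   # (K, M), rows sum to 1
    pi = torch.distributions.Dirichlet(torch.ones(K)).sample()      # (K,)
    return T, E, pi

torch.manual_seed(0)
inits = [random_hmm_params(K=2, M=2) for _ in range(15)]
```

Because the sampling uses PyTorch's global random number generator, calling `set_seed` (or `torch.manual_seed`) first makes the initializations reproducible.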

**Discussion Questions:**  
- What effects do you observe from different initializations?  
- How does the final log-likelihood relate to the initial log-likelihood?  
- Are there any special cases or patterns?  
- Is uniform initialization a good strategy? Why or why not?
- How can we initialize the algorithm to ensure "good" convergence?





### **Exercise 2: Posterior state probabilities and parameter evaluation**

1. **Select learned parameters**  
   - From your Baum–Welch runs, choose the parameter set $(\hat{\pi}, \hat{T}, \hat{E})$ that achieved the **highest log-likelihood** on the training sequences.

2. **Generate a new sequence**  
   - Using the **true model parameters** $(\pi, T, E)$, generate a new observation sequence $X_{\text{new}}$ and record the corresponding hidden states $Z_{\text{true}}$.

3. **Compute posterior state probabilities**  
   - Using the selected learned parameters $(\hat{\pi}, \hat{T}, \hat{E})$, run the **forward–backward algorithm** to compute the posterior probabilities $\gamma_n(i) = \mathbb{P}(Z_n = i \mid X_{\text{new}}, \hat{\theta})$ for each time step.  
   - Repeat the computation using the **true parameters** $(\pi, T, E)$ for comparison.

4. **Compare with true hidden states**  
   - For each time step, compare the **most likely state** from $\gamma_n$ to the true hidden state $Z_{\text{true}}$.


5. **Visualization suggestions**  
   - Plot the **posterior probabilities over time** computed from learned vs. true parameters for each hidden state.  
   - Overlay the **true hidden states** to visually assess how well the learned model captures the latent sequence.  
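When comparing decoded states with $Z_{\text{true}}$, keep in mind that the labels of the learned hidden states are arbitrary: permuting the states of a fitted model (together with the rows and columns of $\hat{T}$, the rows of $\hat{E}$, and the entries of $\hat{\pi}$) leaves the likelihood unchanged, so learned state 0 may correspond to true state 1. A minimal sketch that scores the most likely states up to such a relabeling (the name `accuracy_up_to_relabeling` is hypothetical; `posterior` is assumed to be a $K \times N$ array, e.g. the output of `get_posterior_log` converted with `.numpy()`):

```python
import itertools
import numpy as np

def accuracy_up_to_relabeling(posterior, Z_true):
    """Fraction of steps where argmax_i gamma_n(i) matches Z_true, maximized over state relabelings.

    Returns (accuracy, perm), where perm[k] is the true label assigned to learned state k.
    Enumerates all K! permutations, so it is meant for small K.
    """
    posterior = np.asarray(posterior)            # (K, N)
    Z_true = np.asarray(Z_true)
    Z_hat = posterior.argmax(axis=0)             # most likely learned state at each step
    K = posterior.shape[0]
    best_acc, best_perm = -1.0, None
    for perm in itertools.permutations(range(K)):
        acc = np.mean(np.asarray(perm)[Z_hat] == Z_true)
        if acc > best_acc:
            best_acc, best_perm = acc, perm
    return best_acc, best_perm

posterior = np.array([[0.1, 0.8, 0.3], [0.9, 0.2, 0.7]])   # learned labels are swapped here
acc, perm = accuracy_up_to_relabeling(posterior, Z_true=[0, 1, 0])
```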

How close are the posterior probabilities from the learned parameters to those from the true parameters?  
Does the learned model recover the hidden states accurately?  
Can you identify sequences or time steps where the model struggles?





### **Exercise 3: Effect of sequence length on convergence**

1. **Generate observation sequences**  
   - Using the HMM from Exercise 1, generate multiple sequences of different lengths, e.g., 50, 100, and 1000.  
   - Generate at least 5 sequences for each length.

2. **Run Baum–Welch**  
   - Apply the Baum–Welch algorithm separately to each sequence (i.e., train on one sequence at a time).

3. **Analyze results**  
   - For each sequence length, examine how the learned parameters and log-likelihoods vary across runs.   
   - Discuss how sequence length affects convergence speed, stability, and accuracy of parameter estimation.
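The same relabeling issue affects parameter comparisons across runs. A sketch of one possible error measure, the largest absolute deviation of $(\hat{T}, \hat{E})$ from the true $(T, E)$ minimized over state permutations; the function name and the choice of the max-norm are assumptions, not part of the exercise.

```python
import itertools
import numpy as np

def param_error_up_to_relabeling(T_hat, E_hat, T_true, E_true):
    """Max absolute error of (T_hat, E_hat) vs. the truth, minimized over state permutations."""
    T_hat, E_hat = np.asarray(T_hat, float), np.asarray(E_hat, float)
    T_true, E_true = np.asarray(T_true, float), np.asarray(E_true, float)
    K = T_true.shape[0]
    best = np.inf
    for perm in itertools.permutations(range(K)):
        p = list(perm)
        # Learned state p[k] plays the role of true state k
        T_perm = T_hat[np.ix_(p, p)]
        E_perm = E_hat[p, :]
        err = max(np.abs(T_perm - T_true).max(), np.abs(E_perm - E_true).max())
        best = min(best, err)
    return best

T_true = np.array([[0.9, 0.1], [0.2, 0.8]])
E_true = np.array([[0.5, 0.5], [0.1, 0.9]])
T_swapped = T_true[::-1, ::-1]   # same model with the two state labels exchanged
E_swapped = E_true[::-1]
print(param_error_up_to_relabeling(T_swapped, E_swapped, T_true, E_true))
```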


### Additional exercises:

1. Repeat the analysis for a bigger model, e.g., with 4 hidden states and 6 possible emission symbols. What do you observe?

2. Partial parameter learning
   - Assume the **emission matrix is known**. Modify the Baum–Welch algorithm to update only the **transition matrix** and **initial state probabilities**.  
   - Investigate convergence behavior in this scenario and compare to learning all parameters simultaneously.  
   - Discuss whether convergence is faster or more stable when fewer parameters are estimated.
