In [None]:
import jax.numpy as jnp
from utils import HMM
from sklearn.preprocessing import normalize

In [2]:

class HMM:
    """
    A Hidden Markov Model (HMM) class for modeling systems with hidden states and observable outputs.

    Attributes:
        observations (jnp.ndarray): A 1-D array of observations.
        labels (list or jnp.ndarray): Labels of the hidden states.
        transition_matrix (jnp.ndarray): A square matrix defining state transition probabilities.
        emission_matrix (jnp.ndarray): A matrix defining observation probabilities given states.
        initial_state (jnp.ndarray): Initial starting probabilities of the states.
        states (jnp.ndarray): Array of state indices.
        T (int): Number of observations.
        N (int): Number of states.
        alpha (jnp.ndarray): Forward probabilities matrix.
        beta (jnp.ndarray): Backward probabilities matrix.
        gamma (jnp.ndarray): Posterior probabilities matrix.
        theta (jnp.ndarray): Intermediate matrix for Baum-Welch algorithm.
    """

    def __init__(
        self,
        observations,
        labels,
        transition_matrix,
        emission_matrix,
        initial_state=None,
    ):
        """
        Initializes the HMM with observations, state characteristics, transition matrix, emission matrix,
        and optional initial state.

        Args:
            observations (jnp.ndarray): A 1-D array of observations.
            labels (list or jnp.ndarray): Labels of the hidden states.
            transition_matrix (jnp.ndarray): A square matrix of state transition probabilities.
            emission_matrix (jnp.ndarray): A matrix of observation probabilities given states.
            initial_state (jnp.ndarray, optional): Initial probabilities of the states. If None, the
                stationary distribution of the transition matrix is used.

        Raises:
            Exception: If the input matrices or vectors do not meet the required conditions.
        """
        self.observations = observations
        self.transition_matrix = transition_matrix
        self.emission_matrix = emission_matrix
        self.initial_state = (
            initial_state if initial_state is not None else self.stationary_states()
        )
        self.labels = labels
        self.states = jnp.array(list(range(len(labels))))
        self.T = len(observations)
        self.N = len(self.states)
        self.alpha = None  # Forward probabilities matrix
        self.beta = None  # Backward probabilities matrix
        self.gamma = None  # Posterior probabilities matrix
        self.theta = None  # Intermediate matrix for Baum-Welch algorithm

        assert self.verify_matrices()  # Validate input matrices

    def stationary_states(self):
        """
        Computes the stationary states of the transition matrix.

        Returns:
            jnp.ndarray: The stationary state probabilities.
        """
        eigenvalues, eigenvectors = jnp.linalg.eig(
            self.transition_matrix.T
        )  # Eigen decomposition
        stationary = eigenvectors[
            :, jnp.isclose(eigenvalues, 1)
        ].flatten()  # Eigenvector for eigenvalue 1
        stationary = abs(stationary) / jnp.sum(stationary)  # Normalize to sum to 1
        return jnp.real(stationary)  # Return real part of the stationary state

    def verify_matrices(self):
        """
        Validates the shapes and properties of the matrices and vectors used in the HMM.

        Returns:
            bool: True if all matrices and vectors are valid.

        Raises:
            Exception: If any of the validation checks fail.
        """
        shape_trans_mat = self.transition_matrix.shape

        # Check if observations is a 1-D vector
        if self.observations.ndim != 1 and (
            self.observations.ndim != 2 or 1 not in self.observations.shape
        ):
            raise Exception("The observations must be 1-D vector !")

        # Check if states is a 1-D vector or a 2-D vector with one dimension being 1
        elif self.states.ndim != 1 and (
            self.states.ndim != 2 or 1 not in self.states.shape
        ):
            raise Exception("The states must be 1-D vector !")

        # Check if the transition matrix is square
        elif shape_trans_mat[0] != shape_trans_mat[1]:
            raise Exception(
                f"Invalid transition matrix, transition matrix must be a squared matrix as the number of states ({shape_trans_mat[0]})!"
            )

        # Check if the emission matrix has the same number of rows as the transition matrix
        elif shape_trans_mat[0] != self.emission_matrix.shape[0]:
            raise Exception(
                f"Invalid emission matrix, the emission matrix must have the same number of rows as the transition matrix as they share the same number of states ({shape_trans_mat[0]})!"
            )

        # Check if the transition matrix rows sum to 1 (valid probability distribution)
        elif jnp.any(jnp.sum(self.transition_matrix, axis=1) != 1):
            raise Exception(
                "Invalid transition matrix, sum of probabilities of each state must be 1!"
            )

        # Check if the emission matrix rows sum to 1 (valid probability distribution)
        elif jnp.any(jnp.sum(self.emission_matrix, axis=1) != 1):
            raise Exception(
                "Invalid emission matrix, sum of probabilities of each state must be 1!"
            )

        # Check if the initial state vector has the correct shape
        elif self.initial_state.shape != (
            1,
            shape_trans_mat[0],
        ) and self.initial_state.shape != (shape_trans_mat[0],):
            raise Exception(
                f"Invalid initial state vector, it must be (1,{shape_trans_mat[0]}) or {(shape_trans_mat[0],)}"
            )

        elif jnp.any(self.observations < 0) or jnp.any(
            self.observations > self.emission_matrix.shape[1] - 1
        ):
            raise Exception(
                f"Invalid observations vector, this system has only {self.emission_matrix.shape[1]} possible observations ({list(range(self.emission_matrix.shape[1]))})"
            )

        return True  # All checks passed

    def verify_obs(self, obs):
        """
        Verifies if an observation is valid.

        Args:
            obs (int): The observation to verify.

        Raises:
            Exception: If the observation is not an integer or is out of range.
        """
        if not isinstance(obs, int):
            raise Exception("The observation must be an integer")
        elif obs < 0 or obs > self.emission_matrix.shape[1] - 1:
            raise Exception(
                f"This observation doesn't exist, this system has only {self.emission_matrix.shape[1]} possible observations ({list(range(self.emission_matrix.shape[1]))})"
            )
        return True

    def verify_position(self, position):
        """
        Verifies if a position is valid.

        Args:
            position (int): The position to verify.

        Raises:
            Exception: If the position is not an integer or is out of range.
        """
        if position is not None:
            if not isinstance(position, int):
                raise Exception("The position must be an integer")
            elif position < 1 or position > self.gamma.shape[1]:
                raise Exception(
                    f"This position doesn't exist, this system has only {self.gamma.shape[1]} possible positions ({list(range(1, self.gamma.shape[1] + 1))})"
                )
            return True
        return False

    def initial_alpha(self):
        """
        Initializes the forward probabilities (alpha) at time t=0.

        Updates:
            self.alpha (jnp.ndarray): Forward probabilities matrix initialized for t=0.
        """
        self.alpha = jnp.zeros((self.N, self.T))
        for i in range(self.N):
            self.alpha = self.alpha.at[i, 0].set(
                self.initial_state[i] * self.emission_matrix[i, self.observations[0]]
            )

    def last_alpha(self, t):
        """
        Computes the forward probabilities (alpha) at time t.

        Args:
            t (int): The time step for which to compute the forward probabilities.

        Returns:
            jnp.ndarray: Updated forward probabilities matrix.

        Updates:
            self.alpha (jnp.ndarray): Forward probabilities matrix updated for time t.
        """
        for j in range(self.N):
            sum = 0
            for i in range(self.N):
                sum += self.alpha[i, t - 1] * self.transition_matrix[i, j]
            self.alpha = self.alpha.at[j, t].set(
                sum * self.emission_matrix[j, self.observations[t]]
            )
        return self.alpha

    def forward_pass(self):
        """
        Performs the forward pass to compute the forward probabilities for all time steps.

        Returns:
            jnp.ndarray: The forward probabilities matrix.
        """
        self.initial_alpha()
        for t in range(1, self.T):
            self.last_alpha(t)
        return self.alpha



    def initial_beta(self):
        """
        Initializes the backward probabilities (beta) at time T-1.

        Updates:
            self.beta (jnp.ndarray): Backward probabilities matrix initialized for t=T-1.
        """
        self.beta = jnp.zeros((self.N, self.T))
        self.beta = self.beta.at[:, self.T - 1].set(1)
        return self.beta

    def last_beta(self, t):
        """
        Computes the backward probabilities (beta) at time t.

        Args:
            t (int): The time step for which to compute the backward probabilities.

        Returns:
            jnp.ndarray: Updated backward probabilities matrix.

        Updates:
            self.beta (jnp.ndarray): Backward probabilities matrix updated for time t.
        """
        for j in range(self.N):
            sum = 0
            for i in range(self.N):
                sum += (
                    self.beta[i, t + 1]
                    * self.transition_matrix[j, i]
                    * self.emission_matrix[i, self.observations[t + 1]]
                )
            self.beta = self.beta.at[j, t].set(sum)
        return self.beta

    def backward_pass(self):
        """
        Performs the backward pass to compute the backward probabilities for all time steps.

        Returns:
            jnp.ndarray: The backward probabilities matrix.
        """
        self.initial_beta()
        for t in range(self.T - 2, -1, -1):
            self.last_beta(t)
        return self.beta

    def posterior_probabilities(self, position=None):
        """
        Computes the posterior probabilities (gamma) for all states and time steps.

        Args:
            position (int, optional): The position to print probabilities for.

        Returns:
            jnp.ndarray: The posterior probabilities matrix.

        Prints:
            The probabilities of being in the most and least likely states at the specified position.
        """
        self.forward_pass()
        self.backward_pass()
        self.gamma = normalize(
            self.beta * self.alpha, "l1", axis=0
        )  # Normalize to get probabilities

        if self.verify_position(position):
            print(
                f"For the position {position}, you have {jnp.max(self.gamma[:, position - 1]):.3f} chance to be in state '{self.labels[jnp.argmax(self.gamma[:, position - 1])]}' and {jnp.min(self.gamma[:, position - 1]):.3f} chance to be in state '{self.labels[jnp.argmin(self.gamma[:, position - 1])]}'"
            )
        return self.gamma

    def baum_welch(self, threshold=0.001, verbose=True):
        """
        Performs the Baum-Welch algorithm to estimate the HMM parameters (transition and emission matrices).

        Args:
            threshold (float): Convergence threshold for parameter updates.
            verbose (bool): Whether to print intermediate results.

        Returns:
            None
        """
        # Compute posterior probabilities if not already done

        new_transition = self.transition_matrix
        new_emission = self.emission_matrix
        x = 0  # Iteration counter

        while True:
            self.theta = jnp.zeros(
                (self.T, self.N, self.N)
            )  # Initialize intermediate matrix

            self.posterior_probabilities()

            # Compute theta for each time step and state pair
            for t in range(self.T):
                for i in range(self.N):
                    for j in range(self.N):
                        self.theta = self.theta.at[t, i, j].set(
                            self.alpha[i, t]
                            * self.transition_matrix[i, j]
                            * self.emission_matrix[j, self.observations[t + 1]]
                            * self.beta[j, t + 1]
                        )
                self.theta = self.theta.at[t, :, :].set(
                    self.theta[t, :, :] / jnp.sum(self.theta[t, :, :])
                )  # Normalize
            
            # Update transition matrix
            for i in range(self.N):
                for j in range(self.N):
                    new_transition = new_transition.at[i, j].set(
                        jnp.sum(self.theta[:, i, j]) / jnp.sum(self.gamma[i, :-1])
                    )
            new_transition = jnp.array(
                normalize(new_transition, "l1", axis=1)
            )  # Normalize rows

            # Update emission matrix
            for i in range(self.N):
                sum = jnp.sum(self.gamma[i, :])
                for t in range(self.T):
                    new_emission = new_emission.at[i, t].set(
                        jnp.sum(self.gamma[i, self.observations == t]) / sum
                    )
            new_emission = jnp.array(
                normalize(new_emission, "l1", axis=1)
            )  # Normalize rows

            x += 1  # Increment iteration counter

            # Check for convergence
            if jnp.max(jnp.abs(new_transition - self.transition_matrix)) < threshold:
                self.transition_matrix = new_transition
                self.emission_matrix = new_emission
                self.initial_state = self.stationary_states()
                if verbose:
                    print(f"Converged after {x} iterations.")
                break

            self.transition_matrix = new_transition
            self.emission_matrix = new_emission

    def add_observation(self, obs):
        """
        Adds a new observation to the sequence and updates the forward probabilities.

        Args:
            obs (int): The new observation to add.

        Returns:
            jnp.ndarray: The updated forward probabilities matrix.

        Raises:
            AssertionError: If the observation is invalid.
        """
        assert self.verify_obs(obs)
        self.observations = jnp.hstack([self.observations, obs])
        self.T += 1
        if self.alpha is None:
            self.forward_pass()

        self.alpha = jnp.hstack([self.alpha, jnp.zeros((self.N, 1))])
        self.last_alpha(self.T - 1)
        self.backward_pass()
        self.baum_welch()

In [3]:
states = ["Rainy","Sunny"]
observations = jnp.array([1,2,0, 2, 1,0,1,0,0,2,2,0,1])  # Walk (0), Shop (1), Stay at home (2)
start_prob = jnp.array([0.6, 0.4])
trans_prob = jnp.array([[0.7, 0.3], 
                        [0.4, 0.6]])
emit_prob = jnp.array([[0.1, 0.4, 0.5], 
                       [0.6, 0.3, 0.1]])

In [4]:
hmmm=HMM(observations,states,trans_prob,emit_prob,start_prob)


In [5]:
hmmm.baum_welch()

Converged after 23 iterations.


In [240]:
x,y=jnp.sum(hmmm.theta,axis=2).T,hmmm.gamma


In [241]:
observations2 = jnp.array([1,2,0, 2, 1,0,1,0,0,2,2,0])

In [242]:
hmmm2=HMM(observations2,states,trans_prob,emit_prob,start_prob)
hmmm2.add_observation(1)

Converged after 23 iterations.


In [243]:
jnp.sum(hmmm2.theta,axis=2).T-x,hmmm2.gamma-y

(Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],      dtype=float32),
 array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       dtype=float32))

In [139]:
hmmm.alpha.shape

(2, 13)

In [None]:
class ContinuousHMM(HMM):
    def __init__(self,observations, labels, transition_matrix,means,cov, initial_state=None):
        super().__init__(observations, labels, transition_matrix, None, initial_state)
        self.T=self.observations.shape[0]
        self.means=means
        self.cov=cov
        self.emission_matrix = self.initiate_emissions()

    def verify_matrices(self):
        """
        Validates the shapes and properties of the matrices and vectors used in the HMM.

        Returns:
            bool: True if all matrices and vectors are valid.

        Raises:
            Exception: If any of the validation checks fail.
        """
        shape_trans_mat = self.transition_matrix.shape


        # Check if states is a 1-D vector or a 2-D vector with one dimension being 1
        if self.states.ndim != 1 and (
            self.states.ndim != 2 or 1 not in self.states.shape
        ):
            raise Exception("The states must be 1-D vector !")

        # Check if the transition matrix is square
        elif shape_trans_mat[0] != shape_trans_mat[1]:
            raise Exception(
                f"Invalid transition matrix, transition matrix must be a squared matrix as the number of states ({shape_trans_mat[0]})!"
            )


        # Check if the transition matrix rows sum to 1 (valid probability distribution)
        elif jnp.sum(self.transition_matrix) != len(self.transition_matrix):
            raise Exception(
                "Invalid transition matrix, sum of probabilities of each state must be 1!"
            )

        # Check if the initial state vector has the correct shape
        elif self.initial_state.shape != (
            1,
            shape_trans_mat[0],
        ) and self.initial_state.shape != (shape_trans_mat[0],):
            raise Exception(
                f"Invalid initial state vector, it must be (1,{shape_trans_mat[0]}) or {(shape_trans_mat[0],)}"
            )

        return True  # All checks passed
    
    
    def gaussian_pdf(self, x, mu, sigma):
        return jnp.exp(-0.5 * jnp.dot((x - mu).T, jnp.linalg.solve(sigma, (x - mu)))) / \
            jnp.sqrt((2 * jnp.pi) ** len(mu) * jnp.linalg.det(sigma))

    def initial_alpha(self):
        """
        Initializes the forward probabilities (alpha) at time t=0.

        Updates:
            self.alpha (jnp.ndarray): Forward probabilities matrix initialized for t=0.
        """
        self.alpha = jnp.zeros((self.N, self.T))
        for i in range(self.N):
            self.alpha = self.alpha.at[i, 0].set(
                self.initial_state[i] * self.emission_matrix[i, 0]
            )

    def last_alpha(self, t):
        """
        Computes the forward probabilities (alpha) at time t.

        Args:
            t (int): The time step for which to compute the forward probabilities.

        Returns:
            jnp.ndarray: Updated forward probabilities matrix.

        Updates:
            self.alpha (jnp.ndarray): Forward probabilities matrix updated for time t.
        """
        for j in range(self.N):
            sum = 0
            for i in range(self.N):
                sum += self.alpha[i, t - 1] * self.transition_matrix[i, j]
            self.alpha = self.alpha.at[j, t].set(
                sum * self.emission_matrix[j, t]
            )
        return self.alpha

    def last_beta(self, t):
        """
        Computes the backward probabilities (beta) at time t.

        Args:
            t (int): The time step for which to compute the backward probabilities.

        Returns:
            jnp.ndarray: Updated backward probabilities matrix.

        Updates:
            self.beta (jnp.ndarray): Backward probabilities matrix updated for time t.
        """
        for j in range(self.N):
            sum = 0
            for i in range(self.N):
                sum += (
                    self.beta[i, t + 1]
                    * self.transition_matrix[j, i]
                    * self.emission_matrix[i, t+1]
                )
            self.beta = self.beta.at[j, t].set(sum)
        return self.beta

    def initiate_emissions(self):

        emission_matrix=jnp.zeros((self.N,self.observations.shape[1]))
        for state in range(self.N):
            for obs in range(self.observations.shape[1]):
                emission_matrix=emission_matrix.at[state,obs].set(self.gaussian_pdf(self.observations[obs,:],self.means[state],self.cov[state]))

        emission_matrix=jnp.nan_to_num(emission_matrix, nan=0)
        emission_matrix = jnp.array(
                normalize(emission_matrix, "l1", axis=1)
            )
        return emission_matrix
    def baum_welch(self, threshold=0.001, verbose=True):
        """
        Performs the Baum-Welch algorithm to estimate the HMM parameters (transition and emission matrices).

        Args:
            threshold (float): Convergence threshold for parameter updates.
            verbose (bool): Whether to print intermediate results.

        Returns:
            None
        """
        # Compute posterior probabilities if not already done

        new_transition = self.transition_matrix
        x = 0  # Iteration counter
        print(self.emission_matrix)
        while True:
            self.theta = jnp.zeros(
                (self.T, self.N, self.N)
            )  # Initialize intermediate matrix

            self.posterior_probabilities()

            # Compute theta for each time step and state pair
            for t in range(self.T):
                for i in range(self.N):
                    for j in range(self.N):
                        self.theta = self.theta.at[t, i, j].set(
                            self.alpha[i, t]
                            * self.transition_matrix[i, j]
                            * self.emission_matrix[j, t+1]
                            * self.beta[j, t + 1]
                        )
                self.theta = self.theta.at[t, :, :].set(
                    self.theta[t, :, :] / jnp.sum(self.theta[t, :, :])
                )  # Normalize

            # Update transition matrix
            for i in range(self.N):
                for j in range(self.N):
                    new_transition = new_transition.at[i, j].set(
                        jnp.sum(self.theta[:, i, j]) / jnp.sum(self.gamma[i, :-1])
                    )
            new_transition=jnp.nan_to_num(new_transition, nan=0)
            new_transition = jnp.array(
                normalize(new_transition, "l1", axis=1)
            )  # Normalize rows

            # Update emission matrix
            for state in range(self.N):
                self.means=self.means.at[state].set(jnp.sum(jnp.expand_dims(self.gamma[state, :],axis=-1)*self.observations,axis=0)/self.gamma[state, :].sum())
                centered = self.observations - self.means[state]
                self.cov=self.cov.at[state].set(((jnp.expand_dims(self.gamma[state, :],axis=-1) * centered).T @ centered) / self.gamma[state, :].sum())


            
            x += 1  # Increment iteration counter

            # Check for convergence
            if jnp.max(jnp.abs(new_transition - self.transition_matrix)) < threshold:
                self.transition_matrix = new_transition
                self.emission_matrix = self.initiate_emissions()    
                if verbose:
                    print(f"Converged after {x} iterations.")
                break

            self.transition_matrix = new_transition
            self.emission_matrix = self.initiate_emissions() 
            
        

In [131]:
import numpy as np
trans_prob = jnp.array([[0.7, 0.3], 
                        [0.4, 0.6]])
states = ["Rainy","Sunny"]
obss=np.random.randn(10, 3)
means=jnp.array([[2.1,3.2,1.3],
                 [3.4,2.5,0.6]])
start=jnp.array([0.4,0.6])
covs=jnp.array([jnp.eye(3) for _ in range(2)])
xx=ContinuousHMM(obss,states,trans_prob,means,covs,start)

xx.baum_welch()




[[9.9780917e-01 9.4512117e-04 1.2457410e-03]
 [9.9895102e-01 4.2077460e-04 6.2815769e-04]]
[[ 0.21205407  0.26748067 -0.46065518]
 [ 0.35632083  0.43651223 -0.41537017]]
[[ 0.18874164  0.2545022  -0.47551832]
 [ 0.5039736   0.5444575  -0.3347824 ]]
[[ 0.1782228   0.23794855 -0.46718225]
 [ 0.78567517  0.85864604 -0.31336582]]
[[ 0.16915174  0.22217737 -0.4586704 ]
 [ 1.0464524   1.1931627  -0.3604573 ]]


  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


[[ 0.16611746  0.21791899 -0.4576246 ]
 [ 1.0618404   1.2194287  -0.37294987]]
[[ 0.2456649   0.3068611  -0.45010486]
 [ 0.          0.          0.        ]]
Converged after 7 iterations.


In [101]:
xx.beta

Array([[0.00101668, 0.0021809 , 0.00469145, 0.01009201, 0.0217094 ,
        0.04670015, 0.10045897, 0.21610218, 0.46486792, 1.        ],
       [0.00101627, 0.00218003, 0.00468957, 0.01008795, 0.02170069,
        0.04668141, 0.10041865, 0.21601543, 0.46468133, 1.        ]],      dtype=float32)

In [49]:
import scipy
scipy.stats.norm.cdf(0.02/np.sqrt((0.9)))-scipy.stats.norm.cdf(0)

0.008409818785913625

In [16]:
import numpy as np

class ContinuousHMM2:
    def __init__(self, n_states, n_components=1):
        self.n_states = n_states
        self.n_components = n_components
        self.A = None  # Transition matrix
        self.pi = None  # Initial state probabilities
        self.means = None  # Means for each state
        self.covs = None  # Covariances for each state
        self.weights = None  # Weights for GMM components

    def initialize(self, obs_dim):
        self.A = np.array([[0.7, 0.3], 
                        [0.4, 0.6]])

        self.pi = np.array([0.4,0.6])
        
        self.means = np.array([[.1,.2,.3],
                                [.4,.5,.6]])
        self.covs = np.array(covs)
        
        if self.n_components > 1:
            self.weights = np.ones((self.n_states, self.n_components)) / self.n_components

    def gaussian_pdf(self, x, mu, sigma):
        return np.exp(-0.5 * np.dot((x - mu).T, np.linalg.solve(sigma, (x - mu)))) / \
               np.sqrt((2 * np.pi) ** len(mu) * np.linalg.det(sigma))

    def emission_prob(self, obs, state):
        prob = 0.0
        if self.n_components == 1:
            prob = self.gaussian_pdf(obs, self.means[state], self.covs[state])
        else:
            for k in range(self.n_components):
                prob += self.weights[state, k] * self.gaussian_pdf(obs, self.means[state, k], self.covs[state, k])
        return prob

    def forward(self, obs):
        T = len(obs)
        alpha = np.zeros((T, self.n_states))
        alpha[0] = self.pi * np.array([self.emission_prob(obs[0], i) for i in range(self.n_states)])
        alpha[0] /= alpha[0].sum()

        for t in range(1, T):
            for j in range(self.n_states):
                alpha[t, j] = (alpha[t-1] @ self.A[:, j]) * self.emission_prob(obs[t], j)
            alpha[t] /= alpha[t].sum()

        return alpha

    def backward(self, obs):
        T = len(obs)
        beta = np.zeros((T, self.n_states))
        beta[-1] = 1.0

        for t in reversed(range(T-1)):
            for i in range(self.n_states):
                beta[t, i] = sum(beta[t+1, j] * self.A[i, j] * self.emission_prob(obs[t+1], j) for j in range(self.n_states))
            beta[t] /= beta[t].sum()

        return beta

    def baum_welch(self, obs, max_iter=100, tol=1e-6):
        T = len(obs)
        old_log_likelihood = -np.inf

        for _ in range(max_iter):
            # E-step
            alpha = self.forward(obs)
            beta = self.backward(obs)

            gamma = alpha * beta
            gamma /= gamma.sum(axis=1, keepdims=True)

            xi = np.zeros((T-1, self.n_states, self.n_states))
            for t in range(T-1):
                for i in range(self.n_states):
                    for j in range(self.n_states):
                        xi[t, i, j] = alpha[t, i] * self.A[i, j] * self.emission_prob(obs[t+1], j) * beta[t+1, j]
                xi[t] /= xi[t].sum()

            # M-step
            self.pi = gamma[0]
            self.A = xi.sum(axis=0) / gamma[:-1].sum(axis=0, keepdims=True).T

            for i in range(self.n_states):
                if self.n_components == 1:
                    self.means[i] = np.sum(gamma[:, i][:, None] * obs, axis=0) / gamma[:, i].sum()
                    centered = obs - self.means[i]
                    self.covs[i] = np.dot((gamma[:, i][:, None] * centered).T, centered) / gamma[:, i].sum()
                else:
                    for k in range(self.n_components):
                        gamma_ik = gamma[:, i] * self.weights[i, k]
                        self.means[i, k] = np.sum(gamma_ik[:, None] * obs, axis=0) / gamma_ik.sum()
                        centered = obs - self.means[i, k]
                        self.covs[i, k] = np.dot((gamma_ik[:, None] * centered).T, centered) / gamma_ik.sum()
                        self.weights[i, k] = gamma_ik.sum() / gamma[:, i].sum()

            log_likelihood = np.log(alpha[-1].sum())
            if np.abs(log_likelihood - old_log_likelihood) < tol:
                break
            old_log_likelihood = log_likelihood


In [26]:
import numpy as np
trans_prob = jnp.array([[0.7, 0.3], 
                        [0.4, 0.6]])
states = ["Rainy","Sunny"]
obss=np.random.randn(10, 3)
means=jnp.array([[.1,.2,.3],
                 [.4,.5,.6]])
start=jnp.array([0.4,0.6])
covs=jnp.array([jnp.eye(3) for _ in range(2)])
xx=ContinuousHMM(obss,states,trans_prob,means,covs,start)


xx.baum_welch()
xx.means




Converged after 21 iterations.


Array([[-0.3616777 , -0.04198264, -0.7865546 ],
       [-0.18753476, -0.14427495,  0.10520335]], dtype=float32)

In [21]:

xxx=ContinuousHMM2(2,1)
xxx.initialize(obs_dim=3)
print(xxx.means,xxx.covs)
xxx.baum_welch(obss)
xxx.means,xxx.covs

[[0.1 0.2 0.3]
 [0.4 0.5 0.6]] [[[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]

 [[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]]


(array([[-0.01992633, -0.47433971, -0.27919513],
        [ 0.33300028,  0.40493449,  0.06924713]]),
 array([[[ 0.27247882, -0.11831573,  0.02044858],
         [-0.11831573,  1.3551286 , -0.02501094],
         [ 0.02044858, -0.02501094,  1.0493927 ]],
 
        [[ 0.5349582 ,  0.53679246, -0.27532047],
         [ 0.53679246,  2.0981014 , -0.993038  ],
         [-0.27532047, -0.993038  ,  1.2465011 ]]], dtype=float32))

In [22]:
obss.shape[1]

3

In [9]:
xx.means

Array([[ 1.2280328,  1.0541315, -1.2647258],
       [ 1.2280377,  1.0541357, -1.2647277]], dtype=float32)