# Implementation of the dMMSB model (dynamic)

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import logsumexp
from jax.nn import softmax

from jax import vmap, jit
from jax.tree_util import register_pytree_node_class
from functools import partial

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt


In [None]:
class dMMSB():
    def __init__(self, nodes, roles, timesteps, **kwargs):
        self.N = nodes
        self.K = roles
        self.T = timesteps

        key = kwargs.get('key', jax.random.PRNGKey(0))
        self.B = kwargs.get('B', None)
        self.mu = kwargs.get('mu', None)
        self.Sigma = kwargs.get('Sigma', None)
        self.gamma_tilde = kwargs.get('gamma_tilde', None)
        self.Sigma_tilde = kwargs.get('Sigma_tilde', None)
        self.nu = kwargs.get('nu', None)
        self.Phi = kwargs.get('Phi', None)

        # Initialize model parameters 
        if self.B is None:
            self.B = jax.random.uniform(key, (self.K, self.K)) #shape (K,K)
        if self.nu is None:
            self.nu = jax.random.normal(key) # scalar
        if self.mu is None:
            self.mu = jnp.ones(self.T) * self.nu #shape (T,)
        if self.Phi is None:
            self.Phi = jnp.eye(self.K) *  10 #shape (K,K)
        if self.Sigma is None:
            self.Sigma = jnp.tile(jnp.eye(self.K)[None, :, :], (self.T, 1, 1)) * 10 #shape (T,K,K)

        #KF RTS variables
        self.P = jnp.tile(jnp.eye(self.K)[None, :, :], (self.T, 1, 1)) * 20 #shape (T,K,K) | NOTE: large initial covariance
        self.Y = None
        self.L = None



        self.EPS = 1e-6
        
    def log_likelihood(self, delta, B, E):
        '''
        Compute the log likelihood of the data given the current parameters. eq (25)
        delta: shape (T,N,N,K,K)
        B: shape (K,K)
        E: shape (T,N,N)
        '''
        E_reshaped = E[:, :, :, None, None] # shape (T,N,N,1,1)
        B_reshaped = B[None, None, None, :, :] # shape (1,1,1,K,K)
        logB = jnp.log(B_reshaped + self.EPS) # shape (1,1,1,K,K)
        log1mB = jnp.log(1.0 - B_reshaped + self.EPS) # shape (1,1,1,K,K)

        ll_matrix = delta * (E_reshaped * logB + (1.0 - E_reshaped) * log1mB) # shape (T,N,N,K,K)
        ll = jnp.sum(ll_matrix)
        return ll
    
    #--------------------------------------------------------------
    # Inner Loop functions (From static version)
    #--------------------------------------------------------------

    def _compute_deltas(self, gamma_tilde, E):
        '''
        Compute delta matrices for all pairs (i,j) given current gamma_tilde and B
        gamma_tilde: (N,K)
        E: adjacency matrix (N,N)
        Returns: delta (N,N,K,K)
        '''
        
        gamma_i = gamma_tilde[:, None, :, None] # shape (N,1,K,1)
        gamma_j = gamma_tilde[None, :, None, :] # shape (1,N,1,K)

        gamma_sum = gamma_i + gamma_j # shape (N,N,K,K)

        #print("gamma sum is finite:", jnp.isfinite(gamma_sum).all())
        B_reshaped = self.B[None, None, :, :] # shape (1,1,K,K)
        E_reshaped = E[:, :, None, None] # shape (N,N,1,1)

        bernoulli_term = jnp.where(E_reshaped == 1, B_reshaped, 1 - B_reshaped) # shape (N,N,K,K)

        delta_exp_term = gamma_sum + jnp.log(bernoulli_term + self.EPS) # shape (N,N,K,K)
        #print("delta exp term is finite:", jnp.isfinite(delta_exp_term).all())
        max_delta_exp = jnp.max(delta_exp_term, axis=(-1,-2), keepdims=True)
        delta = jnp.exp(delta_exp_term - (max_delta_exp + logsumexp(delta_exp_term - max_delta_exp, axis=(-1,-2), keepdims=True))) # shape (N,N,K,K) logsumexp trick for numerical stability
        return delta # shape (N,N,K,K)
    
    def _compute_g_H(self, gamma_hat, K):
        '''
        Compute g and H at gamma_hat
        gamma_hat: (N, K)
        K: scalar, number of roles
        Returns: g: (N, K), H: (N, K, K)
        '''
        #g = jnp.exp(gamma_hat) / jnp.sum(jnp.exp(gamma_hat), axis=-1, keepdims=True) # shape (N,K)
        max_gamma = jnp.max(gamma_hat, axis=-1, keepdims=True)
        g = jnp.exp(gamma_hat - (max_gamma + logsumexp(gamma_hat - max_gamma, axis=-1, keepdims=True))) # shape (N,K) logsumexp trick for numerical stability
        H = jnp.einsum('ni,ij->nij', g, jnp.eye(K)) - jnp.einsum('ni,nj->nij', g, g) # shape (N,K,K)
        # print("g is finite:", jnp.isfinite(g).all())
        # print("H is finite:", jnp.isfinite(H).all())
        return g, H

    def _update_sigma_tilde(self, Sigma_inv, H, N):
        '''
        Compute Sigma_tilde = (Sigma^{-1} + (2N-2) H)^{-1}
        Sigma_inv: (K,K)
        H: (N,K,K) Hessian at gamma_hat
        N: scalar, number of nodes
        Returns: Sigma_tilde: (N,K,K)

        '''
        factor = 2.0 * N - 2.0
        A = Sigma_inv[None, :, :] + factor * H # shape (N,K,K)
        #jitter = 1e-6 * jnp.eye(self.K)
        #A = A + jitter[None, :, :]
        Sigma_tilde = jnp.linalg.inv(A)
        return Sigma_tilde # shape (N,K,K)

    def _compute_m_expect(self, delta):
        '''
        Compute m_expect per node: m_i,k = sum_{j != i} (E[z_i->j,k] + E[z_i<-j,k])
        delta: (N,N,K,K)
        Returns: m_expect: (N,K)
        '''
        z_ij = jnp.sum(delta, axis=-1) # shape (N,N,K) Expected z_i->j (sender)
        z_ji = jnp.sum(delta, axis=-2) # shape (N,N,K) Expected z_i<-j (receiver)

        z_ij_expected = jnp.sum(z_ij, axis=1) # shape (N,K)
        z_ji_expected = jnp.sum(z_ji, axis=0) # shape (N,K)

        z_sum = z_ij_expected + z_ji_expected# shape (N,K)

        diag_ij = jnp.diagonal(z_ij, axis1=0, axis2=1).T # shape(N,K)
        diag_ji = jnp.diagonal(z_ji, axis1=0, axis2=1).T # shape(N,K)
    
        m_expect = z_sum - diag_ij - diag_ji # shape (N,K)
      

        return m_expect # shape (N,K)
   
    def _update_gamma_tilde(self, delta, mu, Sigma_tilde, gamma_hat, g , H, N):
        '''
        Update gamma_tilde using Laplace approximation
        delta: (N,N,K,K)
        mu: (K,)
        Sigma_tilde: (N,K,K)
        gamma_hat: (N,K)
        m_expect: (N,K)
        Returns: gamma_tilde: (N,K), Sigma_tilde: (N,K,K)
        '''
        # g, H = self.compute_g_H(gamma_hat) # g: (N,K), H: (N,K,K)
        factor = 2.0 * N - 2.0 #scalar

        m_expect = self._compute_m_expect(delta) # shape (N,K)

        term_1 = m_expect - factor * g + factor * jnp.einsum('nij,nj->ni', H, gamma_hat) - factor * jnp.einsum('nij,j->ni', H, mu) # shape (N,K)

        gamma_tilde = mu[None, :] + jnp.einsum('nij,nj->ni', Sigma_tilde, term_1) # shape (N,K)

        return gamma_tilde # shape (N,K)

    def inner_step_static(self, gamma_tilde, Sigma_tilde, mu, Sigma_inv, E):
        '''
        Perform one inner iteration to update gamma_tilde and Sigma_tilde
        gamma_tilde: (N,K)
        Sigma_tilde: (N,K,K)
        mu: (K,)
        Sigma_inv: (K,K)
        E: adjacency matrix (N,N)
        Returns: updated gamma_tilde, Sigma_tilde, delta
        '''
    
        delta = self._compute_deltas(gamma_tilde, E) # shape (N,N,K,K)
        g, H = self._compute_g_H(gamma_tilde, self.K) # g: (N,K), H: (N,K,K)
        Sigma_tilde = self._update_sigma_tilde(Sigma_inv, H, self.N) # shape (N,K,K)
        gamma_tilde = self._update_gamma_tilde(mu, Sigma_tilde, gamma_tilde, g , H) # shape (N,K)
        return gamma_tilde, Sigma_tilde, delta

    #--------------------------------------------------------------
    
    def inner_step(self, gamma_tilde, Sigma_tilde, mu, Sigma_inv, E):
        '''
        Perform one inner iteration to update gamma_tilde and Sigma_tilde for all time steps using vmap
        gamma_tilde: (T,N,K)
        Sigma_tilde: (T,N,K,K)
        mu: (T,K)
        Sigma_inv: (T,K,K)
        E: adjacency matrix (T,N,N)
        Returns: updated 
        gamma_tilde (T,N,K)
        Sigma_tilde (T,N,K,K)
        delta (T,N,N,K,K)
        Returns updated gamma_tilde, Sigma_tilde, delta
        '''

        gamma_tilde, Sigma_tilde, delta = vmap(self.inner_step_static, in_axes=(0,0,0,0,0))(gamma_tilde, Sigma_tilde, mu, Sigma_inv, E)
        return gamma_tilde, Sigma_tilde, delta

        

    def update_mu(self, mu, P, Y, Sigma, Phi, N):
        '''
        Update mu using Kalman filter and RTS smoother with jax.lax.scan. eq (14,15,16,17)
        mu: shape (T,K) 
        P: shape (T,K,K)
        Y: shape (T,K)
        Sigma: shape (T,K,K) 
        Phi: shape (K,K) 
        N: scalar, number of nodes
        '''
        T = mu.shape[0]

        # --- 1. Kalman Filter (Forward Pass) ---
        def kalman_step(carry, inputs):
            mu_prev, P_prev = carry
            Y_t, Sigma_t = inputs

            # Prediction step
            mu_pred_t = mu_prev
            P_pred_t = P_prev + Phi

            # Update step
            #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma_t/N)
            tmp = P_pred_t + Sigma_t/N
            K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  #numerically stable version
            mu_t = mu_pred_t + K_t @ (Y_t - mu_pred_t)
            P_t = P_pred_t - K_t @ P_pred_t

            new_carry = (mu_t, P_t)
            # Stack filtered and predicted states for the backward pass
            outputs_to_stack = (mu_t, P_t, mu_pred_t, P_pred_t)
            return new_carry, outputs_to_stack

        init_carry = (mu[0], P[0])
        inputs = (Y[1:], Sigma[1:])
        _, (mu_filtered_scanned, P_filtered_scanned, mu_pred_scanned, P_pred_scanned) = jax.lax.scan(
            kalman_step, init_carry, inputs, unroll=True
        )

        # Combine initial state with scanned results
        mu_filtered = jnp.concatenate([mu[0][None, :], mu_filtered_scanned], axis=0)
        P_filtered = jnp.concatenate([P[0][None, :, :], P_filtered_scanned], axis=0)
        # The prediction for time t is mu_{t-1}, so mu_pred starts from mu_0
        mu_pred = jnp.concatenate([mu[0][None, :], mu_pred_scanned], axis=0)
        P_pred = jnp.concatenate([P[0][None, :, :], P_pred_scanned], axis=0)


        # --- 2. RTS Smoother (Backward Pass) ---
        def rts_smoother_step(carry, inputs):
            mu_smooth_next, P_smooth_next = carry
            mu_filtered_t, P_filtered_t, mu_pred_next, P_pred_next = inputs

            #L_t = P_filtered_t @ jnp.linalg.inv(P_pred_next)
            L_t = jnp.linalg.solve(P_pred_next.T, P_filtered_t.T).T  # numerically stable version
            
            # Update step
            mu_smooth_t = mu_filtered_t + L_t @ (mu_smooth_next - mu_pred_next)
            P_smooth_t = P_filtered_t + L_t @ (P_smooth_next - P_pred_next) @ L_t.T

            new_carry = (mu_smooth_t, P_smooth_t)
            outputs_to_stack = (mu_smooth_t, P_smooth_t, L_t)
            return new_carry, outputs_to_stack

        init_carry_smooth = (mu_filtered[-1], P_filtered[-1])
        inputs_smooth = (mu_filtered[:-1], P_filtered[:-1], mu_pred[1:], P_pred[1:])
        
        _, (mu_smooth_scanned, P_smooth_scanned, L_scanned) = jax.lax.scan(
            rts_smoother_step, init_carry_smooth, inputs_smooth, reverse=True, unroll=True
        )

        mu_smooth = jnp.concatenate([mu_smooth_scanned, mu_filtered[-1][None, :]], axis=0)
        P_smooth = jnp.concatenate([P_smooth_scanned, P_filtered[-1][None, :, :]], axis=0)
        L = jnp.concatenate([L_scanned, jnp.zeros((1, P.shape[1], P.shape[2]))], axis=0)  # Last L is not defined

        return mu_smooth, P_smooth, L 

    def update_B(self, delta, E):
        '''
        Update B using the current parameters. eq (26)
        delta: shape (T,N,N,K,K)
        E: shape (T,N,N)
        '''
        E_reshaped = E[:, :, :, None, None] # shape (T,N,N,1,1)
        
        num = jnp.sum(delta * E_reshaped, axis=(0,1,2)) # shape (K,K)
        den = jnp.sum(delta, axis=(0,1,2)) # shape (K,K)

        B_new = (num + self.EPS) / (den + self.EPS)
        B_new = jnp.clip(B_new, 1e-6, 1.0 - 1e-6) 
        return B_new
    
    def update_Phi(self, mu, P, L):
        '''
        Update Phi using the current parameters. eq (19)
        mu: shape (T,K)
        P: shape (T,K,K)
        L: shape (T,K,K)
        '''
        sum = jnp.zeros((self.K, self.K))
        for t in range(0, self.T - 1):
            diff = mu[t+1] - mu[t]
            term1 = jnp.outer(diff, diff)
            term2 = L[t] @ P[t+1] @ L[t].T

            sum += term1 + term2

        Phi_new = sum / (self.T - 1)
        #NOTE: check if Phi is symmetric
        print("Phi_new symmetric:", jnp.allclose(Phi_new, Phi_new.T, atol=1e-6))
        Phi_new = Phi_new + jnp.eye(self.K) * self.EPS # Ensure positive definiteness

        return Phi_new

    def update_Sigma(self, mu, gamma_tilde, Sigma_tilde, N):
        '''
        Update Sigma_tilde using the current parameters. eq (20)
        mu: shape (T,K)
        gamma_tilde: shape (T,N,K)
        Sigma_tilde: shape (T,N,K,K)
        N: scalar, number of nodes
        '''
        
        diff = mu[:, None, :] - gamma_tilde # shape (T,N,K)
        
        sum_outer_products = jnp.einsum('tnk,tnj->tkj', diff, diff)  # shape (T,K,K) 

        sum_Sigma_tilde = jnp.sum(Sigma_tilde, axis=1)  # shape (T,K,K)

        Sigma_new = (sum_outer_products + sum_Sigma_tilde) / N  # shape (T,K,K)
    
        #NOTE: check if Sigma_new is symmetric
        print("Sigma_new symmetric:", jnp.allclose(Sigma_new, jnp.transpose(Sigma_new, (0,2,1)), atol=1e-6))

        Sigma_new = Sigma_new + jnp.eye(self.K)[None, :, :] * self.EPS  # Ensure positive definiteness
        return Sigma_new
    
    def update_nu(self, mu):
        '''
        Update nu using the current parameters. eq (21)
        mu: shape (T,)
        '''
        return mu[0]
    
    def fit(self, E, max_inner_iters=100, max_outer_iters=100, tol=1e-6, verbose=False):
        '''
        Fit the model to adjacency matrix E using variational EM
        Algorithm described in section 4.2 of the paper
        
        E: adjacency matrix (T,N,N)
        max_inner_iters: maximum iterations for inner loop
        max_outer_iters: maximum iterations for outer loop
        tol: tolerance for convergence
        verbose: whether to print progress
        '''
        
        i = 0 
        d_ll = jnp.inf
        prev_outer_ll = -jnp.inf
        while(d_ll > tol and i < max_outer_iters): # 2 (outer loop)
            if verbose:
                print(f"[outer {i}] mu: {self.mu}, Sigma diag: {jnp.diag(self.Sigma)}, B: {self.B}")

            #initialize q(gamma) parameters
            def init_q_gamma(key, mu_t, Sigma_t, N, K):
                '''
                Initialize q(gamma) and Sigma^-1 for a single time step.

                '''
                gamma_tilde = jax.random.multivariate_normal(key, mu_t, Sigma_t, shape=(N,)) # shape (N,K)

                g, H = self.compute_g_H(gamma_tilde, K) # g: (N,K), H: (N,K,K)

                jitter = self.EPS * jnp.eye(K) # for numerical stability
                Sigma_inv = jnp.linalg.inv(Sigma_t + jitter) # shape (K,K
                Sigma_tilde = self._update_sigma_tilde(Sigma_inv, H, N) # shape (N,K,K)

                return gamma_tilde, Sigma_tilde, Sigma_inv

            self.gamma_tilde, self.Sigma_tilde, Sigma_inv = vmap(init_q_gamma, in_axes=(None, 0, 0, None, None))(self.key, self.mu, self.Sigma, self.N, self.K) # shape (T,N,K), (T,N,K,K), (T,K,K)

            #NOTE:add multiple runs with different initializaitons and use of VMAP
            j = 0
            inner_d_ll = jnp.inf
            prev_inner_ll = -jnp.inf
            while(inner_d_ll > tol and j < max_inner_iters): # 2.2 inner loop
                #2.2.1 update q(gamma) and q(z) 
                self.gamma_tilde, self.Sigma_tilde, self.delta = self.inner_step(self.gamma_tilde, self.Sigma_tilde, self.mu, Sigma_inv, E)

                # 2.2.2 update B
                self.B = self.update_B(E, self.delta) # shape (K,K)

                #convergence check
                j += 1
                inner_ll = self.log_likelihood(self.delta, self.B, E) 
                inner_d_ll = jnp.abs(inner_ll - prev_inner_ll)
                #print("inner ll and prev:", inner_ll, prev_inner_ll)    
                prev_inner_ll = inner_ll

                if verbose:
                    print(f"  [inner {j}] ll: {inner_ll:.4f}, d_ll: {inner_d_ll:.6f}")

            # 2.3 RTS smoother to update mu and P
            self.Y = jnp.mean(self.gamma_tilde, axis=1) # shape (T,K)
            self.mu, self.P, self.L = self.update_mu(self.mu, self.P, self.Y, self.Sigma, self.Phi, self.N)

            # 2.4 update nu, Phi, Sigma
            self.nu = self.update_nu(self.mu) # scalar
            self.Phi = self.update_Phi(self.mu, self.P, self.L) # shape (K,K)
            self.Sigma = self.update_Sigma(self.mu, self.gamma_tilde, self.Sigma_tilde, self.N) # shape (T,K,K)


            #convergence check
            i += 1
            outer_ll = inner_ll #last inner ll is outer ll
            d_ll = jnp.abs(outer_ll - prev_outer_ll)
            prev_outer_ll = outer_ll

In [None]:
# def update_mu(mu, P,  Y, Sigma, Phi):
#         '''
#         Update mu using Kalman filter and RTS smoother. eq (14 .. 17)
#         mu: shape (T,)
#         P: shape (T,K,K)
#         Y: shape (T,K)
#         Sigma: shape (T,K,K)
#         Phi: shape (K,K)
#         '''
#         #Initialize arrays
#         T = mu.shape[0]
#         roles = Phi.shape[0]
#         mu_pred = jnp.zeros((T, roles))
#         P_pred = jnp.zeros((T, roles, roles))
#         K = jnp.zeros((T, roles, roles))
#         L = jnp.zeros((T, roles, roles))
#         mu_smooth = jnp.zeros((T, roles))
#         P_smooth = jnp.zeros((T, roles, roles))

#         mu_pred[0] = mu[0]
#         P_pred[0] = P[0]

#         #kalman filter (forward pass)
#         for t in range(1, T):
#             mu_pred[t] = mu[t-1]
#             P_pred[t] = P[t-1] + Phi
#             K[t] = P_pred[t] @ jnp.linalg.inv(P_pred[t] + Sigma[t])
#             mu[t] = mu_pred[t] + K[t] @ (Y[t]- mu_pred[t])
#             P[t] = P_pred[t] - K[t] @ P_pred[t]
        
#         #RTS smoother (backward pass)
#         mu_smooth[-1] = mu[-1]
#         P_smooth[-1] = P[-1]
#         for t in range(T-2, -1, -1):
#             L[t] = P[t] @ jnp.linalg.inv(P_pred[t+1])
#             mu_smooth[t] = mu[t] + L[t] @ (mu_smooth[t+1] - mu_pred[t+1])
#             P_smooth[t] = P[t] + L[t] @ (P_smooth[t+1] - P_pred[t+1]) @ L[t].T

#         return mu_smooth, P_smooth

def update_mu(mu, P, Y, Sigma, Phi, N):
    '''
    Update mu using Kalman filter and RTS smoother. eq (14 .. 17)
    This version uses the JAX functional update syntax: array.at[...].set(...)
    mu: shape (T,K)
    P: shape (T,K,K)
    Y: shape (T,K)
    Sigma: shape (T,K,K)
    Phi: shape (K,K)
    N: scalar, number of nodes
    '''
    T = mu.shape[0]
    K = Phi.shape[0]

    # --- 1. Kalman Filter (Forward Pass) ---
    # Initialize arrays to store results
    mu_filtered = jnp.zeros_like(mu)
    P_filtered = jnp.zeros_like(P)
    mu_pred = jnp.zeros_like(mu)
    P_pred = jnp.zeros_like(P)

    # Set initial conditions for t=0
    mu_filtered = mu_filtered.at[0].set(mu[0])
    P_filtered = P_filtered.at[0].set(P[0])

    # Loop from t=1 to T-1
    for t in range(1, T):
        # Prediction step
        mu_pred_t = mu_filtered[t-1]
        P_pred_t = P_filtered[t-1] + Phi

        # Store predictions
        mu_pred = mu_pred.at[t].set(mu_pred_t)
        P_pred = P_pred.at[t].set(P_pred_t)

        # Update step
        #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma[t]/N)
        tmp = P_pred_t + Sigma[t]/N
        K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  # More stable
        mu_t = mu_pred_t + K_t @ (Y[t] - mu_pred_t)
        P_t = P_pred_t - K_t @ P_pred_t

        # Store filtered results
        mu_filtered = mu_filtered.at[t].set(mu_t)
        P_filtered = P_filtered.at[t].set(P_t)

    # --- 2. RTS Smoother (Backward Pass) ---
    # Initialize arrays for smoothed results
    mu_smooth = jnp.zeros_like(mu)
    P_smooth = jnp.zeros_like(P)

    # Set initial conditions for t=T-1
    mu_smooth = mu_smooth.at[-1].set(mu_filtered[-1])
    P_smooth = P_smooth.at[-1].set(P_filtered[-1])

    # Loop from t=T-2 down to 0
    for t in range(T - 2, -1, -1):
        # Smoother gain
        #L_t = P_filtered[t] @ jnp.linalg.inv(P_pred[t+1])
        L_t = jnp.linalg.solve(P_pred[t+1].T, P_filtered[t].T).T  # More stable
        
        # Update step
        mu_smooth_t = mu_filtered[t] + L_t @ (mu_smooth[t+1] - mu_pred[t+1])
        P_smooth_t = P_filtered[t] + L_t @ (P_smooth[t+1] - P_pred[t+1]) @ L_t.T

        # Store smoothed results
        mu_smooth = mu_smooth.at[t].set(mu_smooth_t)
        P_smooth = P_smooth.at[t].set(P_smooth_t)

    return mu_smooth, P_smooth

def update_mu_scan(mu, P, Y, Sigma, Phi, N):
    '''
    Update mu using Kalman filter and RTS smoother with jax.lax.scan.
    mu: shape (T,K) 
    P: shape (T,K,K)
    Y: shape (T,K)
    Sigma: shape (T,K,K) 
    Phi: shape (K,K) 
    N: scalar, number of nodes
    '''
    T = mu.shape[0]

    # --- 1. Kalman Filter (Forward Pass) ---
    def kalman_step(carry, inputs):
        mu_prev, P_prev = carry
        Y_t, Sigma_t = inputs

        # Prediction step
        mu_pred_t = mu_prev
        P_pred_t = P_prev + Phi

        # Update step
        #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma_t/N)
        tmp = P_pred_t + Sigma_t/N
        K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  #numerically stable version
        mu_t = mu_pred_t + K_t @ (Y_t - mu_pred_t)
        P_t = P_pred_t - K_t @ P_pred_t

        new_carry = (mu_t, P_t)
        # Stack filtered and predicted states for the backward pass
        outputs_to_stack = (mu_t, P_t, mu_pred_t, P_pred_t)
        return new_carry, outputs_to_stack

    init_carry = (mu[0], P[0])
    inputs = (Y[1:], Sigma[1:])
    _, (mu_filtered_scanned, P_filtered_scanned, mu_pred_scanned, P_pred_scanned) = jax.lax.scan(
        kalman_step, init_carry, inputs, unroll=True
    )

    # Combine initial state with scanned results
    mu_filtered = jnp.concatenate([mu[0][None, :], mu_filtered_scanned], axis=0)
    P_filtered = jnp.concatenate([P[0][None, :, :], P_filtered_scanned], axis=0)
    # The prediction for time t is mu_{t-1}, so mu_pred starts from mu_0
    mu_pred = jnp.concatenate([mu[0][None, :], mu_pred_scanned], axis=0)
    P_pred = jnp.concatenate([P[0][None, :, :], P_pred_scanned], axis=0)


    # --- 2. RTS Smoother (Backward Pass) ---
    def rts_smoother_step(carry, inputs):
        mu_smooth_next, P_smooth_next = carry
        mu_filtered_t, P_filtered_t, mu_pred_next, P_pred_next = inputs

        #L_t = P_filtered_t @ jnp.linalg.inv(P_pred_next)
        L_t = jnp.linalg.solve(P_pred_next.T, P_filtered_t.T).T  # numerically stable version
        
        # Update step
        mu_smooth_t = mu_filtered_t + L_t @ (mu_smooth_next - mu_pred_next)
        P_smooth_t = P_filtered_t + L_t @ (P_smooth_next - P_pred_next) @ L_t.T

        new_carry = (mu_smooth_t, P_smooth_t)
        outputs_to_stack = (mu_smooth_t, P_smooth_t)
        return new_carry, outputs_to_stack

    init_carry_smooth = (mu_filtered[-1], P_filtered[-1])
    inputs_smooth = (mu_filtered[:-1], P_filtered[:-1], mu_pred[1:], P_pred[1:])
    
    _, (mu_smooth_scanned, P_smooth_scanned) = jax.lax.scan(
        rts_smoother_step, init_carry_smooth, inputs_smooth, reverse=True, unroll=True
    )

    mu_smooth = jnp.concatenate([mu_smooth_scanned, mu_filtered[-1][None, :]], axis=0)
    P_smooth = jnp.concatenate([P_smooth_scanned, P_filtered[-1][None, :, :]], axis=0)

    return mu_smooth, P_smooth



In [None]:
def test_kalman_smoother_equivalence(seed=0):
    '''
    Tests that the for-loop and jax.lax.scan implementations of the 
    Kalman filter and RTS smoother produce the same results.
    '''
    # 1. Set up test parameters
    key = jax.random.PRNGKey(seed)
    T = 15  # Timesteps
    K = 5   # Roles
    N = 100  # Nodes

    # 2. Generate random valid inputs
    key, subkey = jax.random.split(key)
    mu_init = jax.random.normal(subkey, (T, K))
    
    key, subkey = jax.random.split(key)
    # Ensure P is positive semi-definite
    P_init_rand = jax.random.normal(subkey, (T, K, K))
    P_init = P_init_rand @ jnp.transpose(P_init_rand, (0, 2, 1)) + jnp.eye(K) * 1e-3

    key, subkey = jax.random.split(key)
    Y = jax.random.normal(subkey, (T, K))

    key, subkey = jax.random.split(key)
    # Ensure Sigma is positive semi-definite
    Sigma_rand = jax.random.normal(subkey, (T, K, K))
    Sigma = Sigma_rand @ jnp.transpose(Sigma_rand, (0, 2, 1)) + jnp.eye(K) * 1e-3

    key, subkey = jax.random.split(key)
    # Ensure Phi is positive semi-definite
    Phi_rand = jax.random.normal(subkey, (K, K))
    Phi = Phi_rand @ Phi_rand.T + jnp.eye(K) * 1e-3

    # 3. Run both implementations
    # NOTE: The original loop-based function modifies mu and P in-place,
    # so we must pass copies to ensure a fair comparison.
    mu_loop, P_loop = update_mu(mu_init.copy(), P_init.copy(), Y, Sigma, Phi, N)
    mu_scan, P_scan = update_mu_scan(mu_init.copy(), P_init.copy(), Y, Sigma, Phi, N)

    # 4. Compare results
    mu_are_close = jnp.allclose(mu_loop, mu_scan, atol=1e-5)
    P_are_close = jnp.allclose(P_loop, P_scan, atol=1e-5)

    print(f"Test for Kalman/RTS smoother equivalence:")
    print(f"Smoothed mu results are close: {mu_are_close}")
    print(f"Smoothed P results are close: {P_are_close}")

    if not mu_are_close:
        print("Difference in mu:\n", jnp.abs(mu_loop - mu_scan).max())
    if not P_are_close:
        print("Difference in P:\n", jnp.abs(P_loop - P_scan).max())
    

# Run the test
for i in range(100):

    test_kalman_smoother_equivalence(100+i)

Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 1.7166138e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 3.671646e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 2.2120774e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: 

KeyboardInterrupt: 

In [55]:
key = jax.random.PRNGKey(0)
N = 10
T = 5
K = 3
EPS = 1e-6
mu = jnp.ones((T,K))
Sigma = jnp.tile(jnp.eye(K)[None, :, :], (T, 1, 1)) * 10 #shape (T,K,K)
key, subkey = jax.random.split(key)



def _compute_g_H(gamma_hat, K):
    '''
    Compute g and H at gamma_hat
    gamma_hat: (N, K)
    K: scalar, number of roles
    Returns: g: (N, K), H: (N, K, K)
    '''
    #g = jnp.exp(gamma_hat) / jnp.sum(jnp.exp(gamma_hat), axis=-1, keepdims=True) # shape (N,K)
    max_gamma = jnp.max(gamma_hat, axis=-1, keepdims=True)
    g = jnp.exp(gamma_hat - (max_gamma + logsumexp(gamma_hat - max_gamma, axis=-1, keepdims=True))) # shape (N,K) logsumexp trick for numerical stability
    H = jnp.einsum('ni,ij->nij', g, jnp.eye(K)) - jnp.einsum('ni,nj->nij', g, g) # shape (N,K,K)
    # print("g is finite:", jnp.isfinite(g).all())
    # print("H is finite:", jnp.isfinite(H).all())
    return g, H

def _update_sigma_tilde(Sigma_inv, H, N):
    '''
    Compute Sigma_tilde = (Sigma^{-1} + (2N-2) H)^{-1}
    Sigma_inv: (K,K)
    H: (N,K,K) Hessian at gamma_hat
    N: scalar, number of nodes
    Returns: Sigma_tilde: (N,K,K)

    '''
    factor = 2.0 * N - 2.0
    A = Sigma_inv[None, :, :] + factor * H # shape (N,K,K)
    #jitter = 1e-6 * jnp.eye(self.K)
    #A = A + jitter[None, :, :]
    Sigma_tilde = jnp.linalg.inv(A)
    return Sigma_tilde # shape (N,K,K)

def init_q_gamma(key, mu_t, Sigma_t, N):
                '''
                Initialize q(gamma) and Sigma^-1 for a single time step.

                '''
                gamma_tilde = jax.random.multivariate_normal(key, mu_t, Sigma_t, shape=(N,)) # shape (N,K)

                g, H = _compute_g_H(gamma_tilde, K) # g: (N,K), H: (N,K,K)

                jitter = EPS * jnp.eye(K) # for numerical stability
                Sigma_inv = jnp.linalg.inv(Sigma_t + jitter) # shape (K,K
                Sigma_tilde = _update_sigma_tilde(Sigma_inv, H, N) # shape (N,K,K)

                return gamma_tilde, Sigma_tilde, Sigma_inv

gamma_tilde, Sigma_tilde, Sigma_inv = vmap(init_q_gamma, in_axes=(None, 0, 0, None))(subkey, mu, Sigma, N) # shape (T,N,K), (T,N,K,K), (T,K,K)

print("gamma_tilde shape:", gamma_tilde.shape)
print("Sigma_tilde shape:", Sigma_tilde.shape)
print("Sigma_inv shape:", Sigma_inv.shape)

gamma_tilde shape: (5, 10, 3)
Sigma_tilde shape: (5, 10, 3, 3)
Sigma_inv shape: (5, 3, 3)


In [None]:
def sample_gamma(key, mu_t, Sigma_t, N):
    '''
    Helper function to sample gamma_tilde for a single time step
    '''
    return jax.random.multivariate_normal(key, mu_t, Sigma_t, shape=(N,)) # shape (N,K)

gamma_tilde = vmap(sample_gamma, in_axes=(None, 0, 0, None))(key, mu, Sigma, N) # shape (T,N,K)
print("gamma_tilde shape after vmap:", gamma_tilde.shape)

gamma_tilde shape after vmap: (5, 10, 3)
