# Implementation of the dMMSB model (dynamic)

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import logsumexp
from jax.nn import softmax

from jax import vmap, jit
from jax.tree_util import register_pytree_node_class
from functools import partial

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt


In [None]:
class dMMSB():
    def __init__(self, nodes, roles, timesteps, **kwargs):
        self.N = nodes
        self.K = roles
        self.T = timesteps

        key = kwargs.get('key', jax.random.PRNGKey(0))
        self.B = kwargs.get('B', None)
        self.mu = kwargs.get('mu', None)
        self.Sigma = kwargs.get('Sigma', None)
        self.gamma_tilde = kwargs.get('gamma_tilde', None)
        self.Sigma_tilde = kwargs.get('Sigma_tilde', None)
        self.nu = kwargs.get('nu', None)
        self.Phi = kwargs.get('Phi', None)

        # Initialize model parameters 
        if self.B is None:
            self.B = jax.random.uniform(key, (self.K, self.K)) #shape (K,K)
        if self.nu is None:
            self.nu = jax.random.normal(key) # scalar
        if self.mu is None:
            self.mu = jnp.ones(self.T) * self.nu #shape (T,)
        if self.Phi is None:
            self.Phi = jnp.eye(self.K) *  10 #shape (K,K)
        if self.Sigma is None:
            self.Sigma = jnp.tile(jnp.eye(self.K)[None, :, :], (self.T, 1, 1)) * 10 #shape (T,K,K)

        self.EPS = 1e-6
        
    def log_likelihood(self, delta, B, E):
        '''
        Compute the log likelihood of the data given the current parameters.
        delta: shape (T,N,N,K,K)
        B: shape (K,K)
        E: shape (T,N,N)
        '''
        E_reshaped = E[:, :, :, None, None] # shape (T,N,N,1,1)
        B_reshaped = B[None, None, None, :, :] # shape (1,1,1,K,K)
        logB = jnp.log(B_reshaped + self.EPS) # shape (1,1,1,K,K)
        log1mB = jnp.log(1.0 - B_reshaped + self.EPS) # shape (1,1,1,K,K)

        ll_matrix = delta * (E_reshaped * logB + (1.0 - E_reshaped) * log1mB) # shape (T,N,N,K,K)
        ll = jnp.sum(ll_matrix)
        return ll
    

    def update_gamma_tilde_temporal(self, mu, Sigma_tilde, gamma_hat, g, H):
        '''
        Update gamma_tilde using the current parameters.
        mu: shape (T,)
        Sigma_tilde: shape (T,K,K)
        gamma_hat: shape (N,K)
        g: shape (N,K)
        H: shape (N,K,K)
        '''
        #vmap over time
        pass


    def update_mu_scan(mu, P, Y, Sigma, Phi, N):
        '''
        Update mu using Kalman filter and RTS smoother with jax.lax.scan.
        mu: shape (T,K) 
        P: shape (T,K,K)
        Y: shape (T,K)
        Sigma: shape (T,K,K) 
        Phi: shape (K,K) 
        N: scalar, number of nodes
        '''
        T = mu.shape[0]

        # --- 1. Kalman Filter (Forward Pass) ---
        def kalman_step(carry, inputs):
            mu_prev, P_prev = carry
            Y_t, Sigma_t = inputs

            # Prediction step
            mu_pred_t = mu_prev
            P_pred_t = P_prev + Phi

            # Update step
            #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma_t/N)
            tmp = P_pred_t + Sigma_t/N
            K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  #numerically stable version
            mu_t = mu_pred_t + K_t @ (Y_t - mu_pred_t)
            P_t = P_pred_t - K_t @ P_pred_t

            new_carry = (mu_t, P_t)
            # Stack filtered and predicted states for the backward pass
            outputs_to_stack = (mu_t, P_t, mu_pred_t, P_pred_t)
            return new_carry, outputs_to_stack

        init_carry = (mu[0], P[0])
        inputs = (Y[1:], Sigma[1:])
        _, (mu_filtered_scanned, P_filtered_scanned, mu_pred_scanned, P_pred_scanned) = jax.lax.scan(
            kalman_step, init_carry, inputs, unroll=True
        )

        # Combine initial state with scanned results
        mu_filtered = jnp.concatenate([mu[0][None, :], mu_filtered_scanned], axis=0)
        P_filtered = jnp.concatenate([P[0][None, :, :], P_filtered_scanned], axis=0)
        # The prediction for time t is mu_{t-1}, so mu_pred starts from mu_0
        mu_pred = jnp.concatenate([mu[0][None, :], mu_pred_scanned], axis=0)
        P_pred = jnp.concatenate([P[0][None, :, :], P_pred_scanned], axis=0)


        # --- 2. RTS Smoother (Backward Pass) ---
        def rts_smoother_step(carry, inputs):
            mu_smooth_next, P_smooth_next = carry
            mu_filtered_t, P_filtered_t, mu_pred_next, P_pred_next = inputs

            #L_t = P_filtered_t @ jnp.linalg.inv(P_pred_next)
            L_t = jnp.linalg.solve(P_pred_next.T, P_filtered_t.T).T  # numerically stable version
            
            # Update step
            mu_smooth_t = mu_filtered_t + L_t @ (mu_smooth_next - mu_pred_next)
            P_smooth_t = P_filtered_t + L_t @ (P_smooth_next - P_pred_next) @ L_t.T

            new_carry = (mu_smooth_t, P_smooth_t)
            outputs_to_stack = (mu_smooth_t, P_smooth_t)
            return new_carry, outputs_to_stack

        init_carry_smooth = (mu_filtered[-1], P_filtered[-1])
        inputs_smooth = (mu_filtered[:-1], P_filtered[:-1], mu_pred[1:], P_pred[1:])
        
        _, (mu_smooth_scanned, P_smooth_scanned) = jax.lax.scan(
            rts_smoother_step, init_carry_smooth, inputs_smooth, reverse=True, unroll=True
        )

        mu_smooth = jnp.concatenate([mu_smooth_scanned, mu_filtered[-1][None, :]], axis=0)
        P_smooth = jnp.concatenate([P_smooth_scanned, P_filtered[-1][None, :, :]], axis=0)

        return mu_smooth, P_smooth



   
    

    



In [None]:
# def update_mu(mu, P,  Y, Sigma, Phi):
#         '''
#         Update mu using Kalman filter and RTS smoother. eq (14 .. 17)
#         mu: shape (T,)
#         P: shape (T,K,K)
#         Y: shape (T,K)
#         Sigma: shape (T,K,K)
#         Phi: shape (K,K)
#         '''
#         #Initialize arrays
#         T = mu.shape[0]
#         roles = Phi.shape[0]
#         mu_pred = jnp.zeros((T, roles))
#         P_pred = jnp.zeros((T, roles, roles))
#         K = jnp.zeros((T, roles, roles))
#         L = jnp.zeros((T, roles, roles))
#         mu_smooth = jnp.zeros((T, roles))
#         P_smooth = jnp.zeros((T, roles, roles))

#         mu_pred[0] = mu[0]
#         P_pred[0] = P[0]

#         #kalman filter (forward pass)
#         for t in range(1, T):
#             mu_pred[t] = mu[t-1]
#             P_pred[t] = P[t-1] + Phi
#             K[t] = P_pred[t] @ jnp.linalg.inv(P_pred[t] + Sigma[t])
#             mu[t] = mu_pred[t] + K[t] @ (Y[t]- mu_pred[t])
#             P[t] = P_pred[t] - K[t] @ P_pred[t]
        
#         #RTS smoother (backward pass)
#         mu_smooth[-1] = mu[-1]
#         P_smooth[-1] = P[-1]
#         for t in range(T-2, -1, -1):
#             L[t] = P[t] @ jnp.linalg.inv(P_pred[t+1])
#             mu_smooth[t] = mu[t] + L[t] @ (mu_smooth[t+1] - mu_pred[t+1])
#             P_smooth[t] = P[t] + L[t] @ (P_smooth[t+1] - P_pred[t+1]) @ L[t].T

#         return mu_smooth, P_smooth

def update_mu(mu, P, Y, Sigma, Phi, N):
    '''
    Update mu using Kalman filter and RTS smoother. eq (14 .. 17)
    This version uses the JAX functional update syntax: array.at[...].set(...)
    mu: shape (T,K)
    P: shape (T,K,K)
    Y: shape (T,K)
    Sigma: shape (T,K,K)
    Phi: shape (K,K)
    N: scalar, number of nodes
    '''
    T = mu.shape[0]
    K = Phi.shape[0]

    # --- 1. Kalman Filter (Forward Pass) ---
    # Initialize arrays to store results
    mu_filtered = jnp.zeros_like(mu)
    P_filtered = jnp.zeros_like(P)
    mu_pred = jnp.zeros_like(mu)
    P_pred = jnp.zeros_like(P)

    # Set initial conditions for t=0
    mu_filtered = mu_filtered.at[0].set(mu[0])
    P_filtered = P_filtered.at[0].set(P[0])

    # Loop from t=1 to T-1
    for t in range(1, T):
        # Prediction step
        mu_pred_t = mu_filtered[t-1]
        P_pred_t = P_filtered[t-1] + Phi

        # Store predictions
        mu_pred = mu_pred.at[t].set(mu_pred_t)
        P_pred = P_pred.at[t].set(P_pred_t)

        # Update step
        #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma[t]/N)
        tmp = P_pred_t + Sigma[t]/N
        K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  # More stable
        mu_t = mu_pred_t + K_t @ (Y[t] - mu_pred_t)
        P_t = P_pred_t - K_t @ P_pred_t

        # Store filtered results
        mu_filtered = mu_filtered.at[t].set(mu_t)
        P_filtered = P_filtered.at[t].set(P_t)

    # --- 2. RTS Smoother (Backward Pass) ---
    # Initialize arrays for smoothed results
    mu_smooth = jnp.zeros_like(mu)
    P_smooth = jnp.zeros_like(P)

    # Set initial conditions for t=T-1
    mu_smooth = mu_smooth.at[-1].set(mu_filtered[-1])
    P_smooth = P_smooth.at[-1].set(P_filtered[-1])

    # Loop from t=T-2 down to 0
    for t in range(T - 2, -1, -1):
        # Smoother gain
        #L_t = P_filtered[t] @ jnp.linalg.inv(P_pred[t+1])
        L_t = jnp.linalg.solve(P_pred[t+1].T, P_filtered[t].T).T  # More stable
        
        # Update step
        mu_smooth_t = mu_filtered[t] + L_t @ (mu_smooth[t+1] - mu_pred[t+1])
        P_smooth_t = P_filtered[t] + L_t @ (P_smooth[t+1] - P_pred[t+1]) @ L_t.T

        # Store smoothed results
        mu_smooth = mu_smooth.at[t].set(mu_smooth_t)
        P_smooth = P_smooth.at[t].set(P_smooth_t)

    return mu_smooth, P_smooth

def update_mu_scan(mu, P, Y, Sigma, Phi, N):
    '''
    Update mu using Kalman filter and RTS smoother with jax.lax.scan.
    mu: shape (T,K) 
    P: shape (T,K,K)
    Y: shape (T,K)
    Sigma: shape (T,K,K) 
    Phi: shape (K,K) 
    N: scalar, number of nodes
    '''
    T = mu.shape[0]

    # --- 1. Kalman Filter (Forward Pass) ---
    def kalman_step(carry, inputs):
        mu_prev, P_prev = carry
        Y_t, Sigma_t = inputs

        # Prediction step
        mu_pred_t = mu_prev
        P_pred_t = P_prev + Phi

        # Update step
        #K_t = P_pred_t @ jnp.linalg.inv(P_pred_t + Sigma_t/N)
        tmp = P_pred_t + Sigma_t/N
        K_t = jnp.linalg.solve(tmp.T, P_pred_t.T).T  #numerically stable version
        mu_t = mu_pred_t + K_t @ (Y_t - mu_pred_t)
        P_t = P_pred_t - K_t @ P_pred_t

        new_carry = (mu_t, P_t)
        # Stack filtered and predicted states for the backward pass
        outputs_to_stack = (mu_t, P_t, mu_pred_t, P_pred_t)
        return new_carry, outputs_to_stack

    init_carry = (mu[0], P[0])
    inputs = (Y[1:], Sigma[1:])
    _, (mu_filtered_scanned, P_filtered_scanned, mu_pred_scanned, P_pred_scanned) = jax.lax.scan(
        kalman_step, init_carry, inputs, unroll=True
    )

    # Combine initial state with scanned results
    mu_filtered = jnp.concatenate([mu[0][None, :], mu_filtered_scanned], axis=0)
    P_filtered = jnp.concatenate([P[0][None, :, :], P_filtered_scanned], axis=0)
    # The prediction for time t is mu_{t-1}, so mu_pred starts from mu_0
    mu_pred = jnp.concatenate([mu[0][None, :], mu_pred_scanned], axis=0)
    P_pred = jnp.concatenate([P[0][None, :, :], P_pred_scanned], axis=0)


    # --- 2. RTS Smoother (Backward Pass) ---
    def rts_smoother_step(carry, inputs):
        mu_smooth_next, P_smooth_next = carry
        mu_filtered_t, P_filtered_t, mu_pred_next, P_pred_next = inputs

        #L_t = P_filtered_t @ jnp.linalg.inv(P_pred_next)
        L_t = jnp.linalg.solve(P_pred_next.T, P_filtered_t.T).T  # numerically stable version
        
        # Update step
        mu_smooth_t = mu_filtered_t + L_t @ (mu_smooth_next - mu_pred_next)
        P_smooth_t = P_filtered_t + L_t @ (P_smooth_next - P_pred_next) @ L_t.T

        new_carry = (mu_smooth_t, P_smooth_t)
        outputs_to_stack = (mu_smooth_t, P_smooth_t)
        return new_carry, outputs_to_stack

    init_carry_smooth = (mu_filtered[-1], P_filtered[-1])
    inputs_smooth = (mu_filtered[:-1], P_filtered[:-1], mu_pred[1:], P_pred[1:])
    
    _, (mu_smooth_scanned, P_smooth_scanned) = jax.lax.scan(
        rts_smoother_step, init_carry_smooth, inputs_smooth, reverse=True, unroll=True
    )

    mu_smooth = jnp.concatenate([mu_smooth_scanned, mu_filtered[-1][None, :]], axis=0)
    P_smooth = jnp.concatenate([P_smooth_scanned, P_filtered[-1][None, :, :]], axis=0)

    return mu_smooth, P_smooth



In [41]:
def test_kalman_smoother_equivalence(seed=0):
    '''
    Tests that the for-loop and jax.lax.scan implementations of the 
    Kalman filter and RTS smoother produce the same results.
    '''
    # 1. Set up test parameters
    key = jax.random.PRNGKey(seed)
    T = 15  # Timesteps
    K = 5   # Roles
    N = 100  # Nodes

    # 2. Generate random valid inputs
    key, subkey = jax.random.split(key)
    mu_init = jax.random.normal(subkey, (T, K))
    
    key, subkey = jax.random.split(key)
    # Ensure P is positive semi-definite
    P_init_rand = jax.random.normal(subkey, (T, K, K))
    P_init = P_init_rand @ jnp.transpose(P_init_rand, (0, 2, 1)) + jnp.eye(K) * 1e-3

    key, subkey = jax.random.split(key)
    Y = jax.random.normal(subkey, (T, K))

    key, subkey = jax.random.split(key)
    # Ensure Sigma is positive semi-definite
    Sigma_rand = jax.random.normal(subkey, (T, K, K))
    Sigma = Sigma_rand @ jnp.transpose(Sigma_rand, (0, 2, 1)) + jnp.eye(K) * 1e-3

    key, subkey = jax.random.split(key)
    # Ensure Phi is positive semi-definite
    Phi_rand = jax.random.normal(subkey, (K, K))
    Phi = Phi_rand @ Phi_rand.T + jnp.eye(K) * 1e-3

    # 3. Run both implementations
    # NOTE: The original loop-based function modifies mu and P in-place,
    # so we must pass copies to ensure a fair comparison.
    mu_loop, P_loop = update_mu(mu_init.copy(), P_init.copy(), Y, Sigma, Phi, N)
    mu_scan, P_scan = update_mu_scan(mu_init.copy(), P_init.copy(), Y, Sigma, Phi, N)

    # 4. Compare results
    mu_are_close = jnp.allclose(mu_loop, mu_scan, atol=1e-5)
    P_are_close = jnp.allclose(P_loop, P_scan, atol=1e-5)

    print(f"Test for Kalman/RTS smoother equivalence:")
    print(f"Smoothed mu results are close: {mu_are_close}")
    print(f"Smoothed P results are close: {P_are_close}")

    if not mu_are_close:
        print("Difference in mu:\n", jnp.abs(mu_loop - mu_scan).max())
    if not P_are_close:
        print("Difference in P:\n", jnp.abs(P_loop - P_scan).max())

# Run the test
for i in range(5):
    test_kalman_smoother_equivalence(100+i)

Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 1.7166138e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 3.671646e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: False
Smoothed P results are close: True
Difference in mu:
 2.2120774e-05
Test for Kalman/RTS smoother equivalence:
Smoothed mu results are close: True
Smoothed P results are close: True
