# Implementation of the dMMSB model (dynamic)

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import logsumexp
from jax.nn import softmax

from jax import vmap, jit
from jax.tree_util import register_pytree_node_class
from functools import partial

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt


In [None]:
class dMMSB():
    def __init__(self, nodes, roles, timesteps, **kwargs):
        self.N = nodes
        self.K = roles
        self.T = timesteps

        key = kwargs.get('key', jax.random.PRNGKey(0))
        self.B = kwargs.get('B', None)
        self.mu = kwargs.get('mu', None)
        self.Sigma = kwargs.get('Sigma', None)
        self.gamma_tilde = kwargs.get('gamma_tilde', None)
        self.Sigma_tilde = kwargs.get('Sigma_tilde', None)
        self.nu = kwargs.get('nu', None)
        self.Phi = kwargs.get('Phi', None)

        # Initialize model parameters 
        if self.B is None:
            self.B = jax.random.uniform(key, (self.K, self.K)) #shape (K,K)
        if self.nu is None:
            self.nu = jax.random.normal(key) # scalar
        if self.mu is None:
            self.mu = jnp.ones(self.T) * self.nu #shape (T,)
        if self.Phi is None:
            self.Phi = jnp.eye(self.K) *  10 #shape (K,K)
        if self.Sigma is None:
            self.Sigma = jnp.tile(jnp.eye(self.K)[None, :, :], (self.T, 1, 1)) * 10 #shape (T,K,K)

        self.EPS = 1e-6
        
    def log_likelihood(self, delta, B, E):
        '''
        Compute the log likelihood of the data given the current parameters.
        delta: shape (T,N,N,K,K)
        B: shape (K,K)
        E: shape (T,N,N)
        '''
        E_reshaped = E[:, :, :, None, None] # shape (T,N,N,1,1)
        B_reshaped = B[None, None, None, :, :] # shape (1,1,1,K,K)
        logB = jnp.log(B_reshaped + self.EPS) # shape (1,1,1,K,K)
        log1mB = jnp.log(1.0 - B_reshaped + self.EPS) # shape (1,1,1,K,K)

        ll_matrix = delta * (E_reshaped * logB + (1.0 - E_reshaped) * log1mB) # shape (T,N,N,K,K)
        ll = jnp.sum(ll_matrix)
        return ll
    

    def update_gamma_tilde_temporal(self, mu, Sigma_tilde, gamma_hat, g, H):
        '''
        Update gamma_tilde using the current parameters.
        mu: shape (T,)
        Sigma_tilde: shape (T,K,K)
        gamma_hat: shape (N,K)
        g: shape (N,K)
        H: shape (N,K,K)
        '''
        #vmap over time
        pass

    def 
   
    

    



SyntaxError: invalid syntax (2068489685.py, line 21)