In [114]:
import jax
import sys  
import jax.random as random
import jax.numpy as jnp
from jax.numpy import index_exp as index

sys.path.insert(0, '../')
from dibs.target import make_linear_gaussian_model, make_nonlinear_gaussian_model
from dibs.graph_utils import graph_to_mat #for graph sampling
from dibs.models import LinearGaussian #for parameter sampling

key = random.PRNGKey(123)
print(f"JAX backend: {jax.default_backend()}")

JAX backend: cpu


In [139]:
subkey

DeviceArray([ 961309823, 1704866707], dtype=uint32)

In [141]:
#try to see how to do the for loop and in which type to store the subkeys for sampling multiple parameters
""" seed = 1701
num_steps = 100
key = jax.random.PRNGKey(seed)
print(key)
l = jnp.arange(num_steps)
print(l)
print(type(l))
for i in range(num_steps):
    key, subkey = jax.random.split(key)
    print(i)
    print(key)
    l = l.at[i].set(subkey)
    print(subkey)
l """

' seed = 1701\nnum_steps = 100\nkey = jax.random.PRNGKey(seed)\nprint(key)\nl = jnp.arange(num_steps)\nprint(l)\nprint(type(l))\nfor i in range(num_steps):\n    key, subkey = jax.random.split(key)\n    print(i)\n    print(key)\n    l = l.at[i].set(subkey)\n    print(subkey)\nl '

In [116]:
key, subk = random.split(key)

In [3]:
key, subk = random.split(key)
data, model = make_nonlinear_gaussian_model(key=subk, n_vars=20, graph_prior_str="sf")


### Step one is to sample a random Graph:

In [16]:
data.x #should have shape of 100*20 

100

In [18]:
from dibs.models.graph import ErdosReniDAGDistribution


In [30]:
g_gt = ErdosReniDAGDistribution(n_vars = 20).sample_G(key = subk, return_mat=False)

In [31]:
g_gt

<igraph.Graph at 0x7f75cc792e50>

In [65]:
g_gt.topological_sorting()

[6, 11, 12, 15, 17, 4, 10, 19, 8, 0, 5, 13, 7, 14, 9, 2, 1, 16, 3, 18]

In [80]:
len(g_gt.vs)

20

In [101]:
parent_edges = g_gt.incident(19, mode='in')
print(parent_edges)
parents = list(g_gt.es[e].source for e in parent_edges)

[9, 27]


In [102]:
parents

[6, 15]

In [36]:
g_gt_mat = jnp.array(graph_to_mat(g_gt)) #why do we need this while we have a adj matrix already from the sample_G function?

In [39]:
g_gt_mat # 20*20 shape since we have 20 nodes on the graph

DeviceArray([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
             [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
             [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
             [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
             [0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0,

### Step two is to sample the parameters: 

In [108]:
theta = LinearGaussian(graph_dist = ErdosReniDAGDistribution(
            n_vars=20, 
            n_edges_per_node=2)).sample_parameters(key=subk, n_vars = 20)

In [110]:
theta.shape

(20, 20)

In [56]:
theta[0].shape

#I understand why we have 20 different thetas (corresponding to each Xi i = 1,...,20) but why each is of dimension 20? Just to make the matrix mult possible with the 
#adj matrix? In the end we end up unsing only the ones that have connections to the parents. 

(20,)

### Step three: Sample observations

In [58]:
observations = LinearGaussian(graph_dist = ErdosReniDAGDistribution(
            n_vars=20, 
            n_edges_per_node=2)).sample_obs(key=subk, n_samples = 100, g = g_gt, theta = theta )

In [61]:
# we have 100 observations that each has dimension 20 
observations[9].shape

(20,)

In [94]:
 x = jnp.zeros((100, len(g_gt.vs)))

In [96]:
x.shape

(100, 20)

In [104]:
mean = x[:, jnp.array(parents)] @ theta[jnp.array(parents), 19] 

In [107]:
theta[jnp.array(parents), 19]

DeviceArray([ 3.2422197, -0.8611006], dtype=float32)

In [106]:
len(mean)

100

In [99]:
x.at[index[:, 0]].set(5)

DeviceArray([[5., 0., 0., ..., 0., 0., 0.],
             [5., 0., 0., ..., 0., 0., 0.],
             [5., 0., 0., ..., 0., 0., 0.],
             ...,
             [5., 0., 0., ..., 0., 0., 0.],
             [5., 0., 0., ..., 0., 0., 0.],
             [5., 0., 0., ..., 0., 0., 0.]], dtype=float32)