## Sample shapes 
A library containing NumPyro models for future reference. The model naming convention corresponds to models replicated in R using `brms` and `rethinking`, and in Python using `PyMC`.

### Imports

In [16]:
from jax import random
import jax.numpy as jnp
import numpyro.distributions as dist

### Random notes looking at sample shape

The examples below demonstrate the difference between *sample shape* (specified via the `sample_shape` parameter of the `sample` method) and *batch shape* (specified via the `batch_shape` parameter of the `expand` method).

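Note that batch shape signifies the dimensions of the collection of independent (and possibly differently parameterised) distributions represented by a single distribution object, whereas event shape is the shape of a single draw from one of those distributions. The array returned by `sample` therefore has shape `sample_shape + batch_shape + event_shape`.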

[Reference](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/)

In [17]:
# Three draws from a scalar Exponential(1); only `sample_shape` is given, so
# the result is a flat vector of length 3.
d = dist.Exponential(1).sample(random.PRNGKey(42), sample_shape=(3,))

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print(d, end="\n\n\n")
print(d.shape, end="\n\n\n")
print(dist.Exponential(1).event_shape)

[0.85365486 0.10553633 0.06130229]


(3,)


()


In [18]:
# Same draws as before, but the distribution is first expanded to a batch of
# size 1, which appends a trailing axis of length 1 to the sampled array.
d = (
    dist.Exponential(1)
    .expand(batch_shape=(1,))
    .sample(random.PRNGKey(42), sample_shape=(3,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', d, end="\n\n\n")
# The printed shape is that of the returned array (sample dims, then batch dims).
print('Sample shape: \n', d.shape, end="\n\n\n")
print('Event shape: \n', dist.Exponential(1).event_shape)

Sample: 
 [[0.85365486]
 [0.10553633]
 [0.06130229]]


Sample shape: 
 (3, 1)


Event shape: 
 ()


In [19]:
# Base scalar distribution; kept in `d` so its event shape can be shown below.
d = dist.Exponential(1)

# Expand to a batch of 2 independent Exponentials, then take 3 draws of that
# batch: the result has one row per draw and one column per batch member.
s = (
    d.expand(batch_shape=(2,))
    .sample(random.PRNGKey(42), sample_shape=(3,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', s, end="\n\n\n")
print('Sample shape: \n', s.shape, end="\n\n\n")
print('Event shape: \n', d.event_shape)

Sample: 
 [[1.3086625  2.0340383 ]
 [2.0580726  0.23342834]
 [0.206562   0.7990092 ]]


Sample shape: 
 (3, 2)


Event shape: 
 ()


In [20]:
# Base scalar distribution; kept in `d` so its event shape can be shown below.
d = dist.Exponential(1)

# Same experiment as with batch_shape=(2,), but the batch shape is passed as a
# list instead of a tuple; the recorded output is identical.
s = (
    d.expand(batch_shape=[2])
    .sample(random.PRNGKey(42), sample_shape=(3,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', s, end="\n\n\n")
print('Sample shape: \n', s.shape, end="\n\n\n")
print('Event shape: \n', d.event_shape)

Sample: 
 [[1.3086625  2.0340383 ]
 [2.0580726  0.23342834]
 [0.206562   0.7990092 ]]


Sample shape: 
 (3, 2)


Event shape: 
 ()


In [21]:
# LKJ distribution whose single draw is a 2x2 matrix, so the event shape is
# (2, 2) rather than () as for the scalar Exponential.
d = dist.LKJ(2, 2)

# One draw: a leading sample axis of length 1 in front of the 2x2 event.
s = d.sample(
    random.PRNGKey(42),
    sample_shape=(1,),
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', s, end="\n\n\n")
print('Sample shape: \n', s.shape, end="\n\n\n")
print('Event shape: \n', d.event_shape)

Sample: 
 [[[1.         0.44555074]
  [0.44555074 1.        ]]]


Sample shape: 
 (1, 2, 2)


Event shape: 
 (2, 2)


In [22]:
# LKJ distribution with a 2x2 matrix event.
d = dist.LKJ(2, 2)

# Expand to a batch of 2 and take one draw: the array is laid out as
# (sample, batch, event rows, event cols).
s = (
    d.expand(batch_shape=[2])
    .sample(random.PRNGKey(42), sample_shape=(1,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', s, end="\n\n\n")
print('Sample shape: \n', s.shape, end="\n\n\n")
# Event shape is read from the unexpanded distribution `d`.
print('Event shape: \n', d.event_shape)

Sample: 
 [[[[ 1.          0.44442013]
   [ 0.44442013  1.        ]]

  [[ 1.         -0.11944741]
   [-0.11944741  1.        ]]]]


Sample shape: 
 (1, 2, 2, 2)


Event shape: 
 (2, 2)


In [23]:
# Batch of 2 LKJ distributions, drawn twice: a stack of 2x2 matrices with two
# leading axes (sample, batch).
lkj = dist.LKJ(2, 2)
lkj_sample = (
    lkj.expand(batch_shape=[2])
    .sample(random.PRNGKey(42), sample_shape=(2,))
)

# Diagonal matrix of the scales [3, 4]; pre- and post-multiplying each LKJ
# draw by it gives diag(s) @ R @ diag(s). `@` broadcasts over leading axes.
sd_diag = jnp.diag(jnp.array([3, 4]))

# The leading axes of the covariance stack act as the batch dimensions of the
# multivariate normal; each event is a length-2 vector.
d = dist.MultivariateNormal(
    loc=jnp.stack([1, 2]),
    covariance_matrix=sd_diag @ lkj_sample @ sd_diag,
)
s = d.sample(random.PRNGKey(42), sample_shape=(1,))

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print('Sample: \n', s, end="\n\n\n")
print('Sample shape: \n', s.shape, end="\n\n\n")
print('Event shape: \n', d.event_shape)

Sample: 
 [[[[  2.468875    2.7756164]
   [ -0.8821709  -1.6952798]]

  [[  3.147201   -1.0912826]
   [ -4.238549  -10.6817   ]]]]


Sample shape: 
 (1, 2, 2, 2)


Event shape: 
 (2,)


In [24]:
# One draw from `lkj` (the LKJ distribution bound in an earlier cell), with a
# leading sample axis of length 1.
lkj_sample = lkj.sample(random.PRNGKey(42), sample_shape=(1,))

# Diagonal matrix of the scales [3, 4] used to build diag(s) @ R @ diag(s).
sd_diag = jnp.diag(jnp.array([3, 4]))

# Expand the multivariate normal to a batch of 6 and take a single draw.
# NOTE: unlike the earlier cells, `d` ends up holding the sampled array,
# not the distribution object.
d = (
    dist.MultivariateNormal(
        loc=jnp.stack([1, 2]),
        covariance_matrix=sd_diag @ lkj_sample @ sd_diag,
    )
    .expand(batch_shape=(6,))
    .sample(random.PRNGKey(42), sample_shape=(1,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print(d, end="\n\n\n")
print(d.shape)

[[[ 2.1126237  -1.5476623 ]
  [-0.5430193   8.057651  ]
  [-2.8950949  -2.4994526 ]
  [ 2.6727245   3.3619635 ]
  [-1.4942696  -0.81697655]
  [-0.64660287 -1.1234295 ]]]


(1, 6, 2)


In [25]:
# One draw from `lkj` (the LKJ distribution bound in an earlier cell), with a
# leading sample axis of length 1.
lkj_sample = lkj.sample(random.PRNGKey(42), sample_shape=(1,))

# Diagonal matrix of the scales [3, 4] used to build diag(s) @ R @ diag(s).
sd_diag = jnp.diag(jnp.array([3, 4]))

# Mirror image of the batch-of-6 experiment: batch of size 1, six draws. The
# 6 moves from the batch axis to the sample axis.
# NOTE: `d` ends up holding the sampled array, not the distribution object.
d = (
    dist.MultivariateNormal(
        loc=jnp.stack([1, 2]),
        covariance_matrix=sd_diag @ lkj_sample @ sd_diag,
    )
    .expand(batch_shape=(1,))
    .sample(random.PRNGKey(42), sample_shape=(6,))
)

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print(d, end="\n\n\n")
print(d.shape)

[[[ 2.1126237  -1.5476623 ]]

 [[-0.5430193   8.057651  ]]

 [[-2.8950949  -2.4994526 ]]

 [[ 2.6727245   3.3619635 ]]

 [[-1.4942696  -0.81697655]]

 [[-0.64660287 -1.1234295 ]]]


(6, 1, 2)


In [26]:
# Inspect the covariance matrix passed to the multivariate normal:
# diag([3, 4]) @ lkj_sample @ diag([3, 4]), using the `lkj_sample` left over
# from the preceding cell. `@` is left-associative, so the products are
# evaluated in the same order as the nested jnp.matmul calls.
sd_diag = jnp.diag(jnp.array([3, 4]))
cov_mat = sd_diag @ lkj_sample @ sd_diag

# `end="\n\n\n"` reproduces the newline of a plain print followed by print("\n").
print(cov_mat, end="\n\n\n")
print(cov_mat.shape)

[[[ 9.        5.346609]
  [ 5.346609 16.      ]]]


(1, 2, 2)
