In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static
tfb = tfp.bijectors
tfd = tfp.distributions

In [2]:
event_size = 4
num_components = 3

# Learnable Multivariate Normal with Scaled Identity for chol(Cov)

In [3]:
learnable_mvn_scaled_identity = tfd.Independent(
    tfd.Normal(loc=tf.Variable(tf.zeros(event_size), name='loc'),
               scale=tfp.util.TransformedVariable(
                   tf.ones([event_size, 1]), 
                   bijector=tfb.Exp()), 
               name='scale'), 
    reinterpreted_batch_ndims=1, 
    name='learnable_mvn_scaled_identity')

In [4]:
learnable_mvn_scaled_identity

<tfp.distributions.Independent 'learnable_mvn_scaled_identity' batch_shape=[4] event_shape=[4] dtype=float32>

In [5]:
learnable_mvn_scaled_identity.trainable_variables

(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(4, 1) dtype=float32, numpy=
 array([[0.],
        [0.],
        [0.],
        [0.]], dtype=float32)>)

In [8]:
trainable_normal = tfd.Normal(loc=tf.Variable(0.),
                              scale=tfp.util.TransformedVariable(1., bijector=tfb.Exp()))

In [9]:
trainable_normal.loc

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>

In [10]:
trainable_normal.scale

<TransformedVariable: dtype=float32, shape=[], fn="exp", numpy=1.0>

In [11]:
trainable_normal.parameters

{'loc': <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>,
 'scale': <TransformedVariable: dtype=float32, shape=[], fn="exp", numpy=1.0>,
 'validate_args': False,
 'allow_nan_stats': True,
 'name': 'Normal'}

In [12]:
trainable_normal.stddev()

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

In [14]:
learnable_mvndiag = tfd.Independent(
    tfd.Normal(
        loc=tf.Variable(tf.zeros(event_size), name='loc'),
        scale=tfp.util.TransformedVariable(
            tf.ones(event_size),
            bijector=tfb.Softplus()),  # Use Softplus...cuz why not?
            name='scale'),
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_diag')

print(learnable_mvndiag)
print(learnable_mvndiag.trainable_variables)

tfp.distributions.Independent("learnable_mvn_diag", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0.54132485, 0.54132485, 0.54132485, 0.54132485], dtype=float32)>)


In [15]:
t = tf.ones(4)
t

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([1., 1., 1., 1.], dtype=float32)>

In [16]:
soft = tfb.Softplus()

In [17]:
soft.forward(t)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([1.3132616, 1.3132616, 1.3132616, 1.3132616], dtype=float32)>

In [27]:
soft.inverse(t)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.54132485, 0.54132485, 0.54132485, 0.54132485], dtype=float32)>

In [18]:
np.log(1+np.exp(1))

1.3132616875182228

In [20]:
s = tfp.util.TransformedVariable(
            tf.ones(event_size),
            bijector=tfb.Softplus())

In [21]:
s

<TransformedVariable: dtype=float32, shape=[4], fn="softplus", numpy=array([1., 1., 1., 1.], dtype=float32)>

In [22]:
s.pretransformed_input

<tf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0.54132485, 0.54132485, 0.54132485, 0.54132485], dtype=float32)>

In [23]:
soft.forward(s.pretransformed_input)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([1., 1., 1., 1.], dtype=float32)>

In [24]:
n = tfd.Normal(
        loc=tf.Variable(tf.zeros(event_size), name='loc'),
        scale=tfp.util.TransformedVariable(
            tf.ones(event_size),
            bijector=tfb.Softplus()),  # Use Softplus...cuz why not?
            name='scale')

In [25]:
n.scale

<TransformedVariable: dtype=float32, shape=[4], fn="softplus", numpy=array([1., 1., 1., 1.], dtype=float32)>

In [26]:
n.parameters

{'loc': <tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,
 'scale': <TransformedVariable: dtype=float32, shape=[4], fn="softplus", numpy=array([1., 1., 1., 1.], dtype=float32)>,
 'validate_args': False,
 'allow_nan_stats': True,
 'name': 'scale'}