In [1]:
import tensorflow as tf
import tensorflow_probability as tfp

In [2]:
def happiness_model(weather_prob, weather_to_happiness_probs) :
    weather = yield tfp.distributions.JointDistributionCoroutine.Root(
        tfp.distributions.Bernoulli(
            probs=weather_prob,
            name='weather',
        )
    )
    
    happiness = yield tfp.distributions.Bernoulli(
        probs = weather_to_happiness_probs[weather],
        name = 'happiness' 
    )

In [3]:
theta_weather = tf.constant(0.8)
theta_happiness = tf.constant([0.7, 0.9])

model_joint_orig = tfp.distributions.JointDistributionCoroutineAutoBatched(
    lambda : happiness_model(theta_weather, theta_happiness),
)
model_joint_orig

<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  weather=[],
  happiness=[]
) dtype=StructTuple(
  weather=int32,
  happiness=int32
)>

In [4]:
model_joint_orig.sample()

StructTuple(
  weather=<tf.Tensor: shape=(), dtype=int32, numpy=1>,
  happiness=<tf.Tensor: shape=(), dtype=int32, numpy=1>
)

In [5]:
dataset = model_joint_orig.sample(100)

In [6]:
dataset

StructTuple(
  weather=<tf.Tensor: shape=(100,), dtype=int32, numpy=
    array([1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1,
           0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,
           1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,
           1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1])>,
  happiness=<tf.Tensor: shape=(100,), dtype=int32, numpy=
    array([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0,
           1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,
           0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
           1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1])>
)

In [9]:
theta_weather_fit = tfp.util.TransformedVariable(
    initial_value = 0.5, #initial value
    bijector = tfp.bijectors.SoftClip(low = 0.0, high = 1.0), #is a fn that transforms our variable
    name = 'theta_weather_fit'
)

thetas_happiness_fit = tfp.util.TransformedVariable(
    initial_value = [0.5, 0.5], #initial value
    bijector = tfp.bijectors.SoftClip(low = 0.0, high = 1.0),
    name = 'thetas_happiness_fit'
)


In [12]:
model_joint_fit = tfp.distributions.JointDistributionCoroutineAutoBatched(
    lambda : happiness_model(theta_weather_fit, thetas_happiness_fit)
)
model_joint_fit

<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  weather=[],
  happiness=[]
) dtype=StructTuple(
  weather=int32,
  happiness=int32
)>

In [14]:
model_joint_fit.log_prob(dataset)

<tf.Tensor: shape=(100,), dtype=float32, numpy=
array([-1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862944, -1.3862944, -1.3862944, -1.3862944, -1.3862944,
       -1.3862

In [15]:
neg_log_likelihood = lambda : -tf.reduce_sum(model_joint_fit.log_prob(dataset))

In [16]:
tfp.math.minimize(
loss_fn=neg_log_likelihood, 
    optimizer=tf.optimizers.Adam(0.01), 
    num_steps = 1000
)

<tf.Tensor: shape=(1000,), dtype=float32, numpy=
array([138.62944 , 138.00415 , 137.38367 , 136.76804 , 136.15736 ,
       135.55167 , 134.95111 , 134.35567 , 133.76546 , 133.18053 ,
       132.60092 , 132.02673 , 131.45798 , 130.89473 , 130.33702 ,
       129.78484 , 129.23831 , 128.6974  , 128.16216 , 127.63262 ,
       127.10879 , 126.59068 , 126.07833 , 125.57172 , 125.070854,
       124.575714, 124.086365, 123.602745, 123.12486 , 122.65268 ,
       122.18622 , 121.72544 , 121.27034 , 120.82087 , 120.377014,
       119.93876 , 119.50606 , 119.07887 , 118.65718 , 118.24094 ,
       117.83012 , 117.42468 , 117.02455 , 116.62973 , 116.24015 ,
       115.85579 , 115.476555, 115.10245 , 114.73339 , 114.369354,
       114.0103  , 113.65614 , 113.306854, 112.96237 , 112.62267 ,
       112.287674, 111.95734 , 111.63161 , 111.31045 , 110.9938  ,
       110.681595, 110.37381 , 110.07037 , 109.77122 , 109.47633 ,
       109.18565 , 108.89912 , 108.616684, 108.33829 , 108.06392 ,
       107.79

In [17]:
theta_weather_fit

<TransformedVariable: name=theta_weather_fit, dtype=float32, shape=[], fn="soft_clip", numpy=0.8299995>

In [19]:
theta_weather

<tf.Tensor: shape=(), dtype=float32, numpy=0.8>

In [18]:
thetas_happiness_fit

<TransformedVariable: name=thetas_happiness_fit, dtype=float32, shape=[2], fn="soft_clip", numpy=array([0.64705884, 0.8554211 ], dtype=float32)>

In [20]:
theta_happiness

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.7, 0.9], dtype=float32)>

In [21]:
tf.reduce_sum(model_joint_orig.log_prob(dataset))

<tf.Tensor: shape=(), dtype=float32, numpy=-92.14024>

In [22]:
tf.reduce_sum(model_joint_fit.log_prob(dataset))

<tf.Tensor: shape=(), dtype=float32, numpy=-90.92046>