Skip to content

Commit

Permalink
Change scale transformation of IAF and add debugging patches
Browse files — browse the repository at this point in the history
  • Loading branch information
jdehning committed Mar 1, 2021
1 parent 35b62f8 commit bf42873
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 10 deletions.
19 changes: 12 additions & 7 deletions covid19_npis/model/approximate_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
from .. import transformations


# The code in this module raises an error during cleanup. This hook catches it such that
Expand Down Expand Up @@ -34,11 +35,13 @@ def create_bijector_fn(shift_and_log_scale_fn):

def bijector_fn(x, **condition_kwargs):
params = shift_and_log_scale_fn(x, **condition_kwargs)
shift, log_scale = tf.unstack(params, num=2, axis=-1)
shift, scale = tf.unstack(params, num=2, axis=-1)

bijectors = []
bijectors.append(tfb.Shift(shift))
bijectors.append(tfb.Scale(log_scale=log_scale))
bijectors.append(
tfb.Scale(scale=transformations.Exp_SinhArcsinh().inverse(scale))
)
return tfb.Chain(bijectors, validate_event_size=False)

return bijector_fn
Expand Down Expand Up @@ -73,11 +76,13 @@ def build_iaf(values_iaf_dict, order_list, values_exclude_dict=None):
bijectors_iaf_list.append(
tfb.Invert(
tfb.MaskedAutoregressiveFlow(
shift_and_log_scale_fn=tfp.bijectors.AutoregressiveNetwork(
params=2,
hidden_units=[size_iaf, size_iaf],
input_order=order,
activation="elu",
bijector_fn=create_bijector_fn(
tfp.bijectors.AutoregressiveNetwork(
params=2,
hidden_units=[size_iaf, size_iaf],
input_order=order,
activation="elu",
)
)
)
)
Expand Down
28 changes: 28 additions & 0 deletions scripts/debugging_patches/filter_nan_errors1.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Index: lib/python3.8/site-packages/tensorflow_probability/python/math/generic.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- lib/python3.8/site-packages/tensorflow_probability/python/math/generic.py (date 1614119717126)
+++ lib/python3.8/site-packages/tensorflow_probability/python/math/generic.py (date 1614119717126)
@@ -656,10 +656,19 @@
logcosh = abs_x + tf.math.softplus(-2 * abs_x) - np.log(2).astype(
numpy_dtype)
bound = 45. * np.power(np.finfo(numpy_dtype).tiny, 1 / 6.)
- return tf.where(
+ try:
+ res = tf.where(
abs_x <= bound,
tf.math.exp(tf.math.log(abs_x) + tf.math.log1p(-tf.square(abs_x) / 6.)),
logcosh)
+ except tf.errors.InvalidArgumentError:
+ tf.debugging.disable_check_numerics()
+ res = tf.where(
+ abs_x <= bound,
+ tf.math.exp(tf.math.log(abs_x) + tf.math.log1p(-tf.square(abs_x) / 6.)),
+ logcosh)
+ tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)
+ return res


def _log_cosh_jvp(primals, tangents):
22 changes: 22 additions & 0 deletions scripts/debugging_patches/filter_nan_errors2.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Index: lib/python3.8/site-packages/tensorflow_probability/python/bijectors/sigmoid.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- lib/python3.8/site-packages/tensorflow_probability/python/bijectors/sigmoid.py (date 1613231204680)
+++ lib/python3.8/site-packages/tensorflow_probability/python/bijectors/sigmoid.py (date 1613231204680)
@@ -44,7 +44,13 @@
cutoff = -20
else:
cutoff = -9
- return tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))
+ try:
+ res = tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))
+ except tf.errors.InvalidArgumentError:
+ tf.debugging.disable_check_numerics()
+ res = tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))
+ tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)
+ return res

@tf.custom_gradient
def _stable_grad_softplus(x):
33 changes: 33 additions & 0 deletions scripts/debugging_patches/filter_nan_errors3.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Index: lib/python3.8/site-packages/tensorflow_probability/python/bijectors/softplus.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- lib/python3.8/site-packages/tensorflow_probability/python/bijectors/softplus.py (date 1614198880654)
+++ lib/python3.8/site-packages/tensorflow_probability/python/bijectors/softplus.py (date 1614198880654)
@@ -50,11 +50,22 @@
else:
cutoff = -9

- y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
+ try:
+ y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
+ except tf.errors.InvalidArgumentError:
+ tf.debugging.disable_check_numerics()
+ y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
+ tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)

def grad_fn(dy):
- return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
-
+ try:
+ res = tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
+ except tf.errors.InvalidArgumentError:
+ tf.debugging.disable_check_numerics()
+ res = tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
+ tf.debugging.enable_check_numerics(stack_height_limit=50,
+ path_length_limit=50)
+ return dy * res
return y, grad_fn


7 changes: 4 additions & 3 deletions scripts/dummy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@
# Mute Tensorflow warnings ...
# logging.getLogger("tensorflow").setLevel(logging.ERROR)

# For eventual debugging:
# tf.config.run_functions_eagerly(True)
# tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)

if tf.executing_eagerly():
log.warning("Running in eager mode!")
Expand Down Expand Up @@ -172,6 +169,10 @@ def print_dist_shapes(st):
traceable_quantities.loss, f"loss not finite: {traceable_quantities.loss}"
)

# For eventual debugging:
# tf.config.run_functions_eagerly(True)
# tf.debugging.enable_check_numerics(stack_height_limit=50, path_length_limit=50)

begin = time.time()
posterior = tfp.vi.fit_surrogate_posterior(
logpfn,
Expand Down

0 comments on commit bf42873

Please sign in to comment.