In [1]:
#%env CUDA_DEVICE_ORDER=PCI_BUS_ID
#%env CUDA_VISIBLE_DEVICES=0

In [2]:
%load_ext autoreload
%autoreload 2

# Imports

In [3]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

from matplotlib import animation
from IPython.display import HTML

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

# Data Loading

In [5]:
data = np.load("../../data/moving_mnist/mnist_test_seq.npy")
data.shape

(20, 10000, 64, 64)

# Data reshaping

In [6]:
# We can see that data is of shape (window, n_samples, width, height)
# But we want for keras something of shape (n_samples, window, width, height)
data = np.moveaxis(data, 0, 1)
# Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
data = np.expand_dims(data, axis=-1)
data.shape

(10000, 20, 64, 64, 1)

# See the frame in action:

In [10]:
def display_videos(data, n_rows=3, n_cols=3):
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    ims = []

        
    for i in range(n_rows):
        for j in range(n_cols):
            idx = i*n_rows + j
            video = data[idx]
            im = axs[i][j].imshow(video[0,:,:,:], animated=True)
            ims.append(im)

            plt.close() # this is required to not display the generated image

    def init():
        for i in range(n_rows):
            for j in range(n_cols):
                idx = i*n_rows + j
                video = data[idx]
                im = ims[idx]
                im.set_data(video[0,:,:,:])

    def animate(frame_id):
        for i in range(n_rows):
            for j in range(n_cols):
                idx = i*n_rows + j
                video = data[idx]
                ims[idx].set_data(video[frame_id,:,:,:])
        return ims

    anim = animation.FuncAnimation(fig, animate, 
                                   init_func=init, 
                                   frames=data.shape[1],
                                   blit=True,
                                   interval=100)
    FFwriter = animation.FFMpegWriter(fps=10, codec="libx264")     
    anim.save('basic_animation1.mp4', writer = FFwriter )
    
    return HTML(anim.to_html5_video())
    

In [11]:
display_videos(data[:10], n_rows=1, n_cols=1)

<IPython.core.display.Javascript object>

In [32]:
import cv2

out = cv2.VideoWriter('basic_video.avi', cv2.VideoWriter_fourcc(*'XVID'), 10.0, (64,64))


for frame in data[0]:
    #for i in range(100):
    out.write(cv2.merge((frame,frame,frame)))
out.release()


# Create dataset object

In [20]:
def _preprocess(sample):
    image = tf.cast(sample, tf.float32) / 255.  # Scale to unit interval.
    image = image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
    return image, image

train_dataset = (tf.data.Dataset.from_tensor_slices(data[:9000])
                 .map(_preprocess)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE)
                 .shuffle(int(10e3)))
test_dataset = (tf.data.Dataset.from_tensor_slices(data[9000:])
                 .map(_preprocess)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE)
                 .shuffle(int(10e3)))

# Specify model

In [9]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')


In [10]:
input_shape = data.shape[1:]
encoded_size = 32
base_depth = 32

In [11]:
prior = tfp.distributions.Independent(tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)

encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=input_shape),
    tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu)),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu)),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu)),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu)),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(tfp.layers.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    tfp.layers.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfp.layers.KLDivergenceRegularizer(prior)),
])

In [12]:
with strategy.scope():
    encoder = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=input_shape),
        tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
        tf.keras.layers.Conv3D(base_depth, 5, strides=1,
                    padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3D(base_depth, 5, strides=2,
                    padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3D(2 * base_depth, 5, strides=1,
                    padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3D(2 * base_depth, 5, strides=2,
                    padding='same', activation=tf.nn.leaky_relu),
        #tf.keras.layers.Conv3D(4 * encoded_size, 7, strides=1,
        #            padding='valid', activation=tf.nn.leaky_relu),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(tfp.layers.MultivariateNormalTriL.params_size(encoded_size),
                   activation=None),
        tfp.layers.MultivariateNormalTriL(
            encoded_size,
            activity_regularizer=tfp.layers.KLDivergenceRegularizer(prior)),
    ])

2022-04-11 04:21:18.780269: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


In [13]:
encoder.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lambda (Lambda)             (None, 20, 64, 64, 1)     0         
                                                                 
 conv3d (Conv3D)             (None, 20, 64, 64, 32)    4032      
                                                                 
 conv3d_1 (Conv3D)           (None, 10, 32, 32, 32)    128032    
                                                                 
 conv3d_2 (Conv3D)           (None, 10, 32, 32, 64)    256064    
                                                                 
 conv3d_3 (Conv3D)           (None, 5, 16, 16, 64)     512064    
                                                                 
 flatten (Flatten)           (None, 81920)             0         
                                                                 
 dense (Dense)               (None, 560)               4

In [14]:
with strategy.scope():
    decoder = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=[encoded_size]),
        tf.keras.layers.Reshape([1, 1, 1, encoded_size]),
        tf.keras.layers.Conv3DTranspose(2 * base_depth, (5, 4, 4), strides=1,
                             padding='valid', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3DTranspose(2 * base_depth, (5, 4, 4), strides=(1, 2, 2),
                             padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3DTranspose(2 * base_depth, (5, 4, 4), strides=2,
                             padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3DTranspose(base_depth, (5, 4, 4), strides=(1, 2, 2),
                             padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3DTranspose(base_depth, (5, 4, 4), strides=2,
                             padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv3DTranspose(base_depth, (5, 4, 4), strides=1,
                             padding='same', activation=tf.nn.leaky_relu),
        tf.keras.layers.Conv2D(filters=1, kernel_size=5, strides=1,
                    padding='same', activation=None),
        tf.keras.layers.Flatten(),
        tfp.layers.IndependentBernoulli(input_shape, tfp.distributions.Bernoulli.logits),
    ])

In [15]:
decoder.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 1, 1, 1, 32)       0         
                                                                 
 conv3d_transpose (Conv3DTra  (None, 5, 4, 4, 64)      163904    
 nspose)                                                         
                                                                 
 conv3d_transpose_1 (Conv3DT  (None, 5, 8, 8, 64)      327744    
 ranspose)                                                       
                                                                 
 conv3d_transpose_2 (Conv3DT  (None, 10, 16, 16, 64)   327744    
 ranspose)                                                       
                                                                 
 conv3d_transpose_3 (Conv3DT  (None, 10, 32, 32, 32)   163872    
 ranspose)                                            

In [16]:

with strategy.scope():
    vae = tf.keras.Model(inputs=encoder.inputs,
                    outputs=decoder(encoder.outputs[0]))

In [19]:
from ganime.data.base import load_dataset

my_train_dataset, ds_info = load_dataset("moving_mnist")

In [33]:
for video in train_dataset.take(1):
    print(video[0].shape)

(40, 20, 64, 64, 1)


In [34]:
for video in my_train_dataset.take(1):
    print(video[0].shape)

(256, 20, 64, 64, 1)


In [18]:
negloglik = lambda x, rv_x: -rv_x.log_prob(x)

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)

_ = vae.fit(train_dataset,
            epochs=2,
            validation_data=test_dataset,
            callbacks=[callback]
           )

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-04-11 04:22:10.608736: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/2
INFO:tensorflow:batch_all_reduce: 24 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 24 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:batch_all_reduce: 24 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 24 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-04-11 04:22:37.002860: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 33 of 10000
2022-04-11 04:22:38.825176: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer filled.
2022-04-11 04:22:39.280330: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8101
2022-04-11 04:22:40.073186: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8101
2022-04-11 04:22:41.063056: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8101
2022-04-11 04:22:42.302120: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8101
2022-04-11 04:22:42.932503: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2022-04-11 



InvalidArgumentError: Graph execution error:

Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
Detected at node 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv' defined at (most recent call last):
    File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
      self.run()
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 220, in __call__
      distribution, _ = super(DistributionLambda, self).__call__(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 226, in call
      distribution, value = super(DistributionLambda, self).call(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/keras/layers/core/lambda_layer.py", line 196, in call
      result = self.function(inputs, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 180, in _fn
      value = distribution._value()  # pylint: disable=protected-access
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 215, in _value
      self._convert_to_tensor_fn(self.tensor_distribution)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 327, in _call_sample_n
      x = self._maybe_broadcast_distribution_batch_shape().sample(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1234, in sample
      return self._call_sample_n(sample_shape, seed, **kwargs)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1211, in _call_sample_n
      samples = self._sample_n(
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 298, in _sample_n
      self._augment_sample_shape(sample_shape), seed=seed)
    File "/home/abdalla/GANime/venv/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_broadcast.py", line 236, in _augment_sample_shape
      [ps.maximum(0, n_batch // underlying_n_batch)]],
Node: 'replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv'
9 root error(s) found.
  (0) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[replica_7/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/Prod_1/_430]]
  (1) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[replica_6/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/BatchBroadcastSampleNormal/batch_shape_tensor/BroadcastArgs/_397]]
  (2) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[replica_4/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/Prod_1/_418]]
  (3) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[replica_3/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/Prod/_411]]
  (4) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[replica_2/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/BatchBroadcastSampleNormal/batch_shape_tensor/BroadcastArgs/_389]]
  (5) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[gradient_tape/replica_1/model/multivariate_normal_tri_l/ActivityRegularizer/kldivergence_loss/tensor_coercible_CONSTRUCTED_AT_replica_1_model_multivariate_normal_tri_l/log_prob/chain_of_shift_of_scale_matvec_linear_operator/inverse_log_det_jacobian/scale_matvec_linear_operator/inverse_log_det_jacobian/LinearOperatorLowerTriangular/log_abs_det/DynamicStitch/_626]]
  (6) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[div_no_nan/ReadVariableOp_3/_698]]
  (7) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
	 [[GroupCrossDeviceControlEdges_2/Identity_9/_885]]
  (8) INVALID_ARGUMENT:  Integer division by zero
	 [[{{node replica_5/model/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/BatchBroadcastSampleNormal/sample/floordiv}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_434745]

# Results

In [None]:
# We'll just examine ten random digits.
x = next(iter(test_dataset))[0][:5]
xhat = vae(x)
assert isinstance(xhat, tfp.distributions.Distribution)

In [None]:
print('Originals:')
display_videos(x, n_rows=1, n_cols=5)

In [None]:
print('Decoded Random Samples:')
display_videos(xhat.sample(), n_rows=1, n_cols=5)

In [None]:
print('Decoded Modes:')
display_videos(xhat.mode(), n_rows=1, n_cols=5)

In [None]:
print('Decoded Means:')
display_videos(xhat.mean(), n_rows=1, n_cols=5)

In [None]:
print('Decoded variance:')
display_videos(xhat.variance(), n_rows=1, n_cols=5)

In [None]:
# Now, let's generate ten never-before-seen digits.
z = prior.sample(10)
xtilde = decoder(z)
assert isinstance(xtilde, tfp.distributions.Distribution)

In [None]:
print('Randomly Generated Samples:')
display_videos(xhat.sample(), n_rows=1, n_cols=5)

In [None]:
print('Randomly Generated Modes:')
display_videos(xhat.mode(), n_rows=1, n_cols=5)

In [None]:
print('Randomly Generated Means:')
display_videos(xhat.mean(), n_rows=1, n_cols=5)

In [None]:
print('Randomly Generated variance:')
display_videos(xhat.variance(), n_rows=1, n_cols=5)

In [None]:
encoder.save("mnist_encoder")