In [1]:
!pip install wheel==0.37.1 setuptools==59.6.0
!pip install gym[atari,accept-rom-license]==0.21.0 tensorflow-probability==0.19.0 tqdm==4.64.1

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [2]:
from math import ceil
from typing import Iterable, Tuple, List, Callable
from dataclasses import dataclass

import gym
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from keras import Model
from keras.optimizers import Optimizer, Adam
from keras.metrics import Mean
from keras.layers import \
    Layer, Input, Conv2D, Conv2DTranspose, Lambda, \
    Dropout, Reshape, Flatten, Dense, GaussianNoise, Concatenate

In [3]:
print(tf.__version__)
print(tf.test.gpu_device_name())

2.11.0
/device:GPU:0


In [4]:
@dataclass
class DreamerSettings:
    action_dims: List[int]
    obs_dims: List[int]
    repr_dims: List[int]
    hidden_dims: List[int]
    enc_dims: List[int]
    dropout_rate: float = 0.2
    codebook_size: int = 32

    @property
    def repr_dims_flat(self) -> int:
        return self.repr_dims[0] * self.repr_dims[1]

    @property
    def repr_out_dims_flat(self) -> int:
        return self.repr_dims[0] * self.repr_dims[1] + self.hidden_dims[0]

In [5]:
def sample_obs(env: gym.Env, num_obs: int, proprocess_image) -> tf.data.Dataset:
    all_obs = []
    done = False
    obs = env.reset()
    all_obs.append(proprocess_image(obs))
    while len(all_obs) < num_obs:
        action = env.action_space.sample()
        obs, reward, done, _ = env.step(action)
        all_obs.append(proprocess_image(obs))
        if done and len(all_obs) < num_obs:
            obs = env.reset()
            all_obs.append(proprocess_image(obs))
    return tf.data.Dataset.from_tensor_slices(tensors=(np.array(all_obs)))


def generate_datasets(
        env: gym.Env, settings: DreamerSettings,
        batch_size: int, timesteps: int):

    def preprocess(obs):
        obs = tf.image.resize(obs, settings.obs_dims[:2])
        return tf.cast(obs, dtype=tf.float32) / 255.0

    dataset = sample_obs(env, timesteps, preprocess)
    dataset = dataset.map(lambda obs: (obs, obs))
    dataset = dataset.batch(batch_size)
    num_eval_batches = ceil((timesteps // batch_size) * 0.125)
    train_dataset = dataset.skip(num_eval_batches)
    eval_dataset = dataset.take(num_eval_batches)
    train_dataset = train_dataset.shuffle(100)
    return train_dataset, eval_dataset

In [6]:
class VQCodebook(Layer):
    """Representing a codebook of a vector quantization for a given amount
    of classifications with a given amount of classes each. The embedding
    vectors are initialized to match the inputs to be quantized. When calling
    this layer, it expects to receive one-hot encoded categoricals of shape
    (batch_size, num_classifications, num_classes)."""

    def __init__(
            self, num_classifications: int, num_classes: int,
            name: str="vq_codebook"):
        super(VQCodebook, self).__init__(name=name)
        self.num_classifications = num_classifications
        self.num_classes = num_classes
        self.num_embeddings = num_classifications * num_classes
        self.reshape_out = None

    def init_codebook(self, input_shape: Iterable[int]):
        orig_input_shape = input_shape[1:]
        self.reshape_out = Reshape(orig_input_shape)
        input_dims_flat = tf.reduce_prod(orig_input_shape)
        self.embedding_dims = input_dims_flat // self.num_classifications

        if input_dims_flat % self.num_classifications != 0:
            raise ValueError((
                f"The input dimensions {input_dims_flat} must be divisible "
                f"by the number of classifications {self.num_classifications} "
                f"to support swapping each of the {self.num_classifications} slices "
                "from the input vector with a quantized vector from the codebook."))

        embed_shape = (self.embedding_dims, self.num_embeddings)
        self.embeddings = self.add_weight(
            "embeddings", shape=embed_shape, trainable=True, initializer="random_normal")

    def call(self, categoricals_onehot: tf.Tensor):
        categoricals_sparse = tf.argmax(categoricals_onehot, axis=2)
        id_offsets = tf.range(0, self.num_classifications, dtype=tf.int64) * self.num_classes
        categoricals_embed_sparse = categoricals_sparse + id_offsets
        categoricals_embed = tf.one_hot(categoricals_embed_sparse, depth=self.num_embeddings)
        quantized = tf.matmul(categoricals_embed, self.embeddings, transpose_b=True)
        return self.reshape_out(quantized)

    def most_similar_embeddings(self, inputs: tf.Tensor):
        input_shape = (-1, self.num_classifications, self.embedding_dims)
        embed_shape = (-1, self.num_classifications, self.num_classes)
        inputs_per_classification = tf.reshape(inputs, input_shape)
        embeddings_per_classification = tf.reshape(self.embeddings, embed_shape)
        codebook_ids = []

        for i in range(self.num_classifications):
            embeddings = embeddings_per_classification[:, i, :]
            inputs_classif = inputs_per_classification[:, i, :]

            inputs_sqsum = tf.reduce_sum(inputs_classif ** 2, axis=1, keepdims=True)
            embed_sqsum = tf.reduce_sum(embeddings ** 2, axis=0)
            similarity = tf.matmul(inputs_classif, embeddings)
            distances = inputs_sqsum + embed_sqsum - 2 * similarity

            class_ids = tf.argmin(distances, axis=1, output_type=tf.int64)
            codebook_ids.append(tf.expand_dims(class_ids, axis=0))

        codebook_ids = tf.concat(codebook_ids, axis=0)
        codebook_ids = tf.transpose(codebook_ids, perm=[1, 0])
        return codebook_ids


class VQCategorical(Layer):
    """Representing a transformation of an input vector to be quantized into
    a one-hot encoded categorical matching the quantized vectors of the codebook.
    This layer can be used to receive a high-level latent state from arbitrary input.
    It expects to be used in combination with a codebook instance that is managing
    the embeddings used for quantization."""

    def __init__(self, codebook: VQCodebook, name: str="vq_categorical"):
        super(VQCategorical, self).__init__(name=name)
        self.codebook = codebook
        #self.trainable = False

    def build(self, input_shape: Iterable[int]):
        self.codebook.init_codebook(input_shape)

    def call(self, inputs: tf.Tensor):
        categoricals_sparse = self.codebook.most_similar_embeddings(inputs)
        return tf.one_hot(categoricals_sparse, self.codebook.num_classes)

In [7]:
def create_encoder(settings: DreamerSettings) -> Model:
    model_in = Input(settings.obs_dims, name="obs")
    norm_img = Lambda(lambda x: x * 2.0 - 1.0)
    cnn_1 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_2 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_3 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_4 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_5 = Conv2D(64, (3, 3), padding="same", activation="elu")
    drop_1 = Dropout(rate=settings.dropout_rate)
    drop_2 = Dropout(rate=settings.dropout_rate)
    drop_3 = Dropout(rate=settings.dropout_rate)
    drop_4 = Dropout(rate=settings.dropout_rate)

    img_in = norm_img(model_in)
    model_out = cnn_5(drop_4(cnn_4(drop_3(cnn_3(drop_2(cnn_2(drop_1(cnn_1(img_in)))))))))
    return Model(inputs=model_in, outputs=model_out, name="encoder_model")


def create_decoder(settings: DreamerSettings) -> Model:
    image_channels = settings.obs_dims[-1]
    in_height, in_width = settings.obs_dims[0] // 16, settings.obs_dims[1] // 16
    in_channels = np.prod(settings.obs_dims) / in_height / in_width
    # upscale_source_dims = in_height * in_width * in_channels

    model_in = Input((in_height, in_width, in_channels), name="repr_out")
    #dense_in = Dense(upscale_source_dims, activation="linear", name="dec_in")
    # reshape_in = Reshape((in_height, in_width, -1))
    cnn_1 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_2 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_3 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_4 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_5 = Conv2D(image_channels, (1, 1), padding="same", activation="linear")
    drop_1 = Dropout(rate=settings.dropout_rate)
    drop_2 = Dropout(rate=settings.dropout_rate)
    drop_3 = Dropout(rate=settings.dropout_rate)

    # prep_in = reshape_in(model_in)
    model_out = cnn_5(cnn_4(drop_3(cnn_3(drop_2(cnn_2(drop_1(cnn_1(model_in))))))))
    return Model(inputs=model_in, outputs=model_out, name="decoder_model")


def compose_vqvae(settings: DreamerSettings) -> Tuple[Model, Model]:
    codebook = VQCodebook(settings.repr_dims[0], settings.repr_dims[1])
    cat_quant = VQCategorical(codebook)
    encoder = create_encoder(settings)
    decoder = create_decoder(settings)

    h_fake_noise = GaussianNoise(stddev=1.0)
    z_flatten = Flatten()
    z_h_concat = Concatenate()

    model_in = Input(shape=settings.obs_dims, name="img_orig")
    z_enc = encoder(model_in)
    z_fused_dense = Dense(units=tf.reduce_prod(z_enc.shape[1:]))
    z_enc_reshape = Reshape(z_enc.shape[1:])
    h_fake = h_fake_noise(tf.zeros((128, 512)))
    z_enc = z_enc_reshape(z_fused_dense(z_h_concat((z_flatten(z_enc), h_fake))))

    z_cat = cat_quant(z_enc)
    z_quantized = codebook(z_cat)
    z_st_quantized = z_enc + tf.stop_gradient(z_quantized - z_enc)
    x_reconst = decoder(z_st_quantized)

    vqvae_train = Model(inputs=[model_in], outputs=[x_reconst, z_enc, z_quantized], name="vqvae")
    vqvae_infer = Model(inputs=[model_in], outputs=[x_reconst], name="vqvae")
    return vqvae_train, vqvae_infer


def create_model(settings: DreamerSettings) -> Tuple[Model, Model]:
    model_train, model_infer = compose_vqvae(settings)
    model_train.build([None] + settings.obs_dims)
    model_train.summary()
    return model_train, model_infer

In [8]:
TrainBatch = Tuple[tf.Tensor, tf.Tensor]
Datasets = Tuple[tf.data.Dataset, tf.data.Dataset]


@tf.function
def train_step(
        model: Model, optimizer: Optimizer, batch: TrainBatch,
        committment_cost: float=0.25, data_variance: float=1.0):

    x, y_true = batch
    with tf.GradientTape() as tape:
        reconst, z_enc, z_quantized = model(x)

        committment_loss = tf.reduce_mean((tf.stop_gradient(z_quantized) - z_enc) ** 2)
        codebook_loss = tf.reduce_mean((z_quantized - tf.stop_gradient(z_enc)) ** 2)
        vqvae_loss = committment_cost * committment_loss + codebook_loss

        reconst_loss = tf.reduce_mean((y_true - reconst) ** 2) / data_variance
        total_loss = reconst_loss + vqvae_loss

    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return vqvae_loss, reconst_loss


@tf.function
def eval_step(
        model: Model, batch: TrainBatch,
        committment_cost: float=0.25, data_variance: float=1.0):

    x, y_true = batch
    reconst, z_enc, z_quantized = model(x)

    committment_loss = tf.reduce_mean((tf.stop_gradient(z_quantized) - z_enc) ** 2)
    codebook_loss = tf.reduce_mean((z_quantized - tf.stop_gradient(z_enc)) ** 2)
    vqvae_loss = committment_cost * committment_loss + codebook_loss

    reconst_loss = tf.reduce_mean((y_true - reconst) ** 2) / data_variance
    total_loss = reconst_loss + vqvae_loss
    return vqvae_loss, reconst_loss


class LossLogger:
    def __init__(self):
        self.writer = tf.summary.create_file_writer("logs/vae")
        self.vqvae_loss_mean = Mean()
        self.reconst_loss_mean = Mean()
        self.total_loss_mean = Mean()

    def log_losses(self, vqvae_loss: float, reconst_loss: float):
        total_loss = reconst_loss + vqvae_loss
        self.vqvae_loss_mean(vqvae_loss)
        self.reconst_loss_mean(reconst_loss)
        self.total_loss_mean(total_loss)

    def flush_losses(self, step: int, mode: str):
        with self.writer.as_default():
            tf.summary.scalar(f"{mode}/vqvae_loss", self.vqvae_loss_mean.result(), step=step)
            tf.summary.scalar(f"{mode}/reconst_loss", self.reconst_loss_mean.result(), step=step)
            tf.summary.scalar(f"{mode}/total_loss", self.total_loss_mean.result(), step=step)
        self.vqvae_loss_mean.reset_state()
        self.reconst_loss_mean.reset_state()
        self.total_loss_mean.reset_state()


def train(settings: DreamerSettings, epochs: int, datasets: Datasets):
    train_dataset, test_dataset = datasets
    model_train, model_infer = create_model(settings)
    optimizer = Adam()
    loss_logger = LossLogger()
    cost = 0.25

    for ep in tqdm(range(epochs)):
        for batch in train_dataset:
            vqvae_loss, reconst_loss = train_step(model_train, optimizer, batch, cost)
            loss_logger.log_losses(vqvae_loss, reconst_loss)
        loss_logger.flush_losses(ep + 1, "train")

        for batch in test_dataset:
            vqvae_loss, reconst_loss = eval_step(model_train, batch, cost)
            loss_logger.log_losses(vqvae_loss, reconst_loss)
        loss_logger.flush_losses(ep + 1, "eval")

    return model_infer


settings = DreamerSettings([1], [64, 64, 3], [32, 32], [512], [1024])
env = gym.make("ALE/Pacman-v5")
train_dataset, test_dataset = generate_datasets(env, settings, 128, 16_384)
model = train(settings, 500, (train_dataset, test_dataset))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Model: "vqvae"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 img_orig (InputLayer)          [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 encoder_model (Functional)     (None, 4, 4, 64)     149504      ['img_orig[0][0]']               
                                                                                                  
 flatten (Flatten)              (None, 1024)         0           ['encoder_model[0][0]']          
                                                                                           

100%|██████████| 500/500 [41:55<00:00,  5.03s/it]


In [10]:
def show_subplot(original, reconstructed):
    plt.subplot(1, 2, 1)
    plt.imshow(original.squeeze())
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reconstructed.squeeze())
    plt.title("Reconstructed")
    plt.axis("off")

    plt.show()


def eval_on_test_data(model: Model, num_examples: int, test_dataset):
    test_batches = test_dataset.unbatch().shuffle(100).batch(num_examples)
    test_images = next(iter(test_batches.take(1)))[0].numpy()
    reconst_images = np.clip(model.predict(test_images), 0.0, 1.0)
    for test_image, reconst_image in zip(test_images, reconst_images):
        show_subplot(test_image, reconst_image)


eval_on_test_data(model, 128, test_dataset)

InvalidArgumentError: Graph execution error:

Detected at node 'vqvae/concatenate/concat' defined at (most recent call last):
    File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.8/dist-packages/traitlets/config/application.py", line 982, in launch_instance
      app.start()
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelapp.py", line 505, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.8/dist-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 687, in <lambda>
      lambda f: self._run_callback(functools.partial(callback, future))
    File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 740, in _run_callback
      ret = callback()
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 821, in inner
      self.ctx_run(self.run)
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 782, in run
      yielded = self.gen.send(value)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 378, in dispatch_queue
      yield self.process_one()
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 250, in wrapper
      runner = Runner(ctx_run, result, future, yielded)
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 748, in __init__
      self.ctx_run(self.run)
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 782, in run
      yielded = self.gen.send(value)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 365, in process_one
      yield gen.maybe_future(dispatch(*args))
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 272, in dispatch_shell
      yield gen.maybe_future(handler(stream, idents, msg))
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 540, in execute_request
      self.do_execute(
    File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py", line 294, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.8/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-9-7adf037ddd7d>", line 23, in <module>
      eval_on_test_data(model, 50, test_dataset)
    File "<ipython-input-9-7adf037ddd7d>", line 18, in eval_on_test_data
      reconst_images = np.clip(model.predict(test_images), 0.0, 1.0)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2350, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2137, in predict_function
      return step_function(self, iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2123, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2111, in run_step
      outputs = model.predict_step(data)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2079, in predict_step
      return self(x, training=False)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 511, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 668, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.8/dist-packages/keras/layers/merging/base_merge.py", line 196, in call
      return self._merge_function(inputs)
    File "/usr/local/lib/python3.8/dist-packages/keras/layers/merging/concatenate.py", line 134, in _merge_function
      return backend.concatenate(inputs, axis=self.axis)
    File "/usr/local/lib/python3.8/dist-packages/keras/backend.py", line 3572, in concatenate
      return tf.concat([to_dense(x) for x in tensors], axis)
Node: 'vqvae/concatenate/concat'
ConcatOp : Dimension 0 in both shapes must be equal: shape[0] = [32,1024] vs. shape[1] = [128,512]
	 [[{{node vqvae/concatenate/concat}}]] [Op:__inference_predict_function_2612706]