In [1]:
!pip install wheel==0.37.1 setuptools==59.6.0
!pip install gym[atari,accept-rom-license]==0.21.0 tensorflow-probability==0.19.0

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [16]:
from math import ceil
from dataclasses import dataclass
from typing import List, Tuple

import gym
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow import keras
from keras import Model
from keras.layers import \
    Input, Dense, Reshape, Lambda, Flatten, \
    Conv2D, Conv2DTranspose, Dropout, BatchNormalization

tfpl = tfp.layers
tfd = tfp.distributions
KLDivergenceRegularizer = tfp.layers.KLDivergenceRegularizer
MultivariateNormalTriL = tfp.layers.MultivariateNormalTriL
IndependentBernoulli = tfp.layers.IndependentBernoulli
IndependentNormal = tfp.layers.IndependentNormal
Independent = tfp.distributions.Independent
Bernoulli = tfp.distributions.Bernoulli
Normal = tfp.distributions.Normal
MultivariateNormalDiag = tfp.distributions.MultivariateNormalDiag

In [10]:
print(tf.__version__)
print(tf.test.gpu_device_name())

2.11.0
/device:GPU:0


In [11]:
@dataclass
class DreamerSettings:
    action_dims: List[int]
    obs_dims: List[int]
    repr_dims: List[int]
    hidden_dims: List[int]
    enc_dims: List[int]
    dropout_rate: float = 0.2

    @property
    def repr_dims_flat(self) -> int:
        return self.repr_dims[0] * self.repr_dims[1]

    @property
    def repr_out_dims_flat(self) -> int:
        return self.repr_dims[0] * self.repr_dims[1] + self.hidden_dims[0]

    @property
    def obs_dims_flat(self) -> int:
        return self.obs_dims[0] * self.obs_dims[1] * self.obs_dims[2]

In [12]:
def sample_obs(env: gym.Env, num_obs: int) -> tf.data.Dataset:
    all_obs = []
    done = False
    obs = env.reset()
    all_obs.append(obs)
    while len(all_obs) < num_obs:
        action = env.action_space.sample()
        obs, reward, done, _ = env.step(action)
        all_obs.append(obs)
        if done and len(all_obs) < num_obs:
            obs = env.reset()
            all_obs.append(obs)
    return tf.data.Dataset.from_tensor_slices(tensors=(np.array(all_obs)))

In [6]:
env = gym.make("ALE/Pacman-v5")
settings = DreamerSettings([1], [64, 64, 3], [32, 32], [512], [128])

timesteps = 10_000
dataset = sample_obs(env, timesteps)
dataset = dataset.map(lambda obs: tf.image.resize(obs, settings.obs_dims[:2]))
dataset = dataset.map(lambda obs: (obs / 255.0, obs / 255.0))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [13]:
def create_repr_encoder(settings: DreamerSettings) -> Model:
    enc_units = MultivariateNormalTriL.params_size(settings.enc_dims[0])
    model_in = Input(settings.obs_dims, name="enc_out")
    norm_img = Lambda(lambda x: x * 2.0 - 1.0)
    batch_norm = BatchNormalization()
    cnn_1 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_2 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_3 = Conv2D(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_4 = Conv2D(8, (3, 3), padding="same", activation="elu")
    drop_1 = Dropout(rate=settings.dropout_rate)
    drop_2 = Dropout(rate=settings.dropout_rate)
    drop_3 = Dropout(rate=settings.dropout_rate)
    drop_4 = Dropout(rate=settings.dropout_rate)
    flatten = Flatten()
    dense_out = Dense(enc_units, activation="linear", name="enc_dense")
    posterior_out = MultivariateNormalTriL(settings.enc_dims[0], name="posterior_enc")

    img_in = batch_norm(norm_img(model_in))
    prep_model_convs = drop_4(cnn_4(drop_3(cnn_3(drop_2(cnn_2(drop_1(cnn_1(img_in))))))))
    model_out = posterior_out(dense_out(flatten(prep_model_convs)))
    return Model(inputs=model_in, outputs=model_out, name="encoder_model")


def create_repr_decoder(settings: DreamerSettings) -> Model:
    image_channels = settings.obs_dims[-1]
    upscale_source_dims = (settings.obs_dims[0] // 8 * settings.obs_dims[1] // 8) * 8

    post_units = IndependentNormal.params_size(settings.obs_dims_flat)
    cov_flat = post_units // (settings.obs_dims_flat // image_channels)

    model_in = Input(settings.enc_dims[0], name="repr_out")
    dense_in = Dense(upscale_source_dims, activation="linear", name="dec_in")
    reshape_in = Reshape((settings.obs_dims[0] // 8, settings.obs_dims[1] // 8, -1))
    cnn_1 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_2 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_3 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same", activation="elu")
    cnn_4 = Conv2D(64, (3, 3), padding="same", activation="elu")
    cnn_5 = Conv2D(cov_flat, (1, 1), padding="same", activation="linear")
    drop_1 = Dropout(rate=settings.dropout_rate)
    drop_2 = Dropout(rate=settings.dropout_rate)
    drop_3 = Dropout(rate=settings.dropout_rate)
    flatten = Flatten()
    likelihood_out = IndependentNormal(settings.obs_dims)

    prep_in = reshape_in(dense_in(model_in))
    conv_out = cnn_5(cnn_4(drop_3(cnn_3(drop_2(cnn_2(drop_1(cnn_1(prep_in))))))))
    model_out = likelihood_out(flatten(conv_out))
    return Model(inputs=model_in, outputs=model_out, name="decoder_model")


def compose_models(settings: DreamerSettings) -> Tuple[Model, Model]:
    model_in = Input(settings.obs_dims)
    encoder = create_repr_encoder(settings)
    decoder = create_repr_decoder(settings)
    prior = Independent(
        Normal(loc=tf.zeros(settings.enc_dims[0]), scale=1),
        reinterpreted_batch_ndims=1)

    posterior = encoder(model_in)
    reconst_dist = decoder(posterior)
    model = Model(inputs=[model_in], outputs=[reconst_dist])

    likelihood = reconst_dist.log_prob()
    divergence = tfd.kl_divergence(posterior, prior)
    elbo_loss = tf.reduce_mean(likelihood - divergence)
    loss_model = Model(inputs=[model_in], outputs=[elbo_loss])

    return model, loss_model

In [14]:
model, loss = compose_models(settings)
model.build([None] + settings.obs_dims)

In [15]:
batch_size = 64
num_eval_batches = ceil((timesteps // batch_size) * 0.1)
dataset = dataset.batch(64)
train_dataset = dataset.skip(num_eval_batches)
eval_dataset = dataset.take(num_eval_batches)
train_dataset = train_dataset.shuffle(100)

loss = lambda y_true, y_pred_dist: -y_pred_dist.log_prob(y_true)
model.compile(optimizer="adam", loss=loss, metrics=["mse"])
model.fit(x=train_dataset, epochs=500, validation_data=eval_dataset)
model.save_weights("vae.h5")

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500


Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78/500
Epoch 79/500
Epoch 80/500
Epoch 81/500
Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
  3/141 [..............................] - ETA: 5s - loss: nan - mse: nan 

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

def display_imgs(x, y=None):
    if not isinstance(x, (np.ndarray, np.generic)):
        x = np.array(x)
    plt.ioff()
    n = x.shape[0]
    fig, axs = plt.subplots(1, n, figsize=(n, 1))
    if y is not None:
        fig.suptitle(np.argmax(y, axis=1))
    for i in range(n):
        axs.flat[i].imshow(x[i].squeeze(), interpolation='none', cmap='gray')
        axs.flat[i].axis('off')
    plt.show()
    plt.close()
    plt.ion()

In [None]:
num_rows = 5
img_dataset = train_dataset.unbatch().shuffle(100).batch(num_rows)
img_in, _ = iter(img_dataset).next()
img_out = model(img_in).sample().numpy()
img_in = img_in.numpy()
img_out = np.clip(img_out, 0.0, 1.0)

print(np.max(img_out))
print(np.min(img_out))

display_imgs(img_in)
display_imgs(img_out)