In [None]:
import numpy as ny
import tensorflow as tw

In [None]:
class sp(tw.keras.layers.Layer):
    """Sampling layer: draws a random vector around a mean with a learned spread.

    Given ``[mz, vlz]`` it returns ``mz + exp(0.5 * vlz) * noise`` where
    ``noise`` is standard-normal, i.e. ``vlz`` is treated as a log-variance.
    """

    def call(self, ip):
        """Return one random sample per row of the ``(mz, vlz)`` pair in ``ip``."""
        mean, log_var = ip
        # Use the dynamic shape so the layer works for any batch size.
        dims = tw.shape(mean)
        noise = tw.keras.backend.random_normal(shape=(dims[0], dims[1]))
        spread = tw.exp(0.5 * log_var)
        return mean + spread * noise

In [None]:
# Width of the latent vector produced by the encoder heads.
ldim = 2

# Encoder graph: two stride-2 convolutions, a flatten, and a 16-unit dense
# layer that feeds two `ldim`-wide heads plus the `sp` sampling layer.
ipenc = tw.keras.Input(shape=(28, 28, 1))
md = tw.keras.layers.Conv2D(
    filters=32, kernel_size=3, strides=2, padding="same", activation="relu"
)(ipenc)
md = tw.keras.layers.Conv2D(
    filters=64, kernel_size=3, strides=2, padding="same", activation="relu"
)(md)
md = tw.keras.layers.Flatten()(md)
md = tw.keras.layers.Dense(units=16, activation="relu")(md)
mz = tw.keras.layers.Dense(units=ldim, name="mz")(md)
vlz = tw.keras.layers.Dense(units=ldim, name="vlz")(md)
zi = sp()([mz, vlz])

# The model exposes both heads and the sampled latent vector, in that order.
enc = tw.keras.Model(inputs=ipenc, outputs=[mz, vlz, zi], name="enc")
enc.summary()

In [None]:
# Decoder graph: a dense projection reshaped to a 7x7x64 feature map, two
# stride-2 transposed convolutions, and a single-channel sigmoid output layer.
ipl = tw.keras.Input(shape=(ldim,))
md = tw.keras.layers.Dense(units=7 * 7 * 64, activation="relu")(ipl)
md = tw.keras.layers.Reshape(target_shape=(7, 7, 64))(md)
md = tw.keras.layers.Conv2DTranspose(
    filters=64, kernel_size=3, strides=2, padding="same", activation="relu"
)(md)
md = tw.keras.layers.Conv2DTranspose(
    filters=32, kernel_size=3, strides=2, padding="same", activation="relu"
)(md)
odec = tw.keras.layers.Conv2DTranspose(
    filters=1, kernel_size=3, padding="same", activation="sigmoid"
)(md)

dec = tw.keras.Model(inputs=ipl, outputs=odec, name="dec")
dec.summary()

In [None]:
class VAE(tw.keras.Model):
    """Variational auto-encoder that trains an encoder/decoder pair jointly.

    Args:
        e: Encoder model; calling it on a batch must return the triple
            ``(mean, log_variance, sampled_latent)``.
        d: Decoder model; maps a batch of latent vectors back to images.
        **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``.
    """

    def __init__(self, e, d, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.en = e
        self.de = d

    def train_step(self, dt):
        """Run one optimisation step and return the three loss values.

        Args:
            dt: A batch of images, or a tuple whose first element is that
                batch (any further elements are ignored).

        Returns:
            Dict with the keys ``"total loss"``, ``"reconstruction loss"``
            and ``"kl loss"``.
        """
        if isinstance(dt, tuple):
            dt = dt[0]
        with tw.GradientTape() as tg:
            # Use the sub-models owned by this instance rather than whatever
            # `enc` / `dec` globals happen to exist in the notebook.
            mz, vlz, zi = self.en(dt)
            rec = self.de(zi)
            # binary_crossentropy averages over the channel axis; summing the
            # remaining spatial axes gives a per-image loss for any image
            # size (identical to mean * 28 * 28 for 28x28 inputs).
            recl = tw.reduce_mean(
                tw.reduce_sum(
                    tw.keras.losses.binary_crossentropy(dt, rec), axis=(1, 2)
                )
            )
            # KL divergence to a standard normal: sum over the latent axis,
            # then average over the batch. Averaging over the latent axis as
            # well would shrink the term by a factor of the latent width.
            kll = -0.5 * (1 + vlz - tw.square(mz) - tw.exp(vlz))
            kll = tw.reduce_mean(tw.reduce_sum(kll, axis=1))
            ttl = recl + kll
        gs = tg.gradient(ttl, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gs, self.trainable_weights))
        return {
            "total loss": ttl,
            "reconstruction loss": recl,
            "kl loss": kll,
        }

In [None]:
# Pool the MNIST train and test images (labels are discarded), append a
# trailing channel axis and rescale the pixel values by 1/255 as float32.
(tx, _), (ttx, _) = tw.keras.datasets.mnist.load_data()
mgs = ny.concatenate((tx, ttx), axis=0)
mgs = mgs[..., ny.newaxis].astype("float32") / 255

# Wrap the encoder/decoder pair and train for a single epoch.
v = VAE(enc, dec)
v.compile(optimizer=tw.keras.optimizers.Adam())
v.fit(x=mgs, batch_size=8192, epochs=1)

In [None]:
from matplotlib import pyplot


def pl(enc, dec, num=30, ds=28, sle=2.0, figs=15):
    """Plot a mosaic of images decoded from a regular 2-D lattice of latent points.

    Args:
        enc: Unused; kept so existing ``pl(enc, dec)`` calls keep working.
        dec: Object with a ``predict`` method that maps an ``(n, 2)`` array of
            latent points to ``n`` images, each reshapeable to ``ds x ds``.
        num: Number of lattice points along each latent axis.
        ds: Side length, in pixels, of one decoded image.
        sle: The lattice spans ``[-sle, sle]`` on both latent axes.
        figs: Width and height of the matplotlib figure, in inches.
    """
    gx = ny.linspace(-sle, sle, num)
    # Reversed so the first image row holds the largest z[1] value.
    gy = ny.linspace(-sle, sle, num)[::-1]

    # Decode the whole lattice with ONE batched predict call instead of
    # num * num single-sample calls. Row k * num + l holds (gx[l], gy[k]).
    xx, yy = ny.meshgrid(gx, gy)
    sz = ny.stack([xx.ravel(), yy.ravel()], axis=1)
    decx = ny.asarray(dec.predict(sz))

    # (row, col, y, x) -> (row, y, col, x) -> one (num*ds, num*ds) image, so
    # tile (k, l) lands at rows k*ds:(k+1)*ds and columns l*ds:(l+1)*ds.
    mosaic = (
        decx.reshape(num, num, ds, ds)
        .transpose(0, 2, 1, 3)
        .reshape(num * ds, num * ds)
    )

    pyplot.figure(figsize=(figs, figs))
    # Place one tick at the centre of every tile, labelled with its latent value.
    sr = ds // 2
    er = num * ds + sr
    pr = ny.arange(sr, er, ds)
    pyplot.xticks(pr, ny.round(gx, 1))
    pyplot.yticks(pr, ny.round(gy, 1))
    pyplot.xlabel("z[0]")
    pyplot.ylabel("z[1]")
    pyplot.imshow(mosaic, cmap="Greys_r")
    pyplot.show()


# Render the latent-space mosaic for the encoder/decoder pair defined in this notebook.
pl(enc, dec)