# Training the VAE

## Setup

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import pathlib
# import pandas as pd
import keras
from keras import layers, ops
from datetime import date

## Get the data

In [None]:
# Fetch the Kaggle stock-market dataset into ./data and extract the archive;
# data_root is the path get_file returns for it.
# NOTE(review): Kaggle's API download endpoint may require credentials — confirm.
data_root = keras.utils.get_file(
    origin=(
        "https://www.kaggle.com/api/v1/datasets/download/"
        "jacksoncrow/stock-market-dataset"
    ),
    extract=True,
    cache_dir=".",
    cache_subdir="data",
)

In [None]:
# Ticker symbols whose CSV files are used for training; compared against the
# file stems in the stocks directory below.
tickers = np.load(file="tickers.npy")

In [None]:
data_root_path = pathlib.Path(data_root)
stock_dir = data_root_path.joinpath("stocks")
# Build the lookup set once: membership against a numpy array is a linear
# elementwise scan per file. ravel().tolist() accepts an array of any shape
# and yields plain Python strings.
ticker_set = set(np.ravel(tickers).tolist())
# Sorted so the file list does not depend on filesystem iteration order.
data_strs = sorted(str(x) for x in stock_dir.iterdir() if x.stem in ticker_set)
if not data_strs:
    raise ValueError(f"No files in {stock_dir} match the loaded tickers")
data_columns = ["Open", "High", "Low", "Close", "Adj Close", "Volume"]

In [None]:
# Stream batches of 16 rows from the selected CSV files. Each batch is a dict
# mapping column name -> (batch,) float32 tensor.
# The defaults are dtypes only (no fallback values), so a row with a missing
# or malformed field raises a parse error; ignore_errors() drops such rows
# instead of aborting the pipeline.
# Deriving column_defaults from data_columns keeps it in sync with
# select_columns (one default per selected column).
# NOTE(review): nothing in this notebook rescales these columns, yet the VAE
# uses a sigmoid decoder with binary cross-entropy, which assumes targets in
# [0, 1]. Prices/volumes should be normalised first — confirm intended scaling.
dataloader = tf.data.experimental.make_csv_dataset(
    file_pattern=data_strs,
    batch_size=16,
    column_defaults=[tf.float32] * len(data_columns),
    num_epochs=1,
    select_columns=data_columns,
).ignore_errors()
dataloader

## Create a sampling layer

In [None]:
class Sampling(layers.Layer):
    """Sample a latent vector z with the reparameterisation trick.

    Given ``[z_mean, z_log_var]`` (each ``(batch, latent_dim)``), returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is standard
    normal noise, i.e. the latent code for one input row of stock data.
    """

    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        # Seeded generator: each call draws fresh noise, but the sequence of
        # draws is reproducible for a fixed ``seed``.
        self.seed_generator = keras.random.SeedGenerator(seed=seed)

    def call(self, inputs):
        mean, log_var = inputs
        shape = ops.shape(mean)
        noise = keras.random.normal(
            shape=(shape[0], shape[1]), seed=self.seed_generator
        )
        return mean + ops.exp(0.5 * log_var) * noise

## Build the encoder

In [None]:
# One scalar Keras Input per CSV column, named after that column so the
# dict-of-columns batches from the CSV dataset can be routed to the matching
# input (approach from https://stackoverflow.com/a/65246730).
# A comprehension avoids leaking loop temporaries into the notebook namespace.
input_list = [
    keras.layers.Input(shape=(1,), name=column, dtype="float32")
    for column in data_columns
]

In [None]:
latent_dim = 2

# Merge the per-column (batch, 1) inputs into a single (batch, len(data_columns))
# feature matrix; the concat layer has no weights to train.
encoder_inputs = layers.Concatenate(name="concat", trainable=False)(input_list)
x = layers.Dense(5, activation="relu")(encoder_inputs)

# Two parallel heads give the mean and log-variance of the latent Gaussian,
# from which Sampling draws z.
z_mean, z_log_var = (
    layers.Dense(latent_dim, name="z_mean")(x),
    layers.Dense(latent_dim, name="z_log_var")(x),
)
z = Sampling(seed=1337)([z_mean, z_log_var])

encoder = keras.Model(
    inputs=input_list, outputs=[z_mean, z_log_var, z], name="encoder"
)
encoder.summary()

## Build the decoder

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(5, activation="relu")(latent_inputs)
# One output per CSV column, derived from data_columns so the reconstruction
# always lines up with the concatenated encoder inputs. The sigmoid bounds
# outputs to (0, 1), matching the binary cross-entropy reconstruction loss.
# NOTE(review): that pairing only makes sense if the targets are scaled to
# [0, 1]; the raw columns are not rescaled anywhere in this notebook.
decoder_outputs = layers.Dense(len(data_columns), activation="sigmoid")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## Define the VAE as a `Model` with a custom `train_step`

In [None]:
class VAE(keras.Model):
    """Variational autoencoder with a custom training step.

    ``encoder`` maps a batch to ``(z_mean, z_log_var, z)`` and ``decoder`` maps
    ``z`` back to feature space. Training minimises reconstruction loss plus
    the KL divergence of the latent Gaussian from a standard normal.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of each loss term, reported by fit().
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch and return the running losses.

        ``data`` is the dict of per-column tensors yielded by the CSV dataset.
        """
        # Stack the per-column (batch,) tensors into a (batch, n_features)
        # reconstruction target.
        # NOTE(review): this relies on the dict's iteration order matching the
        # encoder's input order (data_columns); make_csv_dataset appears to
        # emit columns in CSV header order, which matches — confirm.
        target = ops.stack(list(data.values()), axis=1)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-sample loss: sum element-wise BCE over the features, then
            # average over the batch. This keeps the term independent of batch
            # size and on the same footing as the KL term below. (Summing over
            # the whole batch made it grow with batch size.)
            # NOTE(review): BCE assumes targets in [0, 1]; unscaled prices and
            # volumes make this loss unbounded below.
            reconstruction_loss = ops.mean(
                ops.sum(ops.binary_crossentropy(target, reconstruction), axis=1)
            )
            # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over
            # latent dimensions and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

## Train the VAE

In [None]:
# Combine the encoder/decoder pair into one trainable model, optimised with
# Adam at its default settings.
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=keras.optimizers.Adam())

In [None]:
# dataloader is already batched (batch_size=16 in make_csv_dataset), and
# Keras does not re-batch tf.data inputs. So no batch_size is passed here;
# one would be silently ignored. With num_epochs=1 on the dataset, each fit
# epoch makes one pass over the CSV rows.
history = vae.fit(dataloader, epochs=30)