# Training the VAE

## Setup

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import pathlib
import keras
from keras import layers, ops

## Get the data

In [None]:
# Download the Kaggle stock-market archive into ./data and extract it.
data_root = keras.utils.get_file(
    origin="https://www.kaggle.com/api/v1/datasets/download/jacksoncrow/stock-market-dataset",
    cache_dir=".",
    cache_subdir="data",
    extract=True,
)

In [None]:
# Ticker array loaded from a local .npy file; the stock CSV files are later
# filtered by checking each file's stem against it.
tickers = np.load("tickers.npy")

In [None]:
# Locate the per-stock CSV files inside the downloaded dataset.
data_root_path = pathlib.Path(data_root)
stock_dir = data_root_path.joinpath("stocks")

# Build the lookup set once so each file stem is checked in O(1) instead of
# being compared against the whole ticker array for every file.
ticker_set = set(np.ravel(tickers).tolist())
data_strs = [str(x) for x in stock_dir.iterdir() if x.stem in ticker_set]

# Columns selected from every CSV file.
data_columns = ["Open", "High", "Low", "Close", "Adj Close", "Volume"]


def input_stacker(x):
    """Stack the per-column tensors of one CSV batch along axis 1.

    Args:
        x: Mapping from column name to the tensor holding that column's values.

    Returns:
        A single tensor with the columns stacked in the mapping's own order.
    """
    # TODO inefficient (but easy)
    return tf.stack(list(x.values()), axis=1)


def target_adder(x):
    """Return ``(x, x)`` so the same batch serves as both input and target."""
    return (x, x)

In [None]:
# Stream batches of 128 rows straight from the selected CSV files.
# Every selected column is parsed as float32; the defaults list is sized from
# data_columns so the two cannot drift apart.
dataloader = (
    tf.data.experimental.make_csv_dataset(
        file_pattern=data_strs,
        batch_size=128,
        column_defaults=["float32"] * len(data_columns),
        num_epochs=1,
        select_columns=data_columns,
    )
    .ignore_errors()  # skip elements that raise an error instead of aborting
    .map(input_stacker)
    .map(target_adder)
)
dataloader

## Create a custom sampling layer and VAE model

In [None]:
class Sampling(layers.Layer):
    """Samples a latent vector z from (z_mean, z_log_var) and registers the KL loss.

    Args:
        kl_loss_factor: Multiplier applied to the KL divergence term before it
            is added to the layer's losses.
        seed: Seed passed to the layer's `keras.random.SeedGenerator`.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, kl_loss_factor=1, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(seed=seed)
        self.kl_loss_factor = kl_loss_factor

    def call(self, inputs):
        """Return ``z_mean + exp(0.5 * z_log_var) * epsilon`` with normal epsilon.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of equally shaped 2-D tensors.

        Returns:
            The sampled latent tensor, same shape as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Element-wise KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1).
        kl_per_element = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        # add_loss expects a scalar. A non-scalar loss is summed by Keras, which
        # would make the KL term grow with the batch size while the compiled
        # reconstruction loss is a batch mean. Sum over the latent dimensions
        # and average over the batch instead.
        kl_loss = ops.mean(ops.sum(kl_per_element, axis=1))
        self.add_loss(kl_loss * self.kl_loss_factor)
        # Reparameterization: draw the noise separately so gradients can flow
        # through z_mean and z_log_var.
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon

class VAE(keras.Model):
    """Functional model that chains an encoder and a decoder end to end.

    Args:
        encoder: Model mapping input features to a latent tensor.
        decoder: Model mapping a latent tensor back to the feature space.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(self, encoder, decoder, **kwargs):
        # The functional graph is assembled before the base constructor runs,
        # because the base class receives it as inputs/outputs.
        # The input shape is borrowed from the encoder, minus the batch axis.
        feature_shape = encoder.input.shape[1:]
        model_input = keras.Input(shape=feature_shape, name="vae_inputs")
        latent = encoder(model_input)
        reconstruction = decoder(latent)
        super().__init__(inputs=model_input, outputs=reconstruction, **kwargs)
        # Keep handles to the sub-models so they can be used on their own.
        self.encoder = encoder
        self.decoder = decoder

## Build the models

In [None]:
# One input feature per selected CSV column, so the model width follows the
# column list instead of repeating its length as a literal.
input_shape = (len(data_columns),)
# Width of the latent code produced by the encoder and consumed by the decoder.
latent_dim = 2

In [None]:
# Encoder: input features -> small hidden layer -> latent mean and log-variance.
encoder_inputs = keras.Input(shape=input_shape, name="encoder_inputs")

x = layers.Dense(5, activation="relu")(encoder_inputs)
# Two parallel linear heads on the same hidden activations.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

# Sampling draws z from the two heads and registers the KL loss on the layer;
# the fixed seed seeds its random generator.
encoder_outputs = Sampling(kl_loss_factor=1, seed=489, name="encoder_outputs")([z_mean, z_log_var])

encoder = keras.Model(inputs=encoder_inputs, outputs=encoder_outputs, name="encoder")
encoder.summary()

In [None]:
# Decoder: latent code -> small hidden layer -> reconstructed features.
decoder_inputs = keras.Input(shape=(latent_dim,), name="decoder_inputs")

x = layers.Dense(5, activation="relu")(decoder_inputs)
# The output width follows input_shape so the reconstruction always has as many
# features as the encoder input, rather than repeating the literal 6.
# NOTE(review): sigmoid limits outputs to (0, 1), and no scaling of the CSV
# columns is visible in this notebook -- confirm the targets lie in that range.
decoder_outputs = layers.Dense(input_shape[0], activation="sigmoid", name="decoder_outputs")(x)

decoder = keras.Model(inputs=decoder_inputs, outputs=decoder_outputs, name="decoder")
decoder.summary()

In [None]:
# Assemble the end-to-end autoencoder from the two sub-models and print its
# layer table.
vae = VAE(encoder, decoder, name="vae")
vae.summary()

## Train the VAE

In [None]:
# NOTE(review): "accuracy" is kept so the history keys stay the same, but it is
# unlikely to be meaningful for a continuous reconstruction target -- consider
# dropping it.
vae.compile(optimizer="adam", loss="mean_squared_error", metrics=["accuracy", "mean_squared_error"])
# note: still shows nan which is concerning, but otherwise we're doing well I think
# NOTE(review): a possible cause is unscaled inputs pushing exp(z_log_var) in
# the Sampling layer to overflow -- verify before trusting the metrics.
# batch_size is not passed here: dataloader already yields batches of 128, and
# Keras documents that batch_size must not be specified for dataset inputs.
history = vae.fit(dataloader, epochs=30)