# Autoregressives Modell auf MNIST mit Autoencoder

In [None]:
import numpy as np
import time
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
import os
import random

#Kleine Module von Lukas Rinder https://github.com/LukasRinder/normalizing-flows:
from LukasRinder.LukasRinder import load_and_preprocess_mnist
from LukasRinder.LukasRinder import Made
from LukasRinder.LukasRinder import train_density_estimation, nll

from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model

# Short aliases for the TensorFlow Probability submodules used below.
tfb = tfp.bijectors
tfd = tfp.distributions

# Seed TensorFlow's global RNG (numpy / python `random` are not seeded here).
tf.random.set_seed(1234)

## Dieser Abschnitt stammt von https://www.tensorflow.org/tutorials/generative/autoencoder (Zugriff: 28.01.2022)

### Daten laden

In [None]:
# Load MNIST; labels are discarded because the autoencoder reconstructs its
# input. Pixel values are rescaled from [0, 255] to [0, 1] as float32.
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_test = (imgs.astype('float32') / 255. for imgs in (x_train, x_test))

print(x_train.shape)
print(x_test.shape)

### Autoencoder implementieren.

In [None]:
latent_dim = 64


class Autoencoder(Model):
    """Dense autoencoder: flattened image -> ReLU code of size ``latent_dim`` -> 28x28 image."""

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: flatten the input and project it onto the latent code.
        self.encoder = tf.keras.Sequential([
            layers.Flatten(),
            layers.Dense(latent_dim, activation='relu'),
        ])
        # Decoder: map the code to 784 sigmoid outputs and restore the 28x28 shape.
        self.decoder = tf.keras.Sequential([
            layers.Dense(784, activation='sigmoid'),
            layers.Reshape((28, 28)),
        ])

    def call(self, x):
        return self.decoder(self.encoder(x))


autoencoder = Autoencoder(latent_dim)

In [None]:
# Train to reproduce the input under a mean-squared-error loss with Adam.
autoencoder.compile(loss=losses.MeanSquaredError(), optimizer='adam')

### Autoencoder trainieren.

In [None]:
# Reconstruction training: the inputs double as targets; the test split is
# only used to report a validation loss.
autoencoder.fit(
    x_train,
    x_train,
    epochs=10,
    shuffle=True,
    validation_data=(x_test, x_test),
)

### Autoencoder testen.

In [None]:
# Round-trip the test images: encode them to latent codes, then decode the
# codes back to images (both kept as numpy arrays for plotting).
encoded_imgs = autoencoder.encoder(x_test).numpy()
decoded_imgs = autoencoder.decoder(
    encoded_imgs
).numpy()

In [None]:
# Compare the first n test images (top row) with their reconstructions
# (bottom row).
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    for row, (images, caption) in enumerate(((x_test, "original"),
                                             (decoded_imgs, "reconstructed"))):
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(images[i])
        plt.title(caption)
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

### Daten encoden und vorbereiten.

In [None]:
# Encode the first 50000 training images as flow-training data and the rest
# of x_train as flow-validation data.
low_dim_train, low_dim_val = (
    autoencoder.encoder(part).numpy()
    for part in (x_train[:50000], x_train[50000:])
)

In [None]:
# Largest latent activation over the encoded training and validation codes.
# Used to rescale the codes before fitting the flow and to undo that scaling
# after sampling. Vectorized instead of a Python loop over every row.
maximum = max(low_dim_train.max(), low_dim_val.max())
print(maximum)

In [None]:
# Smallest latent activation over the encoded training and validation codes
# (a sanity check on the value range). Vectorized instead of a Python loop.
minimum = min(low_dim_train.min(), low_dim_val.min())
print(minimum)

In [None]:
# Scale the latent codes by the maximum computed above. The encoder ends in a
# ReLU, so the scaled codes lie in [0, 1]. Dividing by the computed `maximum`
# instead of a hard-coded number keeps this consistent with the un-scaling
# (multiplication by `maximum`) applied to samples from the flow.
low_dim_train_scale = tf.cast(low_dim_train / maximum, tf.float32)
low_dim_val_scale = tf.cast(low_dim_val / maximum, tf.float32)
# View each 64-dim code as an 8x8 "image" (assumes latent_dim == 64).
low_dim_train_scale = tf.reshape(low_dim_train_scale, (low_dim_train_scale.shape[0], 8, 8))
low_dim_val_scale = tf.reshape(low_dim_val_scale, (low_dim_val_scale.shape[0], 8, 8))

In [None]:
batch_s = 128

# Training codes: shuffled with a 1000-element buffer, then batched.
shuffled_train = tf.data.Dataset.from_tensor_slices(low_dim_train_scale).shuffle(buffer_size=1000)
batched_train = shuffled_train.batch(batch_size=batch_s)
# Validation codes: batched only, kept in order.
batched_val = tf.data.Dataset.from_tensor_slices(low_dim_val_scale).batch(batch_size=batch_s)

### Visualisierung eines 8x8-Latentcodes als Bild (nur zur Veranschaulichung).

In [None]:
# Display the first scaled latent code of one training batch as a grayscale image.
plt.imshow(
    next(iter(batched_train))[0],
    cmap='gray',
)

### Funktion zum Aufbau des Masked Autoregressive Flow (wie in den anderen Beispielen).

In [None]:
def AutoregressiveFlow(dimension, layers, hidden_shape=None, activation="relu", inverse=False):
    """Build a masked autoregressive flow whose samples are square images.

    A diagonal standard normal on ``dimension`` values is pushed through
    ``layers`` MAF blocks, each followed by a permutation that swaps the two
    halves of the vector. The result is finally reshaped to
    ``sqrt(dimension) x sqrt(dimension)``.

    Args:
        dimension: Number of values per sample; must be a perfect square.
        layers: Number of MAF blocks (this name shadows the Keras ``layers``
            module inside this function only).
        hidden_shape: Hidden layer sizes of each ``Made`` network; defaults to
            ``[512, 512]``.
        activation: Activation passed to each ``Made`` network.
        inverse: If True, each MAF block is wrapped in ``tfb.Invert``.

    Returns:
        Tuple ``(masked_auto_flow, base_dist, bijectors, bijector)``:
        - the ``TransformedDistribution``;
        - its base distribution;
        - the bijector list before it is reversed for ``tfb.Chain``;
        - the resulting ``tfb.Chain``.

    Raises:
        ValueError: If ``dimension`` is not a perfect square.
    """
    # None sentinel instead of a mutable list default shared across calls.
    hidden_shape = [512, 512] if hidden_shape is None else list(hidden_shape)

    side = int(round(np.sqrt(dimension)))
    if side * side != dimension:
        raise ValueError(f"dimension must be a perfect square, got {dimension}")

    base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=dimension, dtype=tf.float32))

    # Swap the two halves of the vector between MAF blocks. Integer division
    # keeps the indices integral, so odd dimensions still get a valid permutation.
    half = dimension // 2
    permutation = tf.cast(np.concatenate((np.arange(half, dimension), np.arange(0, half))), tf.int32)

    bijectors = []
    for _ in range(layers):
        maf = tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn=Made(params=2, hidden_units=hidden_shape, activation=activation))
        bijectors.append(tfb.Invert(maf) if inverse else maf)
        bijectors.append(tfb.Permute(permutation=permutation))

    bijectors.append(tfb.Reshape(event_shape_out=(side, side), event_shape_in=(dimension,)))
    bijector = tfb.Chain(bijectors=list(reversed(bijectors)))

    masked_auto_flow = tfd.TransformedDistribution(distribution=base_dist, bijector=bijector)

    # One log_prob call on a correctly shaped input, presumably so the lazily
    # built Made networks create their weights before the parameter count.
    masked_auto_flow.log_prob(tf.reshape(base_dist.sample(), (side, side)))
    params = sum(int(np.prod(theta.shape)) for theta in masked_auto_flow.trainable_variables)
    print("trainable parameters:", params)
    return masked_auto_flow, base_dist, bijectors, bijector

### Hyperparameter und einen Namen für die Checkpoints festlegen.

In [None]:
# Flow hyperparameters and the name used for the checkpoint directory.
dataset = "mnist_auto"
# NOTE(review): rebinding `layers` shadows the Keras `layers` module imported at
# the top; re-running the Autoencoder cell after this one would fail. The
# name is kept because later cells refer to it.
layers = 20           # number of MAF blocks
base_lr = 1e-3
end_lr = 1e-4
epochs = 200
shape = [128, 128]    # hidden layer sizes of each Made network
# Derived from the data instead of repeated literals, so they stay in sync
# with the train/validation split and the autoencoder's latent size.
mnist_trainsize = len(low_dim_train)
dimension = latent_dim

### Modell initialisieren. In diesem Stadium entspricht MAF der Startverteilung bzw. full_bijector der Identitätsabbildung.

In [None]:
# Build the (untrained) flow with the hyperparameters chosen above.
MAF, base_dist, list_of_bijectors, full_bijector = AutoregressiveFlow(
    dimension,
    layers,
    hidden_shape=shape,
    inverse=False,
)

In [None]:
# LearningRateSchedule objects are evaluated per optimizer step, not per
# epoch. The training loop calls train_density_estimation once per batch,
# presumably applying one update each time. Decay over the total number of
# steps, so the rate reaches end_lr at the end of training rather than after
# the first `epochs` batches.
steps_per_epoch = int(np.ceil(mnist_trainsize / batch_s))
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    base_lr, decay_steps=epochs * steps_per_epoch, end_learning_rate=end_lr, power=0.5)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Checkpoints are written to "<dataset>/tmp_<layers>/ckpt" and track both the
# optimizer state and the flow.
ckpt_dir = "{}/tmp_{}".format(dataset, layers)
ckpt_prefix = os.path.join(ckpt_dir, "ckpt")

ckpt = tf.train.Checkpoint(optimizer=opt, model=MAF)

In [None]:
def TrainFlow(flow, batched_train, batched_val, epochs, train_size, optimizer, checkpoint, checkpoint_pref):
    """Train a flow and checkpoint it whenever the validation loss improves.

    Args:
        flow: Distribution passed to ``train_density_estimation`` and ``nll``.
        batched_train: Batched training dataset.
        batched_val: Batched validation dataset.
        epochs: Number of passes over the training data.
        train_size: Shuffle buffer size used to reorder the training batches
            every epoch.
        optimizer: Optimizer passed to ``train_density_estimation``.
        checkpoint: ``tf.train.Checkpoint`` written when the validation loss
            reaches a new minimum (always after the first epoch).
        checkpoint_pref: File prefix passed to ``checkpoint.write``.

    Returns:
        ``(t_losses, v_losses)``: per-epoch mean training and validation
        losses, as produced by ``tf.reduce_mean``.
    """
    # Dataset.shuffle returns a *new* dataset, so its result must be kept.
    # Built once outside the loop; with reshuffle_each_iteration=True each
    # epoch's iteration sees a fresh batch order.
    train_ds = batched_train.shuffle(buffer_size=train_size, reshuffle_each_iteration=True)

    t_losses, v_losses = [], []
    # Start at +inf so the first epoch always writes a checkpoint. A checkpoint
    # then exists even if the validation loss never improves afterwards.
    min_v_loss = float("inf")
    best_epoch = None
    t_start = time.perf_counter()

    for i in range(epochs):
        t_loss = tf.reduce_mean(
            [train_density_estimation(flow, optimizer, batch) for batch in train_ds])
        v_loss = tf.reduce_mean([nll(flow, batch) for batch in batched_val])

        t_losses.append(t_loss)
        v_losses.append(v_loss)
        print(f"Epoch {i+1}: train loss: {t_loss}, val loss: {v_loss}")

        if float(v_loss) < min_v_loss:
            min_v_loss = float(v_loss)
            best_epoch = i
            checkpoint.write(file_prefix=checkpoint_pref)

    print("train time:", time.perf_counter() - t_start)
    if best_epoch is not None:
        print("best epoch:", best_epoch + 1)

    return t_losses, v_losses

In [None]:
# Train the flow; TrainFlow writes a checkpoint whenever the validation loss improves.
train_losses, val_losses = TrainFlow(
    flow=MAF,
    batched_train=batched_train,
    batched_val=batched_val,
    epochs=epochs,
    train_size=mnist_trainsize,
    optimizer=opt,
    checkpoint=ckpt,
    checkpoint_pref=ckpt_prefix,
)

In [None]:
# Learning curves: per-epoch mean training and validation loss.
for curve, curve_label in ((train_losses, "train loss"), (val_losses, "val loss")):
    plt.plot(range(len(curve)), curve, label=curve_label)
plt.legend()

In [None]:
# Load the weights last saved under ckpt_prefix by the training loop.
ckpt.restore(save_path=ckpt_prefix)

### Stichproben generieren und durchschnittliche Zeit ausgeben.
#### Dazu werden Stichproben der zugehörigen Normalverteilung erzeugt, durch den Flow in 8x8-Latentcodes transformiert, zurückskaliert und anschließend vom Decoder dekodiert.

In [None]:
# Draw samples from the flow, undo the latent scaling (multiply by
# `maximum`), flatten each sample back to a latent vector and decode it into
# an image. Reports the average time per sample. time.perf_counter is the
# monotonic high-resolution clock intended for measuring elapsed time.
n_samples = 20
s_time = time.perf_counter()
examples = tf.reshape(MAF.sample(n_samples) * maximum, (n_samples, latent_dim)).numpy()
examples = autoencoder.decoder(examples).numpy()
sample_time = (time.perf_counter() - s_time) / n_samples
sample_time

### Stichproben visualisieren.

In [None]:
# Show the generated images in a 2 x n grid. The old index (i + 1 + n) only
# used the bottom row and left the top row empty; this fills both rows.
n = 10
rows = 2
plt.figure(figsize=(20, 4))
for i in range(min(rows * n, len(examples))):
    ax = plt.subplot(rows, n, i + 1)
    plt.imshow(examples[i], cmap="gray")
    plt.title("generated")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()