Imports for necessary libraries:

In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import astropy.io.fits as fits
import os
import sys

import pywst as pw

Below is the code for importing the WST coefficients and the RWST coefficients calculated using the methodology of Thorne et al. The WST coefficients serve as the data inputs for our VAE. The RWST coefficients are loaded solely for reference and are not used in comparison or loss calculation, since training a VAE is typically an unsupervised process.

In [2]:
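import pickle
# Load the precomputed coefficient arrays; "with" closes each file after reading.
with open('./wst_rwst_coeffs/WST_polarized_dust.p', 'rb') as f:
    wst_p = pickle.load(f)
with open('./wst_rwst_coeffs/RWST_polarized_dust.p', 'rb') as f:
    rwst_p = pickle.load(f)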

# VAE with TensorFlow
Imports for TensorFlow and for splitting the data into training and test sets:

In [3]:
from IPython import display

import glob
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time
from sklearn.model_selection import train_test_split

2022-12-05 15:29:24.049296: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Below is the preprocessing step, in which we remove the single S0 coefficient so that the layer widths halve cleanly: 1360->680->340->170 for VAE-1 and 1360->680->340->170->85 for VAE-2. The S0 coefficient is then added back, giving 171 (or 86) values for our RWST coefficients. Preserving S0 before and after the embedding follows the approach of Thorne et al.

In [4]:
train_wst, test_wst, train_rwst, test_rwst = train_test_split(wst_p, rwst_p, test_size=0.2, random_state=1)

train_wst_s0 = train_wst[:,0]
train_wst = train_wst[:,1:]

test_wst_s0 = test_wst[:,0]
test_wst = test_wst[:,1:]

train_rwst_s0 = train_rwst[:,0]
train_rwst = train_rwst[:,1:]

test_rwst_s0 = test_rwst[:,0]
test_rwst = test_rwst[:,1:]
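
As a rough illustration of the S0 handling described above, the sketch below splits S0 off a synthetic array and later puts it back in front of a 170-dimensional code. The array shape (10, 1361), the random values, and the names `wst`, `latent`, and `reduced` are assumptions made for this example; the real arrays come from the pickle files.

```python
import numpy as np

# Synthetic stand-in for the WST array: 10 maps, 1 S0 coefficient + 1360 others.
# (Assumed shape; the real data is loaded from the pickle files.)
rng = np.random.default_rng(0)
wst = rng.normal(size=(10, 1361)).astype(np.float32)

s0 = wst[:, 0]      # shape (10,), kept aside
body = wst[:, 1:]   # shape (10, 1360), what the encoder would see

# Hypothetical latent codes standing in for the output of VAE-1.
latent = rng.normal(size=(10, 170)).astype(np.float32)

# Put S0 back in front of the latent code: 1 + 170 = 171 coefficients.
reduced = np.concatenate([s0[:, None], latent], axis=1)

print(body.shape, reduced.shape)
```

For VAE-2 the same step with an 85-dimensional code would give 86 values per map.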

Below is the first Variational Autoencoder designed for this study. Its latent space has 170 dimensions. Its architecture is outlined in the written report.

Underneath the VAE's initializer are four methods for operating the neural net. The "sample" method draws vectors from a standard normal distribution in the latent space (100 by default) and passes them through "decode" to generate new vectors in the input space. The "encode" method takes an input of shape (1,1,1360,1), feeds it through the encoder, and splits the output into the mean and log-variance of the latent distribution for that input. The "reparameterize" method draws a specific vector z from that distribution as z = mean + exp(logvar/2) * eps, with eps drawn from a standard normal; writing the sample this way keeps it differentiable with respect to the mean and log-variance, so the loss can backpropagate through both the decoder and the encoder. The last method, "decode", takes a latent vector and maps it back to the input space, returning logits (or their sigmoid when apply_sigmoid=True). The loss then compares these decoded vectors with the original inputs.

In [5]:
class CVAE_1(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE_1, self).__init__()
        self.latent_dim = latent_dim
        input_shape=(1,1360,1)
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                tf.keras.layers.Conv2D(
                    filters=64,
                    kernel_size=28, #decided by doing input_dim / 9; similar to keras documentation
                    strides=2, 
#                     dilation_rate=2,
                    padding='same',
                    activation='relu', 
                    input_shape=(1,1360,1)), 
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 1 finish
                tf.keras.layers.Conv2D(
                    filters=32,
                    kernel_size=14, 
                    strides=2, 
#                     dilation_rate=2,
                    padding='same',
                    activation='relu', 
                    input_shape=(1,680,64)),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 2 finish
                tf.keras.layers.Conv2D(
                    filters=1,
                    kernel_size=7, 
                    strides=2,
#                     dilation_rate=2,
                    padding='same',
                    activation='relu', 
                    input_shape=(1,340,32)),
                tf.keras.layers.Reshape((1,1,170)),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 3 finish
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(680),
                tf.keras.layers.Dense(2*latent_dim)
            ]
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(170,)),
                tf.keras.layers.Dense(5440),
                tf.keras.layers.Reshape((1,170,32)),
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=128, 
                    strides=(1,2),
                    kernel_size=14, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,170,32)), #layer 1 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, 
                    strides=1,
                    kernel_size=14, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,340,128)), #layer 2 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, 
                    strides=(1,2),
                    kernel_size=28, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,340,64)), #layer 3 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=16, 
                    strides=(1,2),
                    kernel_size=56, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,680,32)), #layer 4 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=1, 
                    strides=1,
                    kernel_size=56, 
                    padding='same',
                    input_shape=(1,1360,16)) #layer 5 finish
            ]
        )
        
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100,self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)
    
    def encode(self, x):
#         print(self.encoder(x))
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
#         print(mean)
#         print(logvar)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
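
The layer widths in the class above can be checked by hand. The sketch below assumes the usual padding='same' rules: a strided convolution gives ceil(width / stride), and a transposed convolution gives width * stride. The helper names are made up for this example and are not part of Keras.

```python
import math

def conv_same_out(width, stride):
    """Output width of a strided convolution with padding='same'."""
    return math.ceil(width / stride)

def conv_transpose_same_out(width, stride):
    """Output width of a transposed convolution with padding='same'."""
    return width * stride

def trace(width, strides, step):
    """Apply `step` once per stride and record every intermediate width."""
    widths = [width]
    for s in strides:
        width = step(width, s)
        widths.append(width)
    return widths

# Encoder of CVAE_1: three stride-2 convolutions along the 1360-wide axis.
encoder_widths = trace(1360, [2, 2, 2], conv_same_out)
# Decoder of CVAE_1: width strides 2, 1, 2, 2, 1 starting from 170.
decoder_widths = trace(170, [2, 1, 2, 2, 1], conv_transpose_same_out)

print(encoder_widths)  # [1360, 680, 340, 170]
print(decoder_widths)  # [170, 340, 340, 680, 1360, 1360]
```

The height axis stays at 1 throughout: the encoder's stride of 2 gives ceil(1 / 2) = 1, and the decoder uses a stride of 1 along that axis.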

In [6]:
class CVAE_2(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE_2, self).__init__()
        self.latent_dim = latent_dim
        input_shape = (1,1360,1)
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                tf.keras.layers.Conv2D(
                    filters=64,
                    kernel_size=28,
                    strides=2, 
                    padding='same',
                    activation='relu', 
                    input_shape=(1,1360,1)), 
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 1 finish
                tf.keras.layers.Conv2D(
                    filters=32,
                    kernel_size=14, 
                    strides=2, 
                    padding='same',
                    activation='relu', 
                    input_shape=(1,680,64)),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 2 finish
                tf.keras.layers.Conv2D(
                    filters=16,
                    kernel_size=7, 
                    strides=2,
                    padding='same',
                    activation='relu', 
                    input_shape=(1,340,32)),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 3 finish
                tf.keras.layers.Conv2D(
                    filters=1,
                    kernel_size=7, 
                    strides=2,
                    padding='same',
                    activation='relu', 
                    input_shape=(1,170,16)),
                tf.keras.layers.Reshape((1,1,85)),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 4 finish
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(340),
                tf.keras.layers.Dense(2*latent_dim)
            ]
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(85,)),
                tf.keras.layers.Dense(2720),
                tf.keras.layers.Reshape((1,85,32)),
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=128, 
                    strides=(1,2),
                    kernel_size=14, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,85,32)), #layer 1 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, 
                    strides=(1,2), #only thing different than the CVAE_Large!
                    kernel_size=14, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,170,128)), #layer 2 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, 
                    strides=(1,2),
                    kernel_size=28, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,340,64)), #layer 3 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=16, 
                    strides=(1,2),
                    kernel_size=56, 
                    padding='same',
                    activation='relu',
                    input_shape=(1,680,32)), #layer 4 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=1, 
                    strides=1,
                    kernel_size=56, 
                    padding='same',
                    input_shape=(1,1360,16)) #layer 5 finish
            ]
        )
        
        
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100,self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)
    
    def encode(self, x):
#         print(self.encoder(x))
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
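
The encode and reparameterize steps shared by both classes can be sketched in NumPy as follows. This is an illustration only: the random `encoder_out` stands in for a real encoder output, and the number of draws is an arbitrary choice for checking the sample moments.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 85

# Stand-in for the encoder output: one row holding [mean | logvar].
encoder_out = rng.normal(size=(1, 2 * latent_dim))
mean, logvar = np.split(encoder_out, 2, axis=1)

def reparameterize(mean, logvar, rng):
    """Draw z = mean + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    eps = rng.standard_normal(size=mean.shape)
    return mean + np.exp(0.5 * logvar) * eps

# Many draws from the same (mean, logvar) should reproduce those moments.
draws = np.stack([reparameterize(mean, logvar, rng) for _ in range(20000)])
mean_error = np.abs((draws.mean(axis=0) - mean) / np.exp(0.5 * logvar)).max()
var_ratio_error = np.abs(draws.var(axis=0) / np.exp(logvar) - 1.0).max()

print(mean_error, var_ratio_error)
```

Because the randomness enters only through eps, z is a deterministic function of the mean and log-variance, which is why gradients can flow back to the encoder.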

In [7]:
# x = tf.random.normal(shape=(1,85))
# test = tf.keras.Sequential(
#             [
#                 tf.keras.layers.InputLayer(input_shape=(85)),
#                 tf.keras.layers.Dense(2720),
#                 tf.keras.layers.Reshape((1,85,32)),
#                 tf.keras.layers.BatchNormalization(momentum=0.9),
#                 tf.keras.layers.Conv2DTranspose(
#                     filters=128, 
#                     strides=(1,2),
#                     kernel_size=14, 
#                     padding='same',
#                     activation='relu',
#                     input_shape=(1,85,32)), #layer 1 finish
#                 tf.keras.layers.BatchNormalization(momentum=0.9),
#                 tf.keras.layers.Conv2DTranspose(
#                     filters=64, 
#                     strides=(1,2), #only thing different than the CVAE_Large!
#                     kernel_size=14, 
#                     padding='same',
#                     activation='relu',
#                     input_shape=(1,340,128)), #layer 2 finish
#                 tf.keras.layers.BatchNormalization(momentum=0.9),
#                 tf.keras.layers.Conv2DTranspose(
#                     filters=32, 
#                     strides=(1,2),
#                     kernel_size=28, 
#                     padding='same',
#                     activation='relu',
#                     input_shape=(1,340,64)), #layer 3 finish
#                 tf.keras.layers.BatchNormalization(momentum=0.9),
#                 tf.keras.layers.Conv2DTranspose(
#                     filters=16, 
#                     strides=(1,2),
#                     kernel_size=56, 
#                     padding='same',
#                     activation='relu',
#                     input_shape=(1,680,32)), #layer 4 finish
#                 tf.keras.layers.BatchNormalization(momentum=0.9),
#                 tf.keras.layers.Conv2DTranspose(
#                     filters=1, 
#                     strides=1,
#                     kernel_size=56, 
#                     padding='same',
#                     input_shape=(1,1360,1)) #layer 5 finish
#             ]
#         )
# print(test(x).shape)

Below are functions for calculating the loss and updating the neural net parameters with a given optimizer. The first function, "log_normal_pdf", takes three parameters: a sample vector, a mean, and the log of the distribution's variance. It returns the log probability density of the vector under a diagonal normal distribution with that mean and variance. The function "compute_loss" uses "log_normal_pdf" to calculate the total loss of the encoding and decoding process. It forms a single-sample estimate of the ELBO: the log probability of input vector x given latent vector z (computed as a negative sigmoid cross-entropy), plus the log probability of z under the standard normal prior, minus the log probability of z under the encoder's distribution for x. The loss is the negative of this estimate. Lastly, the "train_step" function passes one input vector through the VAE, computes the loss with "compute_loss", and uses the gradient of the loss with respect to each of the VAE's parameters to update them.
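
For reference, the quantities computed in the cell that follows can be written out as below. The notation is an assumption introduced here: q_phi denotes the encoder's distribution, p_theta the decoder's, and p(z) the standard normal prior.

```latex
\begin{aligned}
z &= \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\log \mathcal{N}(z; \mu, \sigma^2) &= -\tfrac{1}{2} \sum_i \left[ \frac{(z_i - \mu_i)^2}{\sigma_i^2} + \log \sigma_i^2 + \log 2\pi \right] \\
\mathrm{ELBO}(x) &\approx \log p_\theta(x \mid z) + \log p(z) - \log q_\phi(z \mid x) \\
\mathcal{L}(x) &= -\mathrm{ELBO}(x)
\end{aligned}
```

The second line is the expression evaluated by log_normal_pdf, with logvar playing the role of log sigma^2.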

In [8]:
optimizer = tf.keras.optimizers.Adam(1e-4)

# Lists that collect decoded outputs and latent vectors during testing.
Output_var_VAE_L = []
Latent_var_VAE_L = []
Output_var_VAE_S = []
Latent_var_VAE_S = []

def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(-0.5 * ((sample - mean) ** 2. *tf.exp(-logvar) + logvar + log2pi), axis=raxis)

def compute_loss(model, x, testing=False):
    global Output_var_VAE_L, Latent_var_VAE_L, Output_var_VAE_S, Latent_var_VAE_S
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logits = model.decode(z)
    if testing:
        if z.shape == (1, 170):
            Latent_var_VAE_L.append(z)
            Output_var_VAE_L.append(x_logits)
        elif z.shape == (1, 85):
            Latent_var_VAE_S.append(z)
            Output_var_VAE_S.append(x_logits)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logits, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
def train_step(model, x, optimizer):
    # Executes one training step and returns the loss.
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
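
To see what the last two ELBO terms measure, the sketch below re-implements log_normal_pdf in NumPy (a stand-in for the TensorFlow version, not a replacement) and compares a Monte Carlo average of log q(z|x) - log p(z) with the closed-form KL divergence between a diagonal Gaussian and the standard normal. The dimension, sample count, and random mean and log-variance are arbitrary choices for this illustration.

```python
import numpy as np

def log_normal_pdf(sample, mean, logvar, axis=1):
    """Log density of a diagonal Gaussian, summed over `axis`."""
    log2pi = np.log(2.0 * np.pi)
    return np.sum(
        -0.5 * ((sample - mean) ** 2 * np.exp(-logvar) + logvar + log2pi),
        axis=axis,
    )

rng = np.random.default_rng(0)
dim = 4
mean = rng.normal(size=(1, dim))
logvar = rng.normal(scale=0.5, size=(1, dim))

# Monte Carlo estimate of E_q[log q(z|x) - log p(z)] over many samples of z.
n = 200_000
eps = rng.standard_normal(size=(n, dim))
z = mean + np.exp(0.5 * logvar) * eps
kl_mc = np.mean(log_normal_pdf(z, mean, logvar) - log_normal_pdf(z, 0.0, 0.0))

# Closed-form KL(N(mean, var) || N(0, I)) for comparison.
kl_exact = 0.5 * np.sum(np.exp(logvar) + mean ** 2 - 1.0 - logvar)

print(kl_mc, kl_exact)
```

The two numbers should agree closely, which shows that logpz - logqz_x in compute_loss acts as a single-sample estimate of the negative KL term.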

In [9]:
VAE_L = CVAE_1(170)
VAE_S = CVAE_2(85)

2022-12-05 15:29:51.209414: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
tf.config.run_functions_eagerly(True)
# x = train_wst[0,:].astype(np.float32)
# x = x.reshape(1, 1, 1360, 1)
# train_step(VAE_S, x, optimizer)


In [None]:
epochs = 5
trainset_size = train_wst.shape[0]
testset_size = test_wst.shape[0]
for epoch in range(1, epochs + 1):
    start_time = time.time()
    for i in range(0,trainset_size):
        x = train_wst[i,:].astype(np.float32)
        x = x.reshape(1, 1, 1360, 1)
        train_step(VAE_L, x, optimizer)
    end_time = time.time()
    print("finished training")
    loss = tf.keras.metrics.Mean()
    for j in range(0,testset_size):
        test_x = test_wst[j,:].astype(np.float32)
        test_x = test_x.reshape(1, 1, 1360, 1)
        loss(compute_loss(VAE_L, test_x, testing=True))
    elbo = -loss.result()
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, time elapsed for current epoch: {}'
        .format(epoch, elbo, end_time - start_time))

In [None]:
epochs = 5
trainset_size = train_wst.shape[0]
testset_size = test_wst.shape[0]
for epoch in range(1, epochs + 1):
    start_time = time.time()
    for i in range(0,trainset_size):
        x = train_wst[i,:].astype(np.float32)
        x = x.reshape(1, 1, 1360, 1)
        train_step(VAE_S, x, optimizer)
    end_time = time.time()
    print("finished training")
    loss = tf.keras.metrics.Mean()
    for j in range(0,testset_size):
        test_x = test_wst[j,:].astype(np.float32)
        test_x = test_x.reshape(1, 1, 1360, 1)
        loss(compute_loss(VAE_S, test_x, testing=True))
    elbo = -loss.result()
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, time elapsed for current epoch: {}'
        .format(epoch, elbo, end_time - start_time))
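
The loops above feed one vector at a time. A possible alternative, sketched here under assumptions, is to feed mini-batches. The helper `iterate_batches` and the synthetic `fake_train` array are hypothetical and only show the reshaping; they are not part of the notebook's pipeline.

```python
import numpy as np

def iterate_batches(data, batch_size, rng=None):
    """Yield float32 batches shaped (batch, 1, n_coeffs, 1), optionally shuffled."""
    order = np.arange(data.shape[0])
    if rng is not None:
        rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        rows = data[order[start:start + batch_size]].astype(np.float32)
        yield rows.reshape(-1, 1, data.shape[1], 1)

# Synthetic stand-in for train_wst (S0 already removed): 10 maps x 1360 coefficients.
fake_train = np.random.default_rng(0).normal(size=(10, 1360))
batches = list(iterate_batches(fake_train, batch_size=4))
shapes = [b.shape for b in batches]

print(shapes)  # [(4, 1, 1360, 1), (4, 1, 1360, 1), (2, 1, 1360, 1)]
```

Note that the testing branch of compute_loss only stores results when z has shape (1, 170) or (1, 85), so the evaluation loop would need that check adjusted before it could use batches.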

WSTs output by the VAEs during testing:

In [None]:
print("Stored_output_WSTs_from_testing:\n VAE_L: ")
print(Output_var_VAE_L)
print("\n VAE_S: ")
print(Output_var_VAE_S)

Latent space variables (z) found during testing:

In [None]:
print("Stored_latent_space_variables_from_testing:\n VAE_L: ")
print(Latent_var_VAE_L)
print("\n VAE_S: ")
print(Latent_var_VAE_S)

                            ================Scratchwork below=====================

In [None]:
class CVAE_Large(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE_Large, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(1,1360,1,1)),
                tf.keras.layers.Conv2D(
                    filters=128,
                    kernel_size=28, #decided by doing input_dim / 9; similar to keras documentation
                    strides=1, 
                    dilation_rate=2,
                    padding='same',
                    activation='relu'), 
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 1 finish
                tf.keras.layers.Conv2D(
                    filters=64,
                    kernel_size=14, 
                    strides=1, 
                    dilation_rate=2,
                    padding='same',
                    activation='relu'),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 2 finish
                tf.keras.layers.Conv2D(
                    filters=32,
                    kernel_size=7, 
                    strides=1,
                    dilation_rate=2,
                    padding='same',
                    activation='relu'),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 3 finish
                tf.keras.layers.Dense(680),
                tf.keras.layers.Dense(340)
            ]
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(1,170,1,1)),
                tf.keras.layers.Dense(5440),
                tf.keras.layers.Reshape((1,170,32,1)),
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=128, 
                    strides=2,
                    kernel_size=14, 
                    padding='same',
                    activation='relu'), #layer 1 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, 
                    strides=1,
                    kernel_size=14, 
                    padding='same',
                    activation='relu'), #layer 2 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, 
                    strides=2,
                    kernel_size=28, 
                    padding='same',
                    activation='relu'), #layer 3 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=16, 
                    strides=2,
                    kernel_size=56, 
                    padding='same',
                    activation='relu'), #layer 4 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=1, 
                    strides=1,
                    kernel_size=56, 
                    padding='same') #layer 5 finish
            ]
        )
        
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100,self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

In [None]:
class CVAE_Small(tf.keras.Model):
    def __init__(self, latent_dim):
        super(CVAE_Small, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(1,1360,1)),
                tf.keras.layers.Conv2D(
                    filters=128,
                    kernel_size=28, #decided by doing input_dim / 9; similar to keras documentation
#                     strides=2, 
                    dilation_rate=2,
                    activation='relu'), 
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 1 finish
                tf.keras.layers.Conv2D(
                    filters=64,
                    kernel_size=14, 
#                     strides=2, 
                    dilation_rate=2,
                    activation='relu'),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 2 finish
                tf.keras.layers.Conv2D(
                    filters=32,
                    kernel_size=7, 
#                     strides=2,
                    dilation_rate=2,
                    activation='relu'),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 3 finish
                tf.keras.layers.Conv2D(
                    filters=16,
                    kernel_size=14, 
#                     strides=2,
                    dilation_rate=2,
                    activation='relu'),
                tf.keras.layers.BatchNormalization(momentum=0.9), #layer 4 finish
                tf.keras.layers.Dense(340),
                tf.keras.layers.Dense(170)
            ]
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(85,1)),
                tf.keras.layers.Dense(2720),
                tf.keras.layers.Reshape((1,85,32)),
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=128, 
                    strides=2,
                    kernel_size=14, 
                    padding='same',
                    activation='relu'), #layer 1 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, 
                    strides=2,
                    kernel_size=14, 
                    padding='same',
                    activation='relu'), #layer 2 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, 
                    strides=2,
                    kernel_size=28, 
                    padding='same',
                    activation='relu'), #layer 3 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=16, 
                    strides=2,
                    kernel_size=56, 
                    padding='same',
                    activation='relu'), #layer 4 finish
                tf.keras.layers.BatchNormalization(momentum=0.9),
                tf.keras.layers.Conv2DTranspose(
                    filters=1, 
                    strides=1,
                    kernel_size=56, 
                    padding='same') #layer 5 finish
            ]
        )
        
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100,self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

# VAE with PyTorch
https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
def relu(_x_):
    return np.maximum(0, _x_)

In [None]:
class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
#         find INPUT_SIZE
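        super(VariationalEncoder, self).__init__()
        # Unfinished convolutional stack, left commented out: input_size and the
        # kernel sizes are not yet decided, and forward() does not use these layers.
        # (The PyTorch class is nn.Conv2d, and it requires a kernel_size argument.)
        # self.layer1 = nn.Conv2d(input_size, 256, kernel_size=..., stride=2)
        # self.act1 = nn.ReLU()
        # self.norm1 = nn.BatchNorm2d(256, momentum=0.9)
        # self.layer2 = nn.Conv2d(256, 128, kernel_size=..., stride=2)
        # self.act2 = nn.ReLU()
        # self.norm2 = nn.BatchNorm2d(128, momentum=0.9)
        # self.layer3 = nn.Conv2d(128, 64, kernel_size=..., stride=2)
        # self.act3 = nn.ReLU()
        # self.norm3 = nn.BatchNorm2d(64, momentum=0.9)
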
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)
        self.linear3 = nn.Linear(512, latent_dims)

        self.N = torch.distributions.Normal(0, 1)
        # Move the sampling distribution to the selected device (also works on CPU).
        self.N.loc = self.N.loc.to(device)
        self.N.scale = self.N.scale.to(device)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)
        # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions and the batch.
        self.kl = (0.5 * (sigma**2 + mu**2 - 1) - torch.log(sigma)).sum()
        return z

In [None]:
class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 784)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, 28, 28))

In [None]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

In [None]:
def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            x = x.to(device) # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            loss.backward()
            opt.step()
    return autoencoder

In [None]:
vae = VariationalAutoencoder(latent_dims).to(device) # GPU
vae = train(vae, data)

In [None]:
plot_latent(vae, data)