In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers,Model,utils
import imageio
import PIL
import matplotlib.pyplot as plt

In [3]:
import tensorboard
import datetime
import time

In [4]:
# Load MNIST; only the training images and labels are used, the test split is discarded.
(train_images, labels), (_, _) = tf.keras.datasets.mnist.load_data()

In [5]:
def preprocess_data(data):
    """Binarize pixel data and append a trailing channel axis.

    The values are cast to float32 and divided by 255. Entries above 0.5
    become 1 and entries at or below 0.5 become 0. The input array itself is
    not modified because the cast produces a copy.

    Args:
        data: NumPy array of pixel values (anything supporting ``astype``).

    Returns:
        Tensor holding the binarized values with an extra size-1 last axis.
    """
    scaled = data.astype('float32') / 255
    # Both masks are taken from the scaled values before any overwrite.
    above_half = scaled > 0.5
    at_or_below_half = scaled <= 0.5
    scaled[above_half] = 1
    scaled[at_or_below_half] = 0
    return tf.expand_dims(scaled, -1)

In [6]:
# Binarized training images with a trailing channel axis.
input_images = preprocess_data(data=train_images)

In [7]:
# Row i of the 10x10 identity matrix is the one-hot code for digit i;
# indexing it with the label array encodes every label at once.
one_hot_labels = tf.cast(np.eye(10)[labels], 'float32')

In [8]:
# Shuffle individual examples *before* batching. Batching first would fix the
# membership of every batch once and only reorder whole batches each epoch.
dataset = (
    tf.data.Dataset.from_tensor_slices((input_images, one_hot_labels))
    .shuffle(10000)
    .batch(64)
)

In [9]:
class Encoder(Model):
    """Convolutional encoder that emits the latent distribution parameters.

    Two stride-2, 'same'-padded ELU convolutions are followed by a flatten and
    a linear Dense layer with 50+50 outputs. ``reparameterize`` in this
    notebook reads the first 50 as the mean and the last 50 as the
    log-variance.
    """

    def __init__(self, n_filters):
        """Build the layers.

        Args:
            n_filters: Filter count of the first convolution; the second
                convolution uses twice as many.
        """
        super().__init__()

        conv_kwargs = dict(kernel_size=3, strides=(2, 2), padding='same', activation='elu')
        self.conv1 = layers.Conv2D(n_filters, **conv_kwargs)
        self.conv2 = layers.Conv2D(n_filters * 2, **conv_kwargs)
        self.flat = layers.Flatten()
        self.bottle_neck = layers.Dense(50 + 50)

    def call(self, inputs):
        """Map a batch of images to the concatenated latent parameters.

        Shape notes below assume a batch of 64 MNIST images and n_filters=16.
        """
        # (64,28,28,1) -> (64,14,14,16) -> (64,7,7,32)
        feature_maps = self.conv2(self.conv1(inputs))
        # (64,7,7,32) -> (64,7*7*32) -> (64,50+50)
        return self.bottle_neck(self.flat(feature_maps))

In [10]:
class Classifier(Model):
    """Convolutional digit classifier that outputs 10 unnormalized scores.

    The layer stack matches the encoder's (two stride-2 ELU convolutions and a
    flatten) but ends in a linear Dense layer with 10 units.
    """

    def __init__(self, n_filters):
        """Build the layers.

        Args:
            n_filters: Filter count of the first convolution; the second
                convolution uses twice as many.
        """
        super().__init__()

        self.conv1 = layers.Conv2D(
            n_filters, kernel_size=3, strides=(2, 2), padding='same', activation='elu')
        self.conv2 = layers.Conv2D(
            n_filters * 2, kernel_size=3, strides=(2, 2), padding='same', activation='elu')
        self.flat = layers.Flatten()
        self.bottle_neck = layers.Dense(10)

    def call(self, inputs):
        """Return class logits for a batch of images.

        With a batch of 64 MNIST images and n_filters=16 the shapes run
        (64,28,28,1) -> (64,14,14,16) -> (64,7,7,32) -> (64,7*7*32) -> (64,10).
        """
        x = inputs
        for layer in (self.conv1, self.conv2, self.flat, self.bottle_neck):
            x = layer(x)
        return x

In [11]:
class Decoder(Model):
    """Transposed-convolution decoder that turns a latent vector into image logits.

    A Dense layer expands the input to 7*7*32 values, which are reshaped to a
    (7,7,32) feature map and upsampled twice by stride-2 transposed
    convolutions. A final stride-1 convolution with one filter and no
    activation produces the per-pixel logits.
    """

    def __init__(self, n_filters):
        """Build the layers.

        Args:
            n_filters: Filter count of the last transposed convolution; the
                first transposed convolution uses twice as many.
        """
        super().__init__()

        upsample_kwargs = dict(kernel_size=3, strides=(2, 2), padding='same', activation='elu')
        self.dense = layers.Dense(7 * 7 * 32, activation='elu')
        self.reshape = layers.Reshape(target_shape=(7, 7, 32))
        self.deconv1 = layers.Conv2DTranspose(n_filters * 2, **upsample_kwargs)
        self.deconv2 = layers.Conv2DTranspose(n_filters, **upsample_kwargs)
        self.decoder_output = layers.Conv2D(1, kernel_size=3, strides=(1, 1), padding='same')

    def call(self, inputs):
        """Decode a batch of latent vectors into image logits.

        Shape notes assume a batch of 64, n_filters=16 and a 60-wide input
        (50 latent values concatenated with 10 class scores).
        """
        # (64,60) -> (64,7*7*32)
        hidden = self.dense(inputs)
        # -> (64,7,7,32)
        hidden = self.reshape(hidden)
        # stride-2 upsampling: -> (64,14,14,32)
        hidden = self.deconv1(hidden)
        # stride-2 upsampling: -> (64,28,28,16)
        hidden = self.deconv2(hidden)
        # single-filter projection: -> (64,28,28,1), raw logits
        return self.decoder_output(hidden)

In [12]:
@tf.function
def reparameterize(input_tensor):
    """Draw a latent sample with the reparameterization trick.

    Args:
        input_tensor: 2-D tensor of shape (batch, 2*d). The first d columns
            are read as the mean and the last d columns as the log-variance.
            The encoder in this notebook emits d = 50.

    Returns:
        Tuple ``(z, mean, logvar, sample)`` where ``sample`` is the
        standard-normal noise and ``z = mean + exp(0.5 * logvar) * sample``.
    """
    # Split in half instead of hard-coding 50 so any even width is accepted.
    mean, logvar = tf.split(input_tensor, num_or_size_splits=2, axis=1)

    # Use the dynamic shape: inside tf.function the static batch dimension may
    # be None, which tf.random.normal cannot accept.
    sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)

    # exp(0.5 * logvar) equals sqrt(exp(logvar)) but does not overflow when
    # logvar is large.
    z = mean + tf.exp(0.5 * logvar) * sample
    return z, mean, logvar, sample

In [13]:
def log_normal(mean,logvar,sample):
    """Log-density of a diagonal Gaussian, summed over axis 1.

    Args:
        mean: Mean of the Gaussian (tensor or scalar, broadcast against ``sample``).
        logvar: Log-variance of the Gaussian (tensor or scalar).
        sample: Points at which the density is evaluated.

    Returns:
        Tensor with one log-density per row of ``sample``.
    """
    scaled_sq_dev = (sample - mean) ** 2 * tf.exp(-logvar)
    per_dim_log_prob = -0.5 * (scaled_sq_dev + tf.math.log(2 * np.pi) + logvar)
    return tf.reduce_sum(per_dim_log_prob, axis=1)

@tf.function
def compute_loss(input_image,generated_image,predicted_labels,true_labels,mean,logvar,sample,alpha=1.0):
    """Negative single-sample ELBO plus a weighted classification loss.

    Args:
        input_image: Target images with values in [0, 1], shape (batch, H, W, C).
        generated_image: Decoder logits with the same shape as ``input_image``.
        predicted_labels: Classifier logits, shape (batch, classes).
        true_labels: One-hot targets, same shape as ``predicted_labels``.
        mean: Posterior mean, shape (batch, d).
        logvar: Posterior log-variance, shape (batch, d).
        sample: Standard-normal noise used for the latent draw, as returned in
            the fourth slot of ``reparameterize``, shape (batch, d).
        alpha: Weight of the classification term.

    Returns:
        Scalar tensor ``alpha * class_loss + inference_loss``.
    """
    # The densities must be evaluated at the latent z, not at the raw noise,
    # so rebuild z from the posterior parameters and the noise.
    z = mean + tf.exp(0.5 * logvar) * sample

    # The prior is a unit Gaussian: mean 0 and log-variance 0 (variance 1).
    logpz = log_normal(0.0, 0.0, z)
    logqz = log_normal(mean, logvar, z)

    # Bernoulli log-likelihood of the pixels, summed over height, width, channel.
    logpx_z = -tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=generated_image, labels=input_image),
        axis=[1, 2, 3],
    )
    inference_loss = -tf.reduce_mean(logpx_z + logpz - logqz)

    class_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=predicted_labels, labels=true_labels)
    )
    return alpha * class_loss + inference_loss

In [14]:
# All three sub-networks use a base filter count of 16.
classifier = Classifier(n_filters=16)
encoder = Encoder(n_filters=16)
decoder = Decoder(n_filters=16)

In [15]:
%reload_ext tensorboard
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = r'C:/Users/Boyang/Machine Learning/VAE/logs/gradient_tape/' + current_time + '/train'
test_log_dir = r'C:/Users/Boyang/Machine Learning/VAE/logs/gradient_tape/' + current_time + '/test'
#train_summary_writer = tf.summary.create_file_writer(train_log_dir)
#test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [16]:
# Optimizer shared by all three networks, plus a running mean of the step losses.
adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)

@tf.function()
def train_step(image,labels):
    """Run one optimization step on a batch and record its loss.

    The classifier, encoder and decoder are updated jointly: the decoder input
    is the latent draw concatenated with the classifier logits, and the loss
    is ``compute_loss`` with the classification term weighted by 10.

    Args:
        image: Batch of input images.
        labels: One-hot labels for the batch.
    """
    with tf.GradientTape() as tape:
        predicted_labels = classifier(image)
        z, mean, logvar, sample = reparameterize(encoder(image))
        decoder_input = tf.concat([z, predicted_labels], axis=1)
        generated_image = decoder(decoder_input)
        loss = compute_loss(
            image, generated_image, predicted_labels, labels,
            mean, logvar, sample, alpha=10.0,
        )

    # Accumulate into the running-mean metric.
    train_loss(loss)

    # Build the joint variable list once; it is used for both the gradients
    # and the update.
    trainable_vars = (
        classifier.trainable_variables
        + encoder.trainable_variables
        + decoder.trainable_variables
    )
    gradients = tape.gradient(loss, trainable_vars)
    adam.apply_gradients(zip(gradients, trainable_vars))

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [18]:
# `tf.ConfigProto` exists only in the TF 1.x top-level namespace; the compat
# alias resolves on both TF 1.x and TF 2.x.
# NOTE(review): nothing in this notebook consumes `config`. It only takes
# effect if passed to a session or to enable_eager_execution at startup.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True

In [20]:
checkpoint_path = r"C:/Users/Boyang/Machine Learning/VAE/VAE_saved_models"
# Create the target directory up front so saving weights into it cannot fail
# on a missing folder.
tf.io.gfile.makedirs(checkpoint_path)


def _save_all_weights(epoch):
    """Write encoder, decoder and classifier weights for `epoch` into checkpoint_path."""
    # Join with an explicit separator: plain concatenation produced names such
    # as ".../VAE_saved_modelsencoder_10.h5" next to the folder, not inside it.
    encoder.save_weights(checkpoint_path + '/' + 'encoder_{}.h5'.format(epoch))
    decoder.save_weights(checkpoint_path + '/' + 'decoder_{}.h5'.format(epoch))
    classifier.save_weights(checkpoint_path + '/' + 'classifer_{}.h5'.format(epoch))


# Decoder outputs (raw logits) for the ten digit classes, collected every 10 epochs.
generated_samples = []
for epoch in range(1,101):
    start_time = time.time()
    # Each dataset element is an (image, label) pair. Wrapping the dataset in
    # enumerate() and unpacking three names raised ValueError, and the index
    # was never used. Iterating a Dataset in a Python loop needs eager
    # execution: the default on TF 2.x; on TF 1.x call
    # tf.compat.v1.enable_eager_execution() right after importing tensorflow.
    for image,label in dataset:
        train_step(image,label)
    end_time = time.time()
    if epoch % 10 == 0:
        # Ten fresh latent draws, one paired with each one-hot digit class.
        test_inputs = tf.concat([tf.random.normal((10,50)),tf.cast(np.eye(10)[0:10],'float32')],axis = 1)

        print('Epoch: {},  time elapse for current epoch {}'.format(epoch,end_time - start_time))
        _save_all_weights(epoch)
        generated_samples.append(decoder(test_inputs))

# Final snapshot after the last epoch.
_save_all_weights(epoch)

RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.