In [None]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from PIL import Image
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

In [None]:
def save_img(imgs, names, grid=10, tile=28):
  """Tile a batch of grayscale images into a square grid and save it to disk.

  Args:
    imgs: indexable batch of uint8 images, each `tile` x `tile` pixels
      (the training cell passes an array reshaped to (-1, 28, 28)).
    names: output file path handed to PIL's `Image.save`.
    grid: number of tiles per row and per column (default 10 -> 10x10 grid,
      i.e. a 280x280 canvas with the default tile size).
    tile: side length in pixels of each image; also the paste stride, so
      tiles sit edge to edge with no gaps or overlap.

  Tiles are filled column by column (the outer loop walks the x offset).
  If `imgs` holds fewer than grid*grid images, the remaining cells stay black
  instead of raising IndexError.
  """
  side = grid * tile
  img_new = Image.new('L', (side, side))
  count = len(imgs)
  index = 0
  # Stride must equal the tile size; the previous stride of 80 left only a
  # sparse 4x4 set of 28px tiles separated by black gaps.
  for i in range(0, side, tile):
    for j in range(0, side, tile):
      if index >= count:
        break
      img = Image.fromarray(imgs[index], mode='L')
      img_new.paste(img, (i, j))
      index += 1
  img_new.save(names)

In [None]:
def feature_scale(x):
  """Cast `x` to float32 and divide by 255.

  Used as the `tf.data.Dataset.map` function for the image pipelines, so
  pixel values in 0..255 end up in [0, 1].
  """
  scaled = tf.cast(x, tf.float32) / 255.
  return scaled

In [None]:
# Hyperparameters
lr = 0.001        # Adam learning rate
batch_num = 128   # mini-batch size for the train and test pipelines
dim_reduce = 10   # latent dimension: width of the encoder's mean / log-variance heads

In [None]:
# Load Fashion-MNIST. The labels (y, y_test) are loaded but not used: the VAE
# is trained on images only.
(x,y),(x_test,y_test) = datasets.fashion_mnist.load_data()

# Training pipeline: scale pixels with feature_scale, shuffle, then batch.
data = tf.data.Dataset.from_tensor_slices(x)
data = data.map(feature_scale).shuffle(10000).batch(batch_num)

# Evaluation pipeline: same scaling but no shuffle, so the first batch (used
# for the per-epoch reconstruction grids) is the same every epoch.
data_test = tf.data.Dataset.from_tensor_slices(x_test)
data_test = data_test.map(feature_scale).batch(batch_num)

# Sanity check. The datasets yield image batches only (no labels), so inspect
# the batch shape itself; indexing samples[0]/samples[1] would just show two
# individual images.
data_iter = iter(data)
samples = next(data_iter)
print(samples.shape)

(28, 28) (28, 28)


In [None]:
class VAE(keras.Model):
  """Fully connected variational autoencoder.

  Encoder: input -> Dense(hidden_dim) + ReLU -> two Dense(latent_dim) heads
  producing the posterior mean and log-variance.
  Decoder: z -> Dense(hidden_dim) + ReLU -> Dense(output_dim) logits (no
  sigmoid; the training loss applies it via sigmoid cross-entropy).

  Args:
    latent_dim: size of z. Defaults to the notebook-level `dim_reduce`
      (looked up when the model is constructed).
    hidden_dim: width of the hidden layer in both encoder and decoder.
    output_dim: number of reconstructed features (784 = 28*28 here).
  """

  def __init__(self, latent_dim=None, hidden_dim=128, output_dim=784):
    super(VAE, self).__init__()
    if latent_dim is None:
      latent_dim = dim_reduce
    # encoder
    self.fc_layer_1 = layers.Dense(hidden_dim)
    self.fc_layer_2 = layers.Dense(latent_dim)  # mean head
    self.fc_layer_3 = layers.Dense(latent_dim)  # log-variance head
    # decoder
    self.fc_layer_4 = layers.Dense(hidden_dim)
    self.fc_layer_5 = layers.Dense(output_dim)  # per-feature logits

  def model_encoder(self, x):
    """Return (mean, log_variance) of q(z|x) for a (batch, features) input."""
    h = tf.nn.relu(self.fc_layer_1(x))
    mean_fc = self.fc_layer_2(h)
    log_var_fc = self.fc_layer_3(h)
    return mean_fc, log_var_fc

  def model_decoder(self, z):
    """Map latent codes to reconstruction logits."""
    out = tf.nn.relu(self.fc_layer_4(z))
    return self.fc_layer_5(out)

  def reparameter(self, mean_x, var_x):
    """Sample z = mean + std * eps with eps ~ N(0, I) (reparameterization trick).

    `var_x` is the log-variance produced by the encoder.
    """
    # Dynamic shape so sampling also works when the batch dimension is unknown
    # (e.g. model.build((None, 784)) or tracing under tf.function).
    eps = tf.random.normal(tf.shape(var_x))
    # exp(0.5 * log_var) == sqrt(exp(log_var)), but it avoids overflowing exp()
    # for large log-variances and the infinite sqrt gradient when exp()
    # underflows to 0.
    std = tf.exp(0.5 * var_x)
    return mean_x + std * eps

  def call(self, inputs, training=None):
    """Encode, sample z, decode.

    Returns:
      (logits, mean, log_variance): decoder logits plus the posterior
      parameters needed for the KL term.
    """
    mean_x, var_x = self.model_encoder(inputs)
    z = self.reparameter(mean_x, var_x)
    x = self.model_decoder(z)
    return x, mean_x, var_x

In [None]:
# Instantiate the VAE and create its weights by building it for a concrete
# input shape. Dense kernel shapes depend only on the last (feature)
# dimension, so the batch size of 4 here is arbitrary.
# The Adam optimizer is created in the training cell (with `learning_rate=`;
# the old `lr=` keyword is deprecated and rejected by newer Keras optimizers).
model = VAE()
model.build(input_shape=(4, 784))
model.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  100480    
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
_________________________________________________________________
dense_3 (Dense)              multiple                  1408      
_________________________________________________________________
dense_4 (Dense)              multiple                  101136    
Total params: 205,604
Trainable params: 205,604
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Start each run with an empty output directory for the per-epoch image grids.
# Pure-Python equivalent of `rm -rf img_result && mkdir img_result`, so it
# also works where no POSIX shell is available (e.g. Windows kernels).
import shutil
from pathlib import Path

shutil.rmtree('img_result', ignore_errors=True)
Path('img_result').mkdir(parents=True, exist_ok=True)

In [None]:
# Train the VAE and save a reconstruction grid of the first test batch after
# every epoch. The optimizer is (re)created here, so re-running this cell
# starts with fresh Adam state (the model weights, however, carry over).
# `learning_rate=` replaces the deprecated `lr=` keyword, which newer Keras
# optimizers reject.
epochs = 10
kl_weight = 1.0  # weight of the KL term in the loss (1.0 = standard ELBO)
optimizer = optimizers.Adam(learning_rate=lr)
for i in range(epochs):
  # `x_batch` (not `x`) so the training array loaded earlier is not clobbered.
  for step, x_batch in enumerate(data):
    x_batch = tf.reshape(x_batch, [-1, 784])  # flatten 28x28 images for the Dense encoder
    batch_size = x_batch.shape[0]  # the final batch can be smaller than batch_num
    with tf.GradientTape() as tape:
      logits, mean_x, var_x = model(x_batch)
      # Reconstruction term: sigmoid cross-entropy on the logits, summed over
      # pixels and averaged over the batch.
      rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x_batch, logits=logits)
      rec_loss = tf.reduce_sum(rec_loss) / batch_size
      # KL(q(z|x) || N(0, I)) with var_x treated as the log-variance, matching
      # the model's reparameterization, which exponentiates var_x.
      kl_div = -0.5 * (var_x + 1 - mean_x**2 - tf.exp(var_x))
      kl_div = tf.reduce_sum(kl_div) / batch_size

      loss = rec_loss + kl_weight * kl_div
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if step % 100 == 0:
      print(i, step, 'loss:', float(loss), 'kl_div:', float(kl_div))

  # Reconstruct the first (unshuffled) test batch and save it as an image grid.
  x_val = next(iter(data_test))
  val_x = tf.reshape(x_val, [-1, 784])
  logits, _, _ = model(val_x)
  x_hat = tf.sigmoid(logits)
  x_hat = tf.reshape(x_hat, [-1, 28, 28])
  x_hat = (x_hat.numpy() * 255).astype(np.uint8)
  save_img(x_hat, 'img_result/VAE_img_%d.png' % i)

0 0 loss: 260.6536865234375 kl_div: 15.01053524017334
0 100 loss: 255.05715942382812 kl_div: 14.16098403930664
0 200 loss: 253.0172119140625 kl_div: 14.4535493850708
0 300 loss: 254.93603515625 kl_div: 14.592442512512207
0 400 loss: 273.73236083984375 kl_div: 14.08193588256836
1 0 loss: 261.02496337890625 kl_div: 14.518166542053223
1 100 loss: 259.20953369140625 kl_div: 14.700233459472656
1 200 loss: 247.3898162841797 kl_div: 15.081777572631836
1 300 loss: 250.42955017089844 kl_div: 14.643710136413574
1 400 loss: 236.47540283203125 kl_div: 14.7083740234375
2 0 loss: 246.284912109375 kl_div: 14.419654846191406
2 100 loss: 251.54071044921875 kl_div: 14.560482025146484
2 200 loss: 246.84365844726562 kl_div: 14.768489837646484
2 300 loss: 259.6373596191406 kl_div: 14.294961929321289
2 400 loss: 247.30999755859375 kl_div: 15.784496307373047
3 0 loss: 252.60140991210938 kl_div: 14.31601333618164
3 100 loss: 252.27960205078125 kl_div: 14.752006530761719
3 200 loss: 255.7644805908203 kl_div: 1

In [None]:
# Colab only: download the grid saved after the final epoch (index 9 of range(10)).
from google.colab import files
last_grid_path = 'img_result/VAE_img_9.png'
files.download(last_grid_path)