In [13]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

In [14]:
tf.random.set_seed(1)

In [15]:
(train_img,_),(test_img,_)=tf.keras.datasets.mnist.load_data()

In [16]:
height,width=28,28

train_img=train_img.reshape(train_img.shape[0],height,width,1).astype('float32')
train_img/=255
train_img[train_img>=0.5]=1
train_img[train_img<=0.5]=0

test_img=test_img.reshape(test_img.shape[0],height,width,1).astype('float32')
test_img/=255
test_img[test_img>=0.5]=1
test_img[test_img<=0.5]=0

In [17]:
train_buffer=60000
test_buffer=10000

batch_size=100
latent_dimensions=2

In [18]:
train_set=tf.data.Dataset.from_tensor_slices(train_img).shuffle(train_buffer).batch(batch_size)
test_set=tf.data.Dataset.from_tensor_slices(test_img).shuffle(test_buffer).batch(batch_size)

In [19]:
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D

def encoder(x,filters):
    x=Conv2D(filters,3,activation='relu',padding='same',kernel_initializer='he_normal')(x)
    x=Conv2D(filters,3,activation='relu',padding='same',kernel_initializer='he_normal')(x)
    x=MaxPooling2D()(x)
    return x

def decoder(x,filters):
    x=UpSampling2D()(x)
    x=Conv2D(filters,3,activation='relu',padding='same',kernel_initializer='he_normal')(x)
    x=Conv2D(filters,3,activation='relu',padding='same',kernel_initializer='he_normal')(x)
    return x

In [20]:
from keras.layers import Input, Dense, Flatten, Lambda, Reshape
from keras.models import Model
from keras.optimizers import Adam

def vae(h,w,batch,latent,ini_filt,capacity,optimizer=Adam(lr=0.001)):
    
    inputs=Input((h,w,1))
    
    encoded=encoder(inputs,ini_filt)
    for i in range(1,capacity+1):
        encoded=encoder(encoded,ini_filt*(2**i))
    
    _,*space=encoded.get_shape().as_list()
    encoded_flat=Flatten()(encoded)
    
    def sampling(args):
        z_mean,z_log_sigma=args
        epsilon=tf.keras.backend.random_normal(shape=(batch,latent),mean=0,stddev=1)
        return z_mean+tf.keras.backend.exp(z_log_sigma)*epsilon
    
    z_mean=Dense(latent)(encoded_flat)
    z_log_sigma=Dense(latent)(encoded_flat)
    z=Lambda(sampling,output_shape=(latent,))([z_mean,z_log_sigma])
    
    
    input_embed=Input([latent])
    embed=Dense(np.prod(space),activation='relu')(input_embed)
    embed=Reshape(encoded.shape.as_list()[1:])(embed)
    
    decoded=decoder(embed,ini_filt*(2**capacity))
    for i in range(capacity-1,-1,-1):
        decoded=decoder(decoded,ini_filt*(2**i))
        
    output=Conv2D(3,1,activation='tanh')(decoded)
    
    decode=Model(input_embed,output)
    encode_sample=Model(inputs,z)
    vae_out=decode(encode_sample(inputs))
    vae=Model(inputs,vae_out)
    
    def vae_loss(x,x_decoded_mean):
        mse_loss=tf.keras.backend.mean(mse(x,x_decoded_mean),axis(1,2))*h*w
        kl_loss=-0.5*tf.keras.backend.mean(1+z_log_sigma-tf.keras.backend.square(z_mean)-tf.keras.backend.exp(z_log_sigma),axis=-1)
        return mse_loss+kl_loss
    
    vae.compile(loss=vae_loss,optimizer=optimizer)
    vae.summary()
    return vae,encode_sample,decode

In [21]:
vae,encode,decode=vae(height,width,batch_size,latent_dimensions,8,3)

In [22]:
vae.fit(train_set,verbose=1,epochs=5,validation_data=test_set)

Epoch 1/5


ValueError: in user code:

    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function  *
        return step_function(self, iterator)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\engine\training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\engine\training.py:789 run_step  **
        outputs = model.train_step(data)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\engine\training.py:756 train_step
        _minimize(self.distribute_strategy, tape, self.optimizer, loss,
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\engine\training.py:2736 _minimize
        gradients = optimizer._aggregate_gradients(zip(gradients,  # pylint: disable=protected-access
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:562 _aggregate_gradients
        filtered_grads_and_vars = _filter_grads(grads_and_vars)
    C:\Users\Tim\anaconda3\envs\Moonlander2.0\lib\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py:1270 _filter_grads
        raise ValueError("No gradients provided for any variable: %s." %

    ValueError: No gradients provided for any variable: ['conv2d_17/kernel:0', 'conv2d_17/bias:0', 'conv2d_18/kernel:0', 'conv2d_18/bias:0', 'conv2d_19/kernel:0', 'conv2d_19/bias:0', 'conv2d_20/kernel:0', 'conv2d_20/bias:0', 'conv2d_21/kernel:0', 'conv2d_21/bias:0', 'conv2d_22/kernel:0', 'conv2d_22/bias:0', 'conv2d_23/kernel:0', 'conv2d_23/bias:0', 'conv2d_24/kernel:0', 'conv2d_24/bias:0', 'dense_3/kernel:0', 'dense_3/bias:0', 'dense_4/kernel:0', 'dense_4/bias:0', 'dense_5/kernel:0', 'dense_5/bias:0', 'conv2d_25/kernel:0', 'conv2d_25/bias:0', 'conv2d_26/kernel:0', 'conv2d_26/bias:0', 'conv2d_27/kernel:0', 'conv2d_27/bias:0', 'conv2d_28/kernel:0', 'conv2d_28/bias:0', 'conv2d_29/kernel:0', 'conv2d_29/bias:0', 'conv2d_30/kernel:0', 'conv2d_30/bias:0', 'conv2d_31/kernel:0', 'conv2d_31/bias:0', 'conv2d_32/kernel:0', 'conv2d_32/bias:0', 'conv2d_33/kernel:0', 'conv2d_33/bias:0'].
