In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow import keras
from keras import layers,optimizers,losses

# TensorFlow's device type string is "GPU"; the mixed-case "gPU" matches no device.
physical_devices = tf.config.list_physical_devices("GPU")

  from .autonotebook import tqdm as notebook_tqdm


# Data preparation

In [2]:

# Only the training split is used in this notebook, so load just that split.
# as_supervised=False keeps each example as a dict; normalize_data reads data["image"].
ds_train, infos = tfds.load(
    name="mnist",
    split="train",
    shuffle_files=True,
    with_info=True,
    as_supervised=False,
)

def normalize_to_normal(x):
    """Map values from [-1, 1] back to the pixel range [0, 255].

    This is the inverse of normalize_data's ``2 * (p / 255) - 1``.
    Solving for ``p`` gives ``255 * (x + 1) / 2``, so -1 -> 0, 0 -> 127.5 and 1 -> 255.

    Args:
        x: scalar, array or tensor with values in [-1, 1] (e.g. tanh generator output).

    Returns:
        Same kind of object as ``x``, rescaled to [0, 255].
    """
    # (x + 1) shifts [-1, 1] to [0, 2]; using (x - 1) would land in [-255, 0].
    return 255 * (x + 1) / 2

def normalize_data(data):
    """Turn one dataset example into a float32 image scaled to [-1, 1].

    Only ``data["image"]`` is read; every other key (e.g. the label) is dropped.
    Pixel values are assumed to lie in [0, 255], so 0 maps to -1 and 255 maps to 1.
    """
    pixels = tf.cast(data["image"], tf.float32)
    return 2 * (pixels / 255) - 1

BATCH_SIZE = 64

ds_train = (
    ds_train
    .map(normalize_data, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache after map so later epochs reuse the normalized tensors.
    .cache()
    # Buffer as large as the whole split (60,000 for MNIST) for a full shuffle.
    .shuffle(infos.splits["train"].num_examples)
    # The last batch may be smaller; the training loop sizes its noise per batch.
    .batch(batch_size=BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# element_spec shows the per-batch shape/dtype directly, with no extra take() dataset.
print(ds_train.element_spec)

<TakeDataset element_spec=TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None)>


## Build Model

### Generator Network

In [3]:
class CNNTranspose(layers.Layer):
    """Upsampling unit: transposed convolution -> batch normalization -> ReLU.

    Args:
        channels: number of output filters.
        filter_size: kernel size of the transposed convolution.
        stride: upsampling stride (2 doubles height/width with "same" padding).
        padding: padding mode forwarded to Conv2DTranspose.
        name: optional layer name.

    The convolution has no bias; presumably because the BatchNormalization
    that follows supplies its own learned offset.
    """

    def __init__(self, channels, filter_size, stride=2, padding="same", name=None):
        super().__init__(name=name)
        # Attribute names (cnt, bn) are kept so the tracked sublayers are unchanged.
        self.cnt = layers.Conv2DTranspose(
            filters=channels,
            kernel_size=filter_size,
            strides=stride,
            padding=padding,
            use_bias=False,
        )
        self.bn = layers.BatchNormalization()

    def call(self, inputs, training=False):
        normalized = self.bn(self.cnt(inputs), training=training)
        return tf.nn.relu(normalized)



class Generator(keras.Model):
    """DCGAN-style generator: latent vector z -> 28x28x1 image in [-1, 1].

    The pipeline is:
    Dense(7*7*256) -> reshape to 7x7x256 -> CNNTranspose(128, stride 1, stays 7x7)
    -> CNNTranspose(64, stride 2, 14x14) -> Conv2DTranspose(1, stride 2, 28x28) -> tanh.

    The tanh output range matches the [-1, 1] scaling that normalize_data
    applies to the real training images.
    """

    # Length of the latent noise vector z; the training loop samples z with this size.
    seed_dim = 128

    def __init__(self):
        super().__init__()
        self.linear = layers.Dense(7 * 7 * 256, name="linear")
        # Built once here rather than inside call(), so a new Reshape layer is not
        # created on every forward pass.
        self.reshape_layer = layers.Reshape((7, 7, 256))
        self.tcn1 = CNNTranspose(128, 5, stride=1, name="TPCNN_1")
        self.tcn2 = CNNTranspose(64, 5, stride=2, name="TPCNN_2")
        self.output_layer = layers.Conv2DTranspose(1, 3, 2, "same", name="output")

    def call(self, inputs, training=False):
        """Generate images from a batch of latent vectors (expected shape (batch, seed_dim))."""
        x = self.linear(inputs)
        x = self.reshape_layer(x)
        x = self.tcn1(x, training=training)
        x = self.tcn2(x, training=training)
        x = self.output_layer(x)
        return tf.keras.activations.tanh(x)

    def architecture(self):
        """Print a layer summary by tracing call() on a symbolic (None, seed_dim) input."""
        # (seed_dim,) is a 1-tuple shape; a bare (128) would just be the int 128.
        x = keras.Input((self.seed_dim,))
        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()


generator = Generator()
generator.architecture()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128)]             0         
                                                                 
 linear (Dense)              (None, 12544)             1618176   
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 TPCNN_1 (CNNTranspose)      (None, 7, 7, 128)         819712    
                                                                 
 TPCNN_2 (CNNTranspose)      (None, 14, 14, 64)        205056    
                                                                 
 output (Conv2DTranspose)    (None, 28, 28, 1)         577       
                                                                 
 tf.math.tanh (TFOpLambda)   (None, 28, 28, 1)         0     

### Discriminator Network

In [4]:
class CNN(layers.Layer):
    """Downsampling unit: stride-2 convolution -> batch normalization -> leaky ReLU.

    Args:
        channels: number of output filters.
        filter_size: convolution kernel size.
        name: optional layer name.

    With padding="same" and stride 2, height and width are halved
    (14x14 -> 7x7 for Conv_2 in this notebook). The convolution has no bias term.
    """

    def __init__(self, channels, filter_size, name=None):
        super().__init__(name=name)
        self.cnn = layers.Conv2D(
            filters=channels,
            kernel_size=filter_size,
            strides=2,
            padding="same",
            use_bias=False,
        )
        self.bn = layers.BatchNormalization()

    def call(self, inputs, training=False):
        features = self.cnn(inputs)
        features = self.bn(features, training=training)
        return tf.nn.leaky_relu(features)


class Discriminator(keras.Model):
    """Real-vs-fake classifier for 28x28x1 images.

    The pipeline is:
    Conv2D(64, stride 2) + leaky ReLU -> CNN(128, stride 2) -> flatten -> Dense(1, sigmoid).

    It returns one sigmoid probability per image. In training, real images
    are labeled 1 and generated ones 0.
    """

    def __init__(self):
        super().__init__()
        self.cnn1 = layers.Conv2D(64, 3, 2, padding="same", name="Conv_1")
        self.cnn2 = CNN(128, 5, name="Conv_2")
        # Not used by call() (its line there is disabled). It is named "Conv_3" so it
        # does not duplicate cnn1's "Conv_1"; layer names in a model should be unique.
        self.cnn3 = CNN(256, 5, name="Conv_3")
        # Built once here rather than constructing a new Flatten on every call().
        self.flatten_layer = layers.Flatten()
        self.output_layer = layers.Dense(1, activation="sigmoid", name="output")

    def call(self, inputs, training=False):
        """Return the probability that each image is real, shape (batch, 1)."""
        x = self.cnn1(inputs)
        x = tf.nn.leaky_relu(x)
        x = self.cnn2(x, training=training)
        # x = self.cnn3(x, training=training)  # third downsampling block, disabled
        x = self.flatten_layer(x)
        return self.output_layer(x)

    def architecture(self):
        """Print a layer summary by tracing call() on a symbolic (None, 28, 28, 1) input."""
        x = keras.Input((28, 28, 1))
        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()


discriminator = Discriminator()
discriminator.architecture()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 Conv_1 (Conv2D)             (None, 14, 14, 64)        640       
                                                                 
 tf.nn.leaky_relu (TFOpLambd  (None, 14, 14, 64)       0         
 a)                                                              
                                                                 
 Conv_2 (CNN)                (None, 7, 7, 128)         205312    
                                                                 
 flatten (Flatten)           (None, 6272)              0         
                                                                 
 output (Dense)              (None, 1)                 6273      
                                                           

# Train the Networks


### Optimizer and Loss Function

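* Loss function: the GAN minimax objective $\mathbb{E}_x[\log D(x)] + \mathbb{E}_z[\log(1 - D(G(z)))]$, implemented with binary cross-entropy
* Optimizer: Adam (a separate instance for each network)
* Learning rate: $10^{-4}$ for both networks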


In [5]:
# One Adam instance per network, so each keeps its own moment estimates.
gen_opt = optimizers.Adam(learning_rate=1e-4)
disc_opt = optimizers.Adam(learning_rate=1e-4)

# The discriminator ends in a sigmoid, so the loss is fed probabilities, not logits.
loss_fn = losses.BinaryCrossentropy(from_logits=False)

num_epochs = 50

### Training loop
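* The discriminator tries to maximize this objective, i.e. to output 1 for real images and 0 for generated ones.
* The generator tries to minimize $\log(1 - D(G(z)))$; in practice it maximizes $\log D(G(z))$ instead (the non-saturating loss), which gives stronger gradients early in training.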

In [6]:
import os

import tqdm

# img.save() below does not create missing folders, so create the output
# directory before the first snapshot is written.
os.makedirs("generated_images", exist_ok=True)

for epoch in range(num_epochs):
    for idx, real in enumerate(tqdm.tqdm(ds_train, desc=f"epoch {epoch + 1}/{num_epochs}")):
        # The last batch can be smaller than the configured size, so size the
        # noise from the actual batch.
        batch_size = real.shape[0]
        z = tf.random.normal([batch_size, Generator.seed_dim])  # latent noise

        # Train both networks from one forward pass that both tapes record.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fake = generator(z, training=True)                # G(z)
            real_output = discriminator(real, training=True)  # D(x)
            fake_output = discriminator(fake, training=True)  # D(G(z))

            # Discriminator: real -> 1, fake -> 0. ones_like/zeros_like match the
            # (batch, 1) prediction shape; tf.ones(n, 1) would treat 1 as a dtype.
            real_loss = loss_fn(tf.ones_like(real_output), real_output)   # -log D(x)
            fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)  # -log(1 - D(G(z)))
            disc_loss = real_loss + fake_loss

            # Generator (non-saturating form): push D(G(z)) toward 1, reusing the
            # same training-mode discriminator output rather than a second pass.
            gen_loss = loss_fn(tf.ones_like(fake_output), fake_output)    # -log D(G(z))

        # Take both gradients before applying either update.
        grads_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
        grads_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        gen_opt.apply_gradients(zip(grads_gen, generator.trainable_variables))
        disc_opt.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

        # Run the generator in inference mode only when a snapshot is actually saved.
        if idx % 100 == 0:
            generated_image = generator(z, training=False)
            img = tf.keras.preprocessing.image.array_to_img(normalize_to_normal(generated_image[0]))
            img.save(f"generated_images/generated_img{epoch}_{idx}_.png")

100%|██████████| 938/938 [01:21<00:00, 11.49it/s]
100%|██████████| 938/938 [01:12<00:00, 12.90it/s]
100%|██████████| 938/938 [01:11<00:00, 13.13it/s]
100%|██████████| 938/938 [01:10<00:00, 13.23it/s]
100%|██████████| 938/938 [01:12<00:00, 12.90it/s]
100%|██████████| 938/938 [01:13<00:00, 12.76it/s]
100%|██████████| 938/938 [01:13<00:00, 12.80it/s]
100%|██████████| 938/938 [01:12<00:00, 12.92it/s]
100%|██████████| 938/938 [01:14<00:00, 12.66it/s]
 29%|██▉       | 273/938 [00:21<00:53, 12.50it/s]


KeyboardInterrupt: 

In [None]:
# import tqdm
# import os

# for epoch in range(num_epochs):
#     for idx,real in enumerate(tqdm.tqdm(ds_train)):
#         bach_size=real.shape[0]
#         z=tf.random.normal((bach_size,Generator.seed_dim))  #random noise
#         fake=generator(z) #G(z)

#         if idx % 100 ==0:
#             img=tf.keras.preprocessing.image.array_to_img(normalize_to_normal(fake[0]))
#             img.save(f"generated_images/generated_img{epoch}_{idx}_.png")

#         #train the discriminator
#         with tf.GradientTape() as disc_tape: 
#             real_loss=loss_fn(tf.ones(bach_size,1),discriminator(real)) #log(D(x)) 
#             fake_loss=loss_fn(tf.zeros(bach_size,1),discriminator(fake))#log(1-D(G(Z)))
#             disc_loss=real_loss+fake_loss
        
#         grads=disc_tape.gradient(disc_loss,discriminator.trainable_weights)
#         disc_opt.apply_gradients(zip(grads,discriminator.trainable_weights))
        
#         #train the generator 
#         with tf.GradientTape() as gen_tape:
#             Gz=generator(z)
#             gen_loss=loss_fn(tf.ones(bach_size,1),discriminator(Gz))
        
#         grads=gen_tape.gradient(gen_loss,generator.trainable_weights)
#         gen_opt.apply_gradients(zip(grads,generator.trainable_weights))
import matplotlib.pyplot as plt
import numpy as np

# Draw fresh latent vectors from a standard normal, the same distribution the
# training loop samples z from. Uniform np.random.rand noise would be off-distribution.
# This cell no longer depends on z/fake leaked from the training loop.
rng = np.random.default_rng(0)
noise = rng.standard_normal((64, Generator.seed_dim), dtype=np.float32)
fakes = generator.predict(noise)

plt.style.use("grayscale")
# The generator outputs (28, 28, 1); imshow needs a 2-D array for one channel.
plt.imshow(fakes[3, :, :, 0])
plt.title("Generated sample #3")
plt.axis("off");