In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

2023-09-01 18:00:51.623806: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-01 18:00:51.706091: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-01 18:00:51.707300: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class DCGAN():
    """Deep Convolutional GAN: a generator and a discriminator trained adversarially.

    The generator maps a ``latent_dim``-dimensional standard-normal noise vector
    to an image of shape ``(rows, cols, channels)`` with values in [-1, 1]
    (``tanh`` output). It upsamples twice by a factor of 2, starting from a
    ``(rows // 4, cols // 4)`` feature map, so ``rows`` and ``cols`` must both be
    divisible by 4 (28x28 MNIST starts from 7x7).

    The discriminator maps an image to a single sigmoid score ("real" = 1).
    ``self.combined`` stacks generator -> frozen discriminator; training it
    updates only the generator.

    Args:
        rows: image height in pixels (must be divisible by 4).
        cols: image width in pixels (must be divisible by 4).
        channels: number of image channels (1 for grayscale).
        z: dimensionality of the latent noise vector.

    Raises:
        ValueError: if ``rows`` or ``cols`` is not divisible by 4.
    """

    def __init__(self, rows, cols, channels, z = 10):
        # Input shape
        self.img_rows = rows
        self.img_cols = cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = z

        # The generator upsamples twice by 2x, so any other size could never
        # match the discriminator's input shape. Fail early with a clear message.
        if rows % 4 or cols % 4:
            raise ValueError(
                "rows and cols must be divisible by 4, got (%d, %d)" % (rows, cols))

        # One optimizer per trained model: a single shared instance would share its
        # step counter and state between discriminator and generator updates.
        d_optimizer = tf.keras.optimizers.legacy.Adam(0.0002, 0.5)
        g_optimizer = tf.keras.optimizers.legacy.Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=d_optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator. The
        # discriminator was compiled above while still trainable, so its own
        # train_on_batch keeps updating its weights.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator.
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=g_optimizer)

    def build_generator(self):
        """Build the generator: latent vector -> image of shape ``self.img_shape`` in [-1, 1]."""
        # Spatial size before the two 2x upsamplings (7x7 for 28x28 images).
        init_rows = self.img_rows // 4
        init_cols = self.img_cols // 4

        model = Sequential()

        model.add(Dense(128 * init_rows * init_cols, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((init_rows, init_cols, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image of shape ``self.img_shape`` -> sigmoid score."""

        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        # Pad one row/column at the bottom/right (7x7 -> 8x8 for 28x28 input).
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    @staticmethod
    def _rescale_images(images):
        """Map raw pixel values in [0, 255] to float32 in [-1, 1] (the generator's tanh range).

        3-D input ``(N, rows, cols)`` gets a trailing channel axis so the result is
        always 4-D. float32 halves memory compared with the default float64.
        """
        scaled = np.asarray(images, dtype=np.float32) / 127.5 - 1.0
        if scaled.ndim == 3:
            scaled = np.expand_dims(scaled, axis=-1)
        return scaled

    def train(self, epochs, batch_size=256, save_interval=50, data=None):
        """Train the GAN.

        Despite the name, each "epoch" is ONE mini-batch step: one discriminator
        update on real images, one on generated images, then one generator update.
        It is not a full pass over the dataset.

        Args:
            epochs: number of single-batch training iterations.
            batch_size: number of real images (and of generated images) per step.
            save_interval: call ``save_imgs`` every this many iterations
                (iteration 0 included).
            data: optional raw images with values in [0, 255], shaped
                (N, rows, cols) or (N, rows, cols, channels). Defaults to the
                MNIST training images.

        Raises:
            ValueError: if the image shape of ``data`` does not match ``self.img_shape``.
        """
        if data is None:
            # Load the dataset
            (data, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = self._rescale_images(data)
        if X_train.shape[1:] != self.img_shape:
            raise ValueError("training images have shape %s, expected %s"
                             % (X_train.shape[1:], self.img_shape))

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Sample a random mini-batch of real images (with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and generate a batch of new images.
            # verbose=0 avoids a progress bar from predict() on every iteration.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise, verbose=0)

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # [loss, accuracy] averaged over the real and fake halves
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch, out_dir="images"):
        """Save a 5x5 grid of generated samples to ``<out_dir>/dcgan_mnist_<epoch>.png``.

        ``out_dir`` is created if it does not exist.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise, verbose=0)

        # Rescale images from the tanh range [-1, 1] to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat iterates row by row, matching sample order 0..r*c-1
        for cnt, ax in enumerate(axs.flat):
            if self.channels == 1:
                ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            else:
                ax.imshow(gen_imgs[cnt])
            ax.axis('off')

        # savefig does not create missing directories
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, "dcgan_mnist_%d.png" % epoch))
        # Close this figure explicitly so repeated calls don't accumulate open figures
        plt.close(fig)

In [3]:
# Training configuration. EPOCHS counts single mini-batch iterations (see DCGAN.train).
SEED = 42
EPOCHS = 500
BATCH_SIZE = 256
SAVE_INTERVAL = 50

# Seed Python's `random`, NumPy's global RNG (used for batch sampling and noise)
# and TensorFlow (weight init, dropout) so repeated runs start from the same state.
tf.keras.utils.set_random_seed(SEED)

dcgan = DCGAN(28, 28, 1)
dcgan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)

2023-09-01 18:00:56.413765: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 14, 14, 32)        320       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 32)        0         
                                                                 
 dropout (Dropout)           (None, 14, 14, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 64)          18496     
                                                                 
 zero_padding2d (ZeroPaddin  (None, 8, 8, 64)          0         
 g2D)                                                            
                                                                 
 batch_normalization (Batch  (None, 8, 8, 64)          256       
 Normalization)                                         

26 [D loss: 0.631039, acc.: 67.19%] [G loss: 0.796637]
27 [D loss: 0.414148, acc.: 80.86%] [G loss: 0.757792]
28 [D loss: 0.319247, acc.: 87.89%] [G loss: 0.514823]
29 [D loss: 0.217928, acc.: 93.55%] [G loss: 0.477493]
30 [D loss: 0.263641, acc.: 88.67%] [G loss: 0.565768]
31 [D loss: 0.702416, acc.: 64.06%] [G loss: 1.637596]
32 [D loss: 0.851546, acc.: 52.15%] [G loss: 1.956462]
33 [D loss: 1.035523, acc.: 41.41%] [G loss: 1.701118]
34 [D loss: 0.694549, acc.: 63.87%] [G loss: 1.241961]
35 [D loss: 0.534829, acc.: 74.02%] [G loss: 1.029472]
36 [D loss: 0.685371, acc.: 63.87%] [G loss: 1.102475]
37 [D loss: 0.869354, acc.: 51.37%] [G loss: 1.915508]
38 [D loss: 1.347189, acc.: 30.08%] [G loss: 2.126117]
39 [D loss: 1.548726, acc.: 20.90%] [G loss: 1.774237]
40 [D loss: 1.265191, acc.: 34.96%] [G loss: 1.173418]
41 [D loss: 1.229654, acc.: 32.62%] [G loss: 1.079868]
42 [D loss: 1.053922, acc.: 40.43%] [G loss: 1.081475]
43 [D loss: 0.791206, acc.: 56.25%] [G loss: 0.906569]
44 [D loss

102 [D loss: 0.759228, acc.: 58.40%] [G loss: 1.188369]
103 [D loss: 0.812224, acc.: 51.76%] [G loss: 1.213637]
104 [D loss: 0.798999, acc.: 51.95%] [G loss: 1.225005]
105 [D loss: 0.711207, acc.: 60.35%] [G loss: 1.140019]
106 [D loss: 0.743033, acc.: 57.23%] [G loss: 1.142313]
107 [D loss: 0.711528, acc.: 55.86%] [G loss: 1.135073]
108 [D loss: 0.780103, acc.: 52.34%] [G loss: 1.109139]
109 [D loss: 0.808703, acc.: 51.95%] [G loss: 1.037673]
110 [D loss: 0.847453, acc.: 47.27%] [G loss: 1.012392]
111 [D loss: 0.845852, acc.: 49.02%] [G loss: 1.067168]
112 [D loss: 0.846757, acc.: 50.20%] [G loss: 1.058564]
113 [D loss: 0.965708, acc.: 40.43%] [G loss: 1.179262]
114 [D loss: 0.861926, acc.: 46.88%] [G loss: 1.145931]
115 [D loss: 0.795712, acc.: 52.34%] [G loss: 1.261646]
116 [D loss: 0.756783, acc.: 56.05%] [G loss: 1.211266]
117 [D loss: 0.818340, acc.: 49.41%] [G loss: 1.132066]
118 [D loss: 0.819406, acc.: 51.37%] [G loss: 1.178097]
119 [D loss: 0.832661, acc.: 45.90%] [G loss: 1.

177 [D loss: 0.737845, acc.: 56.05%] [G loss: 1.046140]
178 [D loss: 0.777501, acc.: 52.34%] [G loss: 0.989368]
179 [D loss: 0.756753, acc.: 55.86%] [G loss: 0.999487]
180 [D loss: 0.703269, acc.: 59.38%] [G loss: 1.008132]
181 [D loss: 0.734251, acc.: 55.27%] [G loss: 1.080739]
182 [D loss: 0.680253, acc.: 61.52%] [G loss: 1.115381]
183 [D loss: 0.833149, acc.: 47.66%] [G loss: 1.088480]
184 [D loss: 0.781517, acc.: 51.76%] [G loss: 1.173433]
185 [D loss: 0.858239, acc.: 43.36%] [G loss: 1.145571]
186 [D loss: 0.835563, acc.: 49.61%] [G loss: 1.191088]
187 [D loss: 0.846080, acc.: 49.61%] [G loss: 1.150574]
188 [D loss: 0.784735, acc.: 51.37%] [G loss: 1.020269]
189 [D loss: 0.797499, acc.: 48.24%] [G loss: 1.028972]
190 [D loss: 0.794468, acc.: 51.76%] [G loss: 1.115844]
191 [D loss: 0.760875, acc.: 55.27%] [G loss: 1.044396]
192 [D loss: 0.766981, acc.: 55.27%] [G loss: 1.201386]
193 [D loss: 0.782655, acc.: 53.12%] [G loss: 1.092098]
194 [D loss: 0.758746, acc.: 54.30%] [G loss: 1.

252 [D loss: 0.833606, acc.: 47.85%] [G loss: 1.046008]
253 [D loss: 0.865115, acc.: 42.58%] [G loss: 1.011741]
254 [D loss: 0.830962, acc.: 48.83%] [G loss: 1.141884]
255 [D loss: 0.809414, acc.: 47.66%] [G loss: 1.116294]
256 [D loss: 0.786315, acc.: 51.56%] [G loss: 1.055007]
257 [D loss: 0.679313, acc.: 60.35%] [G loss: 1.057359]
258 [D loss: 0.636670, acc.: 63.09%] [G loss: 1.040301]
259 [D loss: 0.691901, acc.: 57.42%] [G loss: 0.958848]
260 [D loss: 0.627708, acc.: 64.65%] [G loss: 1.042243]
261 [D loss: 0.670628, acc.: 60.16%] [G loss: 1.006880]
262 [D loss: 0.674189, acc.: 59.77%] [G loss: 1.044410]
263 [D loss: 0.792957, acc.: 52.15%] [G loss: 1.032997]
264 [D loss: 0.723201, acc.: 58.20%] [G loss: 1.056284]
265 [D loss: 0.763615, acc.: 48.63%] [G loss: 1.050675]
266 [D loss: 0.834142, acc.: 46.88%] [G loss: 1.064357]
267 [D loss: 0.835608, acc.: 44.92%] [G loss: 1.101878]
268 [D loss: 0.773212, acc.: 51.76%] [G loss: 1.069905]
269 [D loss: 0.807067, acc.: 47.85%] [G loss: 1.

327 [D loss: 0.899799, acc.: 37.89%] [G loss: 1.121085]
328 [D loss: 0.839608, acc.: 47.85%] [G loss: 1.239891]
329 [D loss: 0.826639, acc.: 46.68%] [G loss: 1.157218]
330 [D loss: 0.864784, acc.: 43.16%] [G loss: 1.031208]
331 [D loss: 0.825589, acc.: 44.92%] [G loss: 1.083914]
332 [D loss: 0.827485, acc.: 45.12%] [G loss: 1.005617]
333 [D loss: 0.779957, acc.: 49.61%] [G loss: 0.921941]
334 [D loss: 0.692845, acc.: 57.03%] [G loss: 0.929209]
335 [D loss: 0.670054, acc.: 61.52%] [G loss: 0.898425]
336 [D loss: 0.639511, acc.: 63.09%] [G loss: 0.965590]
337 [D loss: 0.578935, acc.: 68.55%] [G loss: 0.958263]
338 [D loss: 0.645262, acc.: 62.30%] [G loss: 0.977307]
339 [D loss: 0.689266, acc.: 55.66%] [G loss: 1.021504]
340 [D loss: 0.689789, acc.: 58.79%] [G loss: 1.094198]
341 [D loss: 0.839862, acc.: 43.75%] [G loss: 1.105425]
342 [D loss: 0.888963, acc.: 42.19%] [G loss: 1.074312]
343 [D loss: 0.788640, acc.: 50.39%] [G loss: 1.178565]
344 [D loss: 0.810498, acc.: 51.37%] [G loss: 1.

402 [D loss: 0.633663, acc.: 62.89%] [G loss: 0.962584]
403 [D loss: 0.671896, acc.: 61.72%] [G loss: 0.940681]
404 [D loss: 0.613630, acc.: 63.48%] [G loss: 0.955350]
405 [D loss: 0.588292, acc.: 70.12%] [G loss: 0.985506]
406 [D loss: 0.629769, acc.: 68.16%] [G loss: 1.054093]
407 [D loss: 0.648262, acc.: 61.72%] [G loss: 0.997207]
408 [D loss: 0.669293, acc.: 62.50%] [G loss: 1.054397]
409 [D loss: 0.744486, acc.: 51.76%] [G loss: 1.074263]
410 [D loss: 0.810752, acc.: 48.05%] [G loss: 1.116103]
411 [D loss: 0.804482, acc.: 47.46%] [G loss: 1.121960]
412 [D loss: 0.791618, acc.: 50.59%] [G loss: 1.068899]
413 [D loss: 0.785441, acc.: 52.34%] [G loss: 1.098508]
414 [D loss: 0.825176, acc.: 44.14%] [G loss: 1.109426]
415 [D loss: 0.758816, acc.: 52.15%] [G loss: 1.057191]
416 [D loss: 0.680940, acc.: 60.35%] [G loss: 0.998317]
417 [D loss: 0.756885, acc.: 50.98%] [G loss: 0.953478]
418 [D loss: 0.611657, acc.: 65.04%] [G loss: 0.968743]
419 [D loss: 0.728956, acc.: 52.34%] [G loss: 0.

477 [D loss: 0.382479, acc.: 86.52%] [G loss: 1.024582]
478 [D loss: 0.439193, acc.: 84.57%] [G loss: 0.974380]
479 [D loss: 0.521278, acc.: 74.02%] [G loss: 1.089310]
480 [D loss: 0.630230, acc.: 67.58%] [G loss: 1.076224]
481 [D loss: 0.783776, acc.: 52.54%] [G loss: 1.052352]
482 [D loss: 0.765750, acc.: 51.76%] [G loss: 1.162861]
483 [D loss: 0.956390, acc.: 33.79%] [G loss: 1.157515]
484 [D loss: 0.860321, acc.: 41.60%] [G loss: 1.156845]
485 [D loss: 0.902311, acc.: 39.84%] [G loss: 1.107591]
486 [D loss: 0.828520, acc.: 47.66%] [G loss: 1.086677]
487 [D loss: 0.706773, acc.: 56.84%] [G loss: 1.114403]
488 [D loss: 0.766190, acc.: 52.73%] [G loss: 1.079358]
489 [D loss: 0.698594, acc.: 57.03%] [G loss: 1.084055]
490 [D loss: 0.513856, acc.: 74.80%] [G loss: 1.042683]
491 [D loss: 0.498517, acc.: 76.56%] [G loss: 0.987909]
492 [D loss: 0.466751, acc.: 81.25%] [G loss: 0.991956]
493 [D loss: 0.550102, acc.: 69.73%] [G loss: 0.939974]
494 [D loss: 0.439083, acc.: 83.40%] [G loss: 0.

---