Сделаем необходимые импорты

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, MaxPool2D, UpSampling2D
from keras.layers import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

from keras.utils import to_categorical

Определим некоторые переменные

In [2]:
input_dim = 100
batch_size = 256
batch_shape = (batch_size, 28, 28, 1)
num_classes = 10
epochs = 3000

Загрузка датасета и нормализация данных

In [3]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [4]:
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)


In [5]:
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test  = np.reshape(x_test,  (len(x_test),  28, 28, 1))

Построим генератор

In [6]:
generator = Sequential()

generator.add(Dense(7 * 7 * 64, input_dim=input_dim))
generator.add(LeakyReLU(0.2))
generator.add(Dropout(0.3))
generator.add(Reshape((7, 7, 64)))
generator.add(UpSampling2D(size=(2, 2)))

generator.add(Conv2D(64, kernel_size=(5, 5), padding='same'))
generator.add(LeakyReLU(0.2))
generator.add(Dropout(0.3))

generator.add(Conv2D(32, kernel_size=(3, 3), padding='same'))
generator.add(LeakyReLU(0.2))
generator.add(Dropout(0.3))
generator.add(UpSampling2D(size=(2, 2)))

generator.add(Conv2D(1, kernel_size=(5, 5), activation='tanh', padding='same'))

generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 3136)              316736    
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 3136)              0         
                                                                 
 dropout (Dropout)           (None, 3136)              0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 14, 14, 64)       0         
 )                                                               
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 64)        102464    
                                                        

Построим дискриминатор

In [7]:
discriminator = Sequential()
    
discriminator.add(Conv2D(128, kernel_size=(7, 7), strides=(2, 2), input_shape=(28, 28, 1), padding='same'))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.2))
discriminator.add(MaxPool2D((2, 2), padding='same'))
    
discriminator.add(Conv2D(128, kernel_size=(3, 3), padding='same'))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.2))

discriminator.add(Flatten())

discriminator.add(Dense(1, activation='sigmoid'))

discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_3 (Conv2D)           (None, 14, 14, 128)       6400      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 14, 14, 128)       0         
                                                                 
 dropout_3 (Dropout)         (None, 14, 14, 128)       0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 7, 7, 128)        0         
 )                                                               
                                                                 
 conv2d_4 (Conv2D)           (None, 7, 7, 128)         147584    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                      

Создаем GAN

In [8]:
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])

discriminator.trainable = False

In [9]:
generator.compile(loss='binary_crossentropy', optimizer=Adam())

In [10]:
model = Sequential()

model.add(generator)
model.add(discriminator)

model.compile(loss='binary_crossentropy', optimizer=Adam())

In [11]:
def save_imgs(epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("/content/content/mnist_%d.png" % epoch)
        plt.close()

Начнем тренировку с сохранением результатов генерации каждые 200 эпох

In [12]:
half_batch = int(batch_size / 2)

for epoch in range(1, epochs + 1):
  idx = np.random.randint(0, x_train.shape[0], half_batch)
  imgs = x_train[idx]

  noise = np.random.normal(0, 1, (half_batch, 100))

  gen_imgs = generator.predict(noise)

  d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
  d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

  noise = np.random.normal(0, 1, (batch_size, 100))

  valid_y = np.array([1] * batch_size)

  g_loss = model.train_on_batch(noise, valid_y)

  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

  if epoch % 200 == 0:
    save_imgs(epoch)

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
508 [D loss: 0.073696, acc.: 98.83%] [G loss: 5.491120]
509 [D loss: 0.083884, acc.: 98.05%] [G loss: 5.504275]
510 [D loss: 0.043181, acc.: 100.00%] [G loss: 6.321898]
511 [D loss: 0.047155, acc.: 98.83%] [G loss: 6.624780]
512 [D loss: 0.061974, acc.: 98.83%] [G loss: 5.966385]
513 [D loss: 0.094944, acc.: 97.66%] [G loss: 5.319020]
514 [D loss: 0.062812, acc.: 98.83%] [G loss: 6.017850]
515 [D loss: 0.054627, acc.: 98.83%] [G loss: 6.205409]
516 [D loss: 0.067773, acc.: 98.83%] [G loss: 5.909723]
517 [D loss: 0.070653, acc.: 99.22%] [G loss: 5.568751]
518 [D loss: 0.083614, acc.: 98.83%] [G loss: 5.811996]
519 [D loss: 0.052196, acc.: 99.61%] [G loss: 6.305288]
520 [D loss: 0.076356, acc.: 97.66%] [G loss: 6.014046]
521 [D loss: 0.057236, acc.: 98.83%] [G loss: 5.588842]
522 [D loss: 0.075563, acc.: 98.44%] [G loss: 5.753854]
523 [D loss: 0.052624, acc.: 100.00%] [G loss: 6.382204]
524 [D loss: 0.09236