In [None]:
# example of training the discriminator model on real and random mnist images
from numpy import expand_dims
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Conv2D
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers import LeakyReLU

In [None]:
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
  """Build and compile a small CNN that scores images as real (1) or fake (0).

  Args:
    in_shape: shape of a single input image, (height, width, channels).

  Returns:
    A compiled Keras ``Sequential`` model ending in a single sigmoid unit,
    compiled with binary cross-entropy loss and an accuracy metric.
  """
  model = Sequential()
  # two stride-2 convolutions progressively downsample the image
  model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.4))
  model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.4))
  # flatten the feature maps and map them to one real/fake probability
  model.add(Flatten())
  model.add(Dense(1, activation='sigmoid'))
  # compile model; `learning_rate` replaces the deprecated `lr` keyword,
  # which warns in Keras 2 and is rejected by newer Keras releases
  opt = Adam(learning_rate=0.0002, beta_1=0.5)
  model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
  return model

In [None]:
# load and prepare mnist training images
def load_real_samples():
  """Return the MNIST training images as float32 values scaled to [0, 1].

  The test split and all labels are discarded. A trailing channel axis is
  appended so each image carries an explicit (single) channel dimension.
  """
  (train_images, _), (_, _) = load_data()
  # add the channel axis, convert unsigned ints to float32, then map [0,255] -> [0,1]
  return expand_dims(train_images, axis=-1).astype('float32') / 255.0

In [None]:
# select real samples
def generate_real_samples(dataset, n_samples):
  """Draw ``n_samples`` random images from ``dataset``, labelled as real.

  Indices are drawn independently, so the same image may appear more than once.

  Returns:
    (images, labels) where labels is an (n_samples, 1) array of ones.
  """
  picks = randint(0, dataset.shape[0], n_samples)
  labels = ones((n_samples, 1))
  return dataset[picks], labels

In [None]:
# generate n fake samples with class labels
def generate_fake_samples(n_samples, in_shape=(28, 28, 1)):
  """Create ``n_samples`` uniform-noise images labelled as fake (class 0).

  Args:
    n_samples: number of images to generate.
    in_shape: shape of one image. The default matches the default input
      shape of ``define_discriminator``.

  Returns:
    (X, y): X has shape ``(n_samples, *in_shape)`` with float32 values in
    [0, 1], the same dtype that ``load_real_samples`` produces; y is an
    ``(n_samples, 1)`` array of zeros.
  """
  # drawing the full shape at once consumes the RNG in the same (C) order
  # as a flat draw followed by a reshape
  X = rand(n_samples, *in_shape).astype('float32')
  # generate 'fake' class labels (0)
  y = zeros((n_samples, 1))
  return X, y

In [None]:
# train the discriminator model
def train_discriminator(model, dataset, n_iter=100, n_batch=256):
  """Train ``model`` to separate real images from uniform noise.

  Each iteration makes two ``train_on_batch`` calls: one on half a batch of
  real images sampled from ``dataset`` (label 1), and one on half a batch of
  noise images (label 0). The accuracy reported by each call is printed once
  per iteration.

  Args:
    model: compiled discriminator whose ``train_on_batch`` returns
      ``(loss, accuracy)``, i.e. compiled with ``metrics=['accuracy']``.
    dataset: array of real images to sample from.
    n_iter: number of training iterations (batch updates, not epochs).
    n_batch: total images per iteration, split evenly between real and fake.

  Raises:
    ValueError: if ``n_batch`` is too small to give each half at least one image.
  """
  half_batch = int(n_batch / 2)
  if half_batch < 1:
    # an empty half-batch would be passed straight to train_on_batch otherwise
    raise ValueError('n_batch must be at least 2, got %d' % n_batch)
  # each iteration is one pair of half-batch updates, not a pass over the dataset
  for i in range(n_iter):
    # update discriminator on randomly selected 'real' samples
    X_real, y_real = generate_real_samples(dataset, half_batch)
    _, real_acc = model.train_on_batch(X_real, y_real)
    # update discriminator on 'fake' (uniform noise) samples
    X_fake, y_fake = generate_fake_samples(half_batch)
    _, fake_acc = model.train_on_batch(X_fake, y_fake)
    # summarize performance
    print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))

In [None]:
# define the discriminator model
# build and compile a fresh discriminator network
model = define_discriminator()
# MNIST training images, scaled to [0, 1], serve as the 'real' class
dataset = load_real_samples()
# train for 100 iterations of 256 images each (128 real + 128 noise)
train_discriminator(model, dataset, n_iter=100, n_batch=256)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
    8192/11490434 [..............................] - ETA: 0s

  super().__init__(name, **kwargs)


>1 real=56% fake=22%
>2 real=76% fake=37%
>3 real=69% fake=55%
>4 real=70% fake=73%
>5 real=58% fake=87%
>6 real=66% fake=95%
>7 real=63% fake=98%
>8 real=67% fake=99%
>9 real=55% fake=100%
>10 real=56% fake=100%
>11 real=62% fake=100%
>12 real=63% fake=100%
>13 real=58% fake=100%
>14 real=66% fake=100%
>15 real=56% fake=100%
>16 real=58% fake=100%
>17 real=66% fake=100%
>18 real=66% fake=100%
>19 real=63% fake=100%
>20 real=73% fake=100%
>21 real=77% fake=100%
>22 real=76% fake=100%
>23 real=84% fake=100%
>24 real=86% fake=100%
>25 real=89% fake=100%
>26 real=91% fake=100%
>27 real=91% fake=100%
>28 real=91% fake=100%
>29 real=97% fake=100%
>30 real=93% fake=100%
>31 real=95% fake=100%
>32 real=98% fake=100%
>33 real=98% fake=100%
>34 real=100% fake=100%
>35 real=99% fake=100%
>36 real=98% fake=100%
>37 real=100% fake=100%
>38 real=100% fake=100%
>39 real=100% fake=100%
>40 real=100% fake=100%
>41 real=99% fake=100%
>42 real=100% fake=100%
>43 real=99% fake=100%
>44 real=100% fake=100