In [1]:
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.cifar10 import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from matplotlib import pyplot

2024-05-14 18:59:54.289952: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-14 18:59:54.290005: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-14 18:59:54.290047: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-14 18:59:54.299675: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3)):
    """Build and compile the standalone discriminator (real-vs-fake classifier).

    The network is one stride-1 convolution followed by three stride-2
    convolutions (each followed by LeakyReLU), then Flatten, Dropout and a
    single sigmoid output unit.

    Args:
        in_shape: Shape of one input image as (height, width, channels).

    Returns:
        A compiled ``Sequential`` model using binary cross-entropy loss, the
        Adam optimizer and an ``accuracy`` metric.
    """
    model = Sequential()
    # normal (stride-1) convolution on the input image
    model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # downsample
    model.add(Conv2D(256, (3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # classifier head: one sigmoid unit
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    # compile model; `learning_rate` replaces the deprecated `lr` keyword,
    # which newer Keras releases no longer accept
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

In [3]:
# define the standalone generator model
def define_generator(latent_dim):
    """Build the standalone generator mapping latent vectors to images.

    A dense layer produces a 4x4x256 foundation, three stride-2 transposed
    convolutions upsample it (8x8, 16x16, 32x32), and a final tanh
    convolution emits 3 channels. The model is returned uncompiled.

    Args:
        latent_dim: Length of each latent input vector.

    Returns:
        An uncompiled ``Sequential`` generator model.
    """
    base_side, base_channels = 4, 256
    generator = Sequential()
    # dense foundation, reshaped into a 4x4 feature map
    generator.add(Dense(base_channels * base_side * base_side, input_dim=latent_dim))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(Reshape((base_side, base_side, base_channels)))
    # three identical upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32
    for _ in range(3):
        generator.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        generator.add(LeakyReLU(alpha=0.2))
    # output layer: 3 channels, tanh activation
    generator.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
    return generator

In [4]:
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Stack generator and discriminator into one model for generator updates.

    The discriminator is marked non-trainable before the combined model is
    compiled. Note that this sets ``d_model.trainable`` on the object passed
    in by the caller.

    Args:
        g_model: Generator model (latent vectors -> images).
        d_model: Discriminator model (images -> real/fake probability).

    Returns:
        A compiled ``Sequential`` model that chains ``g_model`` then
        ``d_model`` and uses binary cross-entropy loss with Adam.
    """
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # connect them: generator output feeds the discriminator
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    # compile model; `learning_rate` replaces the deprecated `lr` keyword,
    # which newer Keras releases no longer accept
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

In [5]:
# Data loading: keep only the CIFAR-10 'automobile' training images
import numpy as np

# load and prepare cifar10 training images
def load_real_samples(class_name='automobile'):
    """Load CIFAR-10 training images of one class, scaled to [-1, 1].

    Args:
        class_name: Name of the CIFAR-10 class to keep. Defaults to
            'automobile', which matches the original behaviour.

    Returns:
        A float32 array holding the training images of the requested class,
        with pixel values rescaled from [0, 255] to [-1, 1].

    Raises:
        ValueError: If ``class_name`` is not one of the CIFAR-10 class names.
    """
    # only the training split is used; the test split is discarded
    (train_images, train_labels), _ = load_data()

    # class names in label-index order
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck']
    if class_name not in class_names:
        raise ValueError('unknown CIFAR-10 class %r; expected one of %s'
                         % (class_name, class_names))
    class_index = class_names.index(class_name)

    # flatten the label array so it can be used as a boolean mask over the
    # first (sample) axis of the image array
    is_selected = train_labels.reshape(-1) == class_index
    # convert from unsigned ints to floats
    X = train_images[is_selected].astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X

In [6]:
# select real samples
def generate_real_samples(dataset, n_samples):
    """Draw a random batch (with replacement) of real images with label 1.

    Args:
        dataset: Array of real samples indexed along its first axis.
        n_samples: Number of samples to draw.

    Returns:
        Tuple ``(X, y)`` where ``X`` holds the selected rows of ``dataset``
        and ``y`` is an ``(n_samples, 1)`` array of ones ('real' labels).
    """
    n_available = dataset.shape[0]
    # random row indices in [0, n_available)
    chosen_rows = randint(0, n_available, n_samples)
    # 'real' class labels (1)
    real_labels = ones((n_samples, 1))
    return dataset[chosen_rows], real_labels

In [7]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """Sample standard-normal latent vectors for the generator.

    Args:
        latent_dim: Length of each latent vector.
        n_samples: Number of latent vectors to draw.

    Returns:
        Array of shape ``(n_samples, latent_dim)`` of standard-normal draws.
    """
    total_values = latent_dim * n_samples
    # draw every value in one flat call, then fold into a batch of vectors
    flat_noise = randn(total_values)
    return flat_noise.reshape((n_samples, latent_dim))

In [8]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate a batch of fake images from the generator, with label 0.

    Args:
        g_model: Generator model exposing ``predict``.
        latent_dim: Length of each latent vector.
        n_samples: Number of fake samples to produce.

    Returns:
        Tuple ``(X, y)`` where ``X`` is the generator's prediction for freshly
        sampled latent points and ``y`` is an ``(n_samples, 1)`` array of
        zeros ('fake' labels).
    """
    latent_batch = generate_latent_points(latent_dim, n_samples)
    # 'fake' class labels (0)
    fake_labels = zeros((n_samples, 1))
    # run the generator on the sampled latent points
    fake_images = g_model.predict(latent_batch)
    return fake_images, fake_labels

In [9]:
# create and save a plot of generated images
def save_plot(examples, epoch, n=7):
    """Save an n-by-n grid of generated images to a PNG file.

    Args:
        examples: Batch of images with values in [-1, 1], indexed along the
            first axis.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        n: Grid side length. At most ``n * n`` images are drawn; if fewer
            examples are supplied, the remaining grid cells are left empty.
    """
    # scale from [-1,1] to [0,1]
    examples = (examples + 1) / 2.0
    # never index past the end of the batch when it holds fewer than n*n images
    n_shown = min(n * n, len(examples))
    # plot images
    for i in range(n_shown):
        # define subplot (subplot positions are 1-based)
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(examples[i])
    # save plot to file
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    pyplot.savefig(filename)
    pyplot.close()

In [10]:
# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=150):
    """Report discriminator accuracy, save a sample plot and save the generator.

    Args:
        epoch: Zero-based epoch index; output file names use ``epoch + 1``.
        g_model: Generator model.
        d_model: Discriminator model (evaluated on real and fake batches).
        dataset: Array of real images.
        latent_dim: Length of each latent vector.
        n_samples: Number of real and of fake samples used for evaluation.
    """
    # discriminator accuracy on a random batch of real images
    real_images, real_labels = generate_real_samples(dataset, n_samples)
    _, real_accuracy = d_model.evaluate(real_images, real_labels, verbose=0)
    # discriminator accuracy on a freshly generated batch of fake images
    fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, n_samples)
    _, fake_accuracy = d_model.evaluate(fake_images, fake_labels, verbose=0)
    # report both accuracies as percentages
    accuracy_pct = (real_accuracy*100, fake_accuracy*100)
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % accuracy_pct)
    # save a grid of the generated images
    save_plot(fake_images, epoch)
    # save the generator model to file
    model_path = 'generator_model_%03d.h5' % (epoch+1)
    g_model.save(model_path)

In [11]:
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=256):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each step trains the discriminator on half a batch of real images and
    half a batch of generated images, then trains the generator through
    ``gan_model`` on a full batch of latent points labelled as real. Losses
    are printed after every step, and every 10th epoch
    ``summarize_performance`` is called.

    Args:
        g_model: Generator model.
        d_model: Discriminator model.
        gan_model: Combined generator + discriminator model.
        dataset: Array of real images.
        latent_dim: Length of each latent vector.
        n_epochs: Number of passes over the dataset.
        n_batch: Batch size for the generator update; the discriminator sees
            two half-batches per step.
    """
    n_examples = dataset.shape[0]
    batches_per_epoch = int(n_examples / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for epoch in range(n_epochs):
        # enumerate batches over the training set
        for step in range(batches_per_epoch):
            # discriminator update on randomly selected 'real' samples
            real_images, real_labels = generate_real_samples(dataset, half_batch)
            d_loss_real, _ = d_model.train_on_batch(real_images, real_labels)
            # discriminator update on generated 'fake' samples
            fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, half_batch)
            d_loss_fake, _ = d_model.train_on_batch(fake_images, fake_labels)
            # generator update: latent points carry inverted ('real') labels
            # so the generator is trained via the discriminator's error
            latent_batch = generate_latent_points(latent_dim, n_batch)
            inverted_labels = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(latent_batch, inverted_labels)
            # summarize loss on this batch
            progress = (epoch+1, step+1, batches_per_epoch, d_loss_real, d_loss_fake, g_loss)
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % progress)
        # evaluate the model performance every 10th epoch
        is_checkpoint_epoch = (epoch+1) % 10 == 0
        if is_checkpoint_epoch:
            summarize_performance(epoch, g_model, d_model, dataset, latent_dim)

In [12]:
# Driver cell: build the three models, load the data and run training.
# size of the latent space (length of each noise vector fed to the generator)
latent_dim = 100
# create the discriminator (compiled inside define_discriminator)
d_model = define_discriminator()
# create the generator (left uncompiled; it is trained through gan_model)
g_model = define_generator(latent_dim)
# create the gan: generator stacked on the discriminator, which define_gan
# marks as non-trainable
gan_model = define_gan(g_model, d_model)
# load image data (CIFAR-10 'automobile' training images scaled to [-1, 1])
dataset = load_real_samples()
# train model with the defaults of train(): n_epochs=50, n_batch=256
train(g_model, d_model, gan_model, dataset, latent_dim)

2024-05-14 18:59:56.586552: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:894] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-05-14 18:59:56.623001: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:894] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-05-14 18:59:56.623215: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:894] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysf

>1, 1/19, d1=0.694, d2=0.710 g=0.689
>1, 2/19, d1=0.419, d2=0.799 g=0.665
>1, 3/19, d1=0.065, d2=0.951 g=0.882
>1, 4/19, d1=0.035, d2=0.617 g=0.915
>1, 5/19, d1=0.026, d2=0.618 g=0.875
>1, 6/19, d1=0.007, d2=0.944 g=0.862
>1, 7/19, d1=0.013, d2=2.050 g=1.005
>1, 8/19, d1=0.760, d2=0.652 g=0.772
>1, 9/19, d1=0.279, d2=0.638 g=0.815
>1, 10/19, d1=0.112, d2=0.590 g=0.882
>1, 11/19, d1=0.065, d2=0.579 g=0.881
>1, 12/19, d1=0.074, d2=0.593 g=0.872
>1, 13/19, d1=0.022, d2=0.525 g=1.053
>1, 14/19, d1=0.020, d2=0.724 g=1.332
>1, 15/19, d1=0.006, d2=1.813 g=2.212
>1, 16/19, d1=0.191, d2=0.173 g=2.641
>1, 17/19, d1=0.077, d2=0.204 g=3.317
>1, 18/19, d1=0.011, d2=5.023 g=6.377
>1, 19/19, d1=1.727, d2=0.320 g=1.264
>2, 1/19, d1=0.265, d2=0.672 g=1.242
>2, 2/19, d1=0.261, d2=0.430 g=1.682
>2, 3/19, d1=0.397, d2=0.352 g=1.940
>2, 4/19, d1=0.186, d2=0.219 g=2.282
>2, 5/19, d1=0.123, d2=0.278 g=2.394
>2, 6/19, d1=0.060, d2=2.677 g=2.861
>2, 7/19, d1=0.222, d2=0.048 g=4.133
>2, 8/19, d1=0.572, d2=0.363

  saving_api.save_model(


>11, 1/19, d1=0.679, d2=0.808 g=1.944
>11, 2/19, d1=0.824, d2=0.519 g=1.603
>11, 3/19, d1=0.764, d2=0.487 g=1.591
>11, 4/19, d1=0.635, d2=0.425 g=1.716
>11, 5/19, d1=0.642, d2=0.447 g=1.682
>11, 6/19, d1=0.550, d2=0.490 g=1.643
>11, 7/19, d1=0.652, d2=0.634 g=1.314
>11, 8/19, d1=0.573, d2=0.638 g=1.339
>11, 9/19, d1=0.522, d2=0.555 g=1.478
>11, 10/19, d1=0.476, d2=0.502 g=1.790
>11, 11/19, d1=0.539, d2=0.523 g=1.677
>11, 12/19, d1=0.455, d2=0.529 g=1.842
>11, 13/19, d1=0.450, d2=0.526 g=2.151
>11, 14/19, d1=0.566, d2=0.473 g=2.366
>11, 15/19, d1=0.482, d2=0.669 g=2.425
>11, 16/19, d1=0.669, d2=0.675 g=2.086
>11, 17/19, d1=0.667, d2=0.551 g=1.945
>11, 18/19, d1=0.501, d2=0.502 g=2.137
>11, 19/19, d1=0.495, d2=0.516 g=1.896
>12, 1/19, d1=0.582, d2=0.872 g=1.802
>12, 2/19, d1=0.551, d2=0.611 g=1.564
>12, 3/19, d1=0.643, d2=0.654 g=1.420
>12, 4/19, d1=0.634, d2=0.610 g=1.629
>12, 5/19, d1=0.485, d2=0.452 g=1.890
>12, 6/19, d1=0.562, d2=0.439 g=2.005
>12, 7/19, d1=0.483, d2=0.592 g=2.195
>1



>21, 1/19, d1=0.471, d2=0.676 g=1.603
>21, 2/19, d1=0.571, d2=0.580 g=1.493
>21, 3/19, d1=0.418, d2=0.592 g=1.471
>21, 4/19, d1=0.452, d2=0.646 g=1.485
>21, 5/19, d1=0.487, d2=0.609 g=1.481
>21, 6/19, d1=0.436, d2=0.773 g=1.515
>21, 7/19, d1=0.502, d2=0.596 g=1.486
>21, 8/19, d1=0.483, d2=0.588 g=1.541
>21, 9/19, d1=0.522, d2=0.589 g=1.660
>21, 10/19, d1=0.484, d2=0.567 g=1.607
>21, 11/19, d1=0.566, d2=0.717 g=1.540
>21, 12/19, d1=0.565, d2=0.791 g=1.731
>21, 13/19, d1=0.745, d2=0.698 g=1.680
>21, 14/19, d1=0.796, d2=0.691 g=1.740
>21, 15/19, d1=0.826, d2=0.459 g=1.719
>21, 16/19, d1=0.624, d2=0.493 g=1.784
>21, 17/19, d1=0.426, d2=0.524 g=2.068
>21, 18/19, d1=0.496, d2=0.719 g=2.161
>21, 19/19, d1=0.624, d2=0.554 g=1.930
>22, 1/19, d1=0.832, d2=0.547 g=1.410
>22, 2/19, d1=0.690, d2=0.829 g=1.238
>22, 3/19, d1=0.590, d2=0.743 g=1.327
>22, 4/19, d1=0.710, d2=0.604 g=1.376
>22, 5/19, d1=0.630, d2=0.503 g=1.440
>22, 6/19, d1=0.549, d2=0.514 g=1.493
>22, 7/19, d1=0.506, d2=0.527 g=1.639
>2



>31, 1/19, d1=0.671, d2=0.610 g=1.232
>31, 2/19, d1=0.619, d2=0.635 g=1.254
>31, 3/19, d1=0.621, d2=0.540 g=1.315
>31, 4/19, d1=0.630, d2=0.557 g=1.276
>31, 5/19, d1=0.565, d2=0.671 g=1.402
>31, 6/19, d1=0.619, d2=0.546 g=1.457
>31, 7/19, d1=0.665, d2=0.597 g=1.405
>31, 8/19, d1=0.699, d2=0.728 g=1.285
>31, 9/19, d1=0.700, d2=0.573 g=1.262
>31, 10/19, d1=0.642, d2=0.605 g=1.300
>31, 11/19, d1=0.594, d2=0.656 g=1.269
>31, 12/19, d1=0.667, d2=0.641 g=1.255
>31, 13/19, d1=0.605, d2=0.612 g=1.241
>31, 14/19, d1=0.614, d2=0.602 g=1.186
>31, 15/19, d1=0.595, d2=0.655 g=1.222
>31, 16/19, d1=0.572, d2=0.599 g=1.193
>31, 17/19, d1=0.490, d2=0.580 g=1.264
>31, 18/19, d1=0.540, d2=0.571 g=1.351
>31, 19/19, d1=0.495, d2=0.533 g=1.416
>32, 1/19, d1=0.502, d2=0.576 g=1.368
>32, 2/19, d1=0.482, d2=0.653 g=1.427
>32, 3/19, d1=0.574, d2=0.935 g=1.663
>32, 4/19, d1=0.745, d2=0.491 g=1.470
>32, 5/19, d1=0.706, d2=0.727 g=1.335
>32, 6/19, d1=0.745, d2=0.704 g=1.356
>32, 7/19, d1=0.727, d2=0.588 g=1.389
>3



>41, 1/19, d1=0.546, d2=0.650 g=1.617
>41, 2/19, d1=0.656, d2=0.542 g=1.718
>41, 3/19, d1=0.649, d2=0.768 g=1.703
>41, 4/19, d1=0.624, d2=0.614 g=1.577
>41, 5/19, d1=0.552, d2=0.765 g=1.629
>41, 6/19, d1=0.683, d2=3.069 g=1.674
>41, 7/19, d1=0.773, d2=2.063 g=4.376
>41, 8/19, d1=2.070, d2=0.601 g=1.576
>41, 9/19, d1=1.013, d2=0.602 g=1.248
>41, 10/19, d1=0.771, d2=0.600 g=1.285
>41, 11/19, d1=0.668, d2=0.588 g=1.407
>41, 12/19, d1=0.671, d2=0.613 g=1.392
>41, 13/19, d1=0.647, d2=0.672 g=1.327
>41, 14/19, d1=0.700, d2=0.759 g=1.223
>41, 15/19, d1=0.765, d2=0.751 g=1.291
>41, 16/19, d1=0.846, d2=0.716 g=1.252
>41, 17/19, d1=0.743, d2=0.674 g=1.256
>41, 18/19, d1=0.806, d2=0.651 g=1.204
>41, 19/19, d1=0.714, d2=0.669 g=1.250
>42, 1/19, d1=0.672, d2=0.605 g=1.224
>42, 2/19, d1=0.631, d2=0.566 g=1.202
>42, 3/19, d1=0.558, d2=0.605 g=1.235
>42, 4/19, d1=0.498, d2=0.535 g=1.269
>42, 5/19, d1=0.434, d2=0.662 g=1.185
>42, 6/19, d1=0.509, d2=0.708 g=1.198
>42, 7/19, d1=0.492, d2=0.750 g=1.296
>4

