In [40]:
#load and preprocess data
import tensorflow as tf
import numpy as np
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
#use all the available examples from both the training and test sets
all_digits = np.concatenate([train_images, test_images])
all_labels = np.concatenate([train_labels, test_labels])
# Scale the pixel values to [0, 1] range
all_digits = all_digits.astype("float32") / 255.0
#add a channel dimension to the images
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
#one-hot encode the labels
all_labels = tf.keras.utils.to_categorical(all_labels, 10)

# Create tf.data.Dataset.
batch_size = 64
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)


In [41]:
latent_dim = 128    #dim of noise vector
num_classes = 10    #dim of class(one-hot) vector
num_channels = 1    #number of image channel(it is gray-scale image)
image_size = 28

In [42]:
#concatenate noise vector and class vector for generator input
generator_in_channels = latent_dim + num_classes
#number of channels of the input image to discriminator
discriminator_in_channels = num_channels + num_classes

In [43]:
#build generator
inp_gen = tf.keras.layers.Input(shape=(generator_in_channels,))
# generate 128 + num_classes coefficients to reshape into a 7x7x(128 + num_classes) map.
x = tf.keras.layers.Dense(7 * 7 * generator_in_channels, use_bias = False)(inp_gen)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.Reshape((7, 7, generator_in_channels))(x)
x = tf.keras.layers.Conv2DTranspose(128, (4,4), strides=2, padding='same', use_bias=False)(x)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.Conv2DTranspose(128, (4,4), strides=2, padding='same', use_bias=False)(x)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
out_gen = tf.keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid")(x)
generator = tf.keras.models.Model(inp_gen, out_gen)



In [44]:
generator.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 138)]             0         
_________________________________________________________________
dense_2 (Dense)              (None, 6762)              933156    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 6762)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 138)         0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 14, 14, 128)       282624    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 128)       2621

In [45]:
#build discriminator
inp_disc = tf.keras.layers.Input(shape=(28,28, discriminator_in_channels))
x = tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same')(inp_disc)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(inp_disc)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.GlobalMaxPooling2D()(x)
out_disc = tf.keras.layers.Dense(units=1)(x)

discriminator = tf.keras.models.Model(inp_disc, out_disc)

In [46]:
discriminator.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 28, 28, 11)]      0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 128)       12800     
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
global_max_pooling2d_1 (Glob (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 129       
Total params: 12,929
Trainable params: 12,929
Non-trainable params: 0
_________________________________________________________________


In [47]:
train_iterations = 10000

In [54]:
train_iterations = 1
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

gen_loss_tracker = tf.keras.metrics.Mean(name="generator_loss")
disc_loss_tracker = tf.keras.metrics.Mean(name="discriminator_loss")

for itr in range(train_iterations):
    # Unpack the data(one batch of data)
    real_images, one_hot_labels = next(iter(dataset))

    #creating kind of one-hot encoder , each element is in size of image with all 1 or all 0
    image_one_hot_labels = one_hot_labels[:, :, None, None]
    image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=[image_size * image_size])
    image_one_hot_labels = tf.reshape(image_one_hot_labels, (-1, image_size, image_size, num_classes))

    # Sample random points in the latent space and concatenate the labels.
    batch_size = tf.shape(real_images)[0]
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)
    generated_images = generator(random_vector_labels)
    
    # Combine them with real images. Note that we are concatenating the labels with these images here.
    fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)
    real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
    combined_images = tf.concat([fake_image_and_labels, real_image_and_labels], axis=0)

    # Assemble labels discriminating real from fake images.
    labels = tf.concat([tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0)

    #train discriminator
    with tf.GradientTape() as disc_tape:
         logits = discriminator(combined_images)
         d_loss = cross_entropy(labels, logits)
    grads = disc_tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    #train generator
    # Sample random points in the latent space.
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)

    # Assemble labels that say "all real images".
    misleading_labels = tf.ones((batch_size, 1))
    with tf.GradientTape() as gen_tape:
         fake_images = generator(random_vector_labels)
         fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], -1)
         logits = discriminator(fake_image_and_labels)
         g_loss = cross_entropy(misleading_labels, logits)
    grads = gen_tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

     # Monitor loss.
    if (itr + 1) % 100 == 0:
         gen_loss_tracker.update_state(g_loss)
         disc_loss_tracker.update_state(d_loss)
         g_loss = gen_loss_tracker.result()
         d_loss = disc_loss_tracker.result()
         print('train BinaryCrossentropyloss discriminator/generator loss at iteration %d: %.4f/%.4f ' % (itr+1, float(d_loss), float(g_loss)))


In [10]:
dataset

<BatchDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>

In [None]:
for data in dataset:
    real_images, one_hot_labels = data

In [21]:
# Unpack the data(one batch of data)
real_images, one_hot_labels = next(iter(dataset))


In [None]:
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=[image_size * image_size])
image_one_hot_labels = tf.reshape(image_one_hot_labels, (-1, image_size, image_size, num_classes))

In [25]:
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)
generated_images = generator(random_vector_labels)



2023-02-27 14:51:54.700026: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10
2023-02-27 14:51:54.918164: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7


In [26]:
generated_images.shape

TensorShape([64, 28, 28, 1])

In [33]:
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=[image_size * image_size])
image_one_hot_labels = tf.reshape(image_one_hot_labels, (-1, image_size, image_size, num_classes))
image_one_hot_labels

<tf.Tensor: shape=(64, 28, 28, 10), dtype=float32, numpy=
array([[[[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],

        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],

        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ...,

In [28]:
one_hot_labels[:, :, None, None]

<tf.Tensor: shape=(64, 10, 1, 1), dtype=float32, numpy=
array([[[[1.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]],


       [[[1.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]],


       [[[0.]],

        [[1.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]],


       [[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[1.]],

        [[0.]]],


       [[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[1.]]],


       [[[0.]],

        [[0.]],

        [[0.]],

        [[1.]],

        [[0.]],
