# In [0]:  (notebook cell marker)
import tensorflow as tf
import numpy as np
import time
import os
import time
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import TensorBoard

# Load MNIST; only the training images are used in the code below.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Add a trailing channel axis, then rescale so that 0 -> -1 and 255 -> 1.
x_train = x_train.reshape((-1, 28, 28, 1)).astype('float32')
x_train = (x_train - 127.5) / 127.5

buffer_size = 60000
batch_size = 256

# Input pipeline: slice into single images, shuffle, then group into batches.
dataset = tf.data.Dataset.from_tensor_slices(x_train)
dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size)


class gen(keras.Model):
    """Generator network: turns a noise vector into a single-channel image.

    The noise is projected by a dense layer, reshaped to a 7x7x256 feature
    map and passed through four transposed convolutions. The last one has a
    single filter and a tanh activation, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        layers = keras.layers
        # Options shared by every transposed convolution in the network.
        deconv_opts = dict(padding='same', use_bias=False)

        # Stage 1: dense projection of the noise, reshaped to 7x7x256.
        self.layer1 = layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,))
        self.layer2 = layers.BatchNormalization()
        self.layer3 = layers.LeakyReLU()
        self.layer4 = layers.Reshape((7, 7, 256))

        # Stage 2: stride-1 transposed convolution, 128 filters.
        self.layer5 = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), **deconv_opts)
        self.layer6 = layers.BatchNormalization()
        self.layer7 = layers.LeakyReLU()

        # Stage 3: stride-2 transposed convolution, 128 filters.
        self.layer8 = layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), **deconv_opts)
        self.layer9 = layers.BatchNormalization()
        self.layer10 = layers.LeakyReLU()

        # Stage 4: stride-2 transposed convolution, 64 filters.
        self.layer11 = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), **deconv_opts)
        self.layer12 = layers.BatchNormalization()
        self.layer13 = layers.LeakyReLU()

        # Output: one channel, squashed into [-1, 1] by tanh.
        self.layer14 = layers.Conv2DTranspose(1, (5, 5), strides=(1, 1), activation='tanh', **deconv_opts)

    def call(self, inputs):
        """Apply layer1 through layer14 in order and return the result."""
        pipeline = (
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.layer5, self.layer6, self.layer7,
            self.layer8, self.layer9, self.layer10,
            self.layer11, self.layer12, self.layer13,
            self.layer14,
        )
        x = inputs
        for layer in pipeline:
            x = layer(x)
        return x

class dis(keras.Model):
    """Discriminator network: scores an image with a single raw logit.

    Three stride-2 5x5 convolutions (64, 128 and 256 filters) are each
    followed by LeakyReLU and Dropout(0.3). The result is flattened and fed
    to a one-unit dense layer with no activation.

    NOTE(review): the L2 kernel regularizers on the second and third
    convolutions only take effect if the training code adds the model's
    regularization losses to the discriminator loss - confirm this is done.
    """

    def __init__(self):
        super().__init__()
        layers = keras.layers
        # Options shared by all three convolutions.
        conv_opts = dict(strides=(2, 2), padding='same')

        # Block 1: 64 filters, no weight regularization.
        self.layer1 = layers.Conv2D(64, (5, 5), **conv_opts)
        self.layer2 = layers.LeakyReLU()
        # Dropout randomly disconnects input units to reduce overfitting.
        self.layer3 = layers.Dropout(0.3)

        # Block 2: 128 filters with an L2 penalty on the kernel.
        self.layer4 = layers.Conv2D(128, (5, 5), kernel_regularizer=regularizers.l2(0.01), **conv_opts)
        self.layer5 = layers.LeakyReLU()
        self.layer6 = layers.Dropout(0.3)

        # Block 3: 256 filters with an L2 penalty on the kernel.
        self.layer7 = layers.Conv2D(256, (5, 5), kernel_regularizer=regularizers.l2(0.01), **conv_opts)
        self.layer8 = layers.LeakyReLU()
        self.layer9 = layers.Dropout(0.3)

        # Head: flatten and map to one unnormalised score.
        self.layer10 = layers.Flatten()
        self.layer11 = layers.Dense(1)

    def call(self, inputs):
        """Apply layer1 through layer11 in order and return the logit."""
        pipeline = (
            self.layer1, self.layer2, self.layer3,
            self.layer4, self.layer5, self.layer6,
            self.layer7, self.layer8, self.layer9,
            self.layer10, self.layer11,
        )
        x = inputs
        for layer in pipeline:
            x = layer(x)
        return x


# Binary cross-entropy loss instance used by the loss functions below.
# from_logits=True means the predictions passed in are raw scores that have
# not yet been through a sigmoid.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator outputs for real images (target: ones).
        fake_output: Discriminator outputs for generated images (target: zeros).

    Returns:
        The sum of the cross-entropy on the real and on the fake outputs.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator is scored against a target of all ones, i.e. on how
    strongly the discriminator rates its images as real.

    Args:
        fake_output: Discriminator outputs for generated images.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)


# One Adam optimizer per network, each with a learning rate of 1e-4.
generator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)

# The two networks.
generator = gen()
discriminator = dis()

# Checkpoint files are written as ./training_checkpoints/ckpt-*.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track both networks together with their optimizer state.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Training settings.
Epochs = 50                   # number of epochs passed to train()
noise_dim = 100               # length of each generator noise vector
num_example_to_generate = 16  # number of sample images drawn per preview

# Fixed noise batch (not referenced elsewhere in the code visible here).
seed = tf.random.normal([num_example_to_generate, noise_dim])

@tf.function()
def train_step(images):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        images: Batch of real images fed to the discriminator.

    The function returns nothing; its effect is the in-place update of the
    trainable variables of the module-level `generator` and `discriminator`.
    """
    # Noise input for the generator.
    noise = tf.random.normal([batch_size, noise_dim])

    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        generator_images = generator(noise, training=True)

        # training=True keeps the discriminator's Dropout layers active during
        # the update, consistent with the loop in train(); without it Keras
        # runs them in inference mode.
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generator_images, training=True)

        gen_loss = generator_loss(fake_output)
        dis_loss = discriminator_loss(real_output, fake_output)

    # Gradients of each loss with respect to its own network's variables.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = dis_tape.gradient(dis_loss, discriminator.trainable_variables)

    # Apply the updates with the matching optimizer.
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# Compute the loss values
def calculate_losses(real_images):
    """Compute the generator and discriminator losses for one batch.

    No gradients are taken here, so no GradientTape is opened; the forward
    passes and the returned values are otherwise unchanged.

    Args:
        real_images: Batch of real images fed to the discriminator.

    Returns:
        Tuple ``(gen_loss, disc_loss)``.
    """
    noise = tf.random.normal([batch_size, noise_dim])
    generated_images = generator(noise, training=True)

    real_output = discriminator(real_images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)
    return gen_loss, disc_loss

# Generate and save the generator's image predictions for a given input
def generate_and_save_images(model, epoch, test_input):
    """Plot the model's outputs for `test_input` as a grid and save it as PNG.

    Args:
        model: Model invoked as ``model(test_input, training=False)``; channel
            0 of each output is drawn.
        epoch: Integer used to build the output file name.
        test_input: Input batch passed to the model.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]

    # Square grid: 4x4 as before for up to 16 images, larger only when needed
    # so that bigger batches no longer overflow the subplot grid.
    side = max(4, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(side, side, i + 1)
        # Undo the (x - 127.5) / 127.5 scaling before display.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig(r'cimage_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Draw one batch of noise to use as the test input.
test_input = tf.random.normal(shape=[num_example_to_generate, noise_dim])

# Evaluate the GAN at the end of each epoch
def evaluate(generator, discriminator, test_input, real_images=None, eval_batch_size=1024):
    """Measure how often the discriminator classifies images correctly.

    The discriminator outputs raw logits, so an image counts as "real" when
    its logit is positive (sigmoid above 0.5) and as "fake" when it is
    negative.

    Args:
        generator: Generator model, called with ``training=False``.
        discriminator: Discriminator model, called with ``training=False``.
        test_input: Noise batch used to produce the fake images.
        real_images: Real images to score; defaults to the module-level
            ``x_train``.
        eval_batch_size: Number of real images scored per discriminator call,
            so the whole set is never pushed through the network at once.

    Returns:
        Tuple ``(real_accuracy, fake_accuracy)`` of scalar float32 tensors:
        the fraction of real images judged real and of generated images
        judged fake.
    """
    if real_images is None:
        real_images = x_train

    generated_images = generator(test_input, training=False)
    fake_logits = discriminator(generated_images, training=False)
    fake_accuracy = tf.reduce_mean(tf.cast(fake_logits < 0.0, tf.float32))

    # Score the real images in chunks and accumulate the number judged real.
    num_real = int(real_images.shape[0])
    real_correct = tf.constant(0.0)
    for start in range(0, num_real, eval_batch_size):
        chunk = real_images[start:start + eval_batch_size]
        real_logits = discriminator(chunk, training=False)
        real_correct += tf.reduce_sum(tf.cast(real_logits > 0.0, tf.float32))

    # divide_no_nan yields 0 instead of NaN when no real images were given.
    real_accuracy = tf.math.divide_no_nan(real_correct, tf.cast(num_real, tf.float32))
    return real_accuracy, fake_accuracy

# Run training
def train(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset`.

    After every epoch the function evaluates the discriminator, writes the
    scalars to TensorBoard logs under ``logs/``, prints the epoch duration
    and saves a grid of sample images. A checkpoint is saved every 15 epochs.

    Args:
        dataset: Iterable of real image batches.
        epochs: Number of passes over `dataset`.
    """
    # tf.summary.scalar needs an explicit file writer to record into.
    summary_writer = tf.summary.create_file_writer('logs/')

    for epoch in range(epochs):
        start = time.time()
        # Stay None if the dataset yields no batches, so logging can skip them.
        gen_loss = disc_loss = None

        for image_batch in dataset:
            noise = tf.random.normal([batch_size, noise_dim])

            # One tape per network so each loss is differentiated independently.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                generated_images = generator(noise, training=True)

                real_output = discriminator(image_batch, training=True)
                fake_output = discriminator(generated_images, training=True)

                gen_loss = generator_loss(fake_output)
                disc_loss = discriminator_loss(real_output, fake_output)

            gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
            gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

            generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

        real_acc, fake_acc = evaluate(generator, discriminator, test_input)

        with summary_writer.as_default():
            # The logged losses are those of the last batch of the epoch.
            if gen_loss is not None:
                tf.summary.scalar('Generator Loss', gen_loss, step=epoch)
                tf.summary.scalar('Discriminator Loss', disc_loss, step=epoch)
            tf.summary.scalar('Real Accuracy', real_acc, step=epoch)
            tf.summary.scalar('Fake Accuracy', fake_acc, step=epoch)

        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
        generate_and_save_images(generator, epoch, test_input)


# Start training on the MNIST pipeline for the configured number of epochs.
train(dataset, epochs=Epochs)


# KeyboardInterrupt  (notebook output residue)

