In [1]:
import tensorflow as tf
tf.config.run_functions_eagerly(True)

from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy as np
import sys, os, pathlib

In [2]:
img_shape = (28, 28, 1) # MNIST图像尺寸
latent_dim = 200 # 噪声向量的维度

In [3]:
os.makedirs('images', exist_ok=True) # 创建保存图像的文件夹

In [4]:
def build_generator():
    # ======================================= #
    #     生成器，输入一串随机噪声向量生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2), #相比于传统的ReLU:f(x) = max(0, x)--->LeakyReLU:f(x)=max(alpha*x, x) alpha是一个很小的正数 解决了神经元死亡问题 可帮助梯度更好地流动
        layers.BatchNormalization(momentum=0.8),

        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),

        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),

        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

In [5]:
def build_discriminator():
    # ===================================== #
    #   鉴别器，对输入的图片进行判别真假
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid') # 输出真图概率
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

In [6]:
discriminator = build_discriminator()

optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

generator = build_generator()

discriminator.trainable = False # 停止鉴别器学习
# 生成器
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)
validity = discriminator(img)
combined = Model(gan_input, validity) # 输入噪声--造假--鉴别
combined.compile(loss='binary_crossentropy', optimizer=optimizer) # target: 让discriminator将fake鉴别为true

  super().__init__(**kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [7]:
def sample_images(epoch):
    '''
    saving images
    '''
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig('images/%05d.png' % epoch)
    plt.close()

In [8]:
def train(epochs, batch_size=128, sample_interval=50):
    # 加载并预处理MNIST数据
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data() 
    train_images = (train_images - 127.5) / 127.5 # 归一化到[-1, 1] 适用于tanh激活函数的生成器输出层
    train_images = np.expand_dims(train_images, axis=3)

    for epoch in range(epochs):

        # discriminator.trainable = True
        
        # 训练判别器
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx] # 真实图像

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict(noise, verbose=0) # 生成图像 verbose=0 不显示任何进度信息 =1显示进度条 =2只显示每个epoch信息

        true = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
    
        d_loss_true = discriminator.train_on_batch(imgs, true) # 真图标签为1
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) # 假图标签为0
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
        
        # 训练生成器
        # discriminator.trainable = False
        
        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        g_loss = combined.train_on_batch(noise, true) # 期望输出1 实际输出是判别器的真实判断概率值 计算交叉熵损失
        
        print('%d [D loss: %f, acc.: %.2f%%] [G loss: %f]' % (epoch, d_loss[0], 100*d_loss[1], g_loss))
        # if epoch % 10 == 0:
        #     print(f"{epoch} [D loss: {d_loss[0]:.6f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.6f}]")
            
        # 打印损失并保存样本图像
        if epoch % sample_interval == 0:
            sample_images(epoch)
            #if epoch % 100 == 0:
            #    plt.show()


In [9]:
train(epochs=1000, batch_size=256, sample_interval=200)

0 [D loss: 0.947463, acc.: 29.10%] [G loss: 0.867090]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step



[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
1 [D loss: 0.896657, acc.: 42.12%] [G loss: 0.846413]
2 [D loss: 0.893049, acc.: 42.14%] [G loss: 0.833053]
3 [D loss: 0.891344, acc.: 40.05%] [G loss: 0.816624]
4 [D loss: 0.891672, acc.: 37.64%] [G loss: 0.804610]
5 [D loss: 0.896650, acc.: 35.38%] [G loss: 0.791498]
6 [D loss: 0.899874, acc.: 33.95%] [G loss: 0.775906]
7 [D loss: 0.902210, acc.: 32.62%] [G loss: 0.764127]
8 [D loss: 0.906665, acc.: 31.11%] [G loss: 0.750302]
9 [D loss: 0.911955, acc.: 29.63%] [G loss: 0.737507]
10 [D loss: 0.917282, acc.: 28.23%] [G loss: 0.725954]
11 [D loss: 0.922218, acc.: 27.04%] [G loss: 0.713138]
12 [D loss: 0.926288, acc.: 25.98%] [G loss: 0.701705]
13 [D loss: 0.931378, acc.: 24.97%] [G loss: 0.689989]
14 [D loss: 0.937610, acc.: 23.91%] [G loss: 0.678642]
15 [D loss: 0.943731, acc.: 22.96%] [G loss: 0.667809]
16 [D loss: 0.949052, acc.: 22.09%] [G loss: 0.657225]
17 [D loss: 0.956511, acc.: 21.16%] [G loss: 0.647309]
18

In [10]:
import imageio

def compose_gif():
    data_dir = 'images'
    data_dir = pathlib.Path(data_dir)
    paths = list(data_dir.glob('*'))

    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave('test.gif', gif_images, fps=2)

compose_gif()

images/00800.png
images/00400.png
images/00600.png
images/00200.png
images/00000.png


  gif_images.append(imageio.imread(path))
