# 课时62 ACGAN人脸图像生成

In [58]:
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Show the TensorFlow version; the recorded output of this notebook came from 2.1.0.
tf.__version__

'2.1.0'

## 1. 读取数据以及数据预处理

In [59]:
# Every face image, one sub-directory per class: .../face/<class>/<file>.jpg
images_path = glob.glob('../日月光华-GAN大型数据集/2.0部分数据集/face/*/*.jpg')

In [60]:
# Shuffle the image paths reproducibly.
# glob returns paths in a filesystem-dependent order, so sort first: the seeded shuffle then
# yields the same order on every machine, and re-running this cell gives the same result.
np.random.seed(2020)
# TensorFlow has its own RNG (tf.random.normal noise, tf.image.random_crop); seed it as well.
tf.random.set_seed(2020)
images_path = sorted(images_path)
np.random.shuffle(images_path)

In [61]:
# Derive each image's class name from its parent directory (.../face/<class>/<file>.jpg).
# os.path understands the separator glob actually produced on this OS; splitting on '\\'
# only worked on Windows and raised IndexError on Linux/macOS.
assert len(images_path) > 0, 'no images matched the glob pattern'
labels = [os.path.basename(os.path.dirname(p)) for p in images_path]
# Map class names to integer ids. np.unique returns the names sorted, so the mapping is stable.
cls_to_num = {name: idx for idx, name in enumerate(np.unique(labels))}
num_to_cls = {idx: name for name, idx in cls_to_num.items()}
labels = [cls_to_num[name] for name in labels]

In [62]:
# Convert both lists to ndarrays before building the tf.data pipelines with
# tf.data.Dataset.from_tensor_slices (which slices along the first axis).
images_path, labels = np.array(images_path), np.array(labels)

In [84]:
@tf.function
def load_images(path):
    """Read one JPEG file and return a 64x64x3 float32 image scaled to [-1, 1].

    Steps: decode (forced to 3 channels) -> resize to 80x80 -> random 64x64 crop ->
    random horizontal flip -> map [0, 255] to [-1, 1], matching the generator's tanh range.
    """
    img = tf.io.read_file(path)
    # channels=3: otherwise a grayscale JPEG decodes to 1 channel and the 3-channel
    # random_crop below fails.
    img = tf.image.decode_jpeg(img, channels=3)
    # The source images differ in size, so bring them to a common size first. Resizing can
    # distort the aspect ratio (and therefore the generated faces); cropping or padding to a
    # fixed size would avoid that.
    img = tf.image.resize(img, (80, 80))
    img = tf.image.random_crop(img, [64, 64, 3])
    # Random flip for augmentation; flipping every image unconditionally augments nothing.
    img = tf.image.random_flip_left_right(img)
    img = img / 127.5 - 1
    return img

In [85]:
# Image dataset: one decoded, augmented, normalized image per path (decoding in parallel).
AUTOTUNE = tf.data.experimental.AUTOTUNE
images_dataset = (
    tf.data.Dataset.from_tensor_slices(images_path)
    .map(load_images, num_parallel_calls=AUTOTUNE)
)

In [65]:
# Display the dataset's element spec (recorded output: shape (64, 64, 3), float32).
images_dataset

<MapDataset shapes: (64, 64, 3), types: tf.float32>

In [66]:
# Integer class labels, in the same order as images_path, zipped with the images into
# (image, label) pairs.
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
paired = (images_dataset, labels_dataset)
training_datasets = tf.data.Dataset.zip(paired)

In [67]:
# Training hyper-parameters and sizes derived from the data.
BATCH_SIZE = 16
noise_dim = 100                  # length of the generator's latent noise vector
class_nums = len(num_to_cls)     # number of distinct class labels
image_count = len(images_path)   # total number of training images

In [68]:
# Build the batched pipeline from the unbatched (image, label) pairs, so re-running this cell
# does not batch an already-batched dataset. prefetch overlaps input processing with training.
training_datasets = (
    tf.data.Dataset.zip((images_dataset, labels_dataset))
    .shuffle(300)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

## 2. 定义生成器和判别器模型

In [201]:
def generator_model():
    """Build the conditional generator G(noise, label) -> image.

    Inputs:
        noise: float tensor of shape (batch, noise_dim).
        condition_label: one integer class id per sample, shape (batch,).
    Output:
        image tensor of shape (batch, 64, 64, 3) with values in [-1, 1] (tanh),
        the same range load_images produces for real images.
    """
    # Two inputs, so the functional API is used instead of Sequential.
    noise = tf.keras.layers.Input(shape=(noise_dim,))
    # One scalar label per sample (shape=() -> tensor shape (batch,)), matching the
    # per-image labels in training_datasets.
    condition_label = tf.keras.layers.Input(shape=())

    # Embed the class id into a vector of the same length as the noise so the two can be
    # concatenated. input_dim is the number of classes (class_nums). Because each label is a
    # scalar, the embedding output is already (batch, noise_dim) and needs no Flatten.
    x = tf.keras.layers.Embedding(input_dim=class_nums, output_dim=noise_dim)(condition_label)
    x = tf.keras.layers.concatenate([noise, x])  # (batch, 2 * noise_dim)

    # Project the merged vector to a 4x4x512 feature map, the starting point for the
    # transposed convolutions that upsample to the final image size.
    x = tf.keras.layers.Dense(units=4*4*64*8, use_bias=False)(x)
    x = tf.keras.layers.Reshape(target_shape=(4, 4, 64*8))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # -> [8, 8, 64*4]
    x = tf.keras.layers.Conv2DTranspose(filters=64*4, kernel_size=(5, 5), strides=(2, 2),
                                        use_bias=False, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # -> [16, 16, 64*2]
    x = tf.keras.layers.Conv2DTranspose(filters=64*2, kernel_size=(5, 5), strides=(2, 2),
                                        use_bias=False, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # -> [32, 32, 64]
    x = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=(5, 5), strides=(2, 2),
                                        use_bias=False, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # -> [64, 64, 3], squashed to [-1, 1]
    x = tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=(5, 5), strides=(2, 2),
                                        use_bias=False, padding='same')(x)
    x = tf.keras.layers.Activation('tanh')(x)

    model = tf.keras.Model(inputs=[noise, condition_label], outputs=x)
    return model

In [190]:
def discriminator_model():
    """Build the ACGAN discriminator D(image) -> (real/fake logit, class logits).

    Unlike a conditional GAN, the label is not an input; instead D predicts it.
    Input:
        image tensor of shape (batch, 64, 64, 3).
    Outputs:
        real_or_fake logits of shape (batch, 1) and class logits of shape (batch, class_nums).
    """
    input_image = tf.keras.layers.Input(shape=(64, 64, 3))

    x = input_image
    # Four stride-2 conv blocks: spatial 64 -> 32 -> 16 -> 8 -> 4, channels 64 -> 512.
    for filters in (64, 64*2, 64*4, 64*8):
        x = tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=(2, 2),
                                   padding='same', use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Dropout(rate=0.5)(x)

    x = tf.keras.layers.Flatten()(x)
    # Head 1: is the image real or generated (single logit).
    real_or_fake_outputs_logits = tf.keras.layers.Dense(units=1)(x)
    # Head 2: auxiliary classifier over the class_nums image classes (logits).
    category_outputs_logits = tf.keras.layers.Dense(units=class_nums)(x)

    model = tf.keras.Model(inputs=input_image,
                           outputs=[real_or_fake_outputs_logits, category_outputs_logits])
    return model

## 3. 定义损失函数及优化器

In [191]:
# Instantiate both networks once; train_step and the plotting helper use these globals.
generator, discriminator = generator_model(), discriminator_model()

In [192]:
# Both discriminator heads output raw logits, hence from_logits=True for both losses.
# Real/fake head -> binary cross-entropy.
Binary_Crossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Class head -> sparse categorical cross-entropy, because labels are integer ids
# (e.g. 0/1 for the two face classes) rather than one-hot vectors.
Category_Cross_Entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [193]:
def discriminator_loss(real_image_outputs, pred_class_outs, fake_image_outputs, real_class_label):
    """ACGAN discriminator loss.

    Sum of three terms:
      * BCE pushing the real-image logits towards 1,
      * BCE pushing the generated-image logits towards 0,
      * classification loss of the class logits `pred_class_outs` against `real_class_label`
        (in train_step these are the class logits for the REAL images).
    """
    adversarial_real = Binary_Crossentropy(y_true=tf.ones_like(real_image_outputs),
                                           y_pred=real_image_outputs)
    adversarial_fake = Binary_Crossentropy(y_true=tf.zeros_like(fake_image_outputs),
                                           y_pred=fake_image_outputs)
    classification = Category_Cross_Entropy(y_true=real_class_label, y_pred=pred_class_outs)
    return adversarial_real + adversarial_fake + classification

In [194]:
def generator_loss(fake_image_outputs, pred_class_outs, real_class_label):
    """ACGAN generator loss.

    The generator wants its images judged real (real/fake logits pushed towards 1) and
    classified as the label it was conditioned on (`real_class_label`).
    """
    fooling_term = Binary_Crossentropy(y_true=tf.ones_like(fake_image_outputs),
                                       y_pred=fake_image_outputs)
    class_term = Category_Cross_Entropy(y_true=real_class_label, y_pred=pred_class_outs)
    return fooling_term + class_term

In [195]:
# Independent Adam optimizers for the two networks (note: 1e-5 is a small learning rate).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

## 4. 定义梯度更新函数

In [196]:
@tf.function
def train_step(images, labels):
    """Run one ACGAN update of the generator and the discriminator on a batch.

    Args:
        images: batch of real images (presumably (batch, 64, 64, 3) in [-1, 1] from load_images).
        labels: integer class ids, shape (batch,).
    Returns:
        (generator_loss, discriminator_loss) scalar tensors for this batch; callers may ignore them.
    """
    # Dynamic batch size: labels.shape[0] is a trace-time constant, so the smaller final batch
    # of each epoch forced a retrace, and an unknown batch dimension would break noise creation.
    batch_size = tf.shape(labels)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(inputs=(noise, labels), training=True)
        fake_image_out, fake_class_outs = discriminator(inputs=generated_images, training=True)
        real_image_out, real_class_outs = discriminator(inputs=images, training=True)

        # Generator: fakes should look real and be classified as their conditioning label.
        generator_loss_ = generator_loss(fake_image_out, fake_class_outs, labels)
        # Discriminator: real -> 1, fake -> 0, plus correct classification of the real images.
        discriminator_loss_ = discriminator_loss(real_image_out, real_class_outs, fake_image_out, labels)

    generator_gradients = gen_tape.gradient(generator_loss_, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(discriminator_loss_, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    return generator_loss_, discriminator_loss_

## 5. 定义辅助绘图函数

In [197]:
# Fixed noise and labels reused at every plot, so progress is always shown on the same inputs.
noise_seed = tf.random.normal([10, noise_dim])
# 1-D label vector of shape (10,), matching the generator's scalar-per-sample label input.
# A (10, 1) array would make the Embedding output rank 3 and break its concatenation with the
# (10, noise_dim) noise. Same RNG draws as before, just a different shape.
label_seed = np.random.randint(0, class_nums, size=(10,))
# Human-readable class names for the seed labels.
condition = [num_to_cls.get(n) for n in label_seed]

In [198]:
def plot_generator_images(model, noise, label, epoch_num):
    """Generate one image per (noise, label) pair and show them in a single row.

    Args:
        model: generator called as model((noise, label)).
        noise: latent vectors, one row per image.
        label: integer class ids, one per image (a column vector is flattened).
        epoch_num: only used in the progress message.
    """
    print('现在是第%i个epoch.'%(epoch_num))
    # Flatten labels so both (n,) and (n, 1) inputs match the generator's (batch,) label input.
    label_flat = np.ravel(label)
    generated_images = model(inputs=(noise, label_flat), training=False)
    # Map the tanh output from [-1, 1] back to [0, 1] for imshow; clip any tiny overshoot.
    images_01 = np.clip((np.asarray(generated_images) + 1) / 2, 0.0, 1.0)
    # Titles come from the labels actually passed in, not from a global, so any label set
    # is plotted with the correct class names.
    titles = [num_to_cls.get(int(n)) for n in label_flat]
    num_images = images_01.shape[0]
    plt.figure(figsize=(num_images, 1))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(images_01[i])
        plt.title(titles[i])
        plt.axis('off')
    plt.show()

## 6. 定义模型训练函数

In [199]:
def train(datasets, epochs, plot_every=10):
    """Train the ACGAN for `epochs` epochs over `datasets`.

    Plots the fixed-seed samples every `plot_every` epochs (starting at epoch 0) and once
    more after the last epoch. Does nothing when `epochs` <= 0 (previously this raised,
    because the final plot referenced the never-assigned loop variable).
    """
    if epochs <= 0:
        return
    for epoch in range(epochs):
        print('Epoch is:', epoch)
        for images_batch, labels_batch in datasets:
            train_step(images_batch, labels_batch)
            print('.', end=' ')
        print()
        if epoch % plot_every == 0:
            plot_generator_images(generator, noise_seed, label_seed, epoch)
    plot_generator_images(generator, noise_seed, label_seed, epoch)

In [200]:
# Long run; interrupt the kernel to stop early (the generator is saved in a later cell).
train(datasets=training_datasets, epochs=10000)

Epoch is: 0
. . 

KeyboardInterrupt: 

In [None]:
# Save the trained generator, then prepare fresh noise for the demo plots below.
generator.save('generate_v2.h5')
num = 10
noise_seed = tf.random.normal([num, noise_dim])
# One label per image, cycling through the available classes. The former MNIST leftover
# np.arange(10).reshape(-1, 1) used ids 0..9 (out of range for the class_nums-sized
# Embedding) as a column vector (wrong rank for the generator's label input).
cat_seed = np.arange(num) % class_nums
print(cat_seed)

In [None]:
# generate_images is not defined anywhere in this notebook; the plotting helper is
# plot_generator_images.
plot_generator_images(generator, noise_seed, cat_seed, 1)

In [None]:
# Condition all 10 images on the first class (id 0). Ids must lie in [0, class_nums);
# the former id 3 was an MNIST leftover and out of range for this dataset's classes.
cat_seed = np.full(10, 0)
plot_generator_images(generator, noise_seed, cat_seed, 0)

In [None]:
# Condition all 10 images on the last class (id class_nums - 1). The former id 6 was an
# MNIST leftover and out of range for this dataset's classes.
cat_seed = np.full(10, class_nums - 1)
plot_generator_images(generator, noise_seed, cat_seed, 0)