## WGAN

GAN 与 WGAN 的实现基本相同，只有训练（train）部分的代码需要改动。

In [1]:
import tensorflow as tf
import multiprocessing
import os
import numpy as np
import glob
from scipy.misc import toimage
from PIL import Image 
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

In [2]:
# Create a TF1-compat interactive session whose GPU memory allocation is
# allowed to grow (`allow_growth`) rather than being configured up front.
config = ConfigProto()   
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
# Seed both TensorFlow and NumPy random number generators with a fixed value.
tf.random.set_seed(22)
np.random.seed(22)
# NOTE(review): this variable is assigned after `import tensorflow` has
# already run; it presumably has to be set before the import to silence the
# native TensorFlow logs -- confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Environment guards: the notebook expects TensorFlow 2.x and NumPy 1.16.2.
# NOTE(review): `assert` statements are stripped under `python -O`.
assert tf.__version__.startswith('2.')
assert np.__version__.startswith('1.16.2')

In [3]:
# Generator G: latent vector [b, z_dim] => image [b, 88, 88, 3]
class Generator(keras.Model):
    """Turn a latent vector into an image.

    A dense layer produces a 3x3x512 feature map, which three 'valid'
    transposed convolutions enlarge while reducing the channel count:
    3x3x512 -> 9x9x256 -> 29x29x128 -> 88x88x3. The output passes through
    tanh, so pixel values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # z: [b, z_dim] => [b, 3*3*512]
        self.fc = layers.Dense(units=3 * 3 * 512)

        # [b, 3, 3, 512] => [b, 9, 9, 256]
        self.conv1 = layers.Conv2DTranspose(
            filters=256, kernel_size=3, strides=3, padding='valid')
        self.bn1 = layers.BatchNormalization()
        # [b, 9, 9, 256] => [b, 29, 29, 128]
        # (29 rather than 27 because the 5x5 kernel is larger than the stride)
        self.conv2 = layers.Conv2DTranspose(
            filters=128, kernel_size=5, strides=3, padding='valid')
        self.bn2 = layers.BatchNormalization()
        # [b, 29, 29, 128] => [b, 88, 88, 3]
        self.conv3 = layers.Conv2DTranspose(
            filters=3, kernel_size=4, strides=3, padding='valid')

    def call(self, inputs, training=None):
        """Run the generator.

        Args:
            inputs: latent batch, [b, z_dim].
            training: forwarded to the batch-normalisation layers.

        Returns:
            Tensor [b, 88, 88, 3] with values in [-1, 1].
        """
        # Project the latent vector and fold it into a 3x3x512 feature map.
        feature_map = tf.reshape(self.fc(inputs), [-1, 3, 3, 512])
        hidden = tf.nn.leaky_relu(feature_map)

        # Two upsampling stages: transposed conv -> batch norm -> leaky ReLU.
        for deconv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = deconv(hidden)
            hidden = norm(hidden, training=training)
            hidden = tf.nn.leaky_relu(hidden)

        # Final upsampling to 3 channels, squashed into [-1, 1].
        return tf.tanh(self.conv3(hidden))
    
    
# Discriminator D: image [b, 88, 88, 3] => logit [b, 1]
class Discriminator(keras.Model):
    """Score an image batch with one unnormalised logit per image.

    Three 5x5, stride-3, 'valid' convolutions (each followed by batch
    normalisation and leaky ReLU) shrink the image; the result is flattened
    and fed to a single-unit dense layer. For the 88x88 inputs this file
    builds the model with, the spatial sizes are 88 -> 28 -> 8 -> 2.
    """

    def __init__(self):
        super().__init__()
        conv_args = dict(kernel_size=5, strides=3, padding='valid')

        # [b, 88, 88, 3] => [b, 28, 28, 64]
        self.conv1 = layers.Conv2D(64, **conv_args)
        self.bn1 = layers.BatchNormalization()
        # [b, 28, 28, 64] => [b, 8, 8, 128]
        self.conv2 = layers.Conv2D(128, **conv_args)
        self.bn2 = layers.BatchNormalization()
        # [b, 8, 8, 128] => [b, 2, 2, 256]
        self.conv3 = layers.Conv2D(256, **conv_args)
        self.bn3 = layers.BatchNormalization()

        # [b, 2, 2, 256] => [b, 1024]
        self.flatten = layers.Flatten()
        # Single output unit: the raw logit (no sigmoid is applied here).
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        """Return logits of shape [b, 1] for the image batch `inputs`.

        `training` is forwarded to the batch-normalisation layers.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        hidden = inputs
        for conv, norm in stages:
            hidden = tf.nn.leaky_relu(norm(conv(hidden), training=training))

        return self.fc(self.flatten(hidden))    # [b, 1]

In [4]:
def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of images into a grid and save it as one image file.

    Images are laid out left-to-right, `val_block_size` per row, and the rows
    are stacked top-to-bottom. Only complete rows are written: trailing images
    that do not fill a whole row are dropped.

    Args:
        val_out: array of shape [n, h, w, c] with values in [-1, 1]
            (e.g. the tanh output of the generator).
        val_block_size: number of images per grid row (>= 1).
        image_path: destination file; the format is chosen by PIL from the
            file extension.
        color_mode: unused; kept so existing callers keep working.

    Raises:
        ValueError: if `val_block_size` is smaller than 1, `val_out` is not
            4-D, or there are not enough images to fill a single row.
    """
    if val_block_size < 1:
        raise ValueError('val_block_size must be >= 1, got %r' % (val_block_size,))

    images = np.asarray(val_out)
    if images.ndim != 4:
        raise ValueError(
            'val_out must have shape [n, h, w, c], got %r' % (images.shape,))

    n_rows = images.shape[0] // val_block_size
    if n_rows == 0:
        # Without this guard the grid would be empty and the channel check
        # below would fail with an unhelpful IndexError.
        raise ValueError(
            'need at least %d images to fill one row, got %d'
            % (val_block_size, images.shape[0]))

    # Map [-1, 1] to [0, 255]. Clipping first keeps slightly out-of-range
    # values from wrapping around in the uint8 cast.
    pixels = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)

    # Join each group of `val_block_size` images side by side (axis=1),
    # then stack the resulting rows vertically (axis=0).
    rows = []
    for row in range(n_rows):
        start = row * val_block_size
        rows.append(
            np.concatenate(list(pixels[start:start + val_block_size]), axis=1))
    grid = np.concatenate(rows, axis=0)

    if grid.shape[2] == 1:
        # Single-channel output: drop the channel axis so PIL sees a 2-D
        # greyscale array.
        grid = np.squeeze(grid, axis=2)

    # `scipy.misc.toimage` was removed from SciPy; for uint8 data PIL's
    # `Image.fromarray` (PIL.Image is imported at the top of the file)
    # builds the same image.
    Image.fromarray(grid).save(image_path)

In [5]:
# Build the batched image dataset used for training from a list of image files
def make_anime_dataset(img_paths, batch_size, resize=88, drop_remainder=True, shuffle=True, repeat=1):
    """Create a batched dataset of resized, normalised images.

    Each decoded image is resized to `resize` x `resize`, clipped to
    [0, 255] and rescaled to [-1, 1].

    Args:
        img_paths: list of image file paths.
        batch_size: number of images per batch.
        resize: target height and width in pixels.
        drop_remainder: drop the last batch if it is smaller than `batch_size`.
        shuffle: shuffle the file list before decoding.
        repeat: number of passes over the data, forwarded to the dataset's
            `repeat`.

    Returns:
        (dataset, img_shape, len_dataset) where `img_shape` is
        `(resize, resize, 3)` and `len_dataset` is the number of batches in
        one pass over `img_paths`.
    """
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1    # [0, 255] => [-1, 1]
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)

    n_images = len(img_paths)
    if drop_remainder:
        len_dataset = n_images // batch_size
    else:
        # The final, smaller batch is kept, so round up instead of down.
        len_dataset = -(-n_images // batch_size)

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Apply shuffle, filter/map, batch, repeat and prefetch to `dataset`.

    Stages are applied in this order: shuffle (optional), then filter and map
    (filter first unless `filter_after_map` is true), then batch, repeat and
    prefetch.

    Defaults filled in here: `n_map_threads` falls back to the CPU count, and
    `shuffle_buffer_size` falls back to `max(batch_size * 128, 2048)` when
    shuffling is enabled.
    """
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()

    # Shuffling comes first: it is cheaper to shuffle the raw elements than
    # the results of a possibly costly `map`/`filter`.
    if shuffle:
        if shuffle_buffer_size is None:
            # Never use a buffer smaller than 2048 elements.
            shuffle_buffer_size = max(batch_size * 128, 2048)
        dataset = dataset.shuffle(shuffle_buffer_size)

    def _filtered(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def _mapped(ds):
        if map_fn:
            return ds.map(map_fn, num_parallel_calls=n_map_threads)
        return ds

    # Filtering before mapping is the faster order; mapping first is slower
    # but lets the filter see mapped elements.
    stages = (_mapped, _filtered) if filter_after_map else (_filtered, _mapped)
    for stage in stages:
        dataset = stage(dataset)

    batched = dataset.batch(batch_size, drop_remainder=drop_remainder)
    repeated = batched.repeat(repeat)
    return repeated.prefetch(n_prefetch_batch)


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of memory data.

    Slices `memory_data` into a `tf.data.Dataset` and hands it to
    `batch_dataset` together with all remaining options.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists

    """
    options = dict(drop_remainder=drop_remainder,
                   n_prefetch_batch=n_prefetch_batch,
                   filter_fn=filter_fn,
                   map_fn=map_fn,
                   n_map_threads=n_map_threads,
                   filter_after_map=filter_after_map,
                   shuffle=shuffle,
                   shuffle_buffer_size=shuffle_buffer_size,
                   repeat=repeat)
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(sliced, batch_size, **options)


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of disk image for PNG and JPEG.

    Every path is read from disk and decoded to a 3-channel image; when
    `map_fn` is given it is applied to the decoded image (and labels, if any)
    in the same mapping step.

    Parameters
    ----------
        img_paths : 1d-tensor/ndarray/list of str
        labels : nested structure of tensors/ndarrays/lists

    """
    # Dataset elements are either a path alone or a (path, labels) pair.
    memory_data = img_paths if labels is None else (img_paths, labels)

    def _load(path, *label):
        raw = tf.io.read_file(path)
        decoded = tf.image.decode_png(raw, 3)  # fix channels to 3
        return (decoded,) + label

    if map_fn:
        # Fuse decoding and the caller's map function into one map step.
        def _load_and_map(*args):
            return map_fn(*_load(*args))
        fused_map_fn = _load_and_map
    else:
        fused_map_fn = _load

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=fused_map_fn,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)

In [6]:
# loss_ones(): cross-entropy between the logits and an all-ones target
def loss_ones(logits):
    """Return the mean sigmoid cross-entropy of `logits` against labels of 1."""
    # The target has the same shape as the logits: [b, 1] filled with ones.
    targets = tf.ones_like(logits)
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
    return tf.reduce_mean(per_example)

# loss_zeros(): cross-entropy between the logits and an all-zeros target
def loss_zeros(logits):
    """Return the mean sigmoid cross-entropy of `logits` against labels of 0."""
    targets = tf.zeros_like(logits)
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
    return tf.reduce_mean(per_example)

#  **********************  WGAN  *****************************************
# Gradient penalty term, the WGAN-specific part of the discriminator loss
# gp_loss = gradient_penalty(discriminator, batch_x, batch_x_hat)
def gradient_penalty(discriminator, batch_x, batch_x_hat):
    """Penalise discriminator input-gradients whose L2 norm differs from 1.

    The gradient is evaluated at random per-sample interpolations between the
    real batch `batch_x` and the generated batch `batch_x_hat` (both
    [b, h, w, c]). The discriminator is called without a `training` argument.

    Returns:
        Scalar tensor: mean over the batch of (||grad||_2 - 1)^2.
    """
    n_samples = batch_x.shape[0]
    # One mixing coefficient per sample, [b, 1, 1, 1], expanded to [b, h, w, c].
    mix = tf.broadcast_to(tf.random.uniform([n_samples, 1, 1, 1]), batch_x.shape)
    interpolated = mix * batch_x + (1. - mix) * batch_x_hat

    with tf.GradientTape() as gp_tape:
        # `interpolated` is a plain tensor, not a variable, so the tape has
        # to be told explicitly to track it.
        gp_tape.watch([interpolated])
        critic_logits = discriminator(interpolated)
    input_grads = gp_tape.gradient(critic_logits, interpolated)

    # Flatten [b, h, w, c] => [b, h*w*c] and take one L2 norm per sample.
    flat_grads = tf.reshape(input_grads, [input_grads.shape[0], -1])
    per_sample_norm = tf.norm(flat_grads, axis=1)  # [b]

    return tf.reduce_mean((per_sample_norm - 1) ** 2)

#  **********************  WGAN  *****************************************
# Discriminator loss
# d_loss, gp_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training, gp_weight=1.0):
    """Compute the discriminator loss including the gradient penalty.

    The loss is the sigmoid cross-entropy of real images against 1, plus the
    sigmoid cross-entropy of generated images against 0, plus
    `gp_weight` times the gradient penalty.

    NOTE(review): the classification terms are the standard GAN
    cross-entropy losses, not the Wasserstein critic loss
    (mean(D(fake)) - mean(D(real))); confirm this is intended.

    Args:
        generator: model mapping `batch_z` to images.
        discriminator: model mapping images to logits.
        batch_z: latent batch fed to the generator.
        batch_x: batch of real images.
        is_training: training flag passed to both models.
        gp_weight: balancing factor for the gradient penalty; values of
            roughly 1 to 10 are suggested and need manual tuning. Defaults
            to the previously hard-coded 1.0.

    Returns:
        (loss, gp_loss): the total discriminator loss and the unweighted
        gradient penalty.
    """
    # Real images should be classified as real.
    d_real_logits = discriminator(batch_x, is_training)
    # Generated images should be classified as fake.
    batch_x_hat = generator(batch_z, is_training)
    d_fake_logits = discriminator(batch_x_hat, is_training)

    d_loss_real = loss_ones(d_real_logits)
    d_loss_fake = loss_zeros(d_fake_logits)

    gp_loss = gradient_penalty(discriminator, batch_x, batch_x_hat)
    loss = d_loss_real + d_loss_fake + gp_loss * gp_weight

    return loss, gp_loss
    
# Generator loss
# g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Return the generator loss for the latent batch `batch_z`.

    Generated images are scored by the discriminator and the logits are
    compared against an all-ones target, i.e. the generator is rewarded when
    its images are classified as real.
    """
    fake_images = generator(batch_z, is_training)
    fake_logits = discriminator(fake_images, is_training)
    return loss_ones(fake_logits)

In [7]:
# Dimension of the latent vector fed to the generator
z_dim = 100    # z[b, 100] => x_hat[b, 88, 88, 3] => x_hat_pro[b, 1]
# Number of training iterations; main() consumes one batch per "epoch"
epochs = 3000000  # three million
# epochs = 200
# Batch size
# batch_size = 1024
batch_size = 768
# Learning rate shared by both optimizers
learning_rate = 2e-3
# Training flag passed to the models inside the loss functions
is_training = True
# Paths of the dataset images:
# glob.glob returns the list of all file paths matching the given pattern
img_path = glob.glob('/home/kukafee/shared/faces/*.jpg')
# Batched dataset of the training images
dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
# NOTE(review): the sample output below shows a batch size of 512 and
# presumably comes from an earlier run with a different batch_size.
# print(dataset, img_shape) # <PrefetchDataset shapes: (512, 88, 88, 3), types: tf.float32> (88, 88, 3)
# Build an iterator and inspect one element
# sample = next(iter(dataset))
# print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())
# make_anime_dataset repeats the data once (repeat=1); repeat() with no
# count lets main() keep drawing batches from the iterator
dataset = dataset.repeat()
db_iter = iter(dataset)    # iterator that main() pulls real batches from

# Create and build the generator network
generator = Generator()
generator.build(input_shape = (None, z_dim))
# generator.summary()
# Create and build the discriminator network
discriminator = Discriminator()
discriminator.build(input_shape = (None, 88, 88, 3))
# discriminator.summary()
# Optimizers: one Adam instance per network
# beta_1: A float value or a constant float tensor. The exponential decay rate for the 1st
# moment estimates.
g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

In [8]:
# Train the GAN, periodically saving preview images and network weights
def main():
    """Run the training loop.

    Uses the module-level hyper-parameters, data iterator, networks and
    optimizers. Each iteration draws one latent batch and one real batch,
    updates the discriminator twice on that same pair, then updates the
    generator once.

    Every 100 iterations the losses are printed and a 10x10 grid of generated
    images is saved; every 1000 iterations the weights of both networks are
    saved.
    """
    picture_dir = '/home/kukafee/workspace/picture/WGAN'
    weight_dir = '/home/kukafee/workspace/save_model/WGAN'
    # Make sure the output directories exist before anything is written.
    os.makedirs(picture_dir, exist_ok=True)
    os.makedirs(weight_dir, exist_ok=True)

    for epoch in range(epochs):
        # Uniform latent batch in [-1, 1):
        # z[b, 100] => x_hat[b, 88, 88, 3] => x_hat_pro[b, 1]
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        # Next batch of real images for the discriminator
        batch_x = next(db_iter)

        # Train the discriminator D (two updates on the same batch)
        for _ in range(2):
            with tf.GradientTape() as tape:
                d_loss, dg_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # Train the generator G
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            # dg_loss is the gradient penalty; it should stay close to 0.
            print(epoch, 'd-loss: ', float(d_loss), 'g-loss: ', float(g_loss), 'dg_loss', float(dg_loss))

            # Preview the generator. The latent samples use the same
            # uniform [-1, 1) range as during training; the default [0, 1)
            # range would only cover part of the latent space the generator
            # was trained on.
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)    # [100, 100]
            fake_image = generator(z, training=False)    # [100, 88, 88, 3]
            preview_path = os.path.join(picture_dir, 'fake_epoch_%d.png'%epoch)
            save_result(fake_image.numpy(), 10, preview_path, color_mode='P')

        if epoch % 1000 == 0:
            # Save the network weights (parameters only) to disk
            generator.save_weights(os.path.join(weight_dir, 'G_weight_%d.ckpt'%epoch))
            discriminator.save_weights(os.path.join(weight_dir, 'D_weight_%d.ckpt'%epoch))
            print('Saved G&D_weights: %d'%epoch)

In [10]:
# main()

训练了514分钟；模型参数最新保存至Saved G&D_weights: 37000
###### 37000 d-loss:  1.3720537424087524 g-loss:  0.7379453778266907 dg_loss 0.0009328257292509079
###### Saved G&D_weights: 37000
###### 37100 d-loss:  1.376948595046997 g-loss:  0.7198958992958069 dg_loss 0.0010607395088300109
###### 37200 d-loss:  1.3777310848236084 g-loss:  0.7109716534614563 dg_loss 0.000748057325836271
###### 37300 d-loss:  1.3667497634887695 g-loss:  0.7310948967933655 dg_loss 0.0010721470462158322