In [1]:
# DCGAN实战测试
# 这里采用 matplotlib.image 读入图片数组，注意这里读入的数组是 float32 型的，范围是 0-1
# https://blog.csdn.net/qq_36758914/article/details/104878227 这个是一个详细说明的版本
import glob
import tensorflow as tf
import multiprocessing
img_paths = glob.glob('./data/images/*.jpg')
# batch_size = 4
# print(len(img_path)) 63565
# 构建自定义数据集
# https://www.cnblogs.com/heze/p/12390926.html
def make_anime_dataset(img_paths,batch_size,resize=64,drop_remainder=True,shuffle=True,repeat=1):
    @tf.function
    def _map_fn(img):
        img = tf.io.read_file(img)
        img = tf.image.decode_jpeg(img,channels=3)
        img = tf.image.resize(img,[resize,resize])
        img = tf.clip_by_value(img,0,255)
        img = img/127.5-1
        return img
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) # batch_size
    dataset = tf.data.Dataset.from_tensor_slices(img_paths)
    dataset = dataset.shuffle(1000).map(_map_fn).batch(batch_size)
    # print(dataset)
    # print(iter(dataset).next())
    return dataset, img_shape, len_dataset
    
# https://ithelp.ithome.com.tw/articles/10241789?sc=rss.iron dataset用法参考
# https://ithelp.ithome.com.tw/articles/10268970     Shuffle Batch Repeat
# https://www.cnblogs.com/marsggbo/p/9603789.html

#     # drop_remainder https://www.cnblogs.com/wkslearner/p/9484443.html （这个也是参数详解）
#     dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
#     # https://zhuanlan.zhihu.com/p/163656225 prefetch

#   https://blog.csdn.net/xierhacker/article/details/79002902 from_tensor_slices
    


In [2]:
import keras.api._v2.keras as keras
from keras.api._v2.keras import layers
class Generator(keras.Model):
    # 生成器网络
    def __init__(self):
        super(Generator,self).__init__()
        filter = 64

        # 转置卷积层1，输出channel为filter*8，核 4，s 1,padding 0,无偏置
        self.conv1 = layers.Conv2DTranspose(filter*8,4,1,'valid',use_bias=False)
        self.bn1 = layers.BatchNormalization()

        # 转置卷积层2
        self.conv2 = layers.Conv2DTranspose(filter*4,4,2,'same',use_bias=False)
        self.bn2 = layers.BatchNormalization()

        # 转置卷积层3
        self.conv3 = layers.Conv2DTranspose(filter*2,4,2,'same',use_bias=False)
        self.bn3 = layers.BatchNormalization()

        # 转置卷积层4
        self.conv4 = layers.Conv2DTranspose(filter*1,4,2,'same',use_bias=False)
        self.bn4 = layers.BatchNormalization()

        # 转置卷积层5
        self.conv5 = layers.Conv2DTranspose(3,4,2,'same',use_bias=False)

    def call(self, inputs, training=None, mask=None):
        x = inputs
        x = tf.reshape(x,(x.shape[0],1,1,x.shape[1]))
        x = tf.nn.relu(x)

        x = tf.nn.relu(self.bn1(self.conv1(x),training=training))

        x = tf.nn.relu(self.bn2(self.conv2(x),training=training))

        x = tf.nn.relu(self.bn3(self.conv3(x),training=training))

        x = tf.nn.relu(self.bn4(self.conv4(x),training=training))

        x = self.conv5(x)
        
        x = tf.tanh(x) #输出x范围-1 ~1,与预处理一致
        # print(x)
        return x
            

In [3]:
# 判别器
class Discriminator(keras.Model):
    
    def __init__(self):
        super(Discriminator,self).__init__()
        filter = 64

        self.conv1 = layers.Conv2D(filter,4,2,'valid',use_bias=False)
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2D(filter*2,4,2,'valid',use_bias=False)
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(filter*4,4,2,'valid',use_bias=False)
        self.bn3 = layers.BatchNormalization()

        self.conv4 = layers.Conv2D(filter*8,3,1,'valid',use_bias=False)
        self.bn4 = layers.BatchNormalization()

        self.conv5 = layers.Conv2D(filter*16,3,1,'valid',use_bias=False)
        self.bn5 = layers.BatchNormalization()

        # 全局池化层
        self.pool = layers.GlobalAveragePooling2D()

        # 特征打平层
        self.flatten = layers.Flatten()

        # 分类全连接层
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
            # （b，31，31，64）
        x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs),training=training))
        # （b，14,14，128）
        
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        # （b，6，6，256）
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
        # （b，4，4，512）
        x = tf.nn.leaky_relu(self.bn4(self.conv4(x),training=training))
        # （b，2，2，1024）
        x = tf.nn.leaky_relu(self.bn5(self.conv5(x),training=training))
        # （b，1024）
        x = self.pool(x)

        logits = self.fc(x)

        return logits


In [4]:
def celoss_ones(logits):
    y = tf.ones_like(logits)
    loss = keras.losses.binary_crossentropy(y,logits,from_logits=True)
    return tf.reduce_mean(loss)

def celoss_zeros(logits):
    y = tf.zeros_like(logits)
    loss = keras.losses.binary_crossentropy(y,logits,from_logits=True)
    return tf.reduce_mean(loss)
# 判别器误差损失函数

def d_loss_fn(generator,discriminator,batch_z,batch_x,is_training):
    # 采样生成图片
    fake_image = generator(batch_z,is_training)
    # print(fake_image)
    # print("generator")
    # 判定该生成图片
    d_fake_logits = discriminator(fake_image,is_training)

    # 判断真实图片
    d_real_logits = discriminator(batch_x,is_training)

    # 真实图片与1的误差
    d_loss_real = celoss_ones(d_real_logits)

    # 生成图片与0之间的误差
    d_loss_fake = celoss_zeros(d_fake_logits)

    return d_loss_fake + d_loss_real

# 生成器网络损失函数

def g_loss_fn(generator,discriminator,batch_z,is_training):
    # 采样生成图片
    fake_image = generator(batch_z,is_training)
    # 在训练网络时需要迫使生成图片判断为真
    d_fake_logits = discriminator(fake_image,is_training)
    # 计算生成图片与1之间的误差
    loss = celoss_ones(d_fake_logits)

    return loss



In [5]:
import numpy as np
from PIL import Image
# import scipy.misc
def save_result(val_out, val_block_size, image_path, color_mode):
      def preprocess(img):
          img = ((img + 1.0) * 127.5).astype(np.uint8)
          # img = img.astype(np.uint8)
          return img
  
      preprocesed = preprocess(val_out)
      final_image = np.array([])
      single_row = np.array([])
  
      for b in range(val_out.shape[0]):
          # concat image into a row
          if single_row.size == 0:
              single_row = preprocesed[b, :, :, :]
          else:
              single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
  
          # concat image row to final_image
          if (b + 1) % val_block_size == 0:
              if final_image.size == 0:
                  final_image = single_row
              else:
                  final_image = np.concatenate((final_image, single_row), axis=0)
  
              # reset single row
              single_row = np.array([])
  
      if final_image.shape[2] == 1:
          final_image = np.squeeze(final_image, axis=2)
      im = Image.fromarray(final_image)
      im.save(image_path)
    #   Image.save(final_image)
    #   Image(final_image).save()
  
  
# d_losses, g_losses = [], []

In [6]:
# 网络训练
import os
from matplotlib import pyplot as plt
z_dim = 100
epochs = 300
batch_size = 64
is_training = True

generator = Generator()


generator.build(input_shape=(4,z_dim))
# 这个4随便设一个就好

discriminator = Discriminator()

discriminator.build(input_shape=(4,64,64,3))

g_optimizer = keras.optimizers.Adam(learning_rate=2e-4,beta_1=0.5)
d_optimizer = keras.optimizers.Adam(learning_rate=2e-4,beta_1=0.5)
dataset,img_shape,_ = make_anime_dataset(img_paths,batch_size,resize=64)

db_iter = iter(dataset)


# 测试
# test = tf.random.normal([64,100])

# test = generator(test,True)
# print(test)
# test = test.numpy()
# test = ((test + 1.0) * 127.5).astype(np.uint8)
# plt.imshow(test[0])
# plt.show()
# print(next(db_iter))
 # 可视化
# z = tf.random.normal([100, z_dim])
# fake_image = generator(z, training=False)
# img_path = os.path.join('gan_images', 'gan-%d.png'%1)
# save_result(fake_image.numpy(), 10, img_path, color_mode='P')

In [7]:
# tensorboard
import datetime

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

generator_loss = tf.keras.metrics.Mean('generator_loss',dtype=tf.float32)
discriminator_loss = tf.keras.metrics.Mean('discriminator_loss',dtype=tf.float32)

In [8]:
for epoch in range(epochs):  # 训练epochs 次
    for i,batch_x in enumerate(dataset):
         # 1. 训练判别器
         print('epoch:{},step:{}',epoch,i)
         # 采样隐藏向量
         batch_z = tf.random.normal([batch_size, z_dim])
         # 采样真实图片
         # 判别器前向计算
         with tf.GradientTape() as tape:
             d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
         grads = tape.gradient(d_loss, discriminator.trainable_variables)
         d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
        
         discriminator_loss(d_loss)
         # 2. 训练生成器
        #  # 采样隐藏向量
        #  batch_z = tf.random.normal([batch_size, z_dim])
        
         # 生成器前向计算
         with tf.GradientTape() as tape:
             g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
         grads = tape.gradient(g_loss, generator.trainable_variables)
         g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
         
         generator_loss(g_loss)
         
    with train_summary_writer.as_default():
       tf.summary.scalar('generator_loss',generator_loss.result(),step = epoch)
       tf.summary.scalar('discriminator_loss',discriminator_loss.result(),step = epoch)

    template = 'Epoch {}, g_Loss: {},d_loss:{}'
    print(template.format(epoch+1,generator_loss.result(),discriminator_loss.result()))
       # https://zhuanlan.zhihu.com/p/84215973  tf2 tensorboard教程

    if epoch % 10 == 0:
      print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))  # 可视化
      z = tf.random.normal([100, z_dim])
      fake_image = generator(z, training=False)
      img_path = './gan_images'
      img_path = os.path.join(img_path, 'gan-%d.png' % epoch)
      save_result(fake_image.numpy(), 10, img_path, color_mode='P')
    
    generator_loss.reset_states()
    discriminator_loss.reset_states()
        #    if epoch % 10000 == 1:
        #        # print(d_losses)
        #        # print(g_losses)
        #        generator.save_weights('exam11.1_generator.ckpt')
        #        discriminator.save_weights('exam11.1_discriminator.ckpt')



# generator = keras.Sequential()
# generator.add(Generator())

# for epoch in range(10):
#     for _ in range(5)

epoch:{},step:{} 0 0
epoch:{},step:{} 0 1
epoch:{},step:{} 0 2
