In [5]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten,\
    Dropout, Convolution2D, MaxPooling2D, \
    AveragePooling2D, Convolution2DTranspose, Conv2DTranspose,GlobalAveragePooling2D
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.activation import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam,SGD
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import sys
import cv2 as cv
import numpy as np
import glob
import tensorflow as tf

def load_preprosess_image(path):
    """Read a single PNG file and return it as a float32 RGB tensor scaled to [-1, 1].

    Args:
        path: path of the image file (a string / scalar string tensor, as produced
            by a ``tf.data`` dataset of file names).

    Returns:
        float32 tensor with 3 channels whose values lie in [-1, 1].
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
    # Map [0, 255] onto [-1, 1], the same range as the generator's tanh output.
    return pixels / 127.5 - 1

def readPicture(batch_size, pattern=r"autodl-nas/anime/*.png"):
    """Build a shuffled, batched, prefetched tf.data pipeline over the PNG files matching ``pattern``.

    Args:
        batch_size: number of images per batch; the final batch may be smaller.
        pattern: glob pattern selecting the training images (defaults to the
            original hard-coded dataset location).

    Returns:
        A ``tf.data.Dataset`` yielding float32 image batches, each image decoded
        and scaled by ``load_preprosess_image``.

    Raises:
        FileNotFoundError: if ``pattern`` matches no files. An empty file list
            would otherwise fail later inside tf.data with a much less obvious error.
    """
    image_paths = glob.glob(pattern)
    if not image_paths:
        raise FileNotFoundError("no images match %r" % pattern)

    autotune = tf.data.AUTOTUNE
    img_ds = tf.data.Dataset.from_tensor_slices(image_paths)
    # Shuffle the (cheap) file paths *before* decoding: a full-size shuffle buffer
    # placed after map() would keep every decoded image in memory at once.
    img_ds = img_ds.shuffle(len(image_paths))
    img_ds = img_ds.map(load_preprosess_image, num_parallel_calls=autotune)
    img_ds = img_ds.batch(batch_size)
    return img_ds.prefetch(autotune)

class DCGAN(Model):
    """DCGAN that learns to generate 64x64 RGB images from Gaussian latent noise.

    ``generator_model`` maps a ``latent_dim`` noise vector to an image in [-1, 1]
    (tanh output), ``discriminator_model`` maps an image to a sigmoid "real"
    probability, and ``train`` alternates one discriminator update (real batch +
    generated batch) with one generator update through the stacked
    generator -> discriminator model.
    """

    def __init__(self, learning_rate=0.00001):
        """Set the image/latent dimensions and the discriminator optimizer.

        Args:
            learning_rate: Adam learning rate used for both the discriminator and
                the combined (generator-training) model. Defaults to the original 1e-5.
        """
        super().__init__()
        # The generator architecture upsamples 8x8 -> 64x64, so rows/cols must stay 64.
        self.img_rows = 64
        self.img_cols = 64
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        self.learning_rate = learning_rate

        # Optimizer for the discriminator. The combined model gets its own Adam
        # instance in train(): one Adam shared between models with different
        # variable sets mixes their step counters, and newer Keras optimizers
        # raise an error when asked to update variables they were not built with.
        self.optimizer = Adam(learning_rate)

    def generator_model(self):
        """Build the generator: latent vector (``latent_dim``,) -> 64x64 image in [-1, 1]."""
        model = Sequential()
        model.add(Dense(8 * 8 * 256, use_bias=False, input_shape=(self.latent_dim,)))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Reshape((8, 8, 256)))

        # Transposed convolutions; each stride-2 layer doubles the spatial size:
        # 8x8 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
        # use_bias must be the boolean False (the string 'false' is truthy and
        # silently enabled biases); as with the Dense layer above, the following
        # BatchNormalization re-centres the activations, so a bias is redundant.
        model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        # Output layer: 64x64xchannels, tanh keeps pixels in [-1, 1] like the training data.
        model.add(Conv2DTranspose(self.channels, (5, 5),
                                  strides=(2, 2),
                                  padding='same',
                                  use_bias=False,
                                  activation='tanh'))

        return model

    def discriminator_model(self):
        """Build the discriminator: image of shape ``img_shape`` -> probability it is real."""
        model = Sequential()
        model.add(Conv2D(32,
                         (5, 5),
                         strides=(2, 2),
                         padding='same',
                         input_shape=self.img_shape))
        model.add(LeakyReLU())
        model.add(Dropout(0.3))

        model.add(Conv2D(64,
                         (5, 5),
                         strides=(2, 2),
                         padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        model.add(Dropout(0.3))

        model.add(Conv2D(128,
                         (5, 5),
                         strides=(2, 2),
                         padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        model.add(Dropout(0.3))

        model.add(Conv2D(256,
                         (5, 5),
                         strides=(2, 2),
                         padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(GlobalAveragePooling2D())

        model.add(Dense(1024))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        model.add(Dense(1, activation='sigmoid'))
        return model

    def d_on_g(self, d, g):
        """Stack generator ``g`` and discriminator ``d`` into one model: noise -> realness score.

        Freezing ``d`` is deliberately not done here. Keras captures ``trainable``
        flags when a model is compiled, so toggling them while merely stacking
        layers has no effect. ``train`` freezes ``d`` around the compile call of
        the combined model instead.
        """
        model = Sequential()
        model.add(g)
        model.add(d)
        return model

    def train(self, batch_size, epochs, sample_interval=10, sample_dir="images"):
        """Train the GAN on the images returned by ``readPicture``.

        Args:
            batch_size: images per training step; a trailing partial batch is skipped
                so it always matches the fixed-size label arrays.
            epochs: number of passes over the dataset.
            sample_interval: save generated samples every this many epochs
                (starting at epoch 0).
            sample_dir: directory the sample grids are written to as ``<epoch>.png``;
                created if it does not exist.
        """
        discriminator = self.discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

        # The generator is only used via predict() and inside the combined model,
        # so it does not need a loss/optimizer of its own.
        generator = self.generator_model()

        d_on_g = self.d_on_g(discriminator, generator)
        # Freeze D while compiling the combined model so generator steps update only G.
        # The discriminator was already compiled while trainable, and Keras keeps that
        # compile-time state for discriminator.train_on_batch.
        discriminator.trainable = False
        d_on_g.compile(loss='binary_crossentropy', optimizer=Adam(self.learning_rate))
        discriminator.trainable = True

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        # Build the input pipeline once; tf.data reshuffles it on every pass by default.
        all_imgs = readPicture(batch_size)
        os.makedirs(sample_dir, exist_ok=True)

        for epoch in range(epochs):
            ct = 1
            for batch in all_imgs:
                if batch.shape[0] != batch_size:
                    continue

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = generator.predict(noise, verbose=0)

                d_loss_real = discriminator.train_on_batch(batch, valid)
                d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
                # Each result is [loss, accuracy]; average the real and fake halves.
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Generator step: label generated images "real" so G learns to fool the frozen D.
                g_loss = d_on_g.train_on_batch(noise, valid)

                print("epoch:%d batch: %d [D loss: %f] [G loss: %f]" % (epoch, ct, d_loss[0], g_loss))
                ct += 1

            if epoch % sample_interval == 0:
                self._save_samples(generator, epoch, sample_dir)

    def _save_samples(self, generator, epoch, sample_dir, n_samples=4):
        """Plot ``n_samples`` generated images side by side and save them as ``<sample_dir>/<epoch>.png``."""
        noise = np.random.normal(0, 1, (n_samples, self.latent_dim))
        pre_image = generator.predict(noise, verbose=0)
        fig = plt.figure(figsize=(16, 3))
        for i in range(n_samples):
            plt.subplot(1, n_samples, i + 1)
            # Map the tanh output from [-1, 1] back to [0, 1] for imshow.
            plt.imshow((pre_image[i] + 1) / 2)
            plt.axis('off')  # hide the axes
        plt.savefig(os.path.join(sample_dir, "%d.png" % epoch))
        # Close the figure: one is created per sample interval and would otherwise
        # accumulate in memory over a long run.
        plt.close(fig)
    

In [None]:
# Fix the global random seeds so that a Restart & Run All draws the same noise
# and initial weights (GPU kernels may still add some nondeterminism).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

dcgan = DCGAN()
dcgan.train(batch_size=64, epochs=500)

epoch:0 batch: 1 [D loss: 0.783395] [G loss: 0.772676]
epoch:0 batch: 2 [D loss: 0.783488] [G loss: 0.762679]
epoch:0 batch: 3 [D loss: 0.794136] [G loss: 0.757625]
epoch:0 batch: 4 [D loss: 0.770776] [G loss: 0.770300]
epoch:0 batch: 5 [D loss: 0.763570] [G loss: 0.782121]
epoch:0 batch: 6 [D loss: 0.760860] [G loss: 0.765095]
epoch:0 batch: 7 [D loss: 0.780984] [G loss: 0.727041]
epoch:0 batch: 8 [D loss: 0.764672] [G loss: 0.741673]
epoch:0 batch: 9 [D loss: 0.752236] [G loss: 0.731384]
epoch:0 batch: 10 [D loss: 0.759938] [G loss: 0.736268]
epoch:0 batch: 11 [D loss: 0.739253] [G loss: 0.733303]
epoch:0 batch: 12 [D loss: 0.746580] [G loss: 0.702010]
epoch:0 batch: 13 [D loss: 0.756939] [G loss: 0.726575]
epoch:0 batch: 14 [D loss: 0.750128] [G loss: 0.717696]
epoch:0 batch: 15 [D loss: 0.753832] [G loss: 0.719729]
epoch:0 batch: 16 [D loss: 0.746563] [G loss: 0.697668]
epoch:0 batch: 17 [D loss: 0.745324] [G loss: 0.716300]
epoch:0 batch: 18 [D loss: 0.758525] [G loss: 0.703329]
e