In [1]:
# Mount Google Drive at /content/drive; the figure-saving helpers in this
# notebook write their PNG files under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import tensorflow as tf
from tensorflow import keras
import time
import numpy as np
import matplotlib.pyplot as plt

# -------------------get_half_batch_ds-------------------

def _process_x(x):
    """Cast ``x`` to float32, append a trailing axis at position 3, and rescale.

    The rescaling is ``value / 255 * 2 - 1``, so 0 maps to -1 and 255 maps to 1.
    """
    as_float = tf.cast(x, tf.float32)
    with_channel = tf.expand_dims(as_float, axis=3)
    return with_channel / 255. * 2 - 1

def get_half_batch_ds(batch_size):
    """Return the dataset from ``get_ds`` batched at ``batch_size // 2``."""
    half_batch = batch_size // 2
    return get_ds(half_batch)

def get_ds(batch_size):
    """Build a shuffled, batched ``tf.data.Dataset`` from the MNIST training split.

    Args:
        batch_size: number of (image, label) pairs per batch.

    Returns:
        A dataset yielding ``(image, label)`` batches; images go through
        ``_process_x`` and labels are cast to int32.
    """
    train_split, _ = keras.datasets.mnist.load_data()   # the second split is ignored
    images, labels = train_split
    images = _process_x(images)
    labels = tf.cast(labels, tf.int32)

    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))
    pipeline = pipeline.cache()
    pipeline = pipeline.shuffle(1024)
    pipeline = pipeline.batch(batch_size)
    return pipeline.prefetch(tf.data.experimental.AUTOTUNE)

# -------------------------------------------------------

class Attention(keras.layers.Layer):
    """Self-attention over the spatial positions of a ``[n, w, h, c]`` feature map.

    Three 1x1 convolutions (``f``, ``g``, ``h``) project the input to a reduced
    channel size and flatten the spatial axes. ``softmax(f @ g^T)`` gives a
    ``[n, w*h, w*h]`` attention map that is applied to ``h``. The result is
    scaled by the learnable scalar ``gamma``, projected back to ``c`` channels
    by ``v``, and added to the input as a residual.

    Args:
        gamma: initial value of the learnable scalar ``gamma``.
        trainable: forwarded to ``keras.layers.Layer``.

    The gamma weight is named ``gamma<k>`` where ``k`` comes from a class-level
    counter, so the layer does not depend on any module-level global.
    """

    # Number of Attention layers built so far; only used to give every gamma
    # weight a distinct name ("gamma0", "gamma1", ...).
    _gamma_count = 0

    def __init__(self, gamma=0.01, trainable=True):
        super().__init__(trainable=trainable)
        self._gamma = gamma      # initial value for the gamma weight
        self.gamma = None        # scalar weight, created in build()
        self.f = None
        self.g = None
        self.h = None
        self.v = None
        self.attention = None    # attention map of the most recent call()

    def build(self, input_shape):
        """Create the projection sub-layers and the scalar ``gamma`` weight."""
        c = input_shape[-1]
        # Reduce the channel size to cut computation, but never below one
        # filter: c // 8 is 0 for 2 <= c <= 7, which is not a valid Conv2D size.
        reduced = max(c // 8, 1)
        self.f = self.block(reduced)
        self.g = self.block(reduced)
        self.h = self.block(reduced)
        self.v = keras.layers.Conv2D(c, 1, 1)              # scale back to original channel size
        # Keyword arguments and an explicit scalar shape keep this call
        # independent of add_weight's positional parameter order.
        self.gamma = self.add_weight(
            name="gamma{}".format(Attention._gamma_count),
            shape=(),
            initializer=keras.initializers.constant(self._gamma))
        Attention._gamma_count += 1

    @staticmethod
    def block(c):
        """Return a 1x1 convolution to ``c`` channels followed by a spatial flatten."""
        return keras.Sequential([
            keras.layers.Conv2D(c, 1, 1),   # [n, w, h, c_in] -> [n, w, h, c]
            keras.layers.Reshape((-1, c)),  # [n, w, h, c] -> [n, w*h, c]
        ])

    def call(self, inputs, **kwargs):
        """Apply self-attention to ``inputs`` and return a tensor of the same shape."""
        # c' below is the reduced channel size chosen in build().
        f = self.f(inputs)    # [n, w, h, c] -> [n, w*h, c']
        g = self.g(inputs)    # [n, w, h, c] -> [n, w*h, c']
        h = self.h(inputs)    # [n, w, h, c] -> [n, w*h, c']
        scores = tf.matmul(f, g, transpose_b=True)   # [n, w*h, c'] @ [n, c', w*h] = [n, w*h, w*h]
        self.attention = tf.nn.softmax(scores, axis=-1)
        context_wh = tf.matmul(self.attention, h)    # [n, w*h, w*h] @ [n, w*h, c'] = [n, w*h, c']
        in_shape = inputs.shape         # [n, w, h, c]
        ctx_shape = context_wh.shape    # [n, w*h, c']
        # Restore the spatial layout using the static width and height of the input.
        context = tf.reshape(context_wh, [-1, in_shape[1], in_shape[2], ctx_shape[-1]])    # [n, w, h, c']
        # gamma scales the context before the output projection v.
        o = self.v(self.gamma * context) + inputs   # residual
        return o


class SAGAN(keras.Model):
    """Self-attention GAN: a generator and a discriminator with ``Attention`` layers.

    Self-attention strengthens the generator. Training uses the hinge loss
    (commonly used in SVMs), which is a continuous loss. Because the attention
    matrix is large (w*h by w*h), training is fairly slow, which leaves room
    for improvement. Spectral normalization (SN), which stabilises the weight
    gradients, is somewhat cumbersome to implement and is not included here.

    Args:
        latent_dim: length of the random vector fed to the generator.
        img_shape: shape of one image, used as the discriminator input shape.
        gamma: initial value passed to every ``Attention`` layer.
    """
    def __init__(self, latent_dim, img_shape, gamma):
        super().__init__()
        self.gamma = gamma
        self.img_shape = img_shape
        self.latent_dim = latent_dim
        self.g = self._get_generator()
        self.d = self._get_discriminator()
        # Each network gets its own Adam instance so that moment estimates and
        # the step counter are not shared between generator and discriminator,
        # and so that neither optimizer is asked to update variables it was
        # not first applied to.
        self.g_opt = keras.optimizers.Adam(0.0002, beta_1=0.5)
        self.d_opt = keras.optimizers.Adam(0.0002, beta_1=0.5)
        self.opt = self.g_opt   # kept for callers that still read ``.opt``
        self.loss_func = keras.losses.Hinge()       # change loss to hinge based on the paper

    def call(self, n, training=None, mask=None):
        """Generate ``n`` images from standard-normal latent vectors."""
        return self.g.call(tf.random.normal((n, self.latent_dim)), training=training)

    def _get_discriminator(self):
        """Build the discriminator, print its summary, and return it."""
        model = keras.Sequential([
            keras.layers.GaussianNoise(0.01, input_shape=self.img_shape),
            keras.layers.Conv2D(16, 4, strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            Attention(self.gamma),
            keras.layers.Dropout(0.3),

            keras.layers.Conv2D(32, 4, strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),

            keras.layers.Flatten(),
            keras.layers.Dense(1),
        ], name="discriminator")
        model.summary()
        return model

    def _get_generator(self):
        """Build the generator, print its summary, and return it."""
        model = keras.Sequential([
            # [n, latent] -> [n, 7 * 7 * 128] -> [n, 7, 7, 128]
            keras.layers.Dense(7 * 7 * 128, input_shape=(self.latent_dim,)),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Reshape((7, 7, 128)),

            # -> [n, 14, 14, 64]
            keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            Attention(self.gamma),

            # -> [n, 28, 28, 32]
            keras.layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            # -> [n, 28, 28, 1]
            keras.layers.Conv2D(1, (4, 4), padding='same', activation=keras.activations.tanh),
            # Final attention layer on the single-channel image; its residual
            # connection is added after the tanh activation.
            Attention(self.gamma)
        ], name="generator")
        model.summary()
        return model

    def train_d(self, img, d_label):
        """Run one discriminator update on ``img`` with targets ``d_label``; return the loss."""
        with tf.GradientTape() as tape:
            pred = self.d.call(img, training=True)
            loss = self.loss_func(d_label, pred)
        grads = tape.gradient(loss, self.d.trainable_variables)
        self.d_opt.apply_gradients(zip(grads, self.d.trainable_variables))
        return loss

    def train_g(self, d_label):
        """Run one generator update.

        Generates ``len(d_label)`` images, scores them with the discriminator
        (in inference mode), and updates only the generator variables.

        Returns:
            ``(loss, g_img)``: the generator loss and the generated images.
        """
        with tf.GradientTape() as tape:
            g_img = self.call(len(d_label), training=True)
            pred = self.d.call(g_img, training=False)
            loss = self.loss_func(d_label, pred)
        grads = tape.gradient(loss, self.g.trainable_variables)
        self.g_opt.apply_gradients(zip(grads, self.g.trainable_variables))
        return loss, g_img

    def step(self, img):
        """Run one generator update followed by one discriminator update.

        Args:
            img: a batch of real images.

        Returns:
            ``(d_loss, g_loss)`` for this step.
        """
        # The generator is trained on twice as many samples as there are real
        # images, with a target value of 2 instead of 1.
        d_label = 2*tf.ones((len(img) * 2, 1), tf.float32)  # a stronger positive label?
        g_loss, g_img = self.train_g(d_label)

        # Discriminator batch: real images labelled +1, followed by the first
        # half of the generated images labelled -1.
        d_label = tf.concat((tf.ones((len(img), 1), tf.float32), -tf.ones((len(g_img)//2, 1), tf.float32)), axis=0)
        img = tf.concat((img, g_img[:len(g_img)//2]), axis=0)
        d_loss = self.train_d(img, d_label)
        return d_loss, g_loss


def train(gan, ds, epoch):
    """Train ``gan`` on ``ds`` for ``epoch`` epochs and save sample images after each.

    Args:
        gan: object exposing ``step(img)`` that returns ``(d_loss, g_loss)``.
        ds: iterable of ``(img, label)`` batches; the labels are ignored.
        epoch: number of passes over ``ds``.

    Raises:
        ValueError: if ``ds`` yields no batches, since there is then no
            reference batch to pass to ``save_gan``.
    """
    t0 = time.time()
    t_img = None    # first real batch ever seen; reused for every save_gan call
    for ep in range(epoch):
        for t, (img, _) in enumerate(ds):
            d_loss, g_loss = gan.step(img)
            if t_img is None:
                t_img = img
            if t % 400 == 0:
                t1 = time.time()
                # ``time`` is the wall-clock time since the previous log line.
                print(
                    "ep={} | time={:.1f} | t={} | d_loss={:.2f} | g_loss={:.2f}".format(
                        ep, t1 - t0, t, d_loss.numpy(), g_loss.numpy(), ))
                t0 = t1
        if t_img is None:
            raise ValueError("dataset yielded no batches; nothing to train on")
        save_gan(gan, ep, t_img)


def save_gan(model, a, img, **kwargs):
    """Save two image grids for checkpoint index ``a``.

    The first grid holds 100 freshly generated samples. The second holds the
    result of applying the generator's last layer directly to ``img``.
    Extra keyword arguments are accepted and ignored.
    """
    generated = model.call(100, training=False).numpy()
    _save_gan(generated, a, show_label=False)

    final_layer = model.g.layers[-1]
    _save_gan2(final_layer(img), a, show_label=False)

def _save_image_grid(imgs, path, nc=5, nr=5):
    """Draw up to ``nc * nr`` images in an ``nr`` x ``nc`` grid and save it to ``path``.

    Args:
        imgs: a NumPy array, or an object with ``.numpy()``, indexed by image.
            A trailing axis is squeezed when the array has more than 3 dims.
        path: output file path handed to ``savefig``.
        nc: number of grid columns.
        nr: number of grid rows.
    """
    if not isinstance(imgs, np.ndarray):
        imgs = imgs.numpy()
    if imgs.ndim > 3:
        imgs = np.squeeze(imgs, axis=-1)
    # Create (or fetch) figure 0 first and clear that figure. Calling plt.clf()
    # before any figure exists would create an extra empty default figure that
    # is never closed.
    fig = plt.figure(0, (nc * 2, nr * 2))
    fig.clf()
    try:
        # Image i goes to grid cell i + 1; stop early if fewer images than cells.
        for i in range(min(len(imgs), nc * nr)):
            plt.subplot(nr, nc, i + 1)
            plt.imshow(imgs[i], cmap="gray_r")
            plt.axis("off")
        fig.savefig(path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def _save_gan(imgs, a, show_label=False, nc=5, nr=5):
    """Save ``imgs`` as a grid to the ``1_<a>.png`` file in the image directory.

    ``show_label`` is accepted for call compatibility and is not used.
    """
    _save_image_grid(imgs, '/content/drive/MyDrive/張雲南/SAGAN/img/' + f'1_{a}.png', nc=nc, nr=nr)


def _save_gan2(imgs, a, show_label=False, nc=5, nr=5):
    """Save ``imgs`` as a grid to the ``2_<a>.png`` file in the image directory.

    ``show_label`` is accepted for call compatibility and is not used.
    """
    _save_image_grid(imgs, '/content/drive/MyDrive/張雲南/SAGAN/img/' + f'2_{a}.png', nc=nc, nr=nr)

if __name__ == "__main__":
    # Module-level counter, initialised to zero before any model is constructed.
    GAMMA_id = 0
    LATENT_DIM = 100            # length of the generator's random input vector
    IMG_SHAPE = (28, 28, 1)     # shape of one image passed to the discriminator
    BATCH_SIZE = 64             # the dataset is batched at half this size (see get_half_batch_ds)
    GAMMA = 0.01                # initial value handed to every Attention layer
    EPOCH = 20                  # number of passes over the dataset

    d = get_half_batch_ds(BATCH_SIZE)
    m = SAGAN(LATENT_DIM, IMG_SHAPE, GAMMA)
    train(m, d, EPOCH)

1
2
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
batch_normalization_3 (Batch (None, 6272)              25088     
_________________________________________________________________
re_lu_3 (ReLU)               (None, 6272)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 14, 14, 64)        131136    
_________________________________________________________________
batch_normalization_4 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
re_lu_4 (ReLU)               (None, 14, 14, 64)      

<Figure size 432x288 with 0 Axes>