In [None]:
# Make Google Drive available to this Colab runtime under /content/drive.
import google.colab.drive as drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!cd /content/drive/MyDrive/張雲南/infogan

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras.layers import Dense, Input, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras.layers import Conv2D, Flatten, Reshape, Conv2DTranspose, ReLU
import time
import os
import matplotlib.pyplot as plt

#----------------------- utils -----------------------

_b_acc = None
_c_acc = None


def binary_accuracy(label, pred):
    """Binary accuracy of ``pred`` against ``label`` for a single batch.

    Uses a Keras ``BinaryAccuracy`` metric with its default threshold (0.5), so
    ``pred`` should hold probabilities rather than raw logits. One module-level
    metric object is reused and reset on every call, so the result reflects only
    the arguments of this call (no accumulation across calls).

    Returns:
        A scalar tensor with the accuracy for this batch.
    """
    global _b_acc
    if _b_acc is None:
        _b_acc = tf.keras.metrics.BinaryAccuracy()
    # ``reset_states`` was renamed ``reset_state`` in TF 2.5 and the old name is
    # gone in Keras 3; prefer the new name and fall back for older versions.
    reset = getattr(_b_acc, "reset_state", None) or _b_acc.reset_states
    reset()
    _b_acc.update_state(label, pred)
    return _b_acc.result()

def class_accuracy(label, pred):
    """Sparse categorical accuracy of ``pred`` against integer ``label`` for one batch.

    ``pred`` holds per-class scores. Keras ``SparseCategoricalAccuracy`` takes their
    argmax, so logits and probabilities both work. One module-level metric object
    is reused and reset on every call, so the result reflects only this call.

    Returns:
        A scalar tensor with the accuracy for this batch.
    """
    global _c_acc
    if _c_acc is None:
        _c_acc = tf.keras.metrics.SparseCategoricalAccuracy()
    # ``reset_states`` was renamed ``reset_state`` in TF 2.5 and the old name is
    # gone in Keras 3; prefer the new name and fall back for older versions.
    reset = getattr(_c_acc, "reset_state", None) or _c_acc.reset_states
    reset()
    _c_acc.update_state(label, pred)
    return _c_acc.result()

def save_weights(model):
    """Write ``model``'s weights to ``./models/<lower-cased class name>/model.ckpt``.

    The directory is created (relative to the current working directory) if it
    does not exist yet.
    """
    model_dir = f"./models/{type(model).__name__.lower()}"
    os.makedirs(model_dir, exist_ok=True)
    model.save_weights(f"{model_dir}/model.ckpt")

# ------------------- get_half_batch_ds -------------------

def _process_x(x):
    """Turn a stack of uint8 images (N, H, W) into float32 (N, H, W, 1) scaled to [-1, 1]."""
    as_float = tf.cast(x, tf.float32)
    with_channel = tf.expand_dims(as_float, axis=3)
    return with_channel / 255. * 2 - 1

def get_half_batch_ds(batch_size):
    """Return the MNIST training dataset batched at ``batch_size // 2`` examples."""
    half_batch = batch_size // 2
    return get_ds(half_batch)

def get_ds(batch_size):
    """Build a cached, shuffled, batched and prefetched tf.data pipeline over MNIST train.

    Each element is ``(images, labels)``. Images come from ``_process_x``, which
    produces float32 values in [-1, 1] with a trailing channel axis. Labels are int32.
    """
    (train_x, train_y), _test_split = keras.datasets.mnist.load_data()
    images = _process_x(train_x)
    labels = tf.cast(train_y, tf.int32)
    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))
    pipeline = pipeline.cache()
    pipeline = pipeline.shuffle(1024)
    pipeline = pipeline.batch(batch_size)
    return pipeline.prefetch(tf.data.experimental.AUTOTUNE)

# ------------------- gan_cnn -------------------

def mnist_uni_gen_cnn(input_shape):
    """Generator trunk: latent vector -> [n, 28, 28, 1] image with tanh output in [-1, 1].

    Args:
        input_shape: shape of one latent vector, e.g. ``(latent_dim,)``.
    """
    layers = [
        # [n, latent] -> [n, 7 * 7 * 128] -> [n, 7, 7, 128]
        Dense(7 * 7 * 128, input_shape=input_shape),
        BatchNormalization(),
        ReLU(),
        Reshape((7, 7, 128)),
    ]
    # Two stride-2 transposed convolutions: -> [n, 14, 14, 64] -> [n, 28, 28, 32]
    for filters in (64, 32):
        layers += [
            Conv2DTranspose(filters, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            ReLU(),
        ]
    # -> [n, 28, 28, 1]
    layers.append(Conv2D(1, (4, 4), padding='same', activation=keras.activations.tanh))
    return keras.Sequential(layers)


def mnist_uni_disc_cnn(input_shape=(28, 28, 1), use_bn=True):
    """Discriminator trunk: image -> flattened conv features.

    For a 28x28 input the two stride-2 convolutions give
    [n, 28, 28, c] -> [n, 14, 14, 64] -> [n, 7, 7, 128], which is then flattened.

    Args:
        input_shape: shape of one input image.
        use_bn: insert BatchNormalization after each convolution when True.
    """
    layers = []
    for idx, filters in enumerate((64, 128)):
        # Only the first layer declares the input shape.
        extra = {"input_shape": input_shape} if idx == 0 else {}
        layers.append(Conv2D(filters, (4, 4), strides=(2, 2), padding='same', **extra))
        if use_bn:
            layers.append(BatchNormalization())
        layers.append(LeakyReLU())
        layers.append(Dropout(0.3))
    layers.append(Flatten())
    return keras.Sequential(layers)

# ----------------------------------------------------------

def save_gan(model, ep, **kwargs):
    """Generate a label x style grid of images with ``model`` and save it for epoch ``ep``.

    Row ``r`` uses label ``r`` (one row per label, ``model.label_dim`` rows).
    The 10 columns sweep the style code linearly from ``-model.style_scale`` to
    ``+model.style_scale``; the same value is used for every style dimension.

    Args:
        model: InfoGAN-like model exposing ``label_dim``, ``style_dim``,
            ``style_scale`` and ``predict``.
        ep: epoch identifier, used as the output file name.
        **kwargs: accepted for call compatibility; unused.
    """
    n_style_steps = 10
    n_labels = model.label_dim
    # Label-major ordering: image index = label * n_style_steps + style_step.
    img_label = np.arange(0, n_labels).astype(np.int32).repeat(n_style_steps, axis=0)
    sweep = np.linspace(-model.style_scale, model.style_scale, n_style_steps)
    img_style = (np.tile(sweep, n_labels)
                 .reshape((n_labels * n_style_steps, 1))
                 .repeat(model.style_dim, axis=1)
                 .astype(np.float32))
    imgs = model.predict((img_label, img_style))
    _save_gan(ep, imgs, show_label=False, nc=n_style_steps, nr=n_labels)

    # Clear and close whatever pyplot figure is current afterwards, so a stray
    # figure left open by the plotting helper does not accumulate across epochs.
    plt.clf()
    plt.close()

def _save_gan(ep, imgs, show_label=False, nc=10, nr=10, out_dir='/content/drive/MyDrive/張雲南/infogan/img/'):
    """Save an ``nr`` x ``nc`` grid of images to ``<out_dir>/<ep>.png``.

    Args:
        ep: epoch/step identifier used as the file name.
        imgs: images in [-1, 1], shape (N, H, W) or (N, H, W, 1); a numpy array or
            any object with ``.numpy()`` (e.g. a tf.Tensor). Requires N >= nr * nc.
        show_label: unused; kept for call compatibility.
        nc: number of grid columns.
        nr: number of grid rows. Image ``r * nc + c`` is drawn at row r, column c.
        out_dir: destination directory (created if missing). The default matches
            the previously hard-coded location.
    """
    if not isinstance(imgs, np.ndarray):
        imgs = imgs.numpy()
    if imgs.ndim > 3:
        imgs = np.squeeze(imgs, axis=-1)
    imgs = _img_recenter(imgs)  # [-1, 1] -> [0, 255]

    # Explicit figure/axes so exactly one figure is created and closed here.
    fig, axes = plt.subplots(nr, nc, figsize=(nc * 2, nr * 2), squeeze=False)
    for r in range(nr):
        for c in range(nc):
            ax = axes[r][c]
            ax.imshow(imgs[r * nc + c], cmap="gray_r")
            ax.axis("off")
    # savefig raises if the directory is missing, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'{ep}.png'))
    plt.close(fig)

def _img_recenter(img):
    return (img + 1) * 255 / 2

# ----------------------------------------------------------

class InfoGAN(keras.Model):
    """InfoGAN built from a DCGAN-style generator and discriminator.

    * ``d`` (discriminator): image -> (real/fake logit, predicted style code,
      predicted label logits). The two prediction heads form the recognition
      ("Q") network, which recovers the latent code c from an image (c can be
      read as a virtual class or a virtual style).
    * ``g`` (generator): (noise z, label, style) -> image.

    Args:
        rand_dim: size of the unstructured noise vector z.
        style_dim: number of continuous latent (style) codes.
        label_dim: number of categories of the discrete latent (label) code.
        img_shape: shape of one image, e.g. ``(28, 28, 1)``.
        fix_std: if True, Q predicts only the style mean and the std is fixed
            to 1; otherwise Q also predicts a per-dimension std parameter.
        style_scale: style codes are sampled from, and Q's style mean is
            squashed into, ``[-style_scale, style_scale]``.
        lambda_info: weight of the mutual-information loss. ``None`` (default)
            reads the module-level ``LAMBDA`` at training time, matching the
            original behaviour.
    """
    def __init__(self, rand_dim, style_dim, label_dim, img_shape, fix_std=True, style_scale=2, lambda_info=None):
        super().__init__()
        self.rand_dim, self.style_dim, self.label_dim = rand_dim, style_dim, label_dim
        self.img_shape = img_shape
        self.fix_std = fix_std
        self.style_scale = style_scale
        self.lambda_info = lambda_info

        self.g = self._get_generator()
        self.d = self._get_discriminator()

        # Separate optimizers for D (``self.opt``, name kept for compatibility)
        # and G. Sharing one Adam across two disjoint variable sets advances its
        # step counter twice per training step, and the newer (TF >= 2.11)
        # Keras optimizers raise an error when asked to update variables they
        # were not built with.
        self.opt = keras.optimizers.Adam(0.0002, beta_1=0.5)
        self.g_opt = keras.optimizers.Adam(0.0002, beta_1=0.5)
        # Per-sample losses; callers reduce them explicitly.
        self.loss_bool = keras.losses.BinaryCrossentropy(from_logits=True, reduction="none")

    def _info_weight(self):
        """Weight of the mutual-information term (global ``LAMBDA`` when unset)."""
        return LAMBDA if self.lambda_info is None else self.lambda_info

    def call(self, img_info, training=None, mask=None):
        """Generate images from ``img_info = (labels, styles)`` with fresh Gaussian noise."""
        img_label, img_style = img_info
        if isinstance(img_label, np.ndarray):
            img_label = tf.convert_to_tensor(img_label, dtype=tf.int32)
        if isinstance(img_style, np.ndarray):
            img_style = tf.convert_to_tensor(img_style, dtype=tf.float32)
        # tf.shape instead of len(): inside model.predict the batch dimension is
        # symbolic/unknown and len() on such a tensor raises.
        batch_size = tf.shape(img_label)[0]
        noise = tf.random.normal((batch_size, self.rand_dim))
        return self.g([noise, img_label, img_style], training=training)

    def _get_discriminator(self):
        """Build D: shared conv trunk -> (real/fake logit, Q style output, Q label logits)."""
        img = Input(shape=self.img_shape)
        s = keras.Sequential([
            mnist_uni_disc_cnn(self.img_shape),
            Dense(32),
            BatchNormalization(),
            LeakyReLU(),
            Dropout(0.5),
        ])
        # Width of Q's style output: means only, or means followed by std params.
        q_style_dim = self.style_dim if self.fix_std else self.style_dim * 2
        q = keras.Sequential([
            Dense(16, input_shape=(32,)),
            BatchNormalization(),
            LeakyReLU(),
            Dense(q_style_dim + self.label_dim)
        ], name="recognition")
        o = s(img)
        o_bool = Dense(1)(o)
        o_q = q(o)
        if self.fix_std:
            q_style = self.style_scale * tf.tanh(o_q[:, :q_style_dim])
        else:
            q_style = tf.concat(
                (self.style_scale * tf.tanh(o_q[:, :self.style_dim]),
                 tf.nn.relu(o_q[:, self.style_dim:q_style_dim])),
                axis=1)
        q_label = o_q[:, -self.label_dim:]
        model = keras.Model(img, [o_bool, q_style, q_label], name="discriminator")
        model.summary()
        return model

    def _get_generator(self):
        """Build G: concat(noise, one-hot(label), style) -> image."""
        latent_dim = self.rand_dim + self.label_dim + self.style_dim
        noise = Input(shape=(self.rand_dim,))
        style = Input(shape=(self.style_dim, ))
        label = Input(shape=(), dtype=tf.int32)
        label_onehot = tf.one_hot(label, depth=self.label_dim)
        model_in = tf.concat((noise, label_onehot, style), axis=1)
        s = mnist_uni_gen_cnn((latent_dim,))
        o = s(model_in)
        model = keras.Model([noise, label, style], o, name="generator")
        model.summary()
        return model

    # ---------------- mutual-information (Q network) loss ----------------
    def loss_mutual_info(self, style, pred_style, label, pred_label):
        """Per-sample info loss: label cross-entropy minus style Gaussian log-likelihood.

        Args:
            style: true style codes, [n, style_dim].
            pred_style: Q's style output; [n, style_dim] means when ``fix_std``,
                otherwise [n, 2 * style_dim] = means followed by std parameters.
            label: true integer labels, [n].
            pred_label: Q's label logits, [n, label_dim].
        Returns:
            Tensor of shape [n].
        """
        categorical_loss = keras.losses.sparse_categorical_crossentropy(label, pred_label, from_logits=True)

        # Choose between a fixed unit std and a predicted std.
        if self.fix_std:
            style_mean = pred_style
            style_std = tf.ones_like(pred_style)
        else:
            # Split along the FEATURE axis (axis 1), not the batch axis.
            split = pred_style.shape[1] // 2
            style_mean, style_std = pred_style[:, :split], pred_style[:, split:]
            style_std = tf.sqrt(tf.exp(style_std))

        # Gaussian log-likelihood of the true continuous code.
        epsilon = (style - style_mean) / (style_std + 1e-5)
        ll_continuous = tf.reduce_sum(
            - 0.5 * tf.math.log(2 * np.pi) - tf.math.log(style_std + 1e-5) - 0.5 * tf.square(epsilon),
            axis=1,
        )
        return categorical_loss - ll_continuous

    def train_d(self, real_fake_img, real_fake_d_label, fake_img_label, fake_style):
        """One D/Q update.

        Args:
            real_fake_img: real images followed by generated images.
            real_fake_d_label: real(1)/fake(0) targets, [k, 1], for the FIRST k
                images of ``real_fake_img``.
            fake_img_label: labels used to generate the LAST len(fake_img_label) images.
            fake_style: styles used to generate those same trailing images.
        Returns:
            (loss, real/fake accuracy, Q label accuracy on the generated images)
        """
        n_bool = len(real_fake_d_label)
        n_fake = len(fake_img_label)
        with tf.GradientTape() as tape:
            pred_bool, pred_style, pred_class = self.d(real_fake_img, training=True)
            real_fake_pred_bool = pred_bool[:n_bool]
            loss_bool = self.loss_bool(real_fake_d_label, real_fake_pred_bool)
            fake_pred_style = pred_style[-n_fake:]
            fake_pred_label = pred_class[-n_fake:]
            loss_info = self.loss_mutual_info(fake_style, fake_pred_style, fake_img_label, fake_pred_label)
            # Average each term separately: the two losses may cover different
            # numbers of samples.
            loss = tf.reduce_mean(loss_bool) + self._info_weight() * tf.reduce_mean(loss_info)
        grads = tape.gradient(loss, self.d.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.d.trainable_variables))
        # D outputs logits; sigmoid so the 0.5 accuracy threshold means logit > 0.
        return (loss,
                binary_accuracy(real_fake_d_label, tf.sigmoid(real_fake_pred_bool)),
                class_accuracy(fake_img_label, fake_pred_label))

    def train_g(self, random_img_label, random_img_style):
        """One G update: fool D and keep the latent code recoverable by Q.

        Returns:
            (loss, generated images, fraction of generated images D classifies as real)
        """
        d_label = tf.ones((len(random_img_label), 1), tf.float32)   # let D think generated images are real
        with tf.GradientTape() as tape:
            g_img = self.call([random_img_label, random_img_style], training=True)
            pred_bool, pred_style, pred_class = self.d(g_img, training=False)
            loss_bool = self.loss_bool(d_label, pred_bool)
            loss_info = self.loss_mutual_info(random_img_style, pred_style, random_img_label, pred_class)
            loss = tf.reduce_mean(loss_bool) + self._info_weight() * tf.reduce_mean(loss_info)
        grads = tape.gradient(loss, self.g.trainable_variables)
        self.g_opt.apply_gradients(zip(grads, self.g.trainable_variables))
        return loss, g_img, binary_accuracy(d_label, tf.sigmoid(pred_bool))

    def step(self, real_img):
        """One training step: a G update on 2x batch of codes, then a D/Q update.

        Returns:
            (d_loss, d_bool_acc, g_loss, g_bool_acc, sampled labels, d_class_acc)
        """
        n_real = len(real_img)
        n_fake = n_real * 2
        random_img_label = tf.convert_to_tensor(np.random.randint(0, self.label_dim, n_fake), dtype=tf.int32)
        random_img_style = tf.random.uniform((n_fake, self.style_dim), -self.style_scale, self.style_scale)
        g_loss, g_img, g_bool_acc = self.train_g(random_img_label, random_img_style)

        # D input: n_real real images followed by n_fake generated ones.
        # The real/fake loss uses the real images plus the first n_real fakes
        # (balanced); the info loss uses every generated image.
        real_fake_img = tf.concat((real_img, g_img), axis=0)
        real_fake_d_label = tf.concat(
            (tf.ones((n_real, 1), tf.float32), tf.zeros((n_real, 1), tf.float32)), axis=0)
        d_loss, d_bool_acc, d_class_acc = self.train_d(real_fake_img, real_fake_d_label, random_img_label, random_img_style)
        return d_loss, d_bool_acc, g_loss, g_bool_acc, random_img_label, d_class_acc


def train(gan, ds, epochs=None):
    """Train ``gan`` on ``ds``, logging every 400 steps and saving a sample grid per epoch.

    Args:
        gan: model exposing ``step(real_img)`` (e.g. InfoGAN).
        ds: iterable of ``(real_img, label)`` batches; labels are ignored.
        epochs: number of passes over ``ds``; ``None`` uses the global ``EPOCH``.
    Side effects:
        Calls ``save_gan`` after every epoch and ``save_weights`` at the end.
    """
    if epochs is None:
        epochs = EPOCH
    t0 = time.time()
    for ep in range(epochs):
        for t, (real_img, _) in enumerate(ds):
            d_loss, d_bool_acc, g_loss, g_bool_acc, _, d_class_acc = gan.step(real_img)
            if t % 400 == 0:
                t1 = time.time()
                # Argument order matches the labels in the format string.
                print("ep={} | time={:.1f}|t={}|d_acc={:.2f}|d_classacc={:.2f}|g_acc={:.2f}|d_loss={:.2f}|g_loss={:.2f}".format(
                    ep, t1 - t0, t, d_bool_acc.numpy(), d_class_acc.numpy(), g_bool_acc.numpy(), d_loss.numpy(), g_loss.numpy()))
                t0 = t1
        save_gan(gan, ep)
    save_weights(gan)


if __name__ == "__main__":
    STYLE_DIM = 2
    LABEL_DIM = 10
    RAND_DIM = 8
    LAMBDA = 1          # weight of the mutual-information loss (read by InfoGAN during training)
    IMG_SHAPE = (28, 28, 1)
    FIX_STD = True
    STYLE_SCALE = 1
    BATCH_SIZE = 64     # real images per step are BATCH_SIZE // 2 (see get_half_batch_ds)
    EPOCH = 40
    SEED = 42

    # Seed NumPy (sampled labels) and TensorFlow (styles, noise, weight init,
    # dataset shuffling, dropout) so runs are more repeatable.
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    d = get_half_batch_ds(BATCH_SIZE)
    m = InfoGAN(RAND_DIM, STYLE_DIM, LABEL_DIM, IMG_SHAPE, FIX_STD, STYLE_SCALE)
    train(m, d)

Model: "generator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_19 (InputLayer)           [(None,)]            0                                            
__________________________________________________________________________________________________
input_17 (InputLayer)           [(None, 8)]          0                                            
__________________________________________________________________________________________________
tf.one_hot_4 (TFOpLambda)       (None, 10)           0           input_19[0][0]                   
__________________________________________________________________________________________________
input_18 (InputLayer)           [(None, 2)]          0                                            
__________________________________________________________________________________________

NameError: ignored