<a href="https://colab.research.google.com/github/Junghwan-brian/Colab/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup cell: install dependencies, report the TensorFlow build and GPU,
# and build an authenticated PyDrive client for Google Drive.
# NOTE: lines starting with `!` are IPython shell escapes, not Python code.
!pip install -U -q PyDrive
# NOTE(review): `from __future__` imports are no-ops on Python 3, and a future
# import that is not the first statement of the compiled cell may be rejected
# with a SyntaxError -- consider removing this line.
from __future__ import absolute_import, division, print_function, unicode_literals

# Pin TensorFlow 2.1 (GPU build). The next cell imports keras from the
# `tensorflow_core` package, whose layout is presumably specific to TF 2.0/2.1
# -- verify before upgrading. If TensorFlow was already imported in this
# runtime, a runtime restart is presumably needed for the pin to take effect.
!pip install tensorflow-gpu==2.1
import tensorflow as tf
print(tf.__version__)
# Prints the GPU device name (an empty string when no GPU is available).
print(tf.test.gpu_device_name())
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
   
# PyDrive authentication: authorize the Colab user, then build a Drive client
# from the application-default credentials.
# NOTE(review): `drive` appears unused in the cells shown here.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [None]:
import tensorflow as tf
from tensorflow_core.python.keras.api import keras
import numpy as np
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import os
batch_size = 64
# Shuffle buffer size; large enough to hold the whole MNIST training split.
shuffle_buffer = 60000

data = tfds.load('mnist', split='train')


def _normalize(example):
    """Scale an MNIST image from [0, 255] to [-1, 1] as float32.

    The generator ends in a tanh and produces values in [-1, 1], so real
    images must live in the same range. Otherwise the discriminator can tell
    real from fake by pixel range alone instead of by content.
    Assumes tfds MNIST images are uint8 in [0, 255] -- verify if the source
    dataset changes.
    """
    return (tf.cast(example['image'], tf.float32) - 127.5) / 127.5


# Reshuffle every epoch so batches are not always seen in the same order;
# prefetch overlaps input preparation with the training step.
dataset = (data.map(_normalize)
           .shuffle(shuffle_buffer)
           .batch(batch_size)
           .prefetch(tf.data.experimental.AUTOTUNE))


class Generator(keras.Model):
    """DCGAN generator: latent vector of length 100 -> 28x28x1 image.

    A bias-free dense layer projects the latent vector to 7*7*512 units,
    which are reshaped into a 7x7x512 feature map. MNIST images are 28 pixels
    wide, so the map starts at 7 and is upsampled by two stride-2 transposed
    convolutions (7 -> 14 -> 28). Each hidden stage applies batch
    normalization and then ReLU; the output layer uses tanh, so pixel values
    lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names are kept as-is: tf.train.Checkpoint records
        # sub-layers under these names.
        conv_opts = dict(kernel_size=(5, 5), padding='same', use_bias=False)

        # Stage 1: latent projection to 7*7*512 units.
        self.bn1 = keras.layers.BatchNormalization()
        self.dense1 = keras.layers.Dense(
            7 * 7 * 512, use_bias=False, input_shape=(100,))
        self.relu1 = keras.layers.ReLU()

        # Stage 2: 5x5 transposed conv, stride 1 -> (7, 7, 256).
        self.bn2 = keras.layers.BatchNormalization()
        self.deconv1 = keras.layers.Conv2DTranspose(
            filters=256, strides=(1, 1), **conv_opts)
        self.relu2 = keras.layers.ReLU()

        # Stage 3: 5x5 transposed conv, stride 2 -> (14, 14, 128).
        self.bn3 = keras.layers.BatchNormalization()
        self.deconv2 = keras.layers.Conv2DTranspose(
            filters=128, strides=(2, 2), **conv_opts)
        self.relu3 = keras.layers.ReLU()

        # Output: 5x5 transposed conv, stride 2 -> (28, 28, 1), tanh-squashed.
        self.deconv3 = keras.layers.Conv2DTranspose(
            filters=1, strides=(2, 2), activation='tanh', **conv_opts)

    def call(self, x, training=False):
        """Map a batch of latent vectors to images.

        Args:
            x: latent batch, presumably shaped (batch, 100).
            training: forwarded to every batch-norm layer (batch statistics
                when True, moving averages when False).

        Returns:
            Image batch of shape (batch, 28, 28, 1) with values in [-1, 1].
        """
        h = self.relu1(self.bn1(self.dense1(x), training=training))
        h = tf.reshape(h, (-1, 7, 7, 512))
        upsampling_stages = ((self.deconv1, self.bn2, self.relu2),
                             (self.deconv2, self.bn3, self.relu3))
        for deconv, norm, act in upsampling_stages:
            h = act(norm(deconv(h), training=training))
        return self.deconv3(h)



class Discriminator(keras.Model):
    """DCGAN discriminator that scores 28x28x1 images as real or fake.

    Architecture: two 5x5 stride-2 convolutions (64, then 128 filters), each
    followed by batch normalization and LeakyReLU(0.2), then a flatten and a
    single-unit dense layer.

    The dense layer has NO activation, so the model returns raw logits. The
    losses in this file use ``BinaryCrossentropy(from_logits=True)``, which
    applies the sigmoid internally. Adding a sigmoid here as well would
    squash every score into (0, 1) before the loss's own sigmoid. That
    confines predicted probabilities to roughly (0.5, 0.73) and starves the
    discriminator of useful gradient.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = keras.layers.Conv2D(
            64, (5, 5), strides=2, padding='same', input_shape=(28, 28, 1), use_bias=False)
        self.bn1 = keras.layers.BatchNormalization()
        self.lrelu1 = keras.layers.LeakyReLU(alpha=0.2)

        self.conv2 = keras.layers.Conv2D(
            128, (5, 5), strides=2, padding='same', use_bias=False)
        self.bn2 = keras.layers.BatchNormalization()
        self.lrelu2 = keras.layers.LeakyReLU(alpha=0.2)

        self.flatten = keras.layers.Flatten()
        # Linear output: one logit per image (see class docstring).
        self.dense = keras.layers.Dense(1)

    def call(self, x, training=False):
        """Return one real/fake logit per image, shape (batch, 1).

        Args:
            x: image batch, presumably (batch, 28, 28, 1) in [-1, 1].
            training: forwarded to the batch-norm layers.
        """
        x = self.lrelu1(self.bn1(self.conv1(x), training=training))  # (14, 14, 64) for 28x28 input
        x = self.lrelu2(self.bn2(self.conv2(x), training=training))  # (7, 7, 128)
        x = self.flatten(x)
        return self.dense(x)


# Model instances shared by the training step, sampling, and checkpointing.
generator = Generator()
discriminator = Discriminator()

# One Adam optimizer per network (lr 2e-4, beta_1 0.5), so each network keeps
# its own moment estimates.
adam_g, adam_d = (keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
                  for _ in range(2))
# Binary cross-entropy on raw logits: the loss applies the sigmoid itself.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss as the sum of two cross-entropy terms.

    Args:
        real_output: discriminator scores for a batch of real images;
            their target label is 1.
        fake_output: discriminator scores for a batch of generated images;
            their target label is 0.

    Returns:
        Scalar loss tensor.
    """
    # Real samples should be classified as 1.
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    # Generated samples should be classified as 0.
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake


def generator_loss(fake_output):
    """Return the generator loss: cross-entropy against an all-ones target.

    The generator is rewarded when the discriminator labels its samples as
    real (1).

    Args:
        fake_output: discriminator scores for a batch of generated images.
    """
    target = tf.ones_like(fake_output)
    return cross_entropy(target, fake_output)


@tf.function
def train_step(inputs):
    """Run one simultaneous update of the generator and the discriminator.

    Both networks see the same batch of generated images. The generator is
    updated to make the discriminator call them real; the discriminator is
    updated to tell them apart from ``inputs``.

    Args:
        inputs: batch of real images (presumably (batch, 28, 28, 1) --
            TODO confirm against the input pipeline).

    Returns:
        ``(g_loss, d_loss)``: scalar loss tensors for this step.
    """
    # One latent vector per real image. Taking the size from ``inputs``
    # instead of the global ``batch_size`` keeps the real and fake batches
    # equally sized when the dataset ends with a partial batch (the MNIST
    # train split, 60000 images, does not divide evenly by 64).
    num_samples = tf.shape(inputs)[0]
    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        z = tf.random.uniform([num_samples, 100])
        image = generator(z, training=True)
        fake_output = discriminator(image, training=True)
        real_output = discriminator(inputs, training=True)

        g_loss = generator_loss(fake_output)
        d_loss = discriminator_loss(real_output, fake_output)
    # Each tape differentiates only its own loss w.r.t. its own network.
    g_gradient = tape_g.gradient(g_loss, generator.trainable_variables)
    d_gradient = tape_d.gradient(d_loss, discriminator.trainable_variables)

    adam_g.apply_gradients(zip(g_gradient, generator.trainable_variables))
    adam_d.apply_gradients(zip(d_gradient, discriminator.trainable_variables))
    return g_loss, d_loss


# Running means of the per-step discriminator / generator losses.
d_metric, g_metric = keras.metrics.Mean(), keras.metrics.Mean()

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Both networks and both optimizers are tracked, so a saved checkpoint can
# later be restored with `checkpoint.restore(...)`.
_checkpoint_objects = dict(
    generator_optimizer=adam_g,
    discriminator_optimizer=adam_d,
    generator=generator,
    discriminator=discriminator,
)
checkpoint = tf.train.Checkpoint(**_checkpoint_objects)


def generate_and_save_images(model, epoch, test_input):
    """Plot generated samples on a square grid, save it as a PNG, and show it.

    Args:
        model: generator, called as ``model(test_input, training=False)``.
            It presumably returns images shaped (n, H, W, C) with values in
            [-1, 1] (tanh range); only channel 0 is drawn.
        epoch: integer embedded in the file name ``image_at_epoch_XXXX.png``.
        test_input: latent batch fed to the generator.
    """
    predictions = model(test_input, training=False)
    count = predictions.shape[0]
    # Smallest square grid holding every sample; a fixed 3x3 layout would
    # raise for more than 9 samples.
    side = int(np.ceil(np.sqrt(count)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(count):
        plt.subplot(side, side, i + 1)
        # Map tanh output [-1, 1] to pixel range [0, 255]. Fixing vmin/vmax
        # stops imshow from auto-rescaling each image, so brightness is
        # comparable across samples and epochs.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5,
                   cmap='gray', vmin=0, vmax=255)
        plt.axis('off')

    fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure so repeated per-epoch calls don't accumulate open figures.
    plt.close(fig)


# Fixed latent batch, reused for every preview so samples are comparable
# across epochs.
test = tf.random.uniform([9, 100])
epochs = 100
# Checkpoint and preview interval, in epochs.
save_every = 10

for epoch in range(epochs):
    for image in dataset:
        g_loss, d_loss = train_step(image)
        g_metric.update_state(g_loss)
        d_metric.update_state(d_loss)
    print('epoch: {}, g_loss: {}, d_loss: {}'.format(
        epoch + 1, g_metric.result(), d_metric.result()))

    # Count completed epochs from 1, so saves happen after epochs
    # 10, 20, ..., and the final epoch is always saved.
    completed = epoch + 1
    if completed % save_every == 0 or completed == epochs:
        checkpoint.save(file_prefix=checkpoint_prefix)
        generate_and_save_images(generator, completed, test)

    # Per-epoch averages: start the next epoch's running means from scratch.
    g_metric.reset_states()
    d_metric.reset_states()
