### setting

#### import module

In [None]:
# system & address
import os
from zipfile import ZipFile

# data
import numpy as np
from imageio import imread
from skimage.transform import resize

import tensorflow as tf

# model
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ReLU
from tensorflow.keras.optimizers import Adam


# visualization
import matplotlib.pyplot as plt

In [None]:
# Open the image archive read-only. The handle is kept at module level because
# DataLoader reads its members through this name; data.namelist() lists them.
data = ZipFile(file="augmentation.zip", mode="r")

#### Data Loader class

In [None]:
class DataLoader:
    """Serve images for two domains out of the module-level ``data`` zip archive.

    Archive members are selected by name prefix: ``train<domain>`` or
    ``test<domain>``. The paired loader (``load_batch``) uses the domains
    ``"0"`` and ``"1"``. Every image is resized to ``img_res`` and rescaled
    with ``x / 127.5 - 1.0`` (which maps 0..255 to -1..1).

    Attributes:
        dataset_name: Name given at construction; stored but not used for
            locating files (members come from the ``data`` archive).
        img_res: Output shape handed to ``resize``; the caller passes
            ``(img_rows, img_cols)``.
        n_batches: Number of batches per pass; set when ``load_batch`` starts
            iterating.
    """

    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    @staticmethod
    def _members(prefix):
        """Return archive member names that start with ``prefix``.

        Directory entries (names ending in ``/``) are skipped because they
        match the prefix but cannot be decoded as images.
        """
        return [name for name in data.namelist() if name.startswith(prefix) and not name.endswith("/")]

    def load_data(self, domain, batch_size=1, is_testing=False):
        """Return ``batch_size`` randomly chosen images of one domain.

        Args:
            domain: Suffix appended to ``"train"`` / ``"test"`` to form the
                member-name prefix.
            batch_size: Number of images to draw (with replacement).
            is_testing: Read from the ``test`` split and disable the random
                horizontal flip.

        Returns:
            Array of the stacked, resized images scaled by ``x / 127.5 - 1.0``.
        """
        split = "test" if is_testing else "train"
        candidates = self._members(f"{split}{domain}")

        chosen = np.random.choice(candidates, size=batch_size)

        imgs = []
        for member in chosen:
            img = resize(self.imread(member), self.img_res)
            # Augmentation (random mirror) is applied to training images only.
            if not is_testing and np.random.random() > 0.5:
                img = np.fliplr(img)
            imgs.append(img)

        return np.array(imgs) / 127.5 - 1.0

    def load_batch(self, batch_size=1, is_testing=False):
        """Yield ``(imgs_A, imgs_B)`` batches covering one pass over the data.

        Domain A is read from members prefixed ``<split>0`` and domain B from
        ``<split>1``. The number of batches is limited by the smaller domain;
        samples are drawn without replacement, so no image repeats in a pass.

        Args:
            batch_size: Number of image pairs per yielded batch.
            is_testing: Read from the ``test`` split and disable the random
                horizontal flip.

        Yields:
            Tuple of two arrays, each holding ``batch_size`` resized images
            scaled by ``x / 127.5 - 1.0``.
        """
        split = "test" if is_testing else "train"
        path_A = self._members(split + "0")
        path_B = self._members(split + "1")

        self.n_batches = min(len(path_A), len(path_B)) // batch_size
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that the model
        # sees the same number of samples from both domains.
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        # Iterate over every batch; stopping at n_batches - 1 would silently
        # drop the final batch (and yield nothing when there is only one).
        for i in range(self.n_batches):
            start, stop = i * batch_size, (i + 1) * batch_size
            imgs_A, imgs_B = [], []
            for member_A, member_B in zip(path_A[start:stop], path_B[start:stop]):
                img_A = resize(self.imread(member_A), self.img_res)
                img_B = resize(self.imread(member_B), self.img_res)

                # One coin toss per pair: both images are mirrored together.
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A) / 127.5 - 1.0
            imgs_B = np.array(imgs_B) / 127.5 - 1.0

            yield imgs_A, imgs_B

    def imread(self, path):
        """Decode archive member ``path`` as an RGB image of float64 values."""
        # str() turns the numpy string produced by np.random.choice back into
        # a plain str; the context manager closes the member handle promptly.
        with data.open(str(path)) as member:
            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            return imread(member, pilmode="RGB").astype(np.float64)

### model

In [None]:
class CycleGAN:
    """CycleGAN with two generators (A->B, B->A) and two discriminators.

    This base definition only wires the models together; the builder methods
    it calls (``build_generator``, ``build_discriminator``) are supplied by the
    later ``class CycleGAN(CycleGAN)`` definitions in this file.
    """

    def __init__(self, img_rows=128, img_cols=128):
        """Build and compile all models and prepare checkpointing.

        Args:
            img_rows: Height of the images fed to the networks.
            img_cols: Width of the images fed to the networks.
        """
        # parameter setting
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # The discriminator applies four stride-2 convolutions with "same"
        # padding, each producing ceil(n / 2) positions, so its output has
        # ceil(n / 16) positions per axis. Compute rows and columns separately
        # so non-square inputs get a matching target shape.
        downscale = 2 ** 4
        self.disc_patch = (-(-self.img_rows // downscale), -(-self.img_cols // downscale), 1)

        self.gf = 32  # base filter count of the generators
        self.df = 64  # base filter count of the discriminators
        self.lambda_cycle = 10.0  # weight of the cycle-consistency losses
        self.lambda_id = 0.9 * self.lambda_cycle  # weight of the identity losses

        def make_optimizer():
            # A fresh optimizer per compiled model: recent Keras optimizers
            # bind to the first set of variables they update, so one instance
            # cannot be shared by models with different weights.
            return Adam(0.0002, 0.5)

        self.dataset_name = "augmentation"
        self.data_loader = DataLoader(dataset_name=self.dataset_name, img_res=(self.img_rows, self.img_cols))

        # modeling
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss="mse", optimizer=make_optimizer(), metrics=["accuracy"])
        self.d_B.compile(loss="mse", optimizer=make_optimizer(), metrics=["accuracy"])
        # Flagged non-trainable after their own compile so that the combined
        # model below is compiled with the discriminators frozen.
        self.d_A.trainable = False
        self.d_B.trainable = False
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate each input to the other domain.
        fake_A = self.g_BA(img_B)
        fake_B = self.g_AB(img_A)

        # Discriminator verdicts on the translated images.
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Translate back to the source domain (cycle consistency).
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Feed each generator an image already in its target domain (identity).
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        self.combined = Model(inputs=[img_A, img_B], outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id])
        self.combined.compile(
            loss=["mse", "mse", "mae", "mae", "mae", "mae"],
            loss_weights=[1, 1, self.lambda_cycle, self.lambda_cycle, self.lambda_id, self.lambda_id],
            optimizer=make_optimizer(),
        )

        # checkpoint
        checkpoint_path = "model_history"
        os.makedirs(checkpoint_path, exist_ok=True)

        self.ckpt = tf.train.Checkpoint(generator_g=self.g_AB, generator_f=self.g_BA, discriminator_x=self.d_A, discriminator_y=self.d_B)
        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, checkpoint_path, max_to_keep=None)


# Let's take a closer look at the helper functions used to build the models.
class CycleGAN(CycleGAN):
    """Adds the layer-building helpers shared by generator and discriminator."""

    @staticmethod
    def conv2d(layer_input, filters, strides=2, f_size=4, normalization=True):
        """Apply Conv2D -> (optional) BatchNormalization -> ReLU.

        Args:
            layer_input: Tensor to convolve.
            filters: Number of output channels.
            strides: Convolution stride; the default of 2 downsamples.
            f_size: Kernel size.
            normalization: Insert batch normalization after the convolution.

        Returns:
            The activated output tensor.
        """
        d = Conv2D(filters, kernel_size=f_size, strides=strides, padding="same")(layer_input)
        if normalization:
            d = BatchNormalization()(d)
        d = ReLU()(d)
        return d

    @staticmethod
    def deconv2d(layer_input, filters, f_size=4):
        """Upsample by 2, then apply Conv2D -> BatchNormalization -> ReLU.

        Args:
            layer_input: Tensor to upsample.
            filters: Number of output channels.
            f_size: Kernel size of the stride-1 convolution.

        Returns:
            The activated output tensor.
        """
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same")(u)
        u = BatchNormalization()(u)
        u = ReLU()(u)
        return u

    @staticmethod
    def resnet_block(layer_input, filters, strides=1, reverse=False, f_size=3, normalization=True):
        """Two-convolution residual block with a 1x1-convolution shortcut.

        Args:
            layer_input: Tensor entering the block.
            filters: Number of output channels of both paths.
            strides: Stride of the first convolution and of the shortcut.
            reverse: Accepted for interface compatibility; currently unused.
            f_size: Kernel size of the two main-path convolutions.
            normalization: Insert batch normalization after each main-path
                convolution. Previously this flag was accepted but ignored;
                the default keeps the earlier always-normalized behavior.

        Returns:
            ReLU of the sum of the main path and the shortcut.
        """
        c1 = Conv2D(filters, kernel_size=f_size, strides=strides, padding="same")(layer_input)
        if normalization:
            c1 = BatchNormalization()(c1)
        c1 = ReLU()(c1)

        c2 = Conv2D(filters, kernel_size=f_size, strides=1, padding="same")(c1)
        if normalization:
            c2 = BatchNormalization()(c2)

        # The shortcut is always projected with a 1x1 convolution so that its
        # channel count and stride match the main path.
        identity = Conv2D(filters, kernel_size=1, strides=strides, padding="same")(layer_input)
        return ReLU()(c2 + identity)


# ? generator
class CycleGAN(CycleGAN):
    """Adds the generator builder."""

    def build_generator(self):
        """Return an image-to-image model: encoder, nine residual blocks, decoder.

        The encoder is one stride-1 convolution followed by two downsampling
        convolutions; the decoder is two upsampling stages and a final
        tanh-activated convolution producing ``self.channels`` channels.
        """
        base = self.gf
        inputs = Input(shape=self.img_shape)

        # Encoder: keep the resolution once, then halve it twice.
        x = self.conv2d(inputs, base, strides=1)
        for width in (base * 2, base * 4):
            x = self.conv2d(x, width)

        # Bottleneck: nine residual blocks at constant width.
        for _ in range(9):
            x = self.resnet_block(x, base * 4)

        # Decoder: double the resolution twice.
        for width in (base * 4, base * 2):
            x = self.deconv2d(x, width)

        rgb = Conv2D(self.channels, kernel_size=3, strides=1, padding="same", activation="tanh")(x)
        return Model(inputs, rgb)


# ? discriminator
class CycleGAN(CycleGAN):
    """Adds the discriminator builder."""

    def build_discriminator(self):
        """Return a model mapping an image to a single-channel map of scores.

        Four downsampling convolutions (the first without batch normalization)
        are followed by a stride-1 convolution with one output channel and no
        activation.
        """
        image = Input(shape=self.img_shape)

        features = self.conv2d(image, self.df, normalization=False)
        for multiplier in (2, 4, 8):
            features = self.conv2d(features, self.df * multiplier)

        patch_scores = Conv2D(1, kernel_size=4, strides=1, padding="same")(features)
        return Model(image, patch_scores)


# ? sampling function
class CycleGAN(CycleGAN):
    """Adds the preview-image helper."""

    def sample_images(self, epoch, batch_i, folder="images"):
        """Save and show a 2x3 grid of original / translated / reconstructed images.

        One test image is drawn from each domain ("0" and "1"). The top row
        shows the domain-0 image, its translation and its reconstruction; the
        bottom row shows the same for the domain-1 image.

        Args:
            epoch: Epoch number, used in the output file name.
            batch_i: Batch index, used in the output file name.
            folder: Directory for the PNG; created (with parents) if missing.
        """
        # makedirs handles nested paths such as "images/images_resnet" in one
        # call and is a no-op when the directory already exists.
        os.makedirs(folder, exist_ok=True)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="0", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="1", batch_size=1, is_testing=True)

        # Translate each image to the other domain.
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to the original domain.
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale from -1..1 to 0..1 for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ["Original", "Translated", "Reconstructed"]
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis("off")
                cnt += 1
        fig.savefig("%s/%d_%d.png" % (folder, epoch, batch_i))
        plt.show()
        # Close explicitly: this runs many times during training and open
        # figures would otherwise accumulate in memory.
        plt.close(fig)

    # def showYourSelf(self, index=None):
    #     if index < 2:
    #         self.g_AB.summary()
    #     elif index < 3:
    #         self.g_BA.summary()
    #     elif index < 4:
    #         self.d_B.summary()
    #     elif index < 5:
    #         self.d_A.summary()
    #     else:
    #         self.combined.summary()


# ? train
class CycleGAN(CycleGAN):
    """Adds the training loop."""

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates over the dataset.

        For every batch, both discriminators are trained on real and generated
        images, then the combined model updates the generators. A checkpoint
        is written at the end of every epoch.

        Args:
            epochs: Number of passes over the data loader.
            batch_size: Number of image pairs per training step.
            sample_interval: Every this many batches, report the current
                losses and save a preview grid via ``sample_images``.
        """
        # Discriminator targets: ones for real images, zeros for generated.
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # verbose=0 keeps predict from printing a progress bar on
                # every training step.
                fake_B = self.g_AB.predict(imgs_A, verbose=0)
                fake_A = self.g_BA.predict(imgs_B, verbose=0)

                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                d_loss = 0.5 * np.add(dA_loss, dB_loss)
                # Targets follow the combined model's outputs: two validity
                # maps, two reconstructions, two identity mappings.
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

                if i % sample_interval == 0:
                    # The first entry of each result is the (total) loss;
                    # ravel() also copes with a scalar return value.
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                        % (epoch, epochs, i, self.data_loader.n_batches, np.ravel(d_loss)[0], np.ravel(g_loss)[0])
                    )
                    self.sample_images(epoch, i, "images/images_resnet")
            ckpt_save_path = self.ckpt_manager.save()
            print("[Epoch %d/%d] checkpoint saved to %s" % (epoch, epochs, ckpt_save_path))

### train

In [None]:
# Build the model with its default 128x128 input size, then train it for
# 100 epochs with 16 image pairs per step.
cycle_gan = CycleGAN()
cycle_gan.train(batch_size=16, epochs=100)

### Cleanup

In [None]:
# checkpoint_path = "model_history"
# if not os.path.exists(checkpoint_path):
#     os.mkdir(checkpoint_path)
# # checkpoint_path = checkpoint_path + "/%03d-G%.4f-D%.4f.hdf5"

# ckpt = tf.train.Checkpoint(generator_g=cycle_gan.g_AB, generator_f=cycle_gan.g_BA, discriminator_x=cycle_gan.d_A, discriminator_y=cycle_gan.d_B)
# ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=None)
# ckpt_save_path = ckpt_manager.save()

In [None]:
# Release the zip archive opened at the top of the notebook. DataLoader reads
# through this handle, so no further images can be loaded after this point.
data.close()