In [None]:
from tensorflow.keras.layers import (
    Input,Conv2D,LeakyReLU,Activation,
    Resizing,UpSampling2D,MaxPooling2D,Dropout,Concatenate,Conv2DTranspose
)
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import normalize
from tensorflow.keras import backend as K
# random weight init
from tensorflow.keras.initializers import RandomNormal
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os
import plotly.graph_objects as go
import tensorflow as tf
import cv2

In [None]:
class CycleGAN:
    def __init__(
        self,
        input_shape,
        learning_rate,
        lambda_validation,
        lambda_reconstr,
        lambda_id,
        genator_filters,
        discriminator_filters,
        dataset_path
    ):
        self.input_shape = input_shape
        self.learning_rate = learning_rate 
        self.lambda_validation = lambda_validation # 指定Lambda验证损失的权重
        self.lambda_reconstr = lambda_reconstr # 指定Lambda重构损失的权重
        self.lambda_id = lambda_id # 指定Lambda身份损失的权重
        self.genator_filters = genator_filters
        self.discriminator_filters = discriminator_filters
        self.dataset_path = dataset_path
        self.channels = self.input_shape[-1]
        self.width = self.input_shape[0]
        self.height = self.input_shape[1]
        self.image_shape = (self.width,self.height,self.channels)
        self.patch = int(self.width / 2**4) # 计算PatchGAN中的patch大小
        self.disc_patch = (self.patch,self.patch,1) # 计算判别器输出的patch大小
        self.epoch = 0
        self.load_data()
        self.compile_model()
    def load_data(self):
        self.trainA = tf.keras.utils.image_dataset_from_directory(
            os.path.join(self.dataset_path,"trainA"),
            image_size=(self.width,self.height),
            batch_size=1,
            label_mode=None,
            interpolation="bilinear",
            color_mode="rgb"
        )
        self.trainB = tf.keras.utils.image_dataset_from_directory(
            os.path.join(self.dataset_path,"trainB"),
            image_size=(self.width,self.height),
            batch_size=1,
            label_mode=None,
            interpolation="bilinear",
            color_mode="rgb"
        )
        self.testA = tf.keras.utils.image_dataset_from_directory(
            os.path.join(self.dataset_path,"testA"),
            image_size=(self.width,self.height),    
            batch_size=1,
            label_mode=None,
            interpolation="bilinear",
            color_mode="rgb"
        )
        self.testB = tf.keras.utils.image_dataset_from_directory(
            os.path.join(self.dataset_path,"testB"),
            image_size=(self.width,self.height),
            batch_size=1,
            label_mode=None,
            interpolation="bilinear",
            color_mode="rgb"
        )
        # Normalize images
        self.trainA = self.trainA.map(lambda x: x / 255.0)
        self.trainB = self.trainB.map(lambda x: x / 255.0)
        self.testA = self.testA.map(lambda x: x / 255.0)
        self.testB = self.testB.map(lambda x: x / 255.0)
        # Convert dataset to numpy array
        self.trainA = np.array(list(self.trainA.as_numpy_iterator()))
        self.trainB = np.array(list(self.trainB.as_numpy_iterator()))
        self.testA = np.array(list(self.testA.as_numpy_iterator()))
        self.testB = np.array(list(self.testB.as_numpy_iterator()))
        # Reshaping as (None,width,height,channels)
        self.trainA = self.trainA.reshape(-1,self.width,self.height,self.channels)
        self.trainB = self.trainB.reshape(-1,self.width,self.height,self.channels)
        self.testA = self.testA.reshape(-1,self.width,self.height,self.channels)
        self.testB = self.testB.reshape(-1,self.width,self.height,self.channels)
    def build_generator_unet(self):
        def downsample(layer_input, filters, f_size=4):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = InstanceNormalization(axis=-1,center=False,scale=False)(d)
            d = Activation('relu')(d)
            return d
        def upsample(layer_input, skip_input, filters, f_size=4):
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same')(u)
            u = InstanceNormalization(axis=-1,center=False,scale=False)(u) # 使用实例归一化取代批归一化，进一步加强模型的泛化能力
            u = Activation('relu')(u)
            u = Concatenate()([u, skip_input]) # 跳跃连接，将上采样层的输出与下采样层的输出进行连接
            return u
        # Image input
        img = Input(shape=self.input_shape)
        # Downsampling
        d1 = downsample(img, self.genator_filters)
        d2 = downsample(d1, self.genator_filters*2)
        d3 = downsample(d2, self.genator_filters*4)
        d4 = downsample(d3, self.genator_filters*8)
        # Upsampling
        u1 = upsample(d4, d3, self.genator_filters*4)
        u2 = upsample(u1, d2, self.genator_filters*2)
        u3 = upsample(u2, d1, self.genator_filters)
        u4 = UpSampling2D(size=2)(u3)
        # Output
        # 最后的输出层使用tanh激活函数，将像素值归一化到[-1,1]之间，但输出的维通道仍然与输入的通道相同
        output = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)
        return Model(img, output)
    def build_discriminator(self):
        # 创建一个4x4的卷积层，步长为2，用于对输入的图像进行下采样
        def conv4(layer_input, filters, stride, norm=True):
            d = Conv2D(filters, kernel_size=4, strides=stride, padding='same')(layer_input)
            if norm:
                d = InstanceNormalization(axis=-1,center=False,scale=False)(d)   
            d = LeakyReLU(alpha=0.2)(d)
            return d
        img = Input(shape=self.image_shape)
        # 除第一层外，其余层都使用实例归一化
        y = conv4(img, self.discriminator_filters, stride=2, norm=False)
        y = conv4(y, self.discriminator_filters*2, stride=2)
        y = conv4(y, self.discriminator_filters*4, stride=2)
        y = conv4(y, self.discriminator_filters*8, stride=1)
        # 为了防止与MSE计算时产生张量大小，这里使用了一个8x8的Resize层
        # 同时使用双线性插值防止信息丢失
        y = Resizing(
            8,8,interpolation="bilinear", crop_to_aspect_ratio=False
        )(y)
        # 最后输出一个8x8x1的PatchGAN输出
        output = Conv2D(1, kernel_size=4, strides=1, padding='same')(y)
        model = Model(img, output)
        return model
    def compile_model(self):
        # 编译判别器
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(
            loss='mse',
            optimizer=Adam(self.learning_rate),
            metrics=['accuracy']
        )
        self.d_B.compile(
            loss='mse',
            optimizer=Adam(self.learning_rate),
            metrics=['accuracy']
        )
        # 编译生成器
        self.g_AB = self.build_generator_unet()
        self.g_BA = self.build_generator_unet()
        # 锁定判别器
        self.d_A.trainable = False
        self.d_B.trainable = False
        # 输入
        img_A = Input(shape=self.image_shape)
        img_B = Input(shape=self.image_shape)
        # 生成器生成
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # 重构
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # 生成器的判别器
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)
        # 生成器的恒等
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # 创建混合模型
        self.combined = Model(
            inputs=[img_A, img_B],
            outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
            loss_weights=[self.lambda_validation, self.lambda_validation, self.lambda_reconstr
                          , self.lambda_reconstr, self.lambda_id, self.lambda_id],
            optimizer=Adam(self.learning_rate)
        )
        # 解锁判别器
        self.d_A.trainable = True
        self.d_B.trainable = True
    def train(self, epochs, batch_size=1):
        # 真假标签
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        self.disc_loss = []
        self.gen_loss = []
        self.cycle_loss = []
        self.id_loss = []
        self.adv_loss = []
        with tqdm(total=epochs,unit=" epoch ") as pbar:
            for epoch in range(epochs):
                # ----------------------
                #  训练判别器
                # ----------------------
                # 选择一批图片
                idx = np.random.randint(0, self.trainA.shape[0], batch_size)
                imgs_A = self.trainA[idx]
                imgs_B = self.trainB[idx]
                # 生成一批假图片
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)
                # 训练判别器
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
                # ------------------
                #  训练生成器
                # ------------------
                # 训练生成器
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])
                # 更新进度条
                pbar.set_description("Epoch: %d [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f]" % (
                    epoch+1, (dA_loss[0] + dB_loss[0]) / 2, 100 * (dA_loss[1] + dB_loss[1]) / 2, g_loss[0],
                    np.mean(g_loss[1:3]), np.mean(g_loss[3:5]), np.mean(g_loss[5:6])))
                pbar.update(1)
                self.disc_loss.append((dA_loss[0] + dB_loss[0]) / 2)
                self.gen_loss.append(g_loss[0])
                self.cycle_loss.append(np.mean(g_loss[3:5]))
                self.id_loss.append(np.mean(g_loss[5:6]))
                self.adv_loss.append(np.mean(g_loss[1:3]))
                self.epoch+=1
    def sample_image(self):
        # 设置Matplotlib渲染三列两行
        r, c = 2, 3
        # 选择一批图片
        idx = np.random.randint(0, self.testA.shape[0], 1)
        imgs_A = self.testA[idx] 
        imgs_B = self.testB[idx] 
        # 生成图片
        fake_B = self.g_AB.predict(imgs_A) 
        fake_A = self.g_BA.predict(imgs_B) 
        # 重构图片
        reconstr_A = self.g_BA.predict(fake_B) 
        reconstr_B = self.g_AB.predict(fake_A) 
        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        # 将图像切换回正常的RGB通道范围
        imgs_A = 0.5 * imgs_A + 0.5
        imgs_B = 0.5 * imgs_B + 0.5
        fake_A = 0.5 * fake_A + 0.5
        fake_B = 0.5 * fake_B + 0.5
        reconstr_A = 0.5 * reconstr_A + 0.5
        reconstr_B = 0.5 * reconstr_B + 0.5
        # 绘制图像
        axs[0, 0].imshow(imgs_A[0])
        axs[0, 0].set_title(titles[0])
        axs[0, 1].imshow(fake_B[0])
        axs[0, 1].set_title(titles[1])
        axs[0, 2].imshow(reconstr_A[0])
        axs[0, 2].set_title(titles[2])
        axs[1, 0].imshow(imgs_B[0])
        axs[1, 1].imshow(fake_A[0])
        axs[1, 2].imshow(reconstr_B[0])
        fig.show()
        plt.show()
    def plot_loss(self):
        fig = go.Figure()
        fig.add_trace(
            go.Line(
                x=list(range(len(self.disc_loss))),
                y=self.disc_loss,
                name="disc_loss"
            )
        )
        fig.add_trace(
            go.Line(
                x=list(range(len(self.gen_loss))),
                y=self.gen_loss,
                name="gen_loss"
            )
        )
        fig.add_trace(
            go.Line(
                x=list(range(len(self.gen_loss))),
                y=self.cycle_loss,
                name="cycle_loss"
            )
        )
        fig.add_trace(
            go.Line(
                x=list(range(len(self.gen_loss))),
                y=self.id_loss,
                name="id_loss"
            )
        )
        fig.add_trace(
            go.Line(
                x=list(range(len(self.gen_loss))),
                y=self.adv_loss,
                name="adv_loss"
            )
        )
        fig.update_layout(
            title="Training Loss",
            xaxis_title="epoch",
            yaxis_title="loss"
        )
        fig.update_traces(mode="markers+lines")
        fig.show()

In [None]:
# Create the model. Constructing CycleGAN already loads every image folder
# under dataset_path and builds and compiles all networks.
model = CycleGAN(
    input_shape=(128, 128, 3),
    learning_rate=0.0002,
    lambda_validation=1,  # weight of the adversarial (validity) loss
    lambda_reconstr=10,  # weight of the cycle-reconstruction loss
    lambda_id=2,  # weight of the identity loss
    genator_filters=32,
    discriminator_filters=32,
    dataset_path="./apple2orange/"
)

In [None]:
# Train the model. Each "epoch" here is one training iteration on a single
# randomly drawn batch, not a full pass over the dataset.
model.train(epochs=1000, batch_size=1)

In [None]:
# Generate images: show a random test image from each domain next to its
# translation and its reconstruction.
model.sample_image()

In [None]:
# Plot the loss curves recorded during the training call above.
model.plot_loss()