# StarGAN 논문 구현 ( tensorflwo 2.0 )

- 최근 가장 핫하게 떠오른 GAN 알고리즘은 starGAN을 tensorflow2.0을 사용하여 구현.
    - title : StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation (arXiv:1711.09020)
    - authors : Yunjey Choi, Minje Choi, Munyoung Kim, Jung-Woo Ha, Sunghun Kim, Jaegul Choo
- cf) GPU 사용 여건이 안되어 데이터 셋은 랜덤 생성한 텐서 데이터로 논문에서 사용한 데이터 셋의 shape와 특성을 똑같이 구현하여 사용하였습니다.

In [199]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, ReLU, Conv2DTranspose, LeakyReLU
from tensorflow.keras.activations import tanh

In [None]:
tf.executing_eagerly()

# Load Data

- GPU 사용을 하지 않고 프로토 타입을 만드는 연산이기 때문에 랜덤으로 이미지, 라벨에 대응하는 matrix를 생성을 한다.
- image.shape = (batch_size,128,128,3)
- label.shape = (batch_size,1,5)

In [400]:
# random data와 label 생성하는 함수
def random_image_label_data(sample_size, image_size, RGB=3):
    '''
    sample_size : 원하는 샘플 양 설정.
    image_size :  원하는 이미지 사이즈 설정. ex) 128x128 -> image_size=128
    RGB : default=3 
    '''
    N1 = sample_size
    image_size_ = image_size
    RGB_ = RGB
    # np.random.normal : Draw random samples from a normal (Gaussian) distribution.
    image_data = np.random.normal(size=[N1,image_size_,image_size_,RGB_]) 
    label_hair = tf.one_hot(np.random.randint(0,3,size=N1),depth=3)
    label_gender = tf.one_hot(np.random.randint(0,2,size=N1),depth=1)
    label_age = tf.one_hot(np.random.randint(0,2,size=N1),depth=1)
    label_gender_age = tf.concat((label_gender,label_age), axis=1)
#    label_gender_age = np.append(label_gender,label_age, axis=1)
    label = tf.concat((label_hair,label_gender_age), axis=1)
    return image_data, label

## data set

In [3]:
# 데이터 shape 파악
# input이 들어가서 output으로 오류없이 나오는지만 확인._실제로 하면 GPU필요하고 하루 이상 돌아가야함.
image_data, label = random_image_label_data(1000,128)
train_dataset = tf.data.Dataset.from_tensor_slices((image_data, label))
train_dataset = train_dataset.batch(2) # 작은 배치로 해서 코드가 잘 돌아가는지 점검

for images, labels in train_dataset.take(2): 
    print(images.shape)
    print(labels.shape)
    '''
    (2, 128, 128, 3) #첫번째 2개씩 128,128,3짜리
    (2, 5) # 첫번째 label 2개씩 불러오기
    (2, 128, 128, 3)
    (2, 5)
    '''

(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 3)
(2, 5)


# Generator

- generator의 input으로는 imgae와 label이 들어간다. 
- 실제 이 두 메트릭스가 generator로 연산될 때는 하나로 묵어서 들어간다.
    - image = (batch_size,128,128,3)
    - label = (batch_size,1,5)
- concat(image, label) 구조
    1. label을 shape를 (batch_size,1,5)에서 (batch_size,128,128,5)로 만들어 준다.
    1. image와 label을 concatenate해서 shape가 (batch_size,128,128,8)인 dataset을 만들어 준다.
    

In [544]:
def concat(image, label):
    '''
    1차원 행렬인 label(array)을 다차원 행렬인 image(array)에 맞춰서 broadcasting해서 합쳐서 tensor로 출력해주는 함수.
    label, image 는 array와 tensor형태 모두 허용.
    '''
    batchsize = image.shape[0] #  batchsize  
    label_len =  label[0].shape[0] # 하나의 label의 list 길이 #5
    image_size = image[0].shape[0] # 하나의 image의 한변의 길이 #128
    
    label_vol = tf.zeros((batchsize,image_size,image_size,label_len)) \
                    + tf.reshape(labels, shape=[batchsize,1,1,label_len])  
    label_vol = tf.dtypes.cast(label_vol, tf.float64)
    return tf.concat((image, label_vol), axis = 3)

In [545]:
#test
for images, labels in train_dataset.take(2): 
    print(images.shape)
    print(labels.shape)
    print(concat(images, labels).shape)

(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 8)
(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 8)


In [154]:
class InstanceNormalization(tf.keras.layers.Layer): #IN은 layer이기 때문에 'tf.keras.layers.Layer'

    def __init__(self, axes=[1, 2], epsilon=1e-5): # 여기 변수는 class 객체화 할때 입력해주어야하는 변수
        super(InstanceNormalization, self).__init__()
        self.epsilon = epsilon # cut_off
        self.axes = axes

    def build(self, input_shape):  #IN은 하나의 layer기 때문에 build해줘야함.
        shape = tf.TensorShape(input_shape)
        param_shape = shape[-1] 
        # Create a trainable weight variable for this layer. 'InstanceNormalization(tf.keras.layers.Layer)'
        self.gamma = self.add_weight(name='gamma',     #학습하는 parameter명시
                                     shape=param_shape, #param_shape 모양으로 만들어줌.
                                     initializer='ones', #학습하기 전엔 일단 1
                                     trainable=True)     # 학습가능하게 해줌
        self.beta = self.add_weight(name='beta',
                                    shape=param_shape,
                                    initializer='zeros',
                                    trainable=True)
        # Make sure to call the `build` method at the end
        super(InstanceNormalization, self).build(input_shape) # 모양 빌드 해준다.
        
    def call(self, inputs): #객체 실행할때 실질 수행되는 함수
        # 여기 있는 변수는 함수 쓸때 명시해주어야하는 변수
        input_shape = inputs.get_shape() 
        mean, variance = tf.nn.moments(inputs, axes = self.axes, keepdims=True) 
            # tf.nn.moments : Calculates the mean and variance of x along axes = [1, 2]
            # keepdims: produce 'moments' with the same dimensionality as the input.
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon) 
        return self.gamma * normalized + self.beta

### residual block 
![residual%20block.png](attachment:residual%20block.png)
- conv -> BN -> relu -> conv -> BN -> RB ->relu
- 아래 블록이 6개
- 자세한 이유 공부하기?
    - https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf
- 참고 코드 : https://www.tensorflow.org/alpha/tutorials/eager/custom_layers

In [182]:
#참고 코드 : https://www.tensorflow.org/alpha/tutorials/eager/custom_layers
class ResidualBlock(tf.keras.Model):
    def __init__(self):
        super().__init__() 

        self.conv1 = Conv2D(256, kernel_size = 3, strides = 1, padding = 'same')
        self.IN1 = InstanceNormalization()
        self.conv2 = Conv2D(256, kernel_size = 3, strides = 1, padding = 'same')
        self.IN2 = InstanceNormalization()
        self.activation_relu = ReLU()
    
    def call(self, input_tensor):
        x = self.conv1(input_tensor)
        x = self.IN1(x)
        x = self.activation_relu(x)

        x = self.conv2(x)
        x = self.IN2(x)

        x += input_tensor #이게 핵심
        return self.activation_relu(x)

- cf) 'tf.keras.layers.Layer'을 상속해온 이유..
    - We recommend that descendants of Layer implement the following methods:

    - **\_\_init\_\_()**: Save configuration in member variables
    - **build()**: Called once from \_\_call\_\_, when we know the shapes of inputs and dtype. Should have the calls to add_weight(), and then call the super's build() (which sets self.built = True, which is nice in case the user wants to call build() manually before the first \_\_call\_\_).
    - **call()**: Called in \_\_call\_\_ after making sure build() has been called once. Should actually perform the logic of applying the layer to the input tensors (which should be passed in as the first argument).

- Generator network architecture
![G_structure.jpg](attachment:G_structure.jpg)

- cf) padding size
    - $O=\frac{W-F+2P}{S} +1$
    - $O$ = output size, $W$ = input size, $F$ = Filter size, $S$ = stride, $P$ = padding size
    - ex) 1st CONV(N64, K7, S1, P3) in Down sampling
        - $h = \frac{h-7+2P}{1} +1$ $\rightarrow$ $P=3$
    - ex) 2nd CONV(N128, K4, S2, P1) in Down sampling
        - $\frac{h}{2} = \frac{h-4+2P}{2} +1$ $\rightarrow$ $P=1$        

In [197]:
class Generator(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.activation_relu = ReLU()  
        #Down-sampling
        self.conv1 = Conv2D(64, kernel_size = 7, strides = 1, padding = 'SAME') #kernel_size의미
        '''
            kernel_size: An integer or tuple/list of 2 integers, 
                  specifying the height and width of the 2D convolution window. 
                  Can be a single integer to specify the same value for all spatial dimensions.
                  cf) the size of filter
        '''
        self.IN1 = InstanceNormalization() 
        self.conv2 = Conv2D(128, kernel_size = 4, strides = 2, padding = 'same')
        self.IN2 = InstanceNormalization()
        self.conv3 = Conv2D(256, kernel_size = 4, strides = 2, padding = 'same')
        self.IN3 = InstanceNormalization()
        #Bottleneck
        self.Residual1 = ResidualBlock()
        self.Residual2 = ResidualBlock()
        self.Residual3 = ResidualBlock()
        self.Residual4 = ResidualBlock()
        self.Residual5 = ResidualBlock()
        self.Residual6 = ResidualBlock() 
        #Up-sampling    
        self.deconv1 = Conv2DTranspose(128, kernel_size = 4, strides = 2, padding = 'same') #DECONV
        self.IN4 = InstanceNormalization()
        self.deconv2 = Conv2DTranspose(64, kernel_size = 4, strides = 2, padding = 'same')
        self.IN5 = InstanceNormalization()
        self.conv_final = Conv2D(3, kernel_size = 7, strides = 1, padding = 'same')
    
    def call(self, image, label):
        #Down-sampling
        flow = concat(image, label)
        flow = self.conv1(flow)
        flow = self.IN1(flow)
        flow = self.activation_relu(flow)
        flow = self.conv2(flow)
        flow = self.IN2(flow)
        flow = self.activation_relu(flow)
        flow = self.conv3(flow)
        flow = self.IN3(flow)
        flow = self.activation_relu(flow)
        #Bottleneck
        flow = self.Residual1(flow)
        flow = self.Residual2(flow)
        flow = self.Residual3(flow)
        flow = self.Residual4(flow)
        flow = self.Residual5(flow)
        flow = self.Residual6(flow)
        #Up-sampling
        flow = self.deconv1(flow)
        flow = self.IN4(flow)
        flow = self.activation_relu(flow)
        flow = self.deconv2(flow)
        flow = self.IN5(flow)
        flow = self.activation_relu(flow)
        flow = self.conv_final(flow)
        output = tanh(flow)
        
        return output

In [196]:
G = Generator()
# test
for images, labels in train_dataset.take(2): 
    print(images.shape)
    print(labels.shape)
    print(G(images, labels).shape)

(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 3)
(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 3)


# Discriminator
- Discriminator network architecture
![D_structure.jpg](attachment:D_structure.jpg)

In [308]:
class Discriminator(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.activation_leakyRelu = LeakyReLU()
        self.conv_input = Conv2D(64, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv1 = Conv2D(128, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv2 = Conv2D(256, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv3 = Conv2D(512, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv4 = Conv2D(1024, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv5 = Conv2D(2048, kernel_size = 4, strides = 2, padding = 'SAME')
        self.conv_src = Conv2D(1, kernel_size = 3, strides = 1, padding = 'SAME')
        self.conv_cls = Conv2D(5, kernel_size = 2, strides = 1, padding = 'valid')
    
    def call(self, Input_image):
        flow = self.conv_input(Input_image)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv1(flow)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv2(flow)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv3(flow)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv4(flow)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv5(flow)
        flow = self.activation_leakyRelu(flow)
        flow = self.conv_src(flow)
        D_src = tf.sigmoid(flow) #논문엔 안나왔지만 sigmoid 넣어줘야함. loss 함수를 cross_entropy만 쓸것이기 때문.
        flow = self.conv_cls(flow)
        D_cls = tf.sigmoid(flow) #논문엔 안나왔지만 sigmoid 넣어줘야함. loss 함수를 cross_entropy만 쓸것이기 때문.
        ''' 
        Discriminator은 확률을 출력해야하니 마지막에 sigmoid 함수를 가해주거나,
        loss 함수를 sigmoid_corss_entropy를 쓰거나.
        '''
        return D_src, D_cls
    
    

In [309]:
# test
D = Discriminator()
for images, labels in train_dataset.take(2): 
    print(images.shape)
    print(labels.shape)
    print(G(images, labels).shape)
    src, cls = D(images)
    print(src.shape, cls.shape)

(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 3)
(2, 2, 2, 1) (2, 1, 1, 5)
(2, 128, 128, 3)
(2, 5)
(2, 128, 128, 3)
(2, 2, 2, 1) (2, 1, 1, 5)


# Loss function
![starGAN_training.jpg](attachment:starGAN_training.jpg)
- (a) : ${\cal L}_{\rm D} = -{\cal L}_{\rm adv} + \lambda_{\rm cls} {\cal L}^{\rm r}_{\rm cls}$
- (b)~(d) : ${\cal L}_{\rm D} = {\cal L}_{\rm adv} + \lambda_{\rm cls} {\cal L}^{\rm f}_{\rm cls} + \lambda_{\rm rec} {\cal L}_{\rm rec}$
    - ${\cal L}_{\rm adv} = {\bf E}_x [\ln D_{\rm src}(x)] + {\bf E}_{x,c} [\ln(1- D_{\rm src}(G(x,c))]$
        - fake image(G(x,c))를 생성할 때 원본 이미지 x에 요구한 class = c
    - ${\cal L}^{\rm r}_{\rm cls} = {\bf E}_{x, c'} [-\ln D_{\rm cls}(c'|x)]$
        - real image x에 대응되는 class = c'
    - ${\cal L}^{\rm f}_{\rm cls} = {\bf E}_{x, c'} [-\ln D_{\rm cls}(c|G(x,c))]$
    - ${\cal L}_{\rm rec} = {\bf E}_{x, c', c} [\| x - G(G(x,c),c')\|_{\rm L1}]$
    - $\lambda_{\rm cls}=1$ and $\lambda_{\rm rec}=10$ in this paper.

In [317]:
'''
Loss function
'''
epsilon = 1e-6
D = Discriminator()
G = Generator()
# generated_output = G(images, labels)

def adv_loss(real_input, generated_output):
    D_real, D_real_cls = D(real_input)  
    D_fake, D_fake_cls = D(generated_output) 
    return tf.reduce_mean(tf.math.log(D_real+epsilon)+tf.math.log(1-D_fake+epsilon))
    
def real_cls_loss(real_input): 
    D_real, D_real_cls = D(real_input) 
    return tf.reduce_mean(-tf.math.log(D_real_cls + epsilon))

def fake_cls_loss(generated_output, label_fake):
    D_fake, D_fake_cls = D(generated_output) 
    return tf.reduce_mean(-tf.math.log(D_fake_cls + epsilon))

def rec_loss(real_input, label_real, generated_output, label_fake):
    G_real_from_fake = G(generated_output, label_real)
    real_subt_G = tf.norm(tf.math.subtract(real_input,G_real_from_fake), axis=1 )
    return tf.reduce_mean(real_subt_G)

def Loss_G(real_input, label_real, generated_output, label_fake, lambda_cls=1, lambda_rec=10):
    return adv_loss(real_input, generated_output) + lambda_cls*fake_cls_loss(generated_output, label_fake)\
        +lambda_rec*rec_loss(real_input, label_real, generated_output, label_fake)

def Loss_D(real_input, generated_output, lambda_cls=1):
    return -adv_loss(real_input, generated_output) + lambda_cls*real_cls_loss(real_input)


In [319]:
# test
N=2
for images, labels in train_dataset.take(N): 
    _, fake_label = random_image_label_data(N,128)
    print(images.shape)
    print(labels.shape)
    print(fake_label.shape)
    generated_output = G(images, fake_label)
    print(Loss_G(images, labels, generated_output, fake_label))
    print(Loss_D(images, generated_output))

(2, 128, 128, 3)
(2, 5)
(2, 5)
tf.Tensor(135.81918080993773, shape=(), dtype=float64)
tf.Tensor(2.0774274480575943, shape=(), dtype=float64)
(2, 128, 128, 3)
(2, 5)
(2, 5)
tf.Tensor(135.51463978790318, shape=(), dtype=float64)
tf.Tensor(2.0775985789917595, shape=(), dtype=float64)


In [546]:
#train algorithm

generator_optimizer = tf.optimizers.Adam(1e-4)
discriminator_optimizer = tf.optimizers.Adam(1e-4)

@tf.function  #Decorlator
def train_step(images, labels, batch_size):
    _, fake_cls = random_image_label_data(batch_size,128)
    fake_cls = tf.convert_to_tensor(fake_cls)
    generated_images = G(images, fake_cls)
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: #무슨 역할이지?
        
        gen_loss = Loss_G(images, labels, generated_images, fake_cls)
        disc_loss = Loss_D(images, generated_images)

    gradients_of_generator = gen_tape.gradient(gen_loss, G.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, D.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, G.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, D.trainable_variables))

**cf) @tf.function 역할**
    - Input : print(concat(images, labels))
    - Output : (2, 128, 128, 3)

- 함수에 @tf.function 붙이면
    - Input : print(concat(images, labels))
    - Output : tf.Tensor([  2 128 128   3], shape=(4,), dtype=int32)

In [548]:
#test
N=2
for images, labels in train_dataset.take(N): 
    train_step(images, labels, N)
    print('Good!')

Good!
Good!


In [346]:
# def train(dataset, epochs):  
#     for epoch in range(epochs):
#         start = time.time()
    
#     for images in dataset:
#         train_step(images)

#     display.clear_output(wait=True)
#     generate_and_save_images(generator, epoch + 1, random_vector_for_generation)
    
#     # saving (checkpoint) the model every 15 epochs
#     if (epoch + 1) % 15 == 0:
#         checkpoint.save(file_prefix = checkpoint_prefix)
    
#     print ('Time taken for epoch {} is {} sec'.format(epoch + 1,
#                                                       time.time()-start))
#   # generating after the final epoch
#     display.clear_output(wait=True)
#     generate_and_save_images(generator,epochs, random_vector_for_generation)