In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, Flatten, Input, MaxPooling2D, LeakyReLU, BatchNormalization, PReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.preprocessing import image_dataset_from_directory

from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input

# from PIL import Image
# from tensorflow.keras.preprocessing import image
# from tensorflow.keras.preprocessing.image import load_img, img_to_array
# from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model, Sequential
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.core.pylabtools import figsize
from glob import glob
import os
from tqdm import tqdm
import shutil

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  # 텐서플로가 첫 번째 GPU만 사용하도록 제한
  try:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
  except RuntimeError as e:
    # 프로그램 시작시에 접근 가능한 장치가 설정되어야만 합니다
    print(e)

In [4]:
image_size = (64,64,3)
H = image_size[0]
W = image_size[1]
C = image_size[2]

target_size = (256,256,3)
T_H = target_size[0]
T_W = target_size[1]
T_C = target_size[2]

batch_size = 2
epochs = 1000
sample_every = 5

In [5]:
def g_layer_block(kernel_size, units, strides, x):

  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  
  x = BatchNormalization(momentum = 0.5)(x)
  x = PReLU(shared_axes = [1,2])(x)
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  
  
  x = BatchNormalization()(x)
  
  return x


In [6]:
def g_layer_block_2(kernel_size, units, strides, x, scale):
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  x = tf.nn.depth_to_space(x,scale)
  x = PReLU(shared_axes = [1,2])(x)

  return x

In [7]:
def g_layer_block_3(x, connect_x):
  x = g_layer_block(kernel_size=3, units = H, strides = 1, x= x)
  x = tf.keras.layers.add([connect_x, x])
  new_connect_x = x

  return x, new_connect_x

In [8]:
def build_generator(img_size):
  i = Input(shape = img_size)
  x = Conv2D(filters = H, kernel_size = 9, strides = 1, padding = "same")(i)
  x = PReLU(shared_axes = [1,2])(x)

  connect_x = x
  connect_x_2 = connect_x

  num_residual = 16
  for j in range(num_residual):
    x, connect_x = g_layer_block_3(x, connect_x= connect_x)
  
  
  x = Conv2D(kernel_size=3, filters = H, strides = 1, padding = "same")(x)
  x = BatchNormalization(momentum = 0.5)(x)
  x = tf.keras.layers.add([connect_x_2,x])

  x = g_layer_block_2(kernel_size=3, units= H*4, strides= 1, x= x, scale = 2)

  x = g_layer_block_2(kernel_size=3, units= H*4, strides= 1, x= x, scale = 2)
  # x = Conv2D(filters= 3, kernel_size=9, strides = 1, padding = "same" )(x)

  x = Conv2D(filters= 3, kernel_size=9, strides = 1, padding = "same", activation = "tanh")(x)

  model = Model(inputs = i, outputs = x)
  return model


  



In [9]:
def d_layer_block(kernel_size,units,strides,x):
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = LeakyReLU()(x)

  return x

In [10]:
def build_discriminator(img_size):
  i = Input(shape = img_size)
  x = Conv2D(filters=H, kernel_size=3, strides = 1, activation= LeakyReLU())(i)
  x = d_layer_block(3,H,2,x)
  x = d_layer_block(3,H*2,1,x)
  x = d_layer_block(3,H*2,2,x)
  x = d_layer_block(3,H*4,1,x)
  x = d_layer_block(3,H*4,2,x)
  x = d_layer_block(3,H*8,1,x)
  x = d_layer_block(3,H*8,2,x)

  x = Flatten()(x)
  x = Dense(H*16)(x)
  x = LeakyReLU()(x)
  x = Dense(1, activation = "sigmoid")(x)

  model = Model(inputs = i, outputs = x)
  return model

In [11]:
discriminator = build_discriminator(img_size=(T_H,T_W,T_C))

generator = build_generator(img_size = (H,W,C))

z = Input(shape = (H,W,C))

sample_img= generator(z, training = False)

fake_pre = discriminator(sample_img, training = False)

In [12]:
generator.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 conv2d_8 (Conv2D)              (None, 64, 64, 64)   15616       ['input_2[0][0]']                
                                                                                                  
 p_re_lu (PReLU)                (None, 64, 64, 64)   64          ['conv2d_8[0][0]']               
                                                                                                  
 conv2d_9 (Conv2D)              (None, 64, 64, 64)   36928       ['p_re_lu[0][0]']                
                                                                                            

Pretrained model -> VGG19_model -> perceptual_model? -> content loss

In [13]:
# Calling VGG 19 network for Loss
vgg = VGG19(include_top= False, weights = "imagenet", input_shape = (T_H, T_W, T_C))
# vgg = VGG19(include_top= False, weights = "imagenet", input_shape = (None, T_W, T_C))
img = Input(shape = (T_H, T_W, T_C))

VGG19_model = Model(inputs= vgg.input, outputs = vgg.layers[10].output)

VGG19_model.trainable = False

In [14]:
VGG19_model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 256, 256, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 256, 256, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 256, 256, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 128, 128, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 128, 128, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 128, 128, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 64, 64, 128)       0   

In [15]:
# This is my image result file path
file_path = "Result/Sr_gan"
if not os.path.exists(file_path):
  os.makedirs(file_path)

In [16]:
# This is my check point file path
check_path = "Check_point/Sr_gan"
if not os.path.exists(check_path):
  os.makedirs(check_path)
check_point = tf.train.Checkpoint(generator = generator, discriminator = discriminator)
check_point_manager = tf.train.CheckpointManager(checkpoint=check_point, directory= check_path, max_to_keep= 5)

In [17]:
# check_point.restore(check_point_manager.latest_checkpoint)

In [18]:
def image_to_numpy(batch_size, batch, img_path):
  
  batch_np_64 = np.zeros((1, H, W, C))
  batch_np_128 = np.zeros((1, T_H, T_W, T_C))

  for img_in_path in img_path[batch:batch + batch_size]:
    image_open = cv2.imread(img_in_path)
    image_open = cv2.cvtColor(image_open, cv2.COLOR_BGR2RGB)

    image_64 = cv2.resize(image_open, (H,W) , interpolation=cv2.INTER_CUBIC)
    image_128 = cv2.resize(image_open, (T_H, T_W))

    image_64 = np.expand_dims(image_64, axis = 0)
    image_128 = np.expand_dims(image_128, axis = 0)

    batch_np_64 = np.append(batch_np_64, image_64, axis = 0)
    batch_np_128 = np.append(batch_np_128, image_128, axis = 0)

  batch_np_64 = np.delete(batch_np_64, 0, axis = 0)
  batch_np_128 = np.delete(batch_np_128, 0, axis = 0)

  batch_np_64 = ((batch_np_64/255.0)-0.5)*2
  batch_np_128 = ((batch_np_128/255.0)-0.5)*2

  return batch_np_64, batch_np_128

In [19]:
def sample_images(epoch, img_path):
  epoch = epoch
  rows, cols = 4,3
  
  fig, ax = plt.subplots(rows,cols,figsize = (15,15))
  fig.tight_layout()
  counter = 0

  for i in range(0 ,rows*cols-1, cols):

    rand_index = np.random.choice(len(img_path))
    image_open = cv2.imread(img_path[rand_index])
    image_open = cv2.cvtColor(image_open, cv2.COLOR_BGR2RGB)

    image_64 = cv2.resize(image_open,(H,W), interpolation=cv2.INTER_CUBIC)
    image_128 = cv2.resize(image_open, (T_H,T_W))

    image_128 = image_128/255.0

    image_64_plot = image_64/255.0

    image_64 = ((image_64/255.0)-0.5)*2

    image_64 = np.expand_dims(image_64, axis= 0)

    fake_img = generator.predict(image_64)
    
    fake_img = np.reshape(fake_img, (T_H,T_W,T_C))
    
    fake_img = ((fake_img + 1)/2.0)
    

    plt.axis("off")
    ax[counter][0].imshow(image_64_plot)
    ax[counter][1].imshow(fake_img)
    ax[counter][2].imshow(image_128)
    

    counter += 1
  
  while True:
    file_name = file_path + f"/%d.png" %epoch
    if not os.path.exists(file_name):
      break     
    epoch += 2
    
    
  plt.savefig(file_name)
  plt.close()

In [20]:
# sample_images(0)

In [21]:
discriminator_losses = tf.keras.losses.BinaryCrossentropy()
generator_losses = tf.keras.losses.BinaryCrossentropy()
content_losses = tf.keras.losses.MeanSquaredError()

learning_rate = PiecewiseConstantDecay(boundaries = [100000], values = [1e-4, 1e-5])

gen_optimizer = Adam(learning_rate = learning_rate)
dis_optimizer = Adam(learning_rate = learning_rate)


def content_loss(real_imgs, fake_imgs):
  real_imgs = (real_imgs + 1)/2
  fake_imgs = (fake_imgs + 1)/2

  real_imgs = preprocess_input(real_imgs)
  fake_imgs = preprocess_input(fake_imgs)

  how_real_classified = VGG19_model(real_imgs)
  how_fake_classified = VGG19_model(fake_imgs)

  vgg_loss = content_losses(how_real_classified, how_fake_classified)
  

  return vgg_loss
  


def discriminator_loss(real_output, fake_output):
    real_loss = discriminator_losses(tf.ones_like(real_output), real_output)
    fake_loss = discriminator_losses(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    
    return total_loss

def generator_loss(fake_output):

    fake_loss = generator_losses(tf.ones_like(fake_output), fake_output)
    
    return fake_loss

In [22]:
def train_step(real_imgs, lr_imgs):
  
  with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:

    fake_hr_imgs = generator(lr_imgs, training = True)

    real_outputs = discriminator(real_imgs, training = True)
    fake_outputs = discriminator(fake_hr_imgs, training = True)

    d_loss = discriminator_loss(real_output = real_outputs, fake_output = fake_outputs)
    g_loss = generator_loss(fake_output = fake_outputs)
    c_loss = content_loss(real_imgs = real_imgs, fake_imgs=fake_hr_imgs)

    SR_loss = c_loss + 0.001*g_loss
    
    
    g_gradient = gen_tape.gradient(SR_loss, generator.trainable_variables)
    d_gradient = dis_tape.gradient(d_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(g_gradient, generator.trainable_variables))
    dis_optimizer.apply_gradients(zip(d_gradient, discriminator.trainable_variables))

    return d_loss, SR_loss

In [23]:
# # since it take so long to extract all element in imagenet_2013 dataset, I choose to extract them one by one and delete them after training.

# import tarfile
# d_losses = tf.keras.metrics.Mean(name="d_loss")
# g_losses = tf.keras.metrics.Mean(name="g_loss")

# sample_path = glob("/content/drive/MyDrive/every_Gan/imagenet_2013/*.tar")
# file_number = 0
# for tar in sample_path:

#   file_number += 1

#   result_file_with_number = file_path + "file_sample_" + str(file_number)
#   file_name = "/content/extracted_tar"

#   if not os.path.exists(file_name):
#     os.makedirs(file_name)

  
  
#   my_tar = tarfile.open(tar)
#   my_tar.extractall(file_name)
#   my_tar.close()

#   img_p_1 = glob(file_name + "/*/*.JPEG")
#   img_p_2 = glob(file_name + "/*.JPEG")

#   img_path = img_p_1 + img_p_2

#   size_check(img_path)

#   img_p_1 = glob(file_name + "/*/*.JPEG")
#   img_p_2 = glob(file_name + "/*.JPEG")

#   img_path = img_p_1 + img_p_2


#   for epoch in range(epochs):
#     for batch in tqdm(range(0,len(img_path) - 1, batch_size)):
      
#       image_64, image_128 = image_to_numpy(batch_size = batch_size, batch = batch, img_path = img_path)


#       d_loss, SR_loss = train_step(real_imgs = image_128, lr_imgs = image_64)

#       d_losses.update_state(d_loss)
#       g_losses.update_state(SR_loss)
      
#     if epoch % sample_every == 0:
#       print(f"epoch : {epoch}/{epochs}   d_loss : {d_losses.result():0.5f}   g_loss : {g_losses.result():0.5f} ")
#       sample_images(epoch, img_path)
#       # check_point.step.assign_add(1)
      
#       d_losses.reset_states()
#       g_losses.reset_states()

#       check_point_manager.save()
#   shutil.rmtree(file_name, ignore_errors = True)

In [24]:
img_path = glob("data/mirflickr/*.jpg")
# print(f"img_path len : {len(img_path)}")


# for i in img_path:
#   image_open = cv2.imread(i)
#   image_size = image_open.shape

  
#   if image_size[0] < 256 or image_size[1] < 256:   
#     os.remove(i)

# img_path = glob("data/mirflickr/*.jpg")
# print(f"img_path len : {len(img_path)}")

In [25]:
# since it take so long to extract all element in imagenet_2013 dataset, I choose to extract them one by one and delete them after training.

d_losses = tf.keras.metrics.Mean(name="d_loss")
g_losses = tf.keras.metrics.Mean(name="g_loss")



for epoch in range(epochs):
    for batch in tqdm(range(0,len(img_path) - 1, batch_size)):
      
      image_64, image_128 = image_to_numpy(batch_size = batch_size, batch = batch, img_path = img_path)


      d_loss, SR_loss = train_step(real_imgs = image_128, lr_imgs = image_64)

      d_losses.update_state(d_loss)
      g_losses.update_state(SR_loss)
      
    if epoch % sample_every == 0:
      print(f"epoch : {epoch}/{epochs}   d_loss : {d_losses.result():0.5f}   g_loss : {g_losses.result():0.5f} ")
      sample_images(epoch, img_path)
      # check_point.step.assign_add(1)
      
      d_losses.reset_states()
      g_losses.reset_states()

      check_point_manager.save()
  

100%|██████████| 12338/12338 [1:24:52<00:00,  2.42it/s]


epoch : 0/1000   d_loss : 0.43380   g_loss : 76.53303 


100%|██████████| 12338/12338 [1:25:48<00:00,  2.40it/s]
100%|██████████| 12338/12338 [1:25:51<00:00,  2.40it/s]
100%|██████████| 12338/12338 [1:23:51<00:00,  2.45it/s]
100%|██████████| 12338/12338 [1:23:10<00:00,  2.47it/s]
100%|██████████| 12338/12338 [1:24:34<00:00,  2.43it/s]


epoch : 5/1000   d_loss : 0.02949   g_loss : 56.07910 


100%|██████████| 12338/12338 [1:24:21<00:00,  2.44it/s]
100%|██████████| 12338/12338 [1:24:29<00:00,  2.43it/s]
100%|██████████| 12338/12338 [1:24:04<00:00,  2.45it/s]
100%|██████████| 12338/12338 [1:23:22<00:00,  2.47it/s]
100%|██████████| 12338/12338 [1:21:44<00:00,  2.52it/s]


epoch : 10/1000   d_loss : 0.01728   g_loss : 48.65018 


100%|██████████| 12338/12338 [1:21:32<00:00,  2.52it/s]
100%|██████████| 12338/12338 [1:21:28<00:00,  2.52it/s]
100%|██████████| 12338/12338 [1:21:25<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:39<00:00,  2.52it/s]
100%|██████████| 12338/12338 [1:22:22<00:00,  2.50it/s]


epoch : 15/1000   d_loss : 0.00313   g_loss : 45.89977 


100%|██████████| 12338/12338 [1:21:24<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:22<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:57<00:00,  2.51it/s]
100%|██████████| 12338/12338 [1:21:20<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:50<00:00,  2.51it/s]


epoch : 20/1000   d_loss : 0.00093   g_loss : 44.78990 


100%|██████████| 12338/12338 [1:21:56<00:00,  2.51it/s]
100%|██████████| 12338/12338 [1:15:55<00:00,  2.71it/s]
100%|██████████| 12338/12338 [1:17:47<00:00,  2.64it/s]
100%|██████████| 12338/12338 [1:20:50<00:00,  2.54it/s]
100%|██████████| 12338/12338 [1:19:41<00:00,  2.58it/s]


epoch : 25/1000   d_loss : 0.00095   g_loss : 43.72190 


100%|██████████| 12338/12338 [1:19:54<00:00,  2.57it/s]
100%|██████████| 12338/12338 [1:20:36<00:00,  2.55it/s]
100%|██████████| 12338/12338 [1:21:27<00:00,  2.52it/s]
100%|██████████| 12338/12338 [1:21:16<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:16<00:00,  2.53it/s]


epoch : 30/1000   d_loss : 0.00004   g_loss : 42.73145 


100%|██████████| 12338/12338 [1:21:15<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:17<00:00,  2.53it/s]
100%|██████████| 12338/12338 [1:21:26<00:00,  2.52it/s]
100%|██████████| 12338/12338 [1:20:42<00:00,  2.55it/s]
100%|██████████| 12338/12338 [1:19:43<00:00,  2.58it/s]


epoch : 35/1000   d_loss : 0.00060   g_loss : 41.76362 


100%|██████████| 12338/12338 [1:19:49<00:00,  2.58it/s]
100%|██████████| 12338/12338 [1:19:44<00:00,  2.58it/s]
100%|██████████| 12338/12338 [1:14:54<00:00,  2.75it/s]
100%|██████████| 12338/12338 [1:10:59<00:00,  2.90it/s]
100%|██████████| 12338/12338 [1:18:02<00:00,  2.63it/s]


epoch : 40/1000   d_loss : 0.00103   g_loss : 40.81795 


 54%|█████▍    | 6724/12338 [43:30<36:19,  2.58it/s]  


KeyboardInterrupt: 