In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, Flatten, Input, MaxPooling2D, LeakyReLU, BatchNormalization, PReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.preprocessing import image_dataset_from_directory

from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input

# from PIL import Image
# from tensorflow.keras.preprocessing import image
# from tensorflow.keras.preprocessing.image import load_img, img_to_array
# from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model, Sequential
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.core.pylabtools import figsize
from glob import glob
import os
from tqdm import tqdm
import shutil
import tarfile

In [3]:
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [4]:
# !unzip -q /content/drive/MyDrive/every_Gan/grayscale_data/mirflickr25k.zip

In [5]:
# !tar -xvf /content/drive/MyDrive/Coding/Data/ILSVRC2013_DET_train.tar

In [6]:
# !mkdir data
# !mv /content/mirflickr/*.jpg data

In [7]:
# img_path = glob("data/*.JPEG")
# print(len(img_path))
# train_path = "data"

In [56]:
def size_check(img_path):
  counter = 0
  for ii in img_path:
    image_opener = cv2.imread(ii)
    if image_opener.shape[0] < 256 and image_opener.shape[1] < 256:
      os.remove(ii)
  

흑백사진 조심하기 2트

In [57]:
image_size = (32,32,3)
H = image_size[0]
W = image_size[1]
C = image_size[2]

target_size = (64,64,3)
T_H = target_size[0]
T_W = target_size[1]
T_C = target_size[2]

batch_size = 8
epochs = 150
sample_every = 10

In [58]:
def g_layer_block(kernel_size, units, strides, x):

  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  
  x = BatchNormalization(momentum = 0.5)(x)
  x = PReLU(shared_axes = [1,2])(x)
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  
  
  x = BatchNormalization()(x)
  
  return x


In [59]:
def g_layer_block_2(kernel_size, units, strides, x, scale):
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  x = tf.nn.depth_to_space(x,scale)
  x = PReLU(shared_axes = [1,2])(x)

  return x

In [60]:
def g_layer_block_3(x, connect_x):
  x = g_layer_block(kernel_size=3, units = H, strides = 1, x= x)
  x = tf.keras.layers.add([connect_x, x])
  new_connect_x = x

  return x, new_connect_x

In [61]:
def build_generator(img_size):
  i = Input(shape = img_size)
  x = Conv2D(filters = H, kernel_size = 9, strides = 1, padding = "same")(i)
  x = PReLU(shared_axes = [1,2])(x)

  connect_x = x
  connect_x_2 = connect_x

  num_residual = 16
  for j in range(num_residual):
    x, connect_x = g_layer_block_3(x, connect_x= connect_x)
  
  
  x = Conv2D(kernel_size=3, filters = H, strides = 1, padding = "same")(x)
  x = BatchNormalization(momentum = 0.5)(x)
  x = tf.keras.layers.add([connect_x_2,x])

  x = g_layer_block_2(kernel_size=3, units= H*4, strides= 1, x= x, scale = 2)

  # x = g_layer_block_2(kernel_size=3, units= H*4, strides= 1, x= x, scale = 2)
  # x = Conv2D(filters= 3, kernel_size=9, strides = 1, padding = "same" )(x)

  x = Conv2D(filters= 3, kernel_size=9, strides = 1, padding = "same", activation = "tanh")(x)

  model = Model(inputs = i, outputs = x)
  return model


  



In [62]:
def d_layer_block(kernel_size,units,strides,x):
  x = Conv2D(filters = units, kernel_size= kernel_size, strides = strides, padding = "same")(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = LeakyReLU()(x)

  return x

In [63]:
def build_discriminator(img_size):
  i = Input(shape = img_size)
  x = Conv2D(filters=H, kernel_size=3, strides = 1, activation= LeakyReLU())(i)
  x = d_layer_block(3,H,2,x)
  x = d_layer_block(3,H*2,1,x)
  x = d_layer_block(3,H*2,2,x)
  x = d_layer_block(3,H*4,1,x)
  x = d_layer_block(3,H*4,2,x)
  x = d_layer_block(3,H*8,1,x)
  x = d_layer_block(3,H*8,2,x)

  x = Flatten()(x)
  x = Dense(H*16)(x)
  x = LeakyReLU()(x)
  x = Dense(1, activation = "sigmoid")(x)

  model = Model(inputs = i, outputs = x)
  return model

In [64]:
discriminator = build_discriminator(img_size=(T_H,T_W,T_C))

generator = build_generator(img_size = (H,W,C))

z = Input(shape = (H,W,C))

sample_img= generator(z, training = False)

fake_pre = discriminator(sample_img, training = False)

In [65]:
generator.summary()

Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_7 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_52 (Conv2D)             (None, 32, 32, 32)   7808        ['input_7[0][0]']                
                                                                                                  
 p_re_lu_18 (PReLU)             (None, 32, 32, 32)   32          ['conv2d_52[0][0]']              
                                                                                                  
 conv2d_53 (Conv2D)             (None, 32, 32, 32)   9248        ['p_re_lu_18[0][0]']             
                                                                                            

Pretrained model -> VGG19_model -> perceptual_model? -> content loss

In [66]:
vgg = VGG19(include_top= False, weights = "imagenet", input_shape = (T_H, T_W, T_C))
# vgg = VGG19(include_top= False, weights = "imagenet", input_shape = (None, T_W, T_C))
img = Input(shape = (T_H, T_W, T_C))

VGG19_model = Model(inputs= vgg.input, outputs = vgg.layers[10].output)

VGG19_model.trainable = False

In [67]:
VGG19_model.summary()

Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_9 (InputLayer)        [(None, 64, 64, 3)]       0         
                                                                 
 block1_conv1 (Conv2D)       (None, 64, 64, 64)        1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 64, 64, 64)        36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 32, 32, 64)        0         
                                                                 
 block2_conv1 (Conv2D)       (None, 32, 32, 128)       73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 32, 32, 128)       147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 16, 16, 128)       0   

In [68]:
file_path = "/content/drive/MyDrive/Coding/image_result/Pieced_final_sr/"
if not os.path.exists(file_path):
  os.makedirs(file_path)

In [69]:
check_path = "/content/drive/MyDrive/Coding/check_point/Pieced_final_sr"
if not os.path.exists(check_path):
  os.makedirs(check_path)
check_point = tf.train.Checkpoint(generator = generator, discriminator = discriminator)
check_point_manager = tf.train.CheckpointManager(checkpoint=check_point, directory= check_path, max_to_keep= 5)

In [70]:
check_point.restore(check_point_manager.latest_checkpoint)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f817bc7e290>

In [71]:
def image_to_numpy(batch_size, batch, img_path):

  N_minibatch = 256/T_H
  
  batch_np_64 = np.zeros((1, H, W, C))
  batch_np_128 = np.zeros((1, T_H, T_W, T_C))

  for img_in_path in img_path[batch:batch + batch_size]:
    image_open = cv2.imread(img_in_path)
    image_open = cv2.cvtColor(image_open, cv2.COLOR_BGR2RGB)

    image_128, image_64 = minibatch_generator(image_open)

    batch_np_64 = np.append(batch_np_64, image_64, axis = 0)
    batch_np_128 = np.append(batch_np_128, image_128, axis = 0)

  batch_np_64 = np.delete(batch_np_64, 0, axis = 0)
  batch_np_128 = np.delete(batch_np_128, 0, axis = 0)

  batch_np_64 = ((batch_np_64/255.0)-0.5)*2
  batch_np_128 = ((batch_np_128/255.0)-0.5)*2

  return batch_np_64, batch_np_128

In [72]:
def minibatch_generator(image_open):
    image_256 = cv2.resize(image_open, (256,256))

    N_minibatch = 256/T_H
    minibatch_128 = np.zeros((1, T_H, T_W, 3))
    minibatch_64 = np.zeros((1, H,W,3))

    for i in range(int(N_minibatch)):
      for j in range(int(N_minibatch)):

        mini_mini_batch_128 = image_256[i*T_H:(i+1)*T_H, j*T_W:(j+1)*T_W, 0:3]
        mini_mini_batch_64 = cv2.resize(mini_mini_batch_128, (H,W))

        mini_mini_batch_128 = np.expand_dims(mini_mini_batch_128, axis = 0)
        minibatch_128 = np.append(minibatch_128, mini_mini_batch_128, axis = 0)

        mini_mini_batch_64 = np.expand_dims(mini_mini_batch_64, axis = 0)
        minibatch_64 = np.append(minibatch_64, mini_mini_batch_64, axis = 0)

    minibatch_128 = np.delete(minibatch_128, 0, axis = 0)
    minibatch_64 = np.delete(minibatch_64, 0, axis = 0)

    return minibatch_128, minibatch_64

In [73]:
def plot_result(epoch, img_path):
    epoch = epoch
    original_size = 256
    input_image_size = 32
    image_size = 64

    rows = int(original_size/image_size)
    cols = 3

    fig, ax = plt.subplots(rows,cols, figsize = (20,40))
    # fig.tight_layout()
    plt.axis("off")

    rand_index = np.random.choice(len(img_path))
    img_open = cv2.imread(img_path[rand_index])
    real_img = cv2.cvtColor(img_open, cv2.COLOR_BGR2RGB)
    if original_size != 512:
      real_img = cv2.resize(real_img, (original_size, original_size))
      
    real_img = real_img/255.0

    cropped_images = minibatch_generator_1(real_img, input_image_size)

    # Remember that the input is from (-1,1)
    cropped_images = (cropped_images-0.5)*2

    
    start_Index = 0
    end_Index = 0
    add_Index = 0

    for Number_of_images in range(rows):

        blank_image_fake = np.zeros(((Number_of_images +1)*image_size,(Number_of_images +1)*image_size,3))

        start_Index += (Number_of_images)**2
        add_Index = (Number_of_images + 1)**2
        end_Index = start_Index + add_Index


        image_pieces = cropped_images[start_Index:end_Index, 0:image_size, 0:image_size, 0:3]
        
        # print(f"index counter : {start_Index}  end_index : {end_Index}   number_of images : {Number_of_images}")        

        counter = 0
        for image_row in range(0,Number_of_images + 1,1):
            for image_column in range(0,Number_of_images + 1,1):

              fake_img = generator(image_pieces, training = False)

              # Remember that the output is (-1,1), change it to (0,1) for plotting
              fake_img = (fake_img + 1)/2

              
              blank_image_fake[image_row*image_size:(image_row+1)*image_size, image_column*image_size:(image_column+1)*image_size, 0:3] = fake_img[counter]
              counter += 1

        plot_image = cv2.imread(img_path[rand_index])
        plot_image = cv2.resize(plot_image, ((Number_of_images +1)*image_size,(Number_of_images +1)*image_size))
        plot_image = cv2.cvtColor(plot_image, cv2.COLOR_BGR2RGB)

        input_image = cv2.resize(plot_image,((Number_of_images +1)*input_image_size,(Number_of_images +1)*input_image_size))
        

        complete_fake_img = (blank_image_fake).astype(np.float32)
        plot_image = (plot_image/255.0).astype(np.float32)
        input_image = (input_image/255.0).astype(np.float32)
        
      

        ax[Number_of_images][0].imshow(input_image)
        ax[Number_of_images][1].imshow(complete_fake_img)
        ax[Number_of_images][2].imshow(plot_image)


        


    while True:
      file_name = file_path + f"/%d.png" %epoch
      
      if not os.path.exists(file_name):
        break     
      epoch += 2

    plt.savefig(file_name)
    plt.close()

In [74]:
discriminator_losses = tf.keras.losses.BinaryCrossentropy()
generator_losses = tf.keras.losses.BinaryCrossentropy()
content_losses = tf.keras.losses.MeanSquaredError()

learning_rate = PiecewiseConstantDecay(boundaries = [100000], values = [1e-4, 1e-5])

gen_optimizer = Adam(learning_rate = learning_rate)
dis_optimizer = Adam(learning_rate = learning_rate)


def content_loss(real_imgs, fake_imgs):

  # change image rgb range from (-1,1) to (0,1)
  real_imgs = (real_imgs + 1)/2
  fake_imgs = (fake_imgs + 1)/2

  real_imgs = preprocess_input(real_imgs)
  fake_imgs = preprocess_input(fake_imgs)

  how_real_classified = VGG19_model(real_imgs)
  how_fake_classified = VGG19_model(fake_imgs)

  vgg_loss = content_losses(how_real_classified, how_fake_classified)
  

  return vgg_loss
  


def discriminator_loss(real_output, fake_output):
    real_loss = discriminator_losses(tf.ones_like(real_output), real_output)
    fake_loss = discriminator_losses(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    
    return total_loss

def generator_loss(fake_output):

    fake_loss = generator_losses(tf.ones_like(fake_output), fake_output)
    
    return fake_loss

In [75]:
# sample_img_path = glob("/content/extracted_tar/n01503061/*.JPEG")
# plot_result(99999, sample_img_path)

In [76]:
def minibatch_generator_1(image_open, image_size):
    original_size = 256
    max_pieces = int(original_size/image_size)

    N_minibatch = original_size/H
    minibatch_128 = np.zeros((1, H, W, 3))

    for N_pieces in range(max_pieces):
      full_image_size = image_size*(N_pieces + 1)
      image_open_resized = cv2.resize(image_open, (full_image_size, full_image_size))

      for i in range(N_pieces + 1):
        for j in range(N_pieces + 1):
          mini_mini_batch_128 = image_open_resized[i*image_size:(i+1)*image_size, j*image_size:(j+1)*image_size, 0:3]
          mini_mini_batch_128 = np.expand_dims(mini_mini_batch_128, axis = 0)
          minibatch_128 = np.append(minibatch_128, mini_mini_batch_128, axis = 0)

    minibatch_128 = np.delete(minibatch_128, 0, axis = 0)

    return minibatch_128

In [77]:
def train_step(real_imgs, lr_imgs):
  
  with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:

    fake_hr_imgs = generator(lr_imgs, training = True)

    real_outputs = discriminator(real_imgs, training = True)
    fake_outputs = discriminator(fake_hr_imgs, training = True)

    d_loss = discriminator_loss(real_output = real_outputs, fake_output = fake_outputs)
    g_loss = generator_loss(fake_output = fake_outputs)
    c_loss = content_loss(real_imgs = real_imgs, fake_imgs=fake_hr_imgs)

    SR_loss = c_loss + 0.001*g_loss
    
    
    g_gradient = gen_tape.gradient(SR_loss, generator.trainable_variables)
    d_gradient = dis_tape.gradient(d_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(g_gradient, generator.trainable_variables))
    dis_optimizer.apply_gradients(zip(d_gradient, discriminator.trainable_variables))

    return d_loss, SR_loss

In [None]:
d_losses = tf.keras.metrics.Mean(name="d_loss")
g_losses = tf.keras.metrics.Mean(name="g_loss")

import tarfile
import random
d_losses = tf.keras.metrics.Mean(name="d_loss")
g_losses = tf.keras.metrics.Mean(name="g_loss")

sample_path = glob("/content/drive/MyDrive/every_Gan/imagenet_2013/*.tar")
np.random.shuffle(sample_path)
file_number = 0
for tar in sample_path:

  file_number += 1

  result_file_with_number = file_path + "file_sample_" + str(file_number)
  file_name = "/content/extracted_tar"

  if not os.path.exists(file_name):
    os.makedirs(file_name)

  
  
  my_tar = tarfile.open(tar)
  my_tar.extractall(file_name)
  my_tar.close()

  img_p_1 = glob(file_name + "/*/*.JPEG")
  img_p_2 = glob(file_name + "/*.JPEG")

  img_path = img_p_1 + img_p_2

  size_check(img_path)

  img_p_1 = glob(file_name + "/*/*.JPEG")
  img_p_2 = glob(file_name + "/*.JPEG")

  img_path = img_p_1 + img_p_2


  for epoch in range(epochs):
    for batch in tqdm(range(0,len(img_path) - 1, batch_size)):
      
      # image to numpy change image into numpy and return (-1,1) range rgb data
      image_64, image_128 = image_to_numpy(batch_size = batch_size, batch = batch, img_path = img_path)


      d_loss, SR_loss = train_step(real_imgs = image_128, lr_imgs = image_64)

      d_losses.update_state(d_loss)
      g_losses.update_state(SR_loss)
      
    if epoch % sample_every == 0:
      print(f"epoch : {epoch}/{epochs}   d_loss : {d_losses.result():0.5f}   g_loss : {g_losses.result():0.5f} ")
      # sample_images(epoch, img_path)
      plot_result(epoch, img_path)
      # check_point.step.assign_add(1)
      
      d_losses.reset_states()
      g_losses.reset_states()

      # check_point_manager.save()
  shutil.rmtree(file_name, ignore_errors = True)

100%|██████████| 355/355 [03:50<00:00,  1.54it/s]


epoch : 0/150   d_loss : 0.06091   g_loss : 1.95870 


100%|██████████| 355/355 [03:49<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.56it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:47<00:00,  1.56it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.56it/s]


epoch : 10/150   d_loss : 0.00368   g_loss : 1.69171 


100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.55it/s]
100%|██████████| 355/355 [03:48<00:00,  1.56it/s]
 29%|██▊       | 102/355 [01:05<02:40,  1.57it/s]