# DATA Load
kaggle에 업로드되어 있는 danbooru2020 dataset을 사용.<br>
kaggle: https://www.kaggle.com/muoncollider/danbooru2020

In [None]:
!pip install kaggle

In [None]:
# Mount Google Drive at /content/drive; the project folder and the saved
# weights used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Upload a file from the local machine into the Colab working directory.
# The next cell copies kaggle.json from here into ~/.kaggle, so that is the
# file expected to be uploaded.
from google.colab import files
files.upload()

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d muoncollider/danbooru2020

In [None]:
!unzip danbooru2020.zip -d /content/test_folder

# MODEL architecture
model architecture는 SwinIR 모델 사용<br>
model 설명: https://github.com/alzoqm/transformer_model/tree/main/models/swinIR


In [None]:
%cd drive/MyDrive/ColabNotebooks/project/super_resolution

/content/drive/MyDrive/ColabNotebooks/project/super_resolution/model


In [None]:
!ls

new_swinir_sr.ipynb  save_weights	   swinir_tf.py
__pycache__	     SwinIR_SR_test.ipynb


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
import os
import cv2
import tqdm
import torch
import math
import h5py
import imageio
import random

from tqdm import tqdm

from model import swinir_tf as swinir #기존 model load

In [None]:
#####PARAMETERS######
IMG_SIZE = 64
PATCH_SIZE = 1
IN_CHANS = 3
EMB_SIZE = 180
DEPTHS = [6, 6, 6, 6]
NUM_HEADS = [6, 6, 6, 6]
WINDOW_SIZE = 4
MLP_RATIO = 4
QKV_BIAS = True
DROP_RATE = 0.1
ATTN_DROP_RATE = 0.1
DROP_PATH_RATE = 0.1
APE = False
PATCH_NORM = True
UPSCALE = 2
IMG_RANGE = 255
RESI_CONNECTION = '3conv'

# TPU 학습 조성 및 drive 불러오기

In [None]:
# Locate the Colab TPU through the address exposed in the COLAB_TPU_ADDR
# environment variable.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])

# Connect to the TPU cluster and initialise the TPU system before the
# strategy is created.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Distribution strategy used below for building the model, sizing the global
# batch and running the training steps.
strategy = tf.distribute.TPUStrategy(resolver)

# DATA preprocessing
Colab TPU를 사용하더라도 512 x 512 이미지를 한 번에 학습할 수 없기에 이미지를 64 x 64로 나누어서 학습.<br>
한 번에 모든 데이터를 불러오기에는 데이터 크기에 비해 RAM이 부족하므로, 데이터를 나누어서 불러오는 함수(dataload_function)를 생성.<br>
이 함수는 아래의 training 상황에서 epoch 값에 해당하는 데이터를 불러옴.

In [None]:
def image_slice(image, patch_size):
  """Cut an image into non-overlapping tiles.

  Args:
    image: array indexed as [row, column, channel].
    patch_size: (tile_height, tile_width) pair.

  Returns:
    Array holding the tiles in row-major order (left to right, then top to
    bottom). Trailing rows/columns that do not fill a whole tile are dropped.
  """
  tile_h, tile_w = patch_size[0], patch_size[1]
  n_rows = image.shape[0] // tile_h
  n_cols = image.shape[1] // tile_w

  tiles = [
      np.array(image[r * tile_h : (r + 1) * tile_h, c * tile_w : (c + 1) * tile_w, :])
      for r in range(n_rows)
      for c in range(n_cols)
  ]
  return np.array(tiles)

In [None]:
danbooru_path = '/content/test_folder'

# os.listdir returns entries in arbitrary order. The training loop addresses
# folders by index (and resumes from a fixed index), so the list is sorted to
# make an index always refer to the same folder. Entries that are not
# directories are ignored, since listing them below would fail.
# NOTE(review): the sorted order may differ from the order seen by earlier
# sessions -- re-check the resume index after adopting this.
danbooru_list = sorted(
    name for name in os.listdir(danbooru_path)
    if os.path.isdir(danbooru_path + '/' + name)
)

# accl[i] holds the file names found inside folder danbooru_list[i].
accl = [os.listdir(danbooru_path + '/' + name) for name in danbooru_list]

def dataload_function(cnt, accl, danbooru_path, danbooru_list):
  """Load one folder of images and cut them into LR/SR training patches.

  Args:
    cnt: index of the folder to load (position in accl / danbooru_list).
    accl: list whose i-th element is the list of file names in folder i.
    danbooru_path: root directory that contains the folders.
    danbooru_list: folder names, aligned index-by-index with accl.

  Returns:
    (all_img_LR, all_img_SR, length) where all_img_LR holds 128 x 128 patches
    cut from a 256 x 256 downscaled copy of each image, all_img_SR holds
    256 x 256 patches cut from the original image, and length is the number of
    full global batches (len(all_img_LR) // global_batch_size). An index
    outside accl yields empty arrays and length 0.

  NOTE(review): LR and SR patches only line up one-to-one when every source
  image is 512 x 512 -- presumably true for this dataset; confirm.
  """
  all_img_SR = []
  all_img_LR = []

  if 0 <= cnt < len(accl):
    folder = danbooru_path + '/' + danbooru_list[cnt]
    for pic in accl[cnt]:
      image_rgb_SR = cv2.imread(folder + '/' + pic, cv2.IMREAD_COLOR)
      if image_rgb_SR is None:
        # cv2.imread signals an unreadable/non-image file by returning None;
        # skip it instead of crashing in cv2.resize.
        continue
      image_rgb_LR = cv2.resize(image_rgb_SR, (256, 256))  # size of the LR image

      all_img_SR.extend(image_slice(image_rgb_SR, (256, 256)))
      all_img_LR.extend(image_slice(image_rgb_LR, (128, 128)))  # LR patch size

  length = len(all_img_LR) // global_batch_size
  return np.array(all_img_LR), np.array(all_img_SR), length

In [None]:
# Every replica of the strategy processes BATCH_SIZE_PER_REPLICA samples per
# step, so one global batch holds that many samples for each replica in sync.
BATCH_SIZE_PER_REPLICA = 1
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [None]:
def loss_function(y_true, y_pred):
  """Return the mean absolute error between y_true and y_pred.

  Thin wrapper around tf.keras.losses.MAE; its result is returned unchanged.
  """
  return tf.keras.losses.MAE(y_true, y_pred)

In [None]:
# Weights file on Google Drive: loaded right after the model is built and
# overwritten by every checkpoint save during training.
path = '/content/drive/MyDrive/ColabNotebooks/project/super_resolution/save_weights/new_swin_ir_64_save_weights.h5'

In [None]:
# Build the model, optimizer and metric inside the TPU strategy scope.
with strategy.scope():
  model = swinir.swinIR(IMG_SIZE, PATCH_SIZE, IN_CHANS, EMB_SIZE, DEPTHS, NUM_HEADS, WINDOW_SIZE, MLP_RATIO, QKV_BIAS, DROP_RATE, ATTN_DROP_RATE, DROP_PATH_RATE, APE, PATCH_NORM, UPSCALE, IMG_RANGE, RESI_CONNECTION)
  optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)  # NOTE(review): not referenced elsewhere in the visible code
  
  # One dummy forward pass on a random 64 x 64 input before loading weights --
  # presumably so the model's variables exist when load_weights runs; confirm.
  x = tf.random.normal(shape=(1, 64, 64, 3))
  output = model(x)
  # Resume from the previously saved weights.
  model.load_weights(path)

# training
한 번에 모든 데이터를 불러와서 학습하기에는 학습 데이터가 너무 많고 ram이 부족함.<br>
따라서 데이터를 나누어서 가져온 뒤 그 데이터를 삭제하고, 다음 데이터를 불러오는 방식을 사용함.

In [None]:
def train_step(inputs):
  """Run one optimisation step on a single replica.

  Args:
    inputs: (low-resolution batch, high-resolution batch) pair.

  Returns:
    The loss produced by loss_function for this batch (not reduced further).
  """
  lr_batch, hr_batch = inputs
  lr_batch = tf.cast(lr_batch, dtype=tf.float32)

  with tf.GradientTape() as tape:
    prediction = model(lr_batch, training=True)
    loss = loss_function(hr_batch, prediction)

  variables = model.trainable_variables
  grads = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(grads, variables))

  return loss

@tf.function
def distributed_train_step(inputs):
  """Run train_step on every replica and combine the per-replica losses.

  Args:
    inputs: one distributed batch, as yielded by the distributed dataset.

  Returns:
    The element-wise average of the per-replica losses.
  """
  per_replica_losses = strategy.run(train_step, args=(inputs,))
  # train_step returns each replica's loss without dividing by the global
  # batch size, so the replicas are averaged here; summing them would inflate
  # the reported loss by the number of replicas. Only the logged value is
  # affected -- gradients are applied inside train_step.
  return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses,
                         axis=None)
  
# Index of the first data folder to train on. Folders before it are skipped --
# presumably they were consumed by an earlier session; confirm before changing.
START_FOLDER_INDEX = 13

# NOTE: "epoch" is the index of the data folder processed in this pass; each
# folder is loaded, trained on once and released before the next one is read.
for epoch in range(START_FOLDER_INDEX, len(accl)):
  all_img_LR, all_img_SR, length = dataload_function(epoch, accl, danbooru_path, danbooru_list)
  total_loss = 0.0
  step_loss = 0.0
  num_batch = 0

  with tqdm(total=length, desc=f"Train_file_number({epoch})") as pbar:
    for input_len in range(length):
      start = input_len * global_batch_size
      stop = start + global_batch_size
      # Only one global batch is wrapped in a dataset at a time to save RAM.
      dataset = tf.data.Dataset.from_tensor_slices((
          all_img_LR[start:stop], all_img_SR[start:stop]
      ))
      dataset = dataset.cache()
      dataset = dataset.batch(global_batch_size)
      dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
      dataset = strategy.experimental_distribute_dataset(dataset)

      for x in dataset:
        step_loss = distributed_train_step(x)
        total_loss += tf.reduce_mean(step_loss)
        num_batch += 1

      pbar.update(1)

      # The model is large, so weights are also saved periodically mid-folder
      # (at step 99, 499, 899, ...).
      if num_batch % 400 == 99:
        print(f"epoch: {epoch}, step: {num_batch}, loss:{total_loss / num_batch}")
        model.save_weights(path, overwrite=True)

  if num_batch:
    print(f"epoch: {epoch}, loss: {total_loss / num_batch}")
  else:
    # A folder with fewer images than one global batch (or none readable)
    # produces no step; report it instead of dividing by zero.
    print(f"epoch: {epoch}, no full batch available, nothing trained")

  # Release this folder's data before the next one is loaded.
  del all_img_LR
  del all_img_SR

  model.save_weights(path, overwrite=True)