In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import zipfile

In [2]:
# Mount Google Drive inside the Colab VM; every path used later in this notebook
# lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Switch to the project's code directory. The next cell imports Generator and
# Discriminator by bare module name, so presumably those modules are files in this
# folder and this cwd is what makes them importable — verify.
%cd /content/drive/MyDrive/데이콘 대회/이미지품질향상/code

/content/drive/MyDrive/데이콘 대회/이미지품질향상/code


# 학습 모델 불러오기

In [4]:
import Generator
from Generator import Generator
import Discriminator
from Discriminator import Discriminator

In [5]:
# Instantiate the project's builder classes and create the two networks.
# `generator` is the object predict() later calls on batches of image tiles;
# `discriminator` is only referenced again in this notebook when it is registered
# with the checkpoint.
Generator_class = Generator()
generator = Generator_class.make_generator()
Discriminator_class = Discriminator()
discriminator = Discriminator_class.make_discriminator()

In [6]:
# Adam optimizers (learning rate 2e-4, beta_1=0.5). No training step appears in this
# notebook; they are only passed to tf.train.Checkpoint — presumably so the checkpoint
# object has the same structure as the one saved during training (verify against the
# training script).
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [7]:
# NOTE(review): the checkpoint is restored through the absolute `checkpoint_dir` path,
# and later cells read data through absolute paths as well, so this directory change
# does not appear to be required — confirm before removing.
%cd /content/drive/MyDrive/데이콘 대회/이미지품질향상/code/training_checkpoints

/content/drive/MyDrive/데이콘 대회/이미지품질향상/code/training_checkpoints


In [8]:
# Bundle both networks and both optimizers into a single Checkpoint object so they can
# be restored together. The keyword names are the keys each object is tracked under;
# presumably they have to match the names used when the checkpoint was written —
# verify against the training script.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [9]:
# Absolute path of the directory that is searched for the most recent checkpoint.
checkpoint_dir = "/content/drive/MyDrive/데이콘 대회/이미지품질향상/code/training_checkpoints"

In [10]:
# Resolve the newest checkpoint explicitly: tf.train.latest_checkpoint returns None when
# the directory contains no checkpoint, and checkpoint.restore(None) would then leave the
# models at their initial weights without raising, so inference would silently run on an
# untrained generator. Fail loudly instead.
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint_path is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
# Kept as the last expression so the cell still displays the load-status object.
checkpoint.restore(latest_checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbde00d4b10>

# 학습 이미지 테스트 하기

test_data_loader

In [11]:
# Project root on Drive. The trailing slash matters: the paths below are built from it
# by plain string concatenation.
Root_Path = "/content/drive/MyDrive/데이콘 대회/이미지품질향상/"

In [12]:
# Read the test manifest and turn its `input_img` column into full image paths
# (a pandas Series of strings, one entry per test image).
test_csv_path = Root_Path + "test.csv"
test_csv = pd.read_csv(test_csv_path)
test_img_dir = Root_Path + 'test_input_img/'
test_input_files = test_img_dir + test_csv['input_img']

In [28]:
def _blend_tiles(tiles, positions, result_img, voting_mask, img_size):
    """Run the generator on one batch of tiles and accumulate its output in place.

    Each predicted tile is added to `result_img` at its (top, left) position, cropped to
    whatever part of the tile lies inside the image, and `voting_mask` is incremented
    over the same region so the sums can later be averaged.
    """
    # Convert once to a NumPy array so the in-place accumulation below is plain NumPy.
    pred = np.asarray(generator(np.array(tiles))) * 255
    for tile_pred, (top, left) in zip(pred, positions):
        region = result_img[top:top + img_size, left:left + img_size, :]  # view, not a copy
        h, w = region.shape[:2]
        region += tile_pred[:h, :w]
        voting_mask[top:top + img_size, left:left + img_size, :] += 1


def predict(img_paths, stride=256, batch_size=128):
    """Restore images by running the global `generator` over a sliding window of tiles.

    Every image is cut into 256x256 tiles whose top-left corners are `stride` pixels
    apart (tiles that stick out past the border are zero-padded). Tiles are sent to the
    generator in batches, the outputs are summed back at their original positions, and
    each pixel is divided by the number of tiles that covered it.

    Args:
        img_paths: Iterable of image file paths readable by `cv2.imread`. Images are
            handled as 3-channel arrays (tiles are allocated with 3 channels).
        stride: Step in pixels between neighbouring tiles. Values below 256 make tiles
            overlap, so overlapping predictions are averaged.
        batch_size: Maximum number of tiles passed to the generator in one call.

    Returns:
        A list with one `uint8` array per input path, each with the same shape as the
        image that was read. Channel order is presumably the same as `cv2.imread`
        returns, assuming the generator preserves it — verify.

    Raises:
        FileNotFoundError: If `cv2.imread` cannot read one of the paths.
    """
    img_size = 256
    results = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
        img = img.astype(np.float32) / 255

        result_img = np.zeros_like(img)
        voting_mask = np.zeros_like(img)
        tiles, positions = [], []

        for top in tqdm(range(0, img.shape[0], stride)):
            for left in range(0, img.shape[1], stride):
                tile = np.zeros([img_size, img_size, 3], np.float32)
                patch = img[top:top + img_size, left:left + img_size, :]
                tile[:patch.shape[0], :patch.shape[1], :] = patch
                tiles.append(tile)
                positions.append((top, left))
                if len(tiles) >= batch_size:
                    _blend_tiles(tiles, positions, result_img, voting_mask, img_size)
                    tiles, positions = [], []

        # Flush the final, partially filled batch. Without this, every tile after the
        # last full batch is dropped and the pixels only it covers are never predicted.
        if tiles:
            _blend_tiles(tiles, positions, result_img, voting_mask, img_size)

        # Pixels no tile covered (possible only when stride > img_size) keep a vote
        # count of 0; dividing by at least 1 leaves them at 0 instead of producing NaN.
        result_img = result_img / np.maximum(voting_mask, 1)
        # Clip before the cast so out-of-range generator output saturates rather than
        # wrapping around in uint8.
        results.append(np.clip(result_img, 0, 255).astype(np.uint8))

    return results

In [31]:
# stride=32 with predict()'s fixed 256-pixel tiles makes neighbouring tiles overlap,
# so each output pixel is the average of several generator predictions.
test_result = predict(test_input_files, stride = 32)

100%|██████████| 77/77 [00:35<00:00,  2.19it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:35<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:35<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.21it/s]
100%|██████████| 77/77 [00:35<00:00,  2.20it/s]
100%|██████████| 77/77 [00:34<00:00,  2.20it/s]
100%|██████████| 77/77 [00:35<00:00,  2.19it/s]


In [32]:
# Display each test input side by side with its restored output.
# OpenCV images are converted from BGR to RGB before being handed to matplotlib.
for idx, path in enumerate(test_input_files):
    panels = (
        ('input_img', cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)),
        ('output_img', cv2.cvtColor(test_result[idx], cv2.COLOR_BGR2RGB)),
    )

    plt.figure(figsize=(20,10))
    for slot, (panel_title, panel_img) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(panel_img)
        plt.title(panel_title, fontsize=10)
    plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [38]:
# Return to the project root: make_submission() creates its 'submission' folder
# relative to the current working directory.
%cd /content/drive/MyDrive/데이콘 대회/이미지품질향상/

/content/drive/MyDrive/데이콘 대회/이미지품질향상


In [39]:
def make_submission(result, start_index=20000):
    """Write the predicted images as PNGs and bundle them into submission.zip.

    Files are written to a `submission` folder under the current working directory as
    `test_<start_index + i>.png`, and `submission/submission.zip` is created containing
    those PNGs at the archive root. The working directory is left unchanged, so the
    function can be called repeatedly without nesting `submission/submission`.

    Args:
        result: Iterable of images accepted by `cv2.imwrite`, in submission order.
        start_index: Number used in the first file name; later images count up from it.

    Raises:
        IOError: If `cv2.imwrite` reports that an image could not be written.
    """
    out_dir = 'submission'
    os.makedirs(out_dir, exist_ok=True)

    sub_imgs = []
    for i, img in enumerate(result):
        path = os.path.join(out_dir, f'test_{start_index + i}.png')
        # cv2.imwrite returns False on failure rather than raising.
        if not cv2.imwrite(path, img):
            raise IOError(f"cv2.imwrite failed to write {path}")
        sub_imgs.append(path)

    # The context manager closes the archive even if adding a file fails.
    with zipfile.ZipFile(os.path.join(out_dir, "submission.zip"), 'w') as submission:
        for path in sub_imgs:
            # Store only the bare file name so the archive layout stays flat.
            submission.write(path, arcname=os.path.basename(path))

In [40]:
# Write the restored images and submission.zip into ./submission (relative to the
# current working directory).
make_submission(test_result)