In [None]:
import tensorflow as tf

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import numpy as np
import pathlib
import cv2
import resnet_network
import image_preprocess

# tf.data sentinel asking the runtime to tune parallelism / buffer sizes itself.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Number of image channels (RGB). NOTE(review): not passed to the ResNet
# generator builder below; kept in case other cells reference it.
OUTPUT_CHANNELS = 3

# Adam hyperparameters shared by all four optimizers. The optimizers are only
# constructed here so that the tf.train.Checkpoint in the next cell tracks the
# same set of objects as was (presumably) saved at training time.
LEARNING_RATE = 2e-4
ADAM_BETA_1 = 0.5

# Two generators built by the project-local ResNet 9-block builder.
# NOTE(review): `skip=False` semantics live in resnet_network — confirm there.
generator_g_new = resnet_network.build_generator_resnet_9blocks(skip=False)
generator_f_new = resnet_network.build_generator_resnet_9blocks(skip=False)

# pix2pix PatchGAN-style discriminators with instance norm, unconditional
# (target=False means no target image is concatenated to the input).
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

generator_g_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=ADAM_BETA_1)
generator_f_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=ADAM_BETA_1)

discriminator_x_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=ADAM_BETA_1)
discriminator_y_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=ADAM_BETA_1)

In [None]:
# Restore generator/discriminator/optimizer state from the most recent
# checkpoint found in `checkpoint_path`, if there is one.
checkpoint_path = 'directory_for_loading_checkpoints'

# Path prefix of the newest checkpoint, or None when the directory has none.
latest = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_path)

# The keyword names must match the ones used when the checkpoint was saved,
# otherwise the corresponding objects will not be matched on restore.
ckpt = tf.train.Checkpoint(generator_f = generator_f_new,
                           generator_g = generator_g_new,
                           discriminator_x = discriminator_x,
                           discriminator_y = discriminator_y,
                           generator_f_optimizer = generator_f_optimizer,
                           generator_g_optimizer = generator_g_optimizer,
                           discriminator_x_optimizer = discriminator_x_optimizer,
                           discriminator_y_optimizer = discriminator_y_optimizer)

if latest:
    print(f'Loading from latest checkpoint: {latest}')
    # Only inference runs in this notebook: no gradients are applied, so the
    # optimizers' slot variables are never created and their checkpointed
    # values stay unmatched. expect_partial() declares that this is intended
    # and silences the "unresolved object in checkpoint" warnings.
    ckpt.restore(latest).expect_partial()
else:
    # NOTE(review): execution continues with freshly initialised (untrained)
    # weights, so any images generated afterwards are not meaningful.
    print('No checkpoint found')

In [None]:
# Run generator G over every PNG in `test_file_dir`, previewing each input and
# writing the translated image next to it as a 1-based, zero-padded NNN.bmp.
test_file_dir = pathlib.Path(r'data_directory')

# shuffle=False gives a deterministic file order, so the NNN.bmp numbering
# is reproducible between runs. map() with parallel calls still preserves order.
test_files = tf.data.Dataset.list_files(str(test_file_dir/'*.png'), shuffle=False)
test_images = test_files.map(image_preprocess.load_image,
                             num_parallel_calls=AUTOTUNE).cache()
img_num = len(test_files)

for index, image in enumerate(test_images):
    clear_output(wait=True)
    print(f'current processing: {index+1} out of {img_num} images')
    # plt.show() flushes (and, with the inline backend, closes) the figure each
    # iteration so the preview updates instead of stacking images on one axes.
    plt.imshow(image.numpy().astype('float32'))
    plt.show()

    fake_image = generator_g_new(image[tf.newaxis, ...], training=False)[0]
    # Map generator output to 8-bit: (x + 1) * 127.5 takes [-1, 1] (assumed
    # output range — TODO confirm) to [0, 255]; clip + round make the
    # float->uint8 conversion explicit instead of leaving it to cv2.imwrite.
    fake_uint8 = np.clip(np.rint((fake_image.numpy() + 1) * 127.5), 0, 255).astype(np.uint8)

    out_path = test_file_dir / f'{index+1:03}.bmp'
    # OpenCV writes channels in BGR order; the model output is treated as RGB.
    if not cv2.imwrite(str(out_path), cv2.cvtColor(fake_uint8, cv2.COLOR_RGB2BGR)):
        raise OSError(f'cv2.imwrite failed to write {out_path}')