In [None]:
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import trange, tqdm
from IPython.display import clear_output

In [None]:
# Directory holding the input images, relative to the notebook.
INPUT_DIR = "../inputs/"
target_filename = "woman.jpg"
source_filename = "chameleon.jpg"
# Force 3-channel RGB: JPEGs may also be grayscale ("L") or CMYK, which would
# give arrays that are not (rows, cols, 3) and would not match each other.
target_img = Image.open(INPUT_DIR + target_filename).convert("RGB")
source_img = Image.open(INPUT_DIR + source_filename).convert("RGB")
# Put the source on the target's pixel grid so both arrays share one shape.
source_img = source_img.resize(target_img.size)

In [None]:
def plot_images(source_arr, result_arr, target_arr, iteration_number=None):
    """Display the source, current result and target images side by side.

    Args:
        source_arr: image for the left panel.
        result_arr: image for the middle panel.
        target_arr: image for the right panel.
        iteration_number: optional iteration count; when given (0 included)
            it is shown as the title of the middle panel.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 10))
    for ax, arr in zip(axes, (source_arr, result_arr, target_arr)):
        ax.imshow(arr)
        ax.axis("off")
    # Compare against None rather than truthiness: callers pass 0 for the
    # initial state, and that panel should still get a title.
    if iteration_number is not None:
        axes[1].set_title("iteration: %d" % iteration_number)
    plt.show()

In [None]:
def plot_hit_rate_and_mse(hit_rate_percent_list, mse_loss_list, reference_percent=0.1):
    """Plot the per-iteration swap hit rate (log-log) next to the loss curve.

    Args:
        hit_rate_percent_list: hit rate of each iteration, in percent.
        mse_loss_list: per-iteration loss values, plotted under the label
            "RMSE loss".
        reference_percent: constant hit-rate level (in percent) drawn as a
            reference line on the hit-rate panel.
    """
    fig, (ax_hit, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

    ax_hit.loglog(hit_rate_percent_list, label="Hit Rate")
    # The label is derived from the plotted level so the two cannot disagree
    # (a 0.1 line labelled "0.01%" is misleading).
    reference_line = np.full(len(hit_rate_percent_list), reference_percent, dtype=float)
    ax_hit.loglog(reference_line, label="%g%%" % reference_percent)
    ax_hit.set_xlabel("Iteration")
    ax_hit.legend()

    ax_loss.plot(mse_loss_list, label="RMSE loss")
    ax_loss.set_xlabel("Iteration")
    ax_loss.legend()

    plt.show()

In [None]:
def _swap_pass(work, target, num_of_tries):
    """Try `num_of_tries` random pixel-pair swaps on `work` in place; return how many were made.

    A pair is swapped when that lowers the summed per-channel absolute
    difference to `target` at the two positions. `work` and `target` must be
    signed/float arrays of the same (rows, cols, channels) shape so the
    subtractions below cannot wrap around.
    """
    rows, cols = work.shape[:2]
    rows1 = np.random.randint(0, rows, num_of_tries)
    cols1 = np.random.randint(0, cols, num_of_tries)
    rows2 = np.random.randint(0, rows, num_of_tries)
    cols2 = np.random.randint(0, cols, num_of_tries)

    hit = 0
    for r1, c1, r2, c2 in zip(rows1, cols1, rows2, cols2):
        s1, s2 = work[r1, c1], work[r2, c2]
        t1, t2 = target[r1, c1], target[r2, c2]
        if np.abs(s1 - t2).sum() + np.abs(s2 - t1).sum() < np.abs(s1 - t1).sum() + np.abs(s2 - t2).sum():
            # Advanced indexing on the right-hand side yields a copy, so this
            # is a true swap. `a[i], a[j] = a[j], a[i]` on ndarray rows is not:
            # both sides are views and pixel 2 ends up duplicated into pixel 1.
            work[[r1, r2], [c1, c2]] = work[[r2, r1], [c2, c1]]
            hit += 1
    return hit


def pixel_level_shuffle(source_arr, target_arr, permute_source=True, num_of_iter=1000, plot_after_iter=100, num_of_tries_in_iter=None, clear_before_plot=True):
    """Rearrange the pixels of `source_arr` so the result approximates `target_arr`.

    Each iteration draws `num_of_tries_in_iter` random pairs of pixel positions
    and swaps the two result pixels whenever that lowers the summed per-channel
    absolute difference to the target. Pixels are only ever swapped, so the
    result is always a permutation of the source pixels.

    Args:
        source_arr: (rows, cols, channels) image whose pixels are rearranged.
        target_arr: image to approximate; must have the same shape as source_arr.
        permute_source: if True, start from a random permutation of the source
            pixels; otherwise start from the source image itself.
        num_of_iter: number of iterations.
        plot_after_iter: plot progress and record a snapshot every this many
            iterations; a falsy value disables periodic plotting.
        num_of_tries_in_iter: random pairs tried per iteration; falsy means
            rows * cols.
        clear_before_plot: clear the cell output before each progress plot.

    Returns:
        (hit_rate_percent_list, mse_loss_list, images): per-iteration swap rate
        in percent, per-iteration RMSE between result and target, and
        independent snapshots of the result (in source_arr's dtype) taken at
        each periodic plot.

    Raises:
        ValueError: if the arrays are not 3-D or their shapes differ.
    """
    source_arr = np.asarray(source_arr)
    target_arr = np.asarray(target_arr)
    if source_arr.ndim != 3 or source_arr.shape != target_arr.shape:
        raise ValueError(
            "source_arr and target_arr must be (rows, cols, channels) arrays of the same shape, got %s and %s"
            % (source_arr.shape, target_arr.shape))

    if permute_source:
        channels = source_arr.shape[-1]
        result_arr = np.random.permutation(source_arr.reshape(-1, channels)).reshape(source_arr.shape)
    else:
        result_arr = np.array(source_arr)

    # plot initial values
    plot_images(source_arr, result_arr, target_arr, 0)

    rows, cols, _ = target_arr.shape
    if not num_of_tries_in_iter:
        num_of_tries_in_iter = rows * cols

    # Work in float64: image arrays are typically uint8, where `a - b` wraps
    # modulo 256 (10 - 20 == 246) and would corrupt both the swap criterion
    # and the RMSE. float64 holds 8/16-bit integer pixels exactly, so casting
    # back to source_arr.dtype for snapshots is lossless.
    work = result_arr.astype(np.float64)
    target_f = target_arr.astype(np.float64)

    images = []
    mse_loss_list = []
    hit_rate_percent_list = []

    pbar = trange(num_of_iter)
    for iteration in pbar:
        hit = _swap_pass(work, target_f, num_of_tries_in_iter)

        hit_rate_percent_list.append(hit / num_of_tries_in_iter * 100)
        mse_loss_list.append(np.sqrt(np.mean((work - target_f) ** 2)))
        pbar.set_description("Hit Rate %0.6f%%, loss: %0.4f" % (hit_rate_percent_list[-1], mse_loss_list[-1]))

        if plot_after_iter and (iteration + 1) % plot_after_iter == 0:
            if clear_before_plot:
                clear_output()
            # astype copies, so each snapshot is frozen even though `work`
            # keeps being mutated in place.
            snapshot = work.astype(source_arr.dtype)
            images.append(snapshot)
            plot_images(source_arr, snapshot, target_arr, len(mse_loss_list))
            plot_hit_rate_and_mse(hit_rate_percent_list, mse_loss_list)

    return hit_rate_percent_list, mse_loss_list, images

In [None]:
# Downscale both images 10x per side: by default every iteration attempts
# rows * cols swaps in a Python loop, so this is ~100x fewer attempts.
target_arr = np.array(target_img.resize(tuple(side // 10 for side in target_img.size)))
source_arr = np.array(source_img.resize(tuple(side // 10 for side in source_img.size)))
# Start from the source image as-is (no initial permutation).
hit_rate_percent_list, mse_loss_list, images = pixel_level_shuffle(
    source_arr, target_arr,
    permute_source=False,
    num_of_iter=1000,
    plot_after_iter=100,
    clear_before_plot=False,
)

In [None]:
# Same 10x-per-side downscale as the previous run, for a cheap experiment.
target_arr = np.array(target_img.resize(tuple(side // 10 for side in target_img.size)))
source_arr = np.array(source_img.resize(tuple(side // 10 for side in source_img.size)))
# Start from a random permutation of the source pixels.
hit_rate_percent_list, mse_loss_list, images = pixel_level_shuffle(
    source_arr, target_arr,
    permute_source=True,
    num_of_iter=1000,
    plot_after_iter=100,
    clear_before_plot=False,
)

In [None]:
# Full resolution. NOTE: by default each iteration attempts rows * cols swaps
# in a Python loop, so this is far slower than the downscaled runs.
source_arr = np.array(source_img)
target_arr = np.array(target_img)
hit_rate_percent_list, mse_loss_list, images = pixel_level_shuffle(
    source_arr=source_arr,
    target_arr=target_arr,
    permute_source=False,
    num_of_iter=1000,
    plot_after_iter=100,
    clear_before_plot=False,
)

In [None]:
# Full resolution, starting from a random permutation of the source pixels.
# NOTE: by default each iteration attempts rows * cols swaps in a Python loop.
source_arr = np.array(source_img)
target_arr = np.array(target_img)
hit_rate_percent_list, mse_loss_list, images = pixel_level_shuffle(
    source_arr=source_arr,
    target_arr=target_arr,
    permute_source=True,
    num_of_iter=1000,
    plot_after_iter=100,
    clear_before_plot=False,
)