## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings
from typing import List

import torch
import numpy as np
from diffusers import DDIMScheduler

import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from IPython.display import clear_output

sys.path.append("..")
from src.u2net import U2NET
from src.unet import UNet
from src.fid_score import calculate_frechet_distance

from src.tools import (
    set_random_seed,
    get_linked_pushed_loader_metrics,
    get_linked_pushed_loader_stats,
)
from src.plotters import (
    plot_linked_pushed_images,
    plot_linked_pushed_random_paired_images,
)

from src.samplers import PairedLoaderSampler, get_paired_sampler

# Raise PIL's limit on PNG text-chunk size to LARGE_ENOUGH_NUMBER mebibytes.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 2**20

# Silence every warning for the rest of the notebook session.
warnings.filterwarnings(action="ignore")

%matplotlib inline 

In [None]:
# Drop unreachable Python objects first, then release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## 2. Config

Choose the dataset in the first lines of the cell below.


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset selection: keep exactly one of the assignments below uncommented.
# face -> comic
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "face2comic", False
# colored mask -> face
DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
    "celeba_mask",
    "../datasets/CelebAMask-HQ",
    "colored_mask2face",
    False,
)
# sketch -> photo
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
#     "FS2K",
#     "../datasets/FS2K/",
#     "sketch2photo",
#     False,
# )

IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# Number of noising steps in the diffusion process.
DIFFUSION_STEPS = 1000
# Noising steps at which pivotal points are recorded. Must be in ascending
# order: the last entry is used as the number of noising steps to run.
PIVOTAL_LIST = [50, 80, 120]

# GPU selection
DEVICE_IDS = [1]
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this notebook, but no CUDA device is available.")
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# Resume markers (not read anywhere in this notebook):
# CONTINUE[0] is the step to resume from (checkpoint step + 1);
# CONTINUE[1] is the SDE to resume from (number of fully trained SDEs + 1).
CONTINUE = [0, 0]

# All hyperparameters below are set to the values used for the experiments described in the article.

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'
BATCH_SIZE = 8
# model settings
NOT = True  # Train Neural optimal transport or pure regression
T_TYPE = "U2Net"  # 'UNet' # or  ('ResNet_pix2pix' - not implemented)
UNET_BASE_FACTOR = 48  # For UNet
D_TYPE = (
    "ResNet"  # or 'ResNet_pix2pix' - DOES NOT WORK WELL (it is actually not a resnet:)
)
D_USE_BATCHNORM = False  # For ResNet_D

# plot settings
GRAY_PLOTS = False
PLOT_N_SAMPLES = 8

FID_EPOCHS = 1

EXP_NAME = f"DNOT_Paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail early, naming the missing path, rather than later during weight loading.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"Checkpoint directory not found: {LOAD_PATH}")

In [None]:
# MAP_NAME has the form "<first>2<second>". The reference statistics file is
# keyed by the second name unless REVERSE is set, in which case by the first.
use_Y = not REVERSE
domain_names = MAP_NAME.split("2")
stats_domain = domain_names[1] if use_Y else domain_names[0]
filename = f"../stats/{DATASET}_{stats_domain}_{IMG_SIZE}_test.json"

with open(filename, "r") as fp:
    data_stats = json.load(fp)
mu_data = data_stats["mu"]
sigma_data = data_stats["sigma"]
# Only the two statistics are kept; the parsed JSON itself is released.
del data_stats

## 3. Initialize samplers


In [None]:
# Build the paired samplers; only the second returned sampler is used in this notebook.
_, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE, reverse=REVERSE
)

# Free cached GPU memory and garbage, then clear the cell output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

### pivotal sampler


In [None]:
# Shared scheduler; in this notebook it is used only through `add_noise` to build pivotal paths.
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: PairedLoaderSampler,
    batch_size: int = 4,
) -> List[torch.Tensor]:
    """Sample a paired batch and build its pivotal path.

    Both the source and the target batch are noised repeatedly with
    ``SCHEDULER.add_noise``: iteration ``i`` re-noises the already noised
    tensor at timestep ``i``. A snapshot of each is stored whenever the
    1-based step count is in ``PIVOTAL_LIST``.

    Args:
        XY_sampler: Sampler whose ``sample(batch_size)`` returns a
            ``(source, target)`` pair.
        batch_size: Number of pairs requested from the sampler.

    Returns:
        The clean source, the source snapshots in order of increasing
        noise, then the target snapshots in order of decreasing noise
        (without the most-noised target), and finally the clean target.
    """
    clean_source, clean_target = XY_sampler.sample(batch_size)

    noisy_source, noisy_target = clean_source, clean_target
    forward_points = [clean_source]
    backward_points = [clean_target]

    n_steps = min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])
    for step in range(1, n_steps + 1):
        timestep = torch.Tensor([step - 1]).long()
        noisy_source = SCHEDULER.add_noise(
            noisy_source, torch.randn_like(noisy_source), timestep
        )
        noisy_target = SCHEDULER.add_noise(
            noisy_target, torch.randn_like(noisy_target), timestep
        )
        if step in PIVOTAL_LIST:
            forward_points.append(noisy_source)
            backward_points.append(noisy_target)

    # Walk the target side back from noisy to clean. The most-noised target
    # is skipped, so the most-noised source is the only turning point.
    backward_points.reverse()
    return forward_points + backward_points[1:]


# def sample_step_t_pivotal(
#     XY_sampler: PairedLoaderSampler,
#     batch_size: int = 4,
#     pivotal_step: int = 0,
# ):
#     pivotal_path = sample_all_pivotal(XY_sampler, batch_size)
#     pivotal_t, pivotal_t_next = (
#         pivotal_path[pivotal_step],
#         pivotal_path[pivotal_step + 1],
#     )
#     return pivotal_t, pivotal_t_next

### mapping plotters


In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
):
    """Plot the pivotal path between one source image and one target image.

    The path is built as in ``sample_all_pivotal``: both images are noised
    repeatedly with ``SCHEDULER.add_noise`` and a snapshot is kept whenever
    the 1-based step count is in ``PIVOTAL_LIST``. The plotted row is the
    clean source, the noised sources, the noised targets in reverse order
    (without the most-noised one), and the clean target.

    Args:
        source: A single source image. It must be 3-D so that the stacked
            path can be permuted with ``(0, 2, 3, 1)``; presumably
            channel-first with values in [-1, 1] -- TODO confirm.
        target: A single target image with the same shape as ``source``.
        gray: If True, images are shown with the "gray" colormap.

    Returns:
        ``(fig, fig.axes)``: the created figure and its list of axes.
    """
    source_list = [source]
    target_list = [target]
    for i in range(min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])):
        timestep = torch.Tensor([i]).long()
        source = SCHEDULER.add_noise(source, torch.randn_like(source), timestep)
        target = SCHEDULER.add_noise(target, torch.randn_like(target), timestep)
        if (i + 1) in PIVOTAL_LIST:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()

    # Only the source's most-noised point is kept as the turning point.
    pivotal_path = source_list + target_list[1:]
    # pivotal_path = source_list + target_list  # 2 last pivotal points mapping

    # Stack the path, move channels last, rescale by x * 0.5 + 0.5 and clip to [0, 1].
    imgs: np.ndarray = (
        torch.stack(pivotal_path)
        .to("cpu")
        .permute(0, 2, 3, 1)
        .mul(0.5)
        .add(0.5)
        .numpy()
        .clip(0, 1)
    )
    nrows, ncols = 1, len(pivotal_path)
    last_index = imgs.shape[0] - 1
    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        # Brace the index so mathtext subscripts all its digits (e.g. X_{10}),
        # not just the first one.
        title = "Y" if i == last_index else f"$X_{{{i}}}$"
        ax.set_title(title, fontsize=24)

    return fig, fig.axes

## 4. Testing


### init models


In [None]:
# Two models per entry in PIVOTAL_LIST; presumably one per consecutive pair of
# points on the pivotal path (2 * len(PIVOTAL_LIST) + 1 points) -- verify.
Ts = []
n_models = len(PIVOTAL_LIST) * 2

for _ in range(n_models):
    if T_TYPE == "U2Net":
        T = U2NET(in_ch=DATASET1_CHANNELS, out_ch=DATASET2_CHANNELS)
    elif T_TYPE == "UNet":
        T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR)
    else:
        raise NotImplementedError("Unknown T_TYPE: {}".format(T_TYPE))
    # Move the freshly built model to the current CUDA device before storing it.
    T = T.cuda()
    Ts.append(T)

### Load weights for testing


In [None]:
print("Loading weights")

CKPT_ITER = 2000  # user setting: checkpoint iteration to load
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{CKPT_ITER}")
for i, T in enumerate(Ts):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Deserialize onto the CPU so loading does not depend on the GPU index the
    # checkpoint was saved from; load_state_dict then copies the weights into
    # the model's own parameters on its current device.
    state_dict = torch.load(w_path, map_location="cpu")
    T.load_state_dict(state_dict)
    print(f"{w_path}, loaded")

In [None]:
# writer.add_graph(
#     Ts[0], torch.rand(BATCH_SIZE, DATASET1_CHANNELS, IMG_SIZE, IMG_SIZE).cuda()
# )

### Plots Test


In [None]:
# Draw one batch of PLOT_N_SAMPLES test pairs that is reused by the plotting cells below.
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(PLOT_N_SAMPLES)

In [None]:
# Show the pivotal path for the first pair of the fixed test batch.
fig, axes = plot_all_pivotal(X_test_fixed[0], Y_test_fixed[0], GRAY_PLOTS)

In [None]:
# Plot the fixed test batch together with the models in Ts
# (presumably the chained model outputs -- see src.plotters).
fig, axes = plot_linked_pushed_images(X_test_fixed, Y_test_fixed, Ts, gray=GRAY_PLOTS)

In [None]:
# Same kind of plot, but the plotter receives the sampler itself
# (presumably to draw PLOT_N_SAMPLES fresh pairs -- see src.plotters).
fig, axes = plot_linked_pushed_random_paired_images(
    XY_test_sampler, Ts, plot_n_samples=PLOT_N_SAMPLES, gray=GRAY_PLOTS
)

### main test


In [None]:
# Drop unreachable Python objects first, then release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
clear_output(wait=True)
print("Plotting")

# Switch every model to evaluation mode before producing the final plots.
inference_Ts = Ts
for T in inference_Ts:
    T.eval()

print("Fixed Test Images")
fig, axes = plot_linked_pushed_images(
    X_test_fixed, Y_test_fixed, inference_Ts, gray=GRAY_PLOTS
)
# pyplot.show() takes no Figure argument (only keyword `block`); it displays
# the open figures, so the figure is not passed in.
plt.show()
plt.close(fig)

print("Random Test Images")
fig, axes = plot_linked_pushed_random_paired_images(
    XY_test_sampler,
    inference_Ts,
    plot_n_samples=PLOT_N_SAMPLES,
    gray=GRAY_PLOTS,
)
plt.show()
plt.close(fig)

In [None]:
print("Computing FID")
gen_mu, gen_sigma = get_linked_pushed_loader_stats(
    inference_Ts,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=True,
)
# The reference statistics were loaded from the stats JSON as `mu_data` and
# `sigma_data`; the previously used names `target_mu` / `target_sigma` are
# never defined in this notebook. json.load returns nested lists, so they are
# converted to arrays before being handed to the distance computation.
fid = calculate_frechet_distance(
    gen_mu, gen_sigma, np.asarray(mu_data), np.asarray(sigma_data)
)
print(f"FID={fid}")

In [None]:
print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")

# Metric names requested from the helper, evaluated over FID_EPOCHS passes of the test loader.
requested_metrics = ["LPIPS", "PSNR", "SSIM", "MSE", "MAE"]
metrics = get_linked_pushed_loader_metrics(
    inference_Ts,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    verbose=True,
    log_metrics=requested_metrics,
)
print(f"metrics={metrics}")