## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import numpy as np
from diffusers import DDIMScheduler


import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from IPython.display import clear_output

sys.path.append("..")
from src.enot import SDE
from src.cunet import CUNet

from fid_score import calculate_frechet_distance
from src.tools import (
    set_random_seed,
    get_all_pivotal,
    get_step_t_pivotal,
    get_loader_stats,
    get_linked_sde_pushed_loader_stats,
    get_linked_sde_pushed_loader_metrics,
)
from src.plotters import (
    plot_linked_sde_pushed_images,
    plot_linked_sde_pushed_random_paired_images,
)
from src.samplers import PairedLoaderSampler, get_paired_sampler

LARGE_ENOUGH_NUMBER = 100
# Pillow caps the size of PNG text chunks it will decode; raise that cap to
# LARGE_ENOUGH_NUMBER MiB so images with large text metadata still load.
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

# Silence every warning for the rest of the notebook session.
warnings.filterwarnings(action="ignore")

%matplotlib inline 

In [None]:
# Free unreachable Python objects first, then ask PyTorch's CUDA caching
# allocator to release its now-unused cached blocks (order matters).
gc.collect()
torch.cuda.empty_cache()

## 2. Init Config and FID stats

The config file `config.json` is saved in `saved_models/EXP_NAME/`.


### init config

In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset selection: keep exactly one (DATASET, DATASET_PATH, REVERSE) triple active.
# face2comic
DATASET, DATASET_PATH, REVERSE = (
    "comic_faces_v1",
    "../datasets/face2comics_v1.0.0_by_Sxela",
    False,
)

# colored mask -> face
# DATASET, DATASET_PATH, REVERSE = (
#     "celeba_mask",
#     "../datasets/CelebAMask-HQ",
#     False,
# )

# sketch -> face
# DATASET, DATASET_PATH, REVERSE = (
#     "FS2K",
#     "../datasets/FS2K/",
#     False,
# )

IMG_SIZE = 256
# Passed as the first two CUNet arguments; presumably source / target channel
# counts — confirm in src.cunet.
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# Number of noising steps of the diffusion process (DDIM scheduler length).
DIFFUSION_STEPS = 1000
# Timesteps passed to get_all_pivotal to build the pivotal path.
PIVOTAL_LIST = [40, 60, 90]

# GPU selection
DEVICE_ID = 0
if not torch.cuda.is_available():
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("CUDA is required to run this notebook, but no GPU is available")
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

# All hyperparameters below are set to the values used for the experiments described in the article.

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'
# data sample settings
BATCH_SIZE = 2

# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 5  # passed to SDE as n_steps (number of shift steps)
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

# plot settings
GRAY_PLOTS = False
PLOT_N_SAMPLES = 8

FID_EPOCHS = 1

EXP_NAME = f"DENOT_Paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # `raise "..."` is itself a TypeError in Python 3 and would hide the real
    # problem; raise a proper exception that names the missing path.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

## 3. Initialize samplers


In [None]:
# get_paired_sampler returns a pair of samplers; only the second one
# (presumably the test split — confirm in src.samplers) is used here.
_, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH,
    reverse=REVERSE, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
)

In [None]:
# Release cached CUDA blocks, collect unreachable Python objects, and clear
# the cell's output (e.g. dataset-loading logs).
torch.cuda.empty_cache()
gc.collect()
clear_output()

### pivotal sampler


In [None]:
# DDIM scheduler shared by the pivotal helpers; sized to DIFFUSION_STEPS.
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: PairedLoaderSampler,
    batch_size: int = 4,
) -> list[torch.Tensor]:
    """Draw a paired batch and return its pivotal path.

    Args:
        XY_sampler: paired sampler whose ``sample`` yields (source, target).
        batch_size: number of pairs to draw.

    Returns:
        Whatever ``get_all_pivotal`` produces for the sampled batch, using
        the module-level ``SCHEDULER`` and ``PIVOTAL_LIST``.
    """
    x_batch, y_batch = XY_sampler.sample(batch_size)
    return get_all_pivotal(x_batch, y_batch, SCHEDULER, PIVOTAL_LIST)

### mapping plotters


In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
) -> None:
    """Plot the pivotal path between one source image and one target image.

    The path comes from ``get_all_pivotal`` with the module-level
    ``SCHEDULER`` and ``PIVOTAL_LIST`` and is drawn as a single row. Every
    panel except the last is titled ``X_i``; the last is titled ``Y``.

    Args:
        source: a single source image tensor. Assumed (C, H, W) with values
            in [-1, 1]; TODO confirm against get_all_pivotal.
        target: the matching target image tensor, same assumed layout.
        gray: draw with the gray colormap (only affects single-channel images).

    Returns:
        None. The figure is left open so the notebook backend renders it.
    """
    pivotal_path = get_all_pivotal(
        source,
        target,
        SCHEDULER,
        PIVOTAL_LIST,
    )

    # Stack to (K, C, H, W), move channels last for imshow, and map [-1, 1] -> [0, 1].
    imgs: np.ndarray = (
        torch.stack(pivotal_path)
        .detach()
        .to("cpu")
        .permute(0, 2, 3, 1)
        .mul(0.5)
        .add(0.5)
        .numpy()
        .clip(0, 1)
    )
    nrows, ncols = 1, len(pivotal_path)
    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays; drop the channel axis so
            # single-channel images render and the gray colormap applies.
            img = img[..., 0]
        ax.imshow(img, cmap="gray" if gray else None)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        is_last = i == imgs.shape[0] - 1
        # Braces keep multi-digit indices inside the LaTeX subscript (X_{10}).
        ax.set_title("Y" if is_last else f"$X_{{{i}}}$", fontsize=24)

    torch.cuda.empty_cache()
    gc.collect()

## 4. Initialize models


### init models


In [None]:
def _build_sde() -> SDE:
    """Build one CUNet shift model wrapped in an SDE, both moved to CUDA."""
    shift_model = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()
    return SDE(
        shift_model=shift_model,
        epsilon=EPSILON,
        n_steps=N_STEPS,
        time_dim=TIME_DIM,
        n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
        use_positional_encoding=USE_POSITIONAL_ENCODING,
        use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
        predict_shift=PREDICT_SHIFT,
        image_input=IMAGE_INPUT,
    ).cuda()


# Two models per pivotal timestep. NOTE(review): the count must match the
# checkpoints T{i}_{SEED}.pt saved by training — confirm.
Ts = [_build_sde() for _ in range(len(PIVOTAL_LIST) * 2)]

### load weights


In [None]:
print("Loading weights")

CKPT_ITER = 5000  # user setting: training iteration of the checkpoint to evaluate
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{CKPT_ITER}/")
# Ts holds SDE modules, not (T, D) pairs; unpacking each element into two
# names would raise TypeError, so iterate over the models directly.
for i, T in enumerate(Ts):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Load to CPU: load_state_dict copies the tensors into the model's own
    # parameters, so this avoids a transient extra GPU copy and works even if
    # the checkpoint was saved from another device.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    T.load_state_dict(torch.load(w_path, map_location="cpu"))
    print(f"{w_path}, loaded")

## 5. Test Plots


In [None]:
# Fixed batch of PLOT_N_SAMPLES (source, target) test pairs, reused by the plotting cells below.
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(PLOT_N_SAMPLES)

In [None]:
# Pivotal path of the first fixed test pair; honors the GRAY_PLOTS setting like the other plots.
plot_all_pivotal(X_test_fixed[0], Y_test_fixed[0], gray=GRAY_PLOTS)

In [None]:
# Fixed test pairs pushed through the chain of SDE models; pass GRAY_PLOTS
# for consistency with the plotting in the Testing section.
fig, axes = plot_linked_sde_pushed_images(
    X_test_fixed,
    Y_test_fixed,
    Ts,
    gray=GRAY_PLOTS,
)

In [None]:
# Random test pairs from XY_test_sampler pushed through Ts
# (sampling is done inside the plotter — confirm in src.plotters).
fig, axes = plot_linked_sde_pushed_random_paired_images(
    XY_test_sampler, Ts, gray=GRAY_PLOTS, plot_n_samples=PLOT_N_SAMPLES
)

## 6. Testing


In [None]:
# Free unreachable Python objects, then release PyTorch's unused cached CUDA
# blocks before the heavier evaluation cells (order matters).
gc.collect()
torch.cuda.empty_cache()

In [None]:
clear_output(wait=True)
print("Plotting")

print("Fixed Test Images")
fig, axes = plot_linked_sde_pushed_images(
    X_test_fixed,
    Y_test_fixed,
    Ts,
    gray=GRAY_PLOTS,
)
# pyplot.show() takes no figure argument (positional args are forwarded to the
# backend's show, e.g. as `close`/`block`); show pending figures, then free this one.
plt.show()
plt.close(fig)

print("Random Test Images")
fig, axes = plot_linked_sde_pushed_random_paired_images(
    XY_test_sampler,
    Ts,
    plot_n_samples=PLOT_N_SAMPLES,
    gray=GRAY_PLOTS,
)
plt.show()
plt.close(fig)

In [None]:
print("Computing FID")
# Reference statistics of the target-side images of the test loader.
# NOTE(review): the sampler was already built with reverse=REVERSE, so whether
# use_Y should additionally depend on REVERSE here depends on how
# get_paired_sampler / get_loader_stats interpret it — confirm in src.
use_Y = not REVERSE
target_mu, target_sigma = get_loader_stats(
    XY_test_sampler.loader,
    BATCH_SIZE,
    FID_EPOCHS,
    verbose=True,
    use_Y=use_Y,
)
# Statistics of the images produced by pushing the loader through the chain
# of SDE models Ts (pushing logic lives in src.tools).
gen_mu, gen_sigma = get_linked_sde_pushed_loader_stats(
    Ts,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=True,
)
# Frechet distance between the generated and the target statistics.
fid = calculate_frechet_distance(gen_mu, gen_sigma, target_mu, target_sigma)
print(f"FID={fid}")

In [None]:
print("Computing Metrics")  # fixed user-facing typo ("Mtrics")
# Paired image metrics between the outputs pushed through Ts and the test
# targets. NOTE(review): the exact pairing and aggregation are defined in
# src.tools.get_linked_sde_pushed_loader_metrics — confirm.
metrics = get_linked_sde_pushed_loader_metrics(
    Ts,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    verbose=True,
    log_metrics=["LPIPS", "PSNR", "SSIM", "MSE", "MAE"],
)
print(f"metrics={metrics}")