## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import numpy as np
import matplotlib.pyplot as plt

from PIL import PngImagePlugin

from IPython.display import clear_output

sys.path.append("..")
from src.enot import SDE
from src.resnet2 import ResNet_D
from src.cunet import CUNet

from src.tools import (
    set_random_seed,
    weights_init_D,
    get_sde_pushed_loader_metrics,
    get_sde_pushed_loader_stats,
)
from src.fid_score import calculate_frechet_distance
from src.samplers import get_paired_sampler
from src.plotters import (
    plot_sde_pushed_images,
    plot_sde_pushed_random_paired_images,
    plot_fixed_sde_trajectories,
    plot_several_fixed_sde_trajectories,
)


# Raise Pillow's PNG text-chunk size limit (MAX_TEXT_CHUNK, in bytes) to
# LARGE_ENOUGH_NUMBER MiB, i.e. 100 * 1024**2 bytes.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Suppress all warnings for the rest of the notebook session.
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Run the Python garbage collector first, then ask PyTorch to release its
# unused cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## 2. Init Config and FID stats

The config file `config.json` is saved at `saved_models/EXP_NAME/`.

In [None]:
# Experiment configuration for this test notebook: seed, dataset selection,
# model/SDE hyperparameters, plotting/logging settings and output location.
SEED = 0x3060
set_random_seed(SEED)

# Dataset selection: exactly one DATASET/DATASET_PATH/MAP_NAME/REVERSE tuple
# should be active; the alternatives are kept commented out below.
# face2comic
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "comic2face", False
# colored mask -> face
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> photo
DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
    "FS2K",
    "../datasets/FS2K/",
    "sketch2photo",
    False,
)

IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# GPU selection; the models below are moved to CUDA, so a GPU is required.
DEVICE_IDS = [0]
assert torch.cuda.is_available()

CONTINUE = 0

# All hyperparameters below are set to the values used for the experiments
# described in the article.

# training algorithm settings
BATCH_SIZE = 2
T_ITERS = 10
MAX_STEPS = 2500 + 1  # 2501 for testing
# Reciprocal of 3 * IMG_SIZE**2 (presumably channels * height * width of one
# image -- TODO confirm).
INTEGRAL_SCALE = 1 / (3 * IMG_SIZE * IMG_SIZE)
EPSILON_SCHEDULER_LAST_ITER = 20000

# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 10
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

# plot settings
GRAY_PLOTS = False
STEPS_TO_SHOW = 10

# log settings
SMART_INTERVALS = False
INTERVAL_SHRINK_START_TIME = 0.98
TRACK_VAR_INTERVAL = 10
PLOT_INTERVAL = 500
CPKT_INTERVAL = 500

FID_EPOCHS = 1

# The experiment name encodes dataset and seed; it determines OUTPUT_PATH.
EXP_NAME = f"ENOT_Paired_{DATASET}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# Fail fast when the experiment directory is missing.
# A bare string cannot be raised in Python 3 (it produces an unrelated
# TypeError), so raise a real exception that also names the missing path.
if not os.path.exists(OUTPUT_PATH):
    raise FileNotFoundError(f"no such file or directory: {OUTPUT_PATH}")

### load FID stats

In [None]:
# MAP_NAME has the form "<source>2<target>". The reference statistics belong
# to the side the model maps onto: the part after "2" normally, the part
# before it when REVERSE is set.
stats_domain = MAP_NAME.split("2")[0 if REVERSE else 1]
filename = f"../stats/{DATASET}_{stats_domain}_{IMG_SIZE}_test.json"

# Read the precomputed statistics and keep only the "mu" and "sigma" entries.
with open(filename) as fp:
    loaded_stats = json.load(fp)
target_mu = loaded_stats["mu"]
target_sigma = loaded_stats["sigma"]
del loaded_stats

## 3. Initialize samplers


In [None]:
# Build the paired samplers; the first returned sampler is discarded and only
# the second one is kept as XY_test_sampler.
_, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, reverse=REVERSE, batch_size=BATCH_SIZE
)

# Collect Python garbage before emptying the CUDA cache, so that memory freed
# by the collector can actually be released by empty_cache().
gc.collect()
torch.cuda.empty_cache()
clear_output()

## 4. Testing


### init models


In [None]:
# Build D on the GPU and apply the weights_init_D initializer to its modules.
D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
D.apply(weights_init_D)

# Build the CUNet shift network, then wrap it in the SDE model bound to T.
shift_network = CUNet(
    DATASET1_CHANNELS,
    DATASET2_CHANNELS,
    TIME_DIM,
    base_factor=UNET_BASE_FACTOR,
).cuda()

T = SDE(
    shift_model=shift_network,
    epsilon=EPSILON,
    n_steps=N_STEPS,
    time_dim=TIME_DIM,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
).cuda()

# Report the total number of parameter elements of each model.
t_param_count = np.sum([np.prod(p.shape) for p in T.parameters()])
d_param_count = np.sum([np.prod(p.shape) for p in D.parameters()])
print("T params:", t_param_count)
print("D params:", d_param_count)

### load weights

In [None]:
print("Loading weights")

# Checkpoints are read from the folder of the last iteration (MAX_STEPS - 1).
CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{MAX_STEPS - 1}/")  # user setting

# Restore the SDE model weights.
t_weights_path = os.path.join(CKPT_DIR, f"T_{SEED}.pt")
T.load_state_dict(torch.load(t_weights_path))
print(f"{CKPT_DIR} T_{SEED}.pt, loaded")

# Restore the D network weights.
d_weights_path = os.path.join(CKPT_DIR, f"D_{SEED}.pt")
D.load_state_dict(torch.load(d_weights_path))
print(f"{CKPT_DIR} D_{SEED}.pt, loaded")

### Plots Test


In [None]:
# Draw one sample of size 10 from the test sampler; this X/Y pair is reused by
# every "fixed" plot below.
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(10)

In [None]:
# Quick preview of T on the fixed test sample (plotter defaults, no gray flag).
fig, axes = plot_sde_pushed_images(X_test_fixed, Y_test_fixed, T)

In [None]:
# Quick preview of T on pairs taken from the test sampler by the plotter.
fig, axes = plot_sde_pushed_random_paired_images(XY_test_sampler, T)

### main testing

In [None]:
clear_output(wait=True)
print("Plotting")

# inference_T is the same object as T; eval() puts it into evaluation mode
# for all plots below.
inference_T = T
inference_T.eval()

print("Fixed Test Images")
fig, axes = plot_sde_pushed_images(
    X_test_fixed, Y_test_fixed, inference_T, gray=GRAY_PLOTS
)
# wandb.log({"Fixed Test Images": [wandb.Image(fig2img(fig))]}, step=step)
# pyplot.show() only accepts the keyword argument `block`; a Figure must not
# be passed positionally.
plt.show()
plt.close(fig)

print("Random Test Images")
fig, axes = plot_sde_pushed_random_paired_images(
    XY_test_sampler, inference_T, gray=GRAY_PLOTS
)
# wandb.log({"Random Test Images": [wandb.Image(fig2img(fig))]}, step=step)
plt.show()
plt.close(fig)

# NOTE(review): steps_to_draw is computed but not used by this cell, which
# passes STEPS_TO_SHOW to the plotter instead -- confirm which one is intended.
steps_to_draw = min(N_STEPS, 10)
print("Fixed Test Trajectories")
fig, axes = plot_fixed_sde_trajectories(
    X_test_fixed,
    Y_test_fixed,
    inference_T,
    STEPS_TO_SHOW,
    N_STEPS,
    gray=GRAY_PLOTS,
)
plt.show()
plt.close(fig)

In [None]:
# NOTE(review): the commented-out calls below reference X_test_sampler and
# Y_test_sampler, which the visible cells never define (only XY_test_sampler
# is created) -- adapt before enabling them.
# print("TODO: Random Test Trajectories")
# fig, axes = plot_random_sde_trajectories(
#     X_test_sampler,
#     Y_test_sampler,
#     inference_T,
#     STEPS_TO_SHOW,
#     N_STEPS,
#     gray=GRAY_PLOTS,
# )
# plt.show()
# plt.close(fig)

print("Several Fixed Trajectories")
fig, axes = plot_several_fixed_sde_trajectories(
    X_test_fixed,
    Y_test_fixed,
    inference_T,
    STEPS_TO_SHOW,
    N_STEPS,
    gray=GRAY_PLOTS,
)
# pyplot.show() only accepts the keyword argument `block`; a Figure must not
# be passed positionally.
plt.show()
plt.close(fig)

# print("TODO: Several Random Trajectories")
# fig, axes = plot_several_random_sde_trajectories(
#     X_test_sampler,
#     Y_test_sampler,
#     inference_T,
#     STEPS_TO_SHOW,
#     N_STEPS,
#     gray=GRAY_PLOTS,
# )
# plt.show()
# plt.close(fig)

In [None]:
# Compute statistics of T's outputs over the test loader and compare them with
# the reference statistics loaded earlier (target_mu, target_sigma).
print("Computing FID")
gen_mu, gen_sigma = get_sde_pushed_loader_stats(
    T,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=True,
)
fid = calculate_frechet_distance(gen_mu, gen_sigma, target_mu, target_sigma)
print(f"FID={fid}")
# Drop the generated statistics once the distance has been printed.
del gen_mu, gen_sigma

In [None]:
# Evaluate T over the test loader for the metrics named in log_metrics and
# print whatever the helper returns.
print("Computing Metrics")
metrics = get_sde_pushed_loader_metrics(
    T,
    XY_test_sampler.loader,
    n_epochs=FID_EPOCHS,
    verbose=True,
    log_metrics=["LPIPS", "PSNR", "SSIM", "MSE", "MAE"],
)
print(f"metrics={metrics}")