## Imports


In [None]:
import os
import sys
import gc
import warnings

import torch
import numpy as np
from PIL import PngImagePlugin

from IPython.display import clear_output

sys.path.append("..")
from src.cunet import CUNet
from src.enot import SDE
from src.unet import UNet
from src.tools import (
    set_random_seed,
)
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_paired_images,
    plot_sde_pushed_images,
    plot_sde_pushed_random_paired_images,
    plot_linked_pushed_images,
    plot_linked_pushed_random_paired_images,
    plot_linked_sde_pushed_images,
    plot_linked_sde_pushed_random_paired_images,
)
from src.samplers import get_paired_sampler


# Raise Pillow's cap on PNG text chunks to LARGE_ENOUGH_NUMBER MiB (100 MiB) so
# PNG files carrying large text/metadata chunks can be opened instead of being
# rejected by the default limit (presumably some dataset images need this).
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Suppress every Python warning for cleaner notebook output (note: this also
# hides deprecation warnings from torch/numpy).
warnings.filterwarnings("ignore")

# IPython magic (not plain Python): render matplotlib figures inline.
%matplotlib inline 

In [None]:
# Reclaim memory left over from earlier runs of the notebook. Order matters:
# collect unreachable Python objects first (freeing any CUDA tensors they own),
# then hand PyTorch's now-unused cached GPU blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

## General Config

In [None]:
SEED = 0x0000
set_random_seed(SEED)

# Dataset selection: keep exactly one of the assignments below uncommented.
# Every option unpacks into the same four names so later cells work unchanged.
# face -> comic
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = "comic_faces_v1", "../datasets/face2comics_v1.0.0_by_Sxela", "face2comic", False
# colored mask -> face
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> photo
DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
    "FS2K",
    "../datasets/FS2K/",
    "sketch2photo",
    False,
)

# Image resolution and channel counts of the two paired domains; the channel
# counts are passed as the first two constructor arguments of every model below.
IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# GPU choosing: the first id becomes the current CUDA device, so later bare
# `.cuda()` calls place models on it.
DEVICE_IDS = [1]
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available, but this notebook requires a GPU")
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# training algorithm settings (BATCH_SIZE is used by the sampler below;
# SUBSET_SIZE and SUBSET_CLASS are not referenced by the cells shown here)
BATCH_SIZE = 32
SUBSET_SIZE = 2
SUBSET_CLASS = 3

# plot settings
GRAY_PLOTS = True

FID_EPOCHS = 1

## Initialize samplers


In [None]:
# Build the paired samplers; only the second one (the test sampler) is kept for
# the evaluation plots below, the first is discarded into `_`.
_, XY_test_sampler = get_paired_sampler(
    DATASET,
    DATASET_PATH,
    reverse=REVERSE,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
)

In [None]:
# Free memory before plotting. Collect unreachable Python objects first so the
# CUDA tensors they own are released into PyTorch's cache, then return those
# cached blocks to the driver; emptying the cache first would leave the memory
# of not-yet-collected tensors reserved.
gc.collect()
torch.cuda.empty_cache()
clear_output()

In [None]:
# Draw one fixed set of paired test samples (presumably 10 (X, Y) pairs) once;
# the same X_test_fixed / Y_test_fixed are passed to every "fix" plot below so
# all methods are shown on identical inputs.
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(10)

## GNOT

In [None]:
# GNOT checkpoints live under ../saved_models/<EXP_NAME>/.
# NOTE(review): the experiment name says "Unpair" while figures go to
# ./figs/Paired/ — confirm this is the intended checkpoint.
EXP_NAME = f"GNOT_Unpair_{DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail fast and report the missing directory, so a wrong DATASET/SEED is obvious.
if not os.path.isdir(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# GNOT map T: a UNet taking DATASET1_CHANNELS and producing DATASET2_CHANNELS
# (presumably in/out channels). base_factor=48 must match the architecture of
# the checkpoint loaded in the next cell, or load_state_dict will fail.
T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48).cuda()

### load weights

In [None]:
print("Loading weights")

w_path = os.path.join(LOAD_PATH, "T_10000_no_z.pt")  # user setting

# Restore tensors directly onto the GPU selected in the config cell; without
# map_location they go back to whichever GPU they were saved from, which may be
# a different (possibly busy) device.
T.load_state_dict(torch.load(w_path, map_location=f"cuda:{DEVICE_IDS[0]}"))
print(f"{w_path}, loaded")

### plot

In [None]:
# Plot the fixed test pairs alongside T's pushed images; savefig=True
# presumably writes the figure under save_path.
fig, axes = plot_pushed_images(
    X_test_fixed, Y_test_fixed, T,
    save_path="./figs/Paired/GNOT/fix",
    savefig=True,
    gray=GRAY_PLOTS,
)

In [None]:
# Same as the fixed plot, but on 10 pairs drawn fresh from the test sampler.
fig, axes = plot_pushed_random_paired_images(
    XY_test_sampler, T,
    save_path="./figs/Paired/GNOT/random",
    savefig=True,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)

## ENOT

In [None]:
# ---- ENOT: SDE network settings (must match the checkpoint being loaded) ----
# SDE solver
EPSILON = 0  # values used: 0, 1, 10
N_STEPS = 5  # number of SDE steps
N_LAST_STEPS_WITHOUT_NOISE = 1
# shift-network architecture / inputs
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
IMAGE_INPUT = True
PREDICT_SHIFT = True
# memory / init knobs (ONE_STEP_INIT_ITERS is not referenced by the cells shown here)
USE_GRADIENT_CHECKPOINT = False
ONE_STEP_INIT_ITERS = 0

In [None]:
# ENOT checkpoints live under ../saved_models/<EXP_NAME>/.
# NOTE(review): the experiment name says "Unpair" while figures go to
# ./figs/Paired/ — confirm this is the intended checkpoint.
EXP_NAME = f"ENOT_Unpair_{DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail fast and report the missing directory, so a wrong DATASET/SEED is obvious.
if not os.path.isdir(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Shift network: a CUNet additionally conditioned on a TIME_DIM-sized input.
T = CUNet(
    DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
).cuda()

# Wrap the shift network in the SDE module; all settings come from the ENOT
# settings cell and must match the checkpoint loaded below.
T = SDE(
    shift_model=T,
    epsilon=EPSILON,
    n_steps=N_STEPS,
    time_dim=TIME_DIM,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
).cuda()

# Total parameter count. Tensor.numel() is an exact int (1 for 0-d tensors),
# whereas np.prod(()) is the float 1.0 and would turn the printed total into a float.
print("T params:", sum(p.numel() for p in T.parameters()))

### Load weights


In [None]:
print("Loading weights")

w_path = os.path.join(LOAD_PATH, f"T_{SEED}_5000.pt")  # user setting

# Restore tensors directly onto the GPU selected in the config cell; without
# map_location they go back to whichever GPU they were saved from.
T.load_state_dict(torch.load(w_path, map_location=f"cuda:{DEVICE_IDS[0]}"))

print(f"{w_path}, loaded")

### plot

In [None]:
# Plot the fixed test pairs pushed through the ENOT SDE (plot_trajectory=False:
# no intermediate steps); savefig=True presumably writes under save_path.
fig, axes = plot_sde_pushed_images(
    X_test_fixed, Y_test_fixed, T,
    save_path="./figs/Paired/ENOT/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Same as the fixed ENOT plot, but on 10 pairs drawn fresh from the test sampler.
fig, axes = plot_sde_pushed_random_paired_images(
    XY_test_sampler, T,
    save_path="./figs/Paired/ENOT/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)

## DNOT

In [None]:
# ---- DNOT settings (must match the checkpoint being loaded) ----
# Number of noising steps of the diffusion process.
DIFFUSION_STEPS = 1000
# Pivotal steps; the model chain below has 2 * len(PIVOTAL_LIST) networks.
# Alternatives: [0, 100] for testing, [0, 20, 50, 100].
PIVOTAL_LIST = [20, 50, 100]
# Training strategy of the checkpoint: 'Fix' or 'Adapt'.
STRATEGY = "Adapt"
# Width factor of every UNet in the chain.
UNET_BASE_FACTOR = 48

In [None]:
# DNOT checkpoints live under ../saved_models/<EXP_NAME>/.
EXP_NAME = f"DNOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail fast and report the missing directory, so a wrong DATASET/STRATEGY/SEED is obvious.
if not os.path.isdir(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Build the DNOT model chain: 2 * len(PIVOTAL_LIST) UNets of identical
# architecture (presumably two per pivotal step — confirm against the training
# script). Weights are loaded afterwards; entry i is matched to file T{i}_{SEED}.pt.
Ts = []

for i in range(len(PIVOTAL_LIST) * 2):
    # This loop rebinds the global T (and i); when cells run top-to-bottom it
    # also drops this name's reference to the previously built ENOT model.
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    Ts.append(T)

### load weights


In [None]:
print("Loading weights")

CKPT_ITER = 2000  # user setting: training iteration whose checkpoints to load
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{CKPT_ITER}")
for i, T in enumerate(Ts):
    # Checkpoint files are indexed by the model's position in Ts.
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Restore onto the GPU chosen in the config cell, not the one used at save time.
    T.load_state_dict(torch.load(w_path, map_location=f"cuda:{DEVICE_IDS[0]}"))
    print(f"{w_path}, loaded")

### plot

In [None]:
# Plot the fixed test pairs pushed through the linked DNOT chain Ts
# (plot_trajectory=False); savefig=True presumably writes under save_path.
fig, axes = plot_linked_pushed_images(
    X_test_fixed, Y_test_fixed, Ts,
    save_path="./figs/Paired/DNOT/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Same as the fixed DNOT plot, but on 10 pairs drawn fresh from the test sampler.
fig, axes = plot_linked_pushed_random_paired_images(
    XY_test_sampler, Ts,
    save_path="./figs/Paired/DNOT/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)

## DENOT

In [None]:
# ---- DENOT settings (must match the checkpoint being loaded) ----
# diffusion: number of noising steps, and pivotal steps
# (alternatives: [0, 100] for testing, [0, 20, 50, 100])
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]
# training strategy of the checkpoint: 'Fix' or 'Adapt'
STRATEGY = "Fix"
# SDE solver
EPSILON = 0  # values used: 0, 1, 10
N_STEPS = 5  # num of shifts time
N_LAST_STEPS_WITHOUT_NOISE = 1
# shift-network architecture / inputs
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
IMAGE_INPUT = True
PREDICT_SHIFT = True
# memory / init knobs (ONE_STEP_INIT_ITERS is not referenced by the cells shown here)
USE_GRADIENT_CHECKPOINT = False
ONE_STEP_INIT_ITERS = 0

In [None]:
# DENOT checkpoints live under ../saved_models/<EXP_NAME>/.
EXP_NAME = f"DENOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail fast and report the missing directory, so a wrong DATASET/STRATEGY/SEED is obvious.
if not os.path.isdir(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Build the DENOT chain: 2 * len(PIVOTAL_LIST) SDE models of identical
# configuration (presumably two per pivotal step — confirm against the training
# script). Weights are loaded afterwards; entry i is matched to file T{i}_{SEED}.pt.
SDEs = []

for i in range(len(PIVOTAL_LIST) * 2):
    # Shift network conditioned on a TIME_DIM-sized input.
    # Note: this loop rebinds the global T (and i); after it, T is the last SDE.
    T = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()

    # Wrap it in the SDE module with the settings from the DENOT settings cell.
    T = SDE(
        shift_model=T,
        epsilon=EPSILON,
        n_steps=N_STEPS,
        time_dim=TIME_DIM,
        n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
        use_positional_encoding=USE_POSITIONAL_ENCODING,
        use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
        predict_shift=PREDICT_SHIFT,
        image_input=IMAGE_INPUT,
    ).cuda()
    SDEs.append(T)

### load weights

In [None]:
print("Loading weights")

CKPT_ITER = 10000  # user setting: training iteration whose checkpoints to load
# No trailing slash: os.path.join adds the separator, and the printed path
# no longer contains "//".
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{CKPT_ITER}")
for i, T in enumerate(SDEs):
    # Checkpoint files are indexed by the model's position in SDEs.
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Restore onto the GPU chosen in the config cell, not the one used at save time.
    T.load_state_dict(torch.load(w_path, map_location=f"cuda:{DEVICE_IDS[0]}"))
    print(f"{w_path}, loaded")

### plot

In [None]:
# Plot the fixed test pairs pushed through the linked DENOT SDE chain
# (plot_trajectory=False); savefig=True presumably writes under save_path.
fig, axes = plot_linked_sde_pushed_images(
    X_test_fixed, Y_test_fixed, SDEs,
    save_path="./figs/Paired/DENOT/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Same as the fixed DENOT plot, but on 10 pairs drawn fresh from the test sampler.
fig, axes = plot_linked_sde_pushed_random_paired_images(
    XY_test_sampler, SDEs,
    save_path="./figs/Paired/DENOT/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)