# Imports

In [None]:
import os
import sys
import warnings

import torch
from PIL import PngImagePlugin


sys.path.append("..")
from src.cunet import CUNet
from src.enot import SDE
from src.unet import UNet
from src.tools import (
    set_random_seed,
)
from src.plotters import (
    plot_linked_pushed_images,
    plot_linked_pushed_random_paired_images,
    plot_linked_sde_pushed_images,
    plot_linked_sde_pushed_random_paired_images,
)
from src.samplers import get_paired_sampler


LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

warnings.filterwarnings("ignore")

%matplotlib inline 

# General Config

In [None]:
IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# GPU selection
DEVICE_ID = 0
# Every model in this notebook is moved to the GPU with `.cuda()`, so fail early
# with a clear message. An `assert` would give no message and is stripped under
# `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this notebook, but no GPU is available.")
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

# training algorithm settings
BATCH_SIZE = 8

SUBSET_SIZE = 2
SUBSET_CLASS = 3

# plot settings
GRAY_PLOTS = True

FID_EPOCHS = 1

# Initialize samplers

## A -> B sampler

In [None]:
# dataset choosing
# face -> comic
# DATASET, DATASET_PATH, AB_MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "face2comic", False
# mask -> face
# DATASET, DATASET_PATH, AB_MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> face
DATASET, DATASET_PATH, AB_MAP_NAME, REVERSE = (
    "FS2K",
    "../datasets/FS2K/",
    "sketch2photo",
    False,
)

In [None]:
_, AB_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE, reverse=REVERSE
)

In [None]:
A_test_fixed, B_test_fixed = AB_test_sampler.sample(10)

## C -> D sampler

In [None]:
# dataset choosing
# face2comic
DATASET, DATASET_PATH, CD_MAP_NAME, REVERSE = (
    "comic_faces_v1",
    "../datasets/face2comics_v1.0.0_by_Sxela",
    "face2comic",
    False,
)
# colored mask -> face
# DATASET, DATASET_PATH, CD_MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> photo
# DATASET, DATASET_PATH, CD_MAP_NAME, REVERSE = (
#     "FS2K",
#     "../datasets/FS2K/",
#     "sketch2photo",
#     False,
# )

In [None]:
_, CD_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE, reverse=REVERSE
)

In [None]:
C_test_fixed, D_test_fixed = CD_test_sampler.sample(10)

# DNOT

## A -> B

### init model and load weights

In [None]:
SEED = 0x3060
set_random_seed(SEED)
# the step number adding noise in diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]  # [0, 100] for testing,  [0, 20, 50, 100]
# training algorithm settings
STRATEGY = "Adapt"  # 'Fix' or 'Adapt'
# model settings
UNET_BASE_FACTOR = 48

In [None]:
EXP_NAME = f"DNOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError("no such file or directory...")

In [None]:
AB_Ts = []

for i in range(len(PIVOTAL_LIST) * 2):
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    AB_Ts.append(T)

In [None]:
print("Loading weights")

CKPT_DIR = os.path.join(LOAD_PATH, f"iter{2000}")  # user setting
for i, T in enumerate(AB_Ts):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    T.load_state_dict(torch.load(w_path))
    print(f"{w_path}, loaded")

### plot A -> B

In [None]:
fig, axes = plot_linked_pushed_images(
    A_test_fixed,
    B_test_fixed,
    AB_Ts,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{AB_MAP_NAME}/fix",
)

In [None]:
fig, axes = plot_linked_pushed_random_paired_images(
    AB_test_sampler,
    AB_Ts,
    plot_n_samples=10,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{AB_MAP_NAME}/random",
)

## C -> D

### init model and load weights

In [None]:
SEED = 0x3060
set_random_seed(SEED)
# the step number adding noise in diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]  # [0, 100] for testing,  [0, 20, 50, 100]
# training algorithm settings
STRATEGY = "Adapt"  # 'Fix' or 'Adapt'
# model settings
UNET_BASE_FACTOR = 48

In [None]:
EXP_NAME = f"DNOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError("no such file or directory...")

In [None]:
CD_Ts = []

for i in range(len(PIVOTAL_LIST) * 2):
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    CD_Ts.append(T)

In [None]:
print("Loading weights")

CKPT_DIR = os.path.join(LOAD_PATH, f"iter{2000}")  # user setting
for i, T in enumerate(CD_Ts):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    T.load_state_dict(torch.load(w_path))
    print(f"{w_path}, loaded")

### plot C -> D

In [None]:
fig, axes = plot_linked_pushed_images(
    C_test_fixed,
    D_test_fixed,
    CD_Ts,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{CD_MAP_NAME}/fix",
)

In [None]:
fig, axes = plot_linked_pushed_random_paired_images(
    CD_test_sampler,
    CD_Ts,
    plot_n_samples=10,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{CD_MAP_NAME}/random",
)

## A -> D

In [None]:
AD_MAP_NAME = AB_MAP_NAME.split("2")[0] + CD_MAP_NAME.split("2")[1]

### init model

In [None]:
AD_Ts = []
AG_T_num = len(AB_Ts) // 2
GD_T_num = len(CD_Ts) // 2
for i in range(AG_T_num + GD_T_num):
    if i < AG_T_num:
        AD_Ts.append(AB_Ts[i])
    else:
        AD_Ts.append(CD_Ts[i])

### plot A -> D

In [None]:
fig, axes = plot_linked_pushed_images(
    A_test_fixed,
    D_test_fixed,
    AD_Ts,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{AD_MAP_NAME}/fix",
)

In [None]:
A_test_random, _ = AB_test_sampler.sample(10)
_, D_test_random = CD_test_sampler.sample(10)

fig, axes = plot_linked_pushed_images(
    A_test_random,
    D_test_random,
    AD_Ts,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/DNOT/{AD_MAP_NAME}/random",
)

# DENOT

## A -> B

### init model and load weights

In [None]:
SEED = 0x3060
set_random_seed(SEED)
# the step number adding noise in diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]  # [0, 100] for testing,  [0, 20, 50, 100]
# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'
# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 5  # num of shifts time
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

In [None]:
EXP_NAME = f"DENOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError("no such file or directory...")

In [None]:
# Build 2 * len(PIVOTAL_LIST) SDE models for A -> B. Each one wraps a freshly
# constructed CUNet as its shift model, and both are moved to the GPU.
AB_SDEs = []
for _ in range(2 * len(PIVOTAL_LIST)):
    shift_net = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()
    sde_model = SDE(
        shift_model=shift_net,
        epsilon=EPSILON,
        n_steps=N_STEPS,
        time_dim=TIME_DIM,
        n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
        use_positional_encoding=USE_POSITIONAL_ENCODING,
        use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
        predict_shift=PREDICT_SHIFT,
        image_input=IMAGE_INPUT,
    ).cuda()
    AB_SDEs.append(sde_model)

In [None]:
print("Loading weights")

# No trailing "/" here: the path is built with os.path.join, so the printed
# path no longer contains a doubled slash (matches the DNOT loading cells).
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{10000}")  # user setting
for i, T in enumerate(AB_SDEs):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Deserialize on CPU; load_state_dict copies the tensors into the GPU model.
    T.load_state_dict(torch.load(w_path, map_location="cpu"))
    print(f"{w_path}, loaded")

### plot A -> B

In [None]:
# Plot the fixed A/B test batch pushed through AB_SDEs and save the figure
# without trajectories.
fig, axes = plot_linked_sde_pushed_images(
    A_test_fixed,
    B_test_fixed,
    AB_SDEs,
    save_path=f"./figs/Exchange/DENOT/{AB_MAP_NAME}/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Plot AB_SDEs on 10 random paired samples from the A -> B test sampler.
fig, axes = plot_linked_sde_pushed_random_paired_images(
    AB_test_sampler,
    AB_SDEs,
    save_path=f"./figs/Exchange/DENOT/{AB_MAP_NAME}/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)

## C -> D

### Initialize models and load weights

In [None]:
# Seed the RNGs (via the project helper) before the C -> D models are built.
SEED = 0x3060
set_random_seed(SEED)

# the step number adding noise in diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]  # [0, 100] for testing,  [0, 20, 50, 100]

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'

# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT, PREDICT_SHIFT = True, True
N_STEPS = 5  # num of shifts time
UNET_BASE_FACTOR, TIME_DIM = 128, 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

In [None]:
# `DATASET` was last assigned by the C -> D dataset selection cell.
EXP_NAME = f"DENOT_paired_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Report the missing path so a wrong dataset/strategy/seed is easy to spot.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# Build 2 * len(PIVOTAL_LIST) SDE models for C -> D. Each one wraps a freshly
# constructed CUNet as its shift model, and both are moved to the GPU.
CD_SDEs = []
for _ in range(2 * len(PIVOTAL_LIST)):
    shift_net = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()
    sde_model = SDE(
        shift_model=shift_net,
        epsilon=EPSILON,
        n_steps=N_STEPS,
        time_dim=TIME_DIM,
        n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
        use_positional_encoding=USE_POSITIONAL_ENCODING,
        use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
        predict_shift=PREDICT_SHIFT,
        image_input=IMAGE_INPUT,
    ).cuda()
    CD_SDEs.append(sde_model)

In [None]:
print("Loading weights")

# No trailing "/" here: the path is built with os.path.join, so the printed
# path no longer contains a doubled slash (matches the DNOT loading cells).
CKPT_DIR = os.path.join(LOAD_PATH, f"iter{10000}")  # user setting
for i, T in enumerate(CD_SDEs):
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # Deserialize on CPU; load_state_dict copies the tensors into the GPU model.
    T.load_state_dict(torch.load(w_path, map_location="cpu"))
    print(f"{w_path}, loaded")

### plot C -> D

In [None]:
# Plot the fixed C/D test batch pushed through CD_SDEs and save the figure
# without trajectories.
fig, axes = plot_linked_sde_pushed_images(
    C_test_fixed,
    D_test_fixed,
    CD_SDEs,
    save_path=f"./figs/Exchange/DENOT/{CD_MAP_NAME}/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Plot CD_SDEs on 10 random paired samples from the C -> D test sampler.
fig, axes = plot_linked_sde_pushed_random_paired_images(
    CD_test_sampler,
    CD_SDEs,
    save_path=f"./figs/Exchange/DENOT/{CD_MAP_NAME}/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
    plot_n_samples=10,
)

## A -> D

In [None]:
# "<A domain>2<D domain>", e.g. "sketch2photo" + "face2comic" -> "sketch2comic",
# following the same "X2Y" naming as AB_MAP_NAME / CD_MAP_NAME.
AD_MAP_NAME = AB_MAP_NAME.split("2")[0] + "2" + CD_MAP_NAME.split("2")[-1]

### Initialize model

In [None]:
# A -> D chain: the first half of the A -> B SDEs followed by the second half of
# the C -> D SDEs. Uses this section's own counts rather than the DNOT-section
# global `AG_T_num`. Slicing (instead of indexing CD_SDEs with the combined loop
# index) stays correct even if the two lists differ in length.
AG_SDE_num = len(AB_SDEs) // 2
GD_SDE_num = len(CD_SDEs) // 2
AD_SDEs = AB_SDEs[:AG_SDE_num] + CD_SDEs[GD_SDE_num:]

### plot A -> D

In [None]:
# Plot fixed A inputs pushed through the composed A -> D chain, next to the
# fixed D test images, and save the figure without trajectories.
fig, axes = plot_linked_sde_pushed_images(
    A_test_fixed,
    D_test_fixed,
    AD_SDEs,
    save_path=f"./figs/Exchange/DENOT/{AD_MAP_NAME}/fix",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)

In [None]:
# Draw fresh batches: A images from the A -> B test sampler and D images from
# the C -> D test sampler (the two batches are not paired with each other).
A_test_random = AB_test_sampler.sample(10)[0]
D_test_random = CD_test_sampler.sample(10)[1]

fig, axes = plot_linked_sde_pushed_images(
    A_test_random,
    D_test_random,
    AD_SDEs,
    save_path=f"./figs/Exchange/DENOT/{AD_MAP_NAME}/random",
    savefig=True,
    plot_trajectory=False,
    gray=GRAY_PLOTS,
)