# Imports

In [None]:
import os
import sys
import warnings

import torch
from PIL import PngImagePlugin


sys.path.append("..")
from src.cunet import CUNet
from src.enot import SDE
from src.unet import UNet
from src.u2net import U2NET

from src.tools import (
    set_random_seed,
)
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_paired_images,
    plot_linked_pushed_images,
    plot_sde_pushed_images,
    plot_sde_pushed_random_paired_images,
    plot_linked_sde_pushed_images,
)
from src.samplers import get_paired_sampler


# Pillow limits the size of PNG text chunks it will decode. Raise that limit
# to 100 MiB so dataset images with large metadata chunks can still be read.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

# Suppress every warning category to keep the notebook output readable.
warnings.filterwarnings("ignore")

%matplotlib inline 

# General Config

In [None]:
# Image geometry shared by every dataset and model in this notebook.
IMG_SIZE = 256
DATASET1_CHANNELS = 3  # channels of the source-domain images
DATASET2_CHANNELS = 3  # channels of the target-domain images

# GPU choosing
DEVICE_ID = 0
# Use an explicit check rather than `assert`: asserts are stripped under
# `python -O`, and every model below is moved to the GPU with `.cuda()`.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is required: all models in this notebook are placed on the GPU."
    )
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

# Batch size passed to the paired samplers.
BATCH_SIZE = 8

# NOTE(review): SUBSET_SIZE, SUBSET_CLASS and FID_EPOCHS are not referenced by
# the cells visible here; they appear to be kept for parity with the training
# config. Confirm before removing.
SUBSET_SIZE = 2
SUBSET_CLASS = 3

# plot settings
GRAY_PLOTS = True

FID_EPOCHS = 1

# Initialize samplers

## A -> B sampler

In [None]:
# Dataset for the A -> B map. To switch, replace the active assignments below
# with one of these alternatives:
#   face -> comic:
# AB_DATASET, AB_DATASET_PATH, AB_MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "face2comic", False
#   colored mask -> face:
# AB_DATASET, AB_DATASET_PATH, AB_MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False

# Active choice: sketch -> photo (FS2K).
AB_DATASET = "FS2K"
AB_DATASET_PATH = "../datasets/FS2K/"
AB_MAP_NAME = "sketch2photo"
REVERSE = False

In [None]:
# Build the paired A/B samplers. Only the second returned sampler (the test
# split, judging by its name) is kept; the first is discarded.
_, AB_test_sampler = get_paired_sampler(
    AB_DATASET, AB_DATASET_PATH,
    reverse=REVERSE, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
)

In [None]:
(A_test_fixed, B_test_fixed) = AB_test_sampler.sample(10)  # fixed 10-pair A/B test batch reused by the "fix" plots

## C -> D sampler

In [None]:
# Dataset for the C -> D map. To switch, comment out the active assignment and
# uncomment exactly one alternative. The alternatives assign the same CD_*
# names that the C -> D sampler cell reads.
#
# NOTE: REVERSE is shared with the A -> B dataset cell. The A -> B sampler has
# already been built by the time this cell runs, so reassigning REVERSE here
# only affects the C -> D sampler.

# face -> comic (active)
CD_DATASET, CD_DATASET_PATH, CD_MAP_NAME, REVERSE = (
    "comic_faces_v1",
    "../datasets/face2comics_v1.0.0_by_Sxela",
    "face2comic",
    False,
)
# colored mask -> face
# CD_DATASET, CD_DATASET_PATH, CD_MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> photo
# CD_DATASET, CD_DATASET_PATH, CD_MAP_NAME, REVERSE = (
#     "FS2K",
#     "../datasets/FS2K/",
#     "sketch2photo",
#     False,
# )

In [None]:
# Build the paired C/D samplers and keep only the second one (the test
# split, judging by its name).
_, CD_test_sampler = get_paired_sampler(
    CD_DATASET, CD_DATASET_PATH,
    reverse=REVERSE, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
)

In [None]:
(C_test_fixed, D_test_fixed) = CD_test_sampler.sample(10)  # fixed 10-pair C/D test batch

# GNOT

## A -> B

### Initialize model and load weights

In [None]:
# Architecture of the A -> B transport map: "U2Net" or "UNet"
# ("ResNet_pix2pix" is not implemented).
T_TYPE = "U2Net"

# Training-time batch size, kept to mirror the training configuration.
BATCH_SIZE = 32

# Seed used both for reproducibility and in the checkpoint directory/file names.
SEED = 0x3060
set_random_seed(SEED)

In [None]:
# Directory holding the GNOT A -> B checkpoints for this dataset and seed.
EXP_NAME = f"GNOT_paired_{AB_DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail early and name the missing directory, so the user knows which
# checkpoint folder to provide.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# Build the A -> B transport map for the configured architecture, on the GPU.
if T_TYPE == "U2Net":
    AB_T = U2NET(in_ch=DATASET1_CHANNELS, out_ch=DATASET2_CHANNELS).cuda()
elif T_TYPE == "UNet":
    AB_T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48).cuda()
else:
    raise NotImplementedError(f"Unknown T_TYPE: {T_TYPE}")

In [None]:
# Load the A -> B checkpoint. The file name is a user setting: T_<seed>_30000.pt.
print("Loading weights")
w_path = os.path.join(LOAD_PATH, "T_{}_30000.pt".format(SEED))
AB_T.load_state_dict(torch.load(w_path))
print("{}, loaded".format(w_path))

### plot A -> B

In [None]:
# Plot the fixed A/B test pairs together with AB_T's outputs. With savefig=True
# the figure goes under ./figs/Exchange/GNOT/<map name>/fix.
fig, axes = plot_pushed_images(
    A_test_fixed, B_test_fixed, AB_T,
    gray=GRAY_PLOTS, savefig=True,
    save_path="./figs/Exchange/GNOT/{}/fix".format(AB_MAP_NAME),
)

In [None]:
# Plot 10 randomly drawn A/B test pairs together with AB_T's outputs.
fig, axes = plot_pushed_random_paired_images(
    AB_test_sampler, AB_T,
    plot_n_samples=10, gray=GRAY_PLOTS, savefig=True,
    save_path="./figs/Exchange/GNOT/{}/random".format(AB_MAP_NAME),
)

## C -> D

### Initialize model and load weights

In [None]:
# Architecture of the C -> D transport map: "U2Net" or "UNet"
# ("ResNet_pix2pix" is not implemented).
T_TYPE = "U2Net"

# Training-time batch size, kept to mirror the training configuration.
BATCH_SIZE = 32

# Seed used both for reproducibility and in the checkpoint directory/file names.
SEED = 0x3060
set_random_seed(SEED)

In [None]:
# Directory holding the GNOT C -> D checkpoints for this dataset and seed.
EXP_NAME = f"GNOT_paired_{CD_DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail early and name the missing directory, so the user knows which
# checkpoint folder to provide.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# Build the C -> D transport map for the configured architecture, on the GPU.
if T_TYPE == "U2Net":
    CD_T = U2NET(in_ch=DATASET1_CHANNELS, out_ch=DATASET2_CHANNELS).cuda()
elif T_TYPE == "UNet":
    CD_T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48).cuda()
else:
    raise NotImplementedError(f"Unknown T_TYPE: {T_TYPE}")

In [None]:
# Load the C -> D checkpoint. The file name is a user setting: T_<seed>_30000.pt.
print("Loading weights")
w_path = os.path.join(LOAD_PATH, "T_{}_30000.pt".format(SEED))
CD_T.load_state_dict(torch.load(w_path))
print("{}, loaded".format(w_path))

### plot C -> D

In [None]:
# Evaluate the C -> D map on its own fixed test pairs. The previous version
# passed the A/B batch here, so this plot never showed C -> D test data and
# C_test_fixed / D_test_fixed went unused.
fig, axes = plot_pushed_images(
    C_test_fixed,
    D_test_fixed,
    CD_T,
    gray=GRAY_PLOTS,
    savefig=True,
    save_path=f"./figs/Exchange/GNOT/{CD_MAP_NAME}/fix",
)

In [None]:
# Plot random C/D test pairs with CD_T's outputs. Draw them from the C -> D
# sampler; the A -> B sampler feeds the wrong domain into this model.
fig, axes = plot_pushed_random_paired_images(
    CD_test_sampler,
    CD_T,
    plot_n_samples=10,
    gray=GRAY_PLOTS,
    savefig=True,
    save_path=f"./figs/Exchange/GNOT/{CD_MAP_NAME}/random",
)

## A -> D

In [None]:
# Source domain of A -> B joined to target domain of C -> D, keeping the "2"
# separator (e.g. "sketch2photo" + "face2comic" -> "sketch2comic"). The
# previous concatenation dropped the separator ("sketchcomic").
AD_MAP_NAME = "2".join((AB_MAP_NAME.split("2", 1)[0], CD_MAP_NAME.split("2", 1)[1]))

### Chain the A -> B and C -> D models

In [None]:
AD_Ts = list((AB_T, CD_T))  # A -> B map first, then C -> D (presumably applied in list order by the linked plotter)

### plot A -> D

In [None]:
# Push the fixed A batch through the chained maps. The D batch fills the
# reference slot only: A and D are unpaired, so it is not a ground truth.
fig, axes = plot_linked_pushed_images(
    A_test_fixed, D_test_fixed, AD_Ts,
    gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/GNOT/{}/fix".format(AD_MAP_NAME),
)

In [None]:
# Draw a random A batch (from A -> B) and a random D batch (from C -> D),
# discarding their partners.
A_test_random, _ = AB_test_sampler.sample(10)
_, D_test_random = CD_test_sampler.sample(10)

# The D batch is only a visual reference: it is unpaired with the A inputs.
fig, axes = plot_linked_pushed_images(
    A_test_random, D_test_random, AD_Ts,
    gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/GNOT/{}/random".format(AD_MAP_NAME),
)

# ENOT

## A -> B

### Initialize model and load weights

In [None]:
# Seed used both for reproducibility and in the checkpoint directory/file names.
SEED = 0x3060
set_random_seed(SEED)

# ENOT (SDE) settings for the A -> B model. These presumably need to match the
# configuration the checkpoint was trained with.
EPSILON = 0  # candidate values: 0, 1, 10
N_STEPS = 10
N_LAST_STEPS_WITHOUT_NOISE = 1
UNET_BASE_FACTOR = 128
TIME_DIM = 128
IMAGE_INPUT = PREDICT_SHIFT = USE_POSITIONAL_ENCODING = True
USE_GRADIENT_CHECKPOINT = False
ONE_STEP_INIT_ITERS = 0

In [None]:
# Directory holding the ENOT A -> B checkpoints for this dataset and seed.
EXP_NAME = f"ENOT_paired_{AB_DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail early and name the missing directory, so the user knows which
# checkpoint folder to provide.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# Shift network for the A -> B SDE. It is a CUNet that receives TIME_DIM as its
# third positional argument (its meaning is defined in src.cunet).
T = CUNet(
    DATASET1_CHANNELS,
    DATASET2_CHANNELS,
    TIME_DIM,
    base_factor=UNET_BASE_FACTOR,
).cuda()

# Wrap the shift network in the SDE using the ENOT settings, on the GPU.
AB_SDE = SDE(
    shift_model=T,
    n_steps=N_STEPS,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    epsilon=EPSILON,
    time_dim=TIME_DIM,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
).cuda()

In [None]:
# Load the A -> B SDE checkpoint. The location is a user setting:
# <LOAD_PATH>/iter5000/T_<seed>.pt.
print("Loading weights")
w_path = os.path.join(LOAD_PATH, "iter5000/T_{}.pt".format(SEED))
AB_SDE.load_state_dict(torch.load(w_path))
print("{}, loaded".format(w_path))

### plot A -> B

In [None]:
# Plot the fixed A/B test pairs with AB_SDE's outputs (no trajectory).
fig, axes = plot_sde_pushed_images(
    A_test_fixed, B_test_fixed, AB_SDE,
    gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/ENOT/{}/fix".format(AB_MAP_NAME),
)

In [None]:
# Plot 10 randomly drawn A/B test pairs with AB_SDE's outputs (no trajectory).
fig, axes = plot_sde_pushed_random_paired_images(
    AB_test_sampler, AB_SDE,
    plot_n_samples=10, gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/ENOT/{}/random".format(AB_MAP_NAME),
)

## C -> D

### Initialize model and load weights

In [None]:
# Seed used both for reproducibility and in the checkpoint directory/file names.
SEED = 0x3060
set_random_seed(SEED)

# ENOT (SDE) settings for the C -> D model. These presumably need to match the
# configuration the checkpoint was trained with.
EPSILON = 0  # candidate values: 0, 1, 10
N_STEPS = 5  # passed to SDE(n_steps=...); differs from the A -> B model (10)
N_LAST_STEPS_WITHOUT_NOISE = 1
UNET_BASE_FACTOR = 128
TIME_DIM = 128
IMAGE_INPUT = PREDICT_SHIFT = USE_POSITIONAL_ENCODING = True
USE_GRADIENT_CHECKPOINT = False
ONE_STEP_INIT_ITERS = 0

In [None]:
# Directory holding the ENOT C -> D checkpoints for this dataset and seed.
EXP_NAME = f"ENOT_paired_{CD_DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

# Fail early and name the missing directory, so the user knows which
# checkpoint folder to provide.
if not os.path.exists(LOAD_PATH):
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# Shift network for the C -> D SDE. It is a CUNet that receives TIME_DIM as its
# third positional argument (its meaning is defined in src.cunet).
T = CUNet(
    DATASET1_CHANNELS,
    DATASET2_CHANNELS,
    TIME_DIM,
    base_factor=UNET_BASE_FACTOR,
).cuda()

# Wrap the shift network in the SDE using the ENOT settings, on the GPU.
CD_SDE = SDE(
    shift_model=T,
    n_steps=N_STEPS,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    epsilon=EPSILON,
    time_dim=TIME_DIM,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
).cuda()

In [None]:
# Load the C -> D SDE checkpoint. The location is a user setting:
# <LOAD_PATH>/iter5000/T_<seed>.pt.
print("Loading weights")
w_path = os.path.join(LOAD_PATH, "iter5000/T_{}.pt".format(SEED))
CD_SDE.load_state_dict(torch.load(w_path))
print("{}, loaded".format(w_path))

### plot C -> D

In [None]:
# Evaluate the C -> D SDE on its own fixed test pairs. The previous version
# passed the A/B batch here, so this plot never showed C -> D test data.
fig, axes = plot_sde_pushed_images(
    C_test_fixed,
    D_test_fixed,
    CD_SDE,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/ENOT/{CD_MAP_NAME}/fix",
)

In [None]:
# Plot random C/D test pairs with CD_SDE's outputs. Draw them from the C -> D
# sampler; the A -> B sampler feeds the wrong domain into this model.
fig, axes = plot_sde_pushed_random_paired_images(
    CD_test_sampler,
    CD_SDE,
    plot_n_samples=10,
    gray=GRAY_PLOTS,
    plot_trajectory=False,
    savefig=True,
    save_path=f"./figs/Exchange/ENOT/{CD_MAP_NAME}/random",
)

## A -> D

In [None]:
# Source domain of A -> B joined to target domain of C -> D, keeping the "2"
# separator (e.g. "sketch2photo" + "face2comic" -> "sketch2comic"). The
# previous concatenation dropped the separator ("sketchcomic").
AD_MAP_NAME = "2".join((AB_MAP_NAME.split("2", 1)[0], CD_MAP_NAME.split("2", 1)[1]))

### Chain the A -> B and C -> D models

In [None]:
AD_SDEs = list((AB_SDE, CD_SDE))  # A -> B SDE first, then C -> D (presumably applied in list order by the linked plotter)

### plot A -> D

In [None]:
# Push the fixed A batch through the chained SDEs. The D batch fills the
# reference slot only: A and D are unpaired, so it is not a ground truth.
fig, axes = plot_linked_sde_pushed_images(
    A_test_fixed, D_test_fixed, AD_SDEs,
    gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/ENOT/{}/fix".format(AD_MAP_NAME),
)

In [None]:
# Draw a random A batch (from A -> B) and a random D batch (from C -> D),
# discarding their partners.
A_test_random, _ = AB_test_sampler.sample(10)
_, D_test_random = CD_test_sampler.sample(10)

# The D batch is only a visual reference: it is unpaired with the A inputs.
fig, axes = plot_linked_sde_pushed_images(
    A_test_random, D_test_random, AD_SDEs,
    gray=GRAY_PLOTS, plot_trajectory=False, savefig=True,
    save_path="./figs/Exchange/ENOT/{}/random".format(AD_MAP_NAME),
)