# Code for experiments with colored MNIST and CelebA


## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from PIL import PngImagePlugin

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.enot import SDE, integrate
from src.resnet2 import ResNet_D
from src.cunet import CUNet

from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
)
from src.samplers import get_paired_sampler
from src.plotters import (
    plot_sde_pushed_images,
    plot_sde_pushed_random_paired_images,
)


# Pillow limits how large a PNG text chunk it will decode; raise that limit to
# LARGE_ENOUGH_NUMBER MiB (100 * 1024**2 bytes). Presumably some dataset PNGs
# carry large metadata chunks — NOTE(review): confirm which dataset needs this.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline 

In [None]:
# Start from a clean slate: collect unreachable Python objects, then hand
# cached-but-unused CUDA memory back from PyTorch's allocator.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

## 2. Config

Choose the dataset in the first lines of the cell below.


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset choice: keep exactly one of the following lines uncommented.
# DATASET, DATASET_PATH, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', False
# DATASET, DATASET_PATH, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", False
DATASET, DATASET_PATH, REVERSE = "FS2K", "../datasets/FS2K/", False

IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# GPU choosing
DEVICE_IDS = [0]
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this notebook")

# Step of the checkpoint to resume from; 0 trains from scratch.
CONTINUE = 0

# All hyperparameters below are set to the values used in the experiments described in the article.

# training algorithm settings
BATCH_SIZE = 1  # 1 for testing
T_ITERS = 10
MAX_STEPS = 2500 + 1  # 2501 for testing
# Presumably normalizes the energy integral by the number of values in a
# 3-channel IMG_SIZE x IMG_SIZE image — NOTE(review): confirm.
INTEGRAL_SCALE = 1 / (3 * IMG_SIZE * IMG_SIZE)
EPSILON_SCHEDULER_LAST_ITER = 20000

# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
BETA_D, BETA_T = 0.9, 0.9
T_GRADIENT_MAX_NORM = float(500)
D_GRADIENT_MAX_NORM = float(500)

# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 10
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

# plot settings
GRAY_PLOTS = False
STEPS_TO_SHOW = 10

# log settings
SMART_INTERVALS = False
INTERVAL_SHRINK_START_TIME = 0.98
TRACK_VAR_INTERVAL = 10
PLOT_INTERVAL = 500
CPKT_INTERVAL = 500

FID_EPOCHS = 50

EXP_NAME = f"ENOT_Paired_{DATASET}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Hyperparameters of this run, written to OUTPUT_PATH/config.json. Keys after
# FID_EPOCHS were previously missing although they change the run's result.
config = dict(
    SEED=SEED,
    DATASET=DATASET,
    T_ITERS=T_ITERS,
    D_LR=D_LR,
    T_LR=T_LR,
    BATCH_SIZE=BATCH_SIZE,
    UNET_BASE_FACTOR=UNET_BASE_FACTOR,
    N_STEPS=N_STEPS,
    EPSILON=EPSILON,
    USE_POSITIONAL_ENCODING=USE_POSITIONAL_ENCODING,
    TIME_DIM=TIME_DIM,
    INTEGRAL_SCALE=INTEGRAL_SCALE,
    ONE_STEP_INIT_ITERS=ONE_STEP_INIT_ITERS,
    T_GRADIENT_MAX_NORM=T_GRADIENT_MAX_NORM,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    PREDICT_SHIFT=PREDICT_SHIFT,
    SMART_INTERVALS=SMART_INTERVALS,
    INTERVAL_SHRINK_START_TIME=INTERVAL_SHRINK_START_TIME,
    USE_GRADIENT_CHECKPOINT=USE_GRADIENT_CHECKPOINT,
    N_LAST_STEPS_WITHOUT_NOISE=N_LAST_STEPS_WITHOUT_NOISE,
    TRACK_VAR_INTERVAL=TRACK_VAR_INTERVAL,
    EPSILON_SCHEDULER_LAST_ITER=EPSILON_SCHEDULER_LAST_ITER,
    FID_EPOCHS=FID_EPOCHS,
    REVERSE=REVERSE,
    IMG_SIZE=IMG_SIZE,
    DATASET1_CHANNELS=DATASET1_CHANNELS,
    DATASET2_CHANNELS=DATASET2_CHANNELS,
    MAX_STEPS=MAX_STEPS,
    BETA_D=BETA_D,
    BETA_T=BETA_T,
    IMAGE_INPUT=IMAGE_INPUT,
)

with open(os.path.join(OUTPUT_PATH, "config.json"), "w") as json_file:
    json.dump(config, json_file, indent=4)

# Resume marker; log.json is rewritten by the training loop at each checkpoint.
log = dict(CONTINUE=CONTINUE)
with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
    json.dump(log, log_file, indent=4)

In [None]:
# Reference statistics ("mu", "sigma") of the test split. In this notebook they
# are only consumed by the FID computation, which is commented out in the loop.
filename = f"../stats/{DATASET}_{IMG_SIZE}_{REVERSE}_test.json"
with open(filename, "r") as fp:
    data_stats = json.load(fp)
mu_data, sigma_data = (data_stats[key] for key in ("mu", "sigma"))
del data_stats

## 3. Initialize samplers


In [None]:
XY_sampler, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, reverse=REVERSE
)

# Collect unreachable objects first; only then can empty_cache() return the
# CUDA blocks those objects were holding (the reverse order leaves them cached).
gc.collect()
torch.cuda.empty_cache()
clear_output()

## 4. Training


### Models initialization


In [None]:
# Critic/potential network D on DATASET2_CHANNELS-channel images, with custom init.
D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
D.apply(weights_init_D)

# Conditional U-Net used as the shift model inside the SDE transport map T.
shift_net = CUNet(
    DATASET1_CHANNELS,
    DATASET2_CHANNELS,
    TIME_DIM,
    base_factor=UNET_BASE_FACTOR,
).cuda()

T = SDE(
    shift_model=shift_net,
    epsilon=EPSILON,
    n_steps=N_STEPS,
    time_dim=TIME_DIM,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
).cuda()

# Multi-GPU wrapping for fresh runs; resumed runs wrap after loading weights.
if CONTINUE == 0 and len(DEVICE_IDS) > 1:
    T = nn.DataParallel(T, device_ids=DEVICE_IDS)
    D = nn.DataParallel(D, device_ids=DEVICE_IDS)

for label, net in (("T params:", T), ("D params:", D)):
    print(label, np.sum([np.prod(param.shape) for param in net.parameters()]))

In [None]:
# Adam for both networks (tiny weight decay) with the LR halved at fixed steps.
LR_MILESTONES = (15000, 25000, 40000, 55000, 70000)

T_opt = torch.optim.Adam(
    T.parameters(), lr=T_LR, betas=(BETA_T, 0.999), weight_decay=1e-10
)
D_opt = torch.optim.Adam(
    D.parameters(), lr=D_LR, betas=(BETA_D, 0.999), weight_decay=1e-10
)
T_scheduler, D_scheduler = (
    torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=list(LR_MILESTONES), gamma=0.5
    )
    for opt in (T_opt, D_opt)
)

if CONTINUE > 0:

    def _ckpt_path(prefix):
        """Path of the `prefix` checkpoint saved at step CONTINUE."""
        return os.path.join(OUTPUT_PATH, f"{prefix}_{SEED}_{CONTINUE}.pt")

    T_opt.load_state_dict(torch.load(_ckpt_path("T_opt")))
    T_scheduler.load_state_dict(torch.load(_ckpt_path("T_scheduler")))

    # Weights are saved unwrapped, so load them before any DataParallel wrap.
    T.load_state_dict(torch.load(_ckpt_path("T")))
    D.load_state_dict(torch.load(_ckpt_path("D")))

    if len(DEVICE_IDS) > 1:
        T = nn.DataParallel(T, device_ids=DEVICE_IDS)
        D = nn.DataParallel(D, device_ids=DEVICE_IDS)

    D_opt.load_state_dict(torch.load(_ckpt_path("D_opt")))
    D_scheduler.load_state_dict(torch.load(_ckpt_path("D_scheduler")))

In [None]:
# One fixed batch of train and test pairs, reused by the plotting cells.
N_FIXED_SAMPLES = 10
X_fixed, Y_fixed = XY_sampler.sample(N_FIXED_SAMPLES)
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(N_FIXED_SAMPLES)

In [None]:
# Optional Weights & Biases logging, disabled: metrics go to the tensorboardX
# `writer` instead, and the wandb.log calls in the training loop are commented
# out too. Re-enabling it also requires `import wandb`, which is not imported here.
# wandb.init(name=EXP_NAME, config=config)

### Plot test


In [None]:
# Smoke-test the plotting helpers on fixed and random train/test pairs before
# training; only the last figure produced is logged to TensorBoard.
_plot_jobs = (
    (plot_sde_pushed_images, (X_fixed, Y_fixed)),
    (plot_sde_pushed_random_paired_images, (XY_sampler,)),
    (plot_sde_pushed_images, (X_test_fixed, Y_test_fixed)),
    (plot_sde_pushed_random_paired_images, (XY_test_sampler,)),
)
for _plot_fn, _plot_args in _plot_jobs:
    fig, axes = _plot_fn(*_plot_args, T)
writer.add_image("sde pushed random paired images", fig2tensor(fig))

### Main training loop and logging


In [None]:
def epsilon_scheduler(step, epsilon=None, last_iter=None):
    """Linearly warm up the SDE noise level.

    Returns ``epsilon * step / last_iter`` capped at ``epsilon``. For a
    non-negative ``epsilon`` the value therefore grows from 0 at step 0 to
    ``epsilon`` at ``last_iter`` and stays there afterwards.

    Args:
        step: Current training step.
        epsilon: Target noise level; defaults to the global ``EPSILON``.
        last_iter: Step at which the warm-up ends; defaults to the global
            ``EPSILON_SCHEDULER_LAST_ITER``. A non-positive value disables the
            warm-up and returns ``epsilon`` directly (previously 0 raised
            ZeroDivisionError and negatives produced a negative epsilon).

    Returns:
        The noise level to use at ``step``.
    """
    if epsilon is None:
        epsilon = EPSILON
    if last_iter is None:
        last_iter = EPSILON_SCHEDULER_LAST_ITER
    if last_iter <= 0:
        return epsilon
    return min(epsilon, epsilon * (step / last_iter))


def _unwrap(model):
    """Return the module inside an nn.DataParallel wrapper, or the model itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _push_forward(X0):
    """Simulate T from X0; return the final samples and the scaled energy integral."""
    trajectory, times, shifts = T(X0)
    XN = trajectory[:, -1]
    # Squared L2 norm of each shift, flattened from dim 2 on — presumably one
    # value per (batch item, time step); TODO confirm the SDE output layout.
    norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
    integral = INTEGRAL_SCALE * integrate(norm, times[0])
    return XN, integral


def _show_and_log(fig, tag, step):
    """Log `fig` to TensorBoard under `tag`, display it, then close it."""
    writer.add_image(tag, fig2tensor(fig), step)
    # plt.show() takes no figure argument (`block` is keyword-only on regular
    # backends), so the figure is not passed positionally.
    plt.show()
    plt.close(fig)


def _save_checkpoint(step):
    """Save T, D, both optimizers and both LR schedulers, then record `step` in log.json."""

    def path(prefix):
        return os.path.join(OUTPUT_PATH, f"{prefix}_{SEED}_{step}.pt")

    # Save unwrapped weights so they load with or without DataParallel.
    torch.save(_unwrap(T).state_dict(), path("T"))
    torch.save(_unwrap(D).state_dict(), path("D"))
    torch.save(D_opt.state_dict(), path("D_opt"))
    torch.save(T_opt.state_dict(), path("T_opt"))
    torch.save(D_scheduler.state_dict(), path("D_scheduler"))
    torch.save(T_scheduler.state_dict(), path("T_scheduler"))

    # Record the step of the newest checkpoint (previously this rewrote the
    # unchanged starting CONTINUE value, so log.json never advanced).
    log["CONTINUE"] = step
    with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
        json.dump(log, log_file, indent=4)


# A checkpoint saved at step k already contains that step's T/D/scheduler
# updates, so a resumed run continues from k + 1 instead of repeating step k.
START_STEP = CONTINUE + 1 if CONTINUE > 0 else 0
progress_bar = tqdm(total=MAX_STEPS, initial=START_STEP)

for step in range(START_STEP, MAX_STEPS):
    # ---- T phase: T_ITERS updates of the transport SDE with D frozen ----
    unfreeze(T)
    freeze(D)
    # Plotting/checkpointing below switch T to eval mode; make sure the T
    # updates always run in training mode (harmless if unfreeze() already does).
    T.train()

    new_epsilon = epsilon_scheduler(step)
    _unwrap(T).set_epsilon(new_epsilon)
    writer.add_scalar("Epsilon", new_epsilon, step)

    for t_iter in range(T_ITERS):
        T_opt.zero_grad()

        X0, X1 = XY_sampler.sample(BATCH_SIZE)
        X0.requires_grad_()

        XN, integral = _push_forward(X0)
        T_loss = (integral + D(X1) - D(XN)).mean()
        writer.add_scalar("T_loss", T_loss.item(), step)
        T_loss.backward()
        T_gradient_norm = torch.nn.utils.clip_grad_norm_(
            T.parameters(), max_norm=T_GRADIENT_MAX_NORM
        )
        T_opt.step()

    # wandb.log({f"T gradient norm": T_gradient_norm.item()}, step=step)
    # wandb.log({f"T_loss": T_loss.item()}, step=step)

    T_scheduler.step()
    del T_loss, X0, X1, XN, integral
    gc.collect()
    torch.cuda.empty_cache()

    # ---- D phase: one update of the potential with T frozen ----
    freeze(T)
    unfreeze(D)

    D_opt.zero_grad()

    X0, X1 = XY_sampler.sample(BATCH_SIZE)
    # No gradient flows into T here; skipping its autograd graph saves memory.
    with torch.no_grad():
        XN, integral = _push_forward(X0)

    D_X1 = D(X1)
    D_XN = D(XN)

    D_loss = (-integral - D_X1 + D_XN).mean()
    writer.add_scalar("D_loss", D_loss.item(), step)
    D_loss.backward()
    D_gradient_norm = torch.nn.utils.clip_grad_norm_(
        D.parameters(), max_norm=D_GRADIENT_MAX_NORM
    )
    D_opt.step()
    D_scheduler.step()

    # wandb.log({f"D gradient norm": D_gradient_norm.item()}, step=step)
    # wandb.log({f"D_loss": D_loss.item()}, step=step)
    # wandb.log({f"integral": integral.mean().item()}, step=step)
    # wandb.log({f"D_X1": D_X1.mean().item()}, step=step)
    # wandb.log({f"D_XN": D_XN.mean().item()}, step=step)
    del D_loss, X0, X1, XN, integral, D_X1, D_XN
    gc.collect()
    torch.cuda.empty_cache()

    progress_bar.update(1)

    # ---- Periodic plots on test pairs ----
    if step % PLOT_INTERVAL == 0:
        progress_bar.close()
        clear_output(wait=True)
        # Recreate the bar at the current position (step + 1 steps are done);
        # it was previously reset to CONTINUE after every plot.
        progress_bar = tqdm(total=MAX_STEPS, initial=step + 1)
        print("Plotting")

        T.eval()

        print("Fixed Test Images")
        fig, axes = plot_sde_pushed_images(
            X_test_fixed, Y_test_fixed, T, gray=GRAY_PLOTS
        )
        _show_and_log(fig, "Fixed Test Images", step)

        print("Random Test Images")
        fig, axes = plot_sde_pushed_random_paired_images(
            XY_test_sampler, T, gray=GRAY_PLOTS
        )
        _show_and_log(fig, "Random Test Images", step)

    # ---- Periodic checkpoints ----
    if step != 0 and step % CPKT_INTERVAL == 0:
        T.eval()
        freeze(T)
        _save_checkpoint(step)

    # ---- Optional metric tracking (currently disabled) ----
    if step % TRACK_VAR_INTERVAL == 0:
        pass
        # print("Computing FID")
        # mu, sigma = get_sde_pushed_loader_stats(
        #     T,
        #     XY_test_sampler.loader,
        #     n_epochs=FID_EPOCHS,
        #     batch_size=BATCH_SIZE,
        #     verbose=True,
        # )
        # fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
        # print(f"FID={fid}")
        # writer.add_scalar("Metrics/FID", fid, step)
        # del mu, sigma

        # print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
        # metrics = get_sde_pushed_loader_metrics(
        #     T,
        #     XY_test_sampler.loader,
        #     n_epochs=FID_EPOCHS,
        #     batch_size=BATCH_SIZE,
        #     verbose=True,
        #     log_metrics=["mse", "l1"]
        # )
        # print(f"metrics={metrics}")
        # writer.add_scalar("Metrics/LPIPS(VGG)", metrics["vgg"], step)
        # writer.add_scalar("Metrics/LPIPS(Alex)", metrics["alex"], step)
        # writer.add_scalar("Metrics/L1", metrics["l1"], step)
        # writer.add_scalar("Metrics/MSE", metrics["mse"], step)

    gc.collect()
    torch.cuda.empty_cache()

## Clear resources


In [None]:
# Close the TensorBoard writer and the progress bar independently, so that a
# failure in one, or a name that was never created because an earlier cell did
# not run (NameError), does not prevent the other from being closed.
for _close in (lambda: writer.close(), lambda: progress_bar.close()):
    try:
        _close()
    except Exception as e:
        print(e)