# Code for experiments with colored MNIST and CelebA


## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from PIL import PngImagePlugin

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.losses import VGGPerceptualLoss as VGGLoss
from src.resnet2 import ResNet_D
from src.unet import UNet
from src.u2net import U2NET

from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
)
from src.samplers import get_paired_sampler
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_paired_images,
)


# Raise Pillow's cap on PNG text-chunk size to LARGE_ENOUGH_NUMBER MiB
# (presumably some dataset PNGs carry metadata chunks above the default cap).
LARGE_ENOUGH_NUMBER = 100
_MIB = 1024**2
PngImagePlugin.MAX_TEXT_CHUNK = _MIB * LARGE_ENOUGH_NUMBER

# Silence every warning category to keep notebook output readable.
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Start from a clean slate: collect unreachable Python objects, then release
# unoccupied memory held by PyTorch's CUDA caching allocator.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

## 2. Config

Choose the dataset in the first lines of the cell below.


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset choice: keep exactly one of the following lines uncommented.
# DATASET, DATASET_PATH, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', False
DATASET, DATASET_PATH, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", False
# DATASET, DATASET_PATH, REVERSE = "FS2K", "../datasets/FS2K/", False

IMG_SIZE = 256
DATASET1_CHANNELS = 3  # channels of the input images X (T's in-channels)
DATASET2_CHANNELS = 3  # channels of the target images Y (T's out-channels)

# GPU choice. With more than one id, T and D are wrapped in nn.DataParallel.
DEVICE_IDS = [1]
# An explicit check rather than `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this notebook.")
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# Step to resume from: 0 trains from scratch; > 0 loads the checkpoints saved at that step.
CONTINUE = 0

# All hyperparameters below are set to the values used for the experiments described in the article.

# training algorithm settings
BATCH_SIZE = 2  # 1 for testing
T_ITERS = 10  # T updates per D update
MAX_STEPS = 30000 + 1  # 2501 for testing
COST = "vgg"  # one of 'rmse', 'mse', 'mae', 'vgg'


# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
# NOTE(review): these max norms are saved to config.json, but the training loop in
# this notebook never clips gradients with them — confirm whether that is intended.
T_GRADIENT_MAX_NORM = float(100)
D_GRADIENT_MAX_NORM = float(100)

# network settings
T_TYPE = "U2Net"  # or 'UNet' ('ResNet_pix2pix' is not implemented)
D_TYPE = "ResNet"  # 'ResNet_pix2pix' does NOT work well (it is actually not a ResNet)

D_NORM = "none"  # our ResNet_D supports "batchnorm" or "none"
CONDITIONAL = False  # conditional NOT: D also receives the input X (no longer needed)
NOT = True  # True: train Neural Optimal Transport; False: pure regression

# plot settings
GRAY_PLOTS = False

# log settings (all intervals are measured in training steps)
TRACK_VAR_INTERVAL = 1000
PLOT_INTERVAL = 500
CPKT_INTERVAL = 5000

FID_EPOCHS = 1

EXP_NAME = f"GNOT_Paired_{DATASET}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Persist the run's hyperparameters next to its checkpoints. The original keys
# come first (unchanged order); the keys after FID_EPOCHS were previously missing,
# yet they change the experiment, so a run could not be reproduced from this file.
config = dict(
    SEED=SEED,
    DATASET=DATASET,
    T_TYPE=T_TYPE,
    D_TYPE=D_TYPE,
    T_ITERS=T_ITERS,
    D_LR=D_LR,
    T_LR=T_LR,
    BATCH_SIZE=BATCH_SIZE,
    CONDITIONAL=CONDITIONAL,
    NOT=NOT,
    COST=COST,
    T_GRADIENT_MAX_NORM=T_GRADIENT_MAX_NORM,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    TRACK_VAR_INTERVAL=TRACK_VAR_INTERVAL,
    FID_EPOCHS=FID_EPOCHS,
    DATASET_PATH=DATASET_PATH,
    REVERSE=REVERSE,
    IMG_SIZE=IMG_SIZE,
    DATASET1_CHANNELS=DATASET1_CHANNELS,
    DATASET2_CHANNELS=DATASET2_CHANNELS,
    D_NORM=D_NORM,
    MAX_STEPS=MAX_STEPS,
)

with open(os.path.join(OUTPUT_PATH, "config.json"), "w", encoding="utf-8") as json_file:
    json.dump(config, json_file, indent=4)

# log.json records the step a run can be resumed from (see CONTINUE).
log = dict(CONTINUE=CONTINUE)
with open(os.path.join(OUTPUT_PATH, "log.json"), "w", encoding="utf-8") as log_file:
    json.dump(log, log_file, indent=4)

In [None]:
# The 'vgg' cost is computed by a VGGPerceptualLoss module placed on the GPU;
# the other costs are computed directly from tensors in the training loop.
if COST == "vgg":
    _vgg_module = VGGLoss()
    vgg_loss = _vgg_module.cuda()

## Loading dataset statistics for testing (currently disabled)

In [None]:
# filename = "../stats/{}_{}_{}_test.json".format(DATASET, IMG_SIZE, REVERSE)
# with open(filename, "r") as fp:
#     data_stats = json.load(fp)
#     mu_data, sigma_data = data_stats["mu"], data_stats["sigma"]
# del data_stats

## 3. Initialize samplers


In [None]:
# Build the paired (X, Y) train and test samplers for the chosen dataset.
# A variant with num_workers=8 and device='cuda' was also used previously.
_sampler_options = dict(img_size=IMG_SIZE, reverse=REVERSE, num_workers=12)
XY_sampler, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, **_sampler_options
)

# Free memory left over from dataset setup and clear the cell's output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

## 4. Training


### Models initialization


In [None]:
# Build network D and the map T according to D_TYPE / T_TYPE / D_NORM.
if D_NORM not in ("batchnorm", "none"):
    # Any other value used to silently enable batchnorm.
    raise ValueError("Unknown D_NORM: {}".format(D_NORM))

if D_TYPE == "ResNet":
    # D scores T's output (DATASET2_CHANNELS channels). With CONDITIONAL, the input X
    # (DATASET1_CHANNELS channels) is concatenated along dim 1, so D sees both.
    D_IN_CHANNELS = DATASET2_CHANNELS + (DATASET1_CHANNELS if CONDITIONAL else 0)
    D = ResNet_D(
        IMG_SIZE,
        nc=D_IN_CHANNELS,
        bn=D_NORM == "batchnorm",
    ).cuda()
    D.apply(weights_init_D)
else:
    raise NotImplementedError("Unknown D_TYPE: {}".format(D_TYPE))

if T_TYPE == "UNet":
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48).cuda()
elif T_TYPE == "U2Net":
    T = U2NET(in_ch=DATASET1_CHANNELS, out_ch=DATASET2_CHANNELS).cuda()
else:
    raise NotImplementedError("Unknown T_TYPE: {}".format(T_TYPE))

if len(DEVICE_IDS) > 1:
    T = nn.DataParallel(T, device_ids=DEVICE_IDS)
    D = nn.DataParallel(D, device_ids=DEVICE_IDS)

print("T params:", sum(p.numel() for p in T.parameters()))
print("D params:", sum(p.numel() for p in D.parameters()))

In [None]:
# Adam optimizers for T and D; both learning rates are halved at each milestone step.
T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, weight_decay=1e-10)
D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, weight_decay=1e-10)

LR_MILESTONES = [15000, 30000, 45000, 70000]
T_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    T_opt, milestones=LR_MILESTONES, gamma=0.5
)
D_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    D_opt, milestones=LR_MILESTONES, gamma=0.5
)

if CONTINUE > 0:

    def _load_checkpoint(name):
        """Load checkpoint ``name`` saved at step CONTINUE onto the training GPU."""
        return torch.load(
            os.path.join(OUTPUT_PATH, f"{name}_{SEED}_{CONTINUE}.pt"),
            # Remap tensors even if the run was saved on a different GPU index.
            map_location=f"cuda:{DEVICE_IDS[0]}",
        )

    # Checkpoints store the bare module weights (no "module." key prefix). T and D
    # may already be wrapped in DataParallel by the model cell, so load into the
    # wrapped module and do NOT wrap them a second time.
    (T.module if isinstance(T, nn.DataParallel) else T).load_state_dict(
        _load_checkpoint("T")
    )
    (D.module if isinstance(D, nn.DataParallel) else D).load_state_dict(
        _load_checkpoint("D")
    )

    T_opt.load_state_dict(_load_checkpoint("T_opt"))
    D_opt.load_state_dict(_load_checkpoint("D_opt"))
    T_scheduler.load_state_dict(_load_checkpoint("T_scheduler"))
    D_scheduler.load_state_dict(_load_checkpoint("D_scheduler"))

In [None]:
# Ten fixed training pairs, drawn once for the pre-training sanity plots.
_fixed_train_pairs = XY_sampler.sample(10)
X_fixed, Y_fixed = _fixed_train_pairs

In [None]:
# Ten fixed test pairs, drawn once and re-plotted during training so that
# successive figures are directly comparable.
_fixed_test_pairs = XY_test_sampler.sample(10)
X_test_fixed, Y_test_fixed = _fixed_test_pairs

### Plots Test


In [None]:
# Plot T's outputs before training on fixed and random pairs from both splits.
# Only the last figure (random test pairs) is logged to TensorBoard.
_sanity_plot_makers = [
    lambda: plot_pushed_images(X_fixed, Y_fixed, T),
    lambda: plot_pushed_random_paired_images(XY_sampler, T),
    lambda: plot_pushed_images(X_test_fixed, Y_test_fixed, T),
    lambda: plot_pushed_random_paired_images(XY_test_sampler, T),
]
for _make_plot in _sanity_plot_makers:
    fig, axes = _make_plot()
writer.add_image("pushed random paired images", fig2tensor(fig))

### Main training cycle and logging


In [None]:
# Main training loop. Each step runs T_ITERS updates of T with D frozen,
# then (when NOT is True) one update of D with T frozen. Plots, checkpoints
# and (currently disabled) metrics are produced at fixed step intervals.


def _unwrap(model):
    """Return ``model.module`` for a DataParallel wrapper, else ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _ckpt_path(name, step):
    """Return the path of checkpoint ``name`` saved at training step ``step``."""
    return os.path.join(OUTPUT_PATH, f"{name}_{SEED}_{step}.pt")


progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE)

# Steps before CONTINUE belong to the run being resumed and are skipped.
for step in range(CONTINUE, MAX_STEPS):
    # T optimization
    unfreeze(T)
    freeze(D)

    for _ in range(T_ITERS):
        T_opt.zero_grad()
        X, Y = XY_sampler.sample(BATCH_SIZE)
        T_X = T(X)

        if CONDITIONAL:
            # D scores T's output together with its input, stacked on channels.
            T_X = torch.cat([T_X, X], dim=1)

        # The first DATASET2_CHANNELS channels of T_X are T's output image.
        if COST == "rmse":
            # Per-sample L2 norm of the residual, averaged over the batch.
            T_loss = (
                (Y - T_X[:, :DATASET2_CHANNELS]).flatten(start_dim=1).norm(dim=1).mean()
            )
        elif COST == "mse":
            # Per-sample SUM of squared residuals, averaged over the batch.
            T_loss = (
                (Y - T_X[:, :DATASET2_CHANNELS])
                .flatten(start_dim=1)
                .square()
                .sum(dim=1)
                .mean()
            )
        elif COST == "mae":
            # Per-sample SUM of absolute residuals, averaged over the batch.
            T_loss = (
                (Y - T_X[:, :DATASET2_CHANNELS])
                .flatten(start_dim=1)
                .abs()
                .sum(dim=1)
                .mean()
            )
        elif COST == "vgg":
            T_loss = vgg_loss(Y, T_X[:, :DATASET2_CHANNELS]).mean()
        else:
            raise Exception("Unknown COST")

        if NOT:
            # T minimizes cost - E[D(T(X))]; without NOT this is plain regression.
            T_loss -= D(T_X).mean()

        writer.add_scalar("T_loss", T_loss.item(), step)

        T_loss.backward()
        T_opt.step()
    T_scheduler.step()
    del T_loss, T_X, X, Y
    gc.collect()
    torch.cuda.empty_cache()

    if NOT:
        # D optimization: D minimizes E[D(T(X))] - E[D(Y)] with T frozen.
        freeze(T)
        unfreeze(D)
        X, _ = XY_sampler.sample(BATCH_SIZE)
        with torch.no_grad():
            T_X = T(X)
        _, Y = XY_sampler.sample(BATCH_SIZE)  # We may use the previous batch here
        if CONDITIONAL:
            with torch.no_grad():
                T_X = torch.cat([T_X, X], dim=1)
                Y = torch.cat([Y, X], dim=1)
        D_opt.zero_grad()
        D_loss = D(T_X).mean() - D(Y).mean()
        writer.add_scalar("D_loss", D_loss.item(), step)
        D_loss.backward()
        D_opt.step()
        D_scheduler.step()
        del D_loss, Y, X, T_X, _
        gc.collect()
        torch.cuda.empty_cache()

    progress_bar.update(1)

    if step % PLOT_INTERVAL == 0:
        progress_bar.close()
        clear_output(wait=True)
        # step + 1 steps are complete; restarting at CONTINUE would make the
        # bar jump backwards after every plot.
        progress_bar = tqdm(total=MAX_STEPS, initial=step + 1)
        print("Plotting")

        # Plot in eval mode, then restore T's previous mode so that training
        # never silently continues with T left in eval mode.
        was_training = T.training
        T.eval()

        print("Fixed Test Images")
        fig, axes = plot_pushed_images(X_test_fixed, Y_test_fixed, T, gray=GRAY_PLOTS)
        # pyplot.show() takes no figure argument; it shows the open figure(s).
        plt.show()
        plt.close(fig)

        print("Random Test Images")
        fig, axes = plot_pushed_random_paired_images(
            XY_test_sampler, T, gray=GRAY_PLOTS
        )
        plt.show()
        plt.close(fig)

        T.train(was_training)

    if step != 0 and step % CPKT_INTERVAL == 0:
        freeze(T)
        # Save bare module weights so checkpoints load with or without DataParallel.
        torch.save(_unwrap(T).state_dict(), _ckpt_path("T", step))
        torch.save(_unwrap(D).state_dict(), _ckpt_path("D", step))
        torch.save(D_opt.state_dict(), _ckpt_path("D_opt", step))
        torch.save(T_opt.state_dict(), _ckpt_path("T_opt", step))
        torch.save(D_scheduler.state_dict(), _ckpt_path("D_scheduler", step))
        torch.save(T_scheduler.state_dict(), _ckpt_path("T_scheduler", step))

        # Record the step just checkpointed, so log.json points at resumable files.
        log["CONTINUE"] = step
        with open(os.path.join(OUTPUT_PATH, "log.json"), "w", encoding="utf-8") as log_file:
            json.dump(log, log_file, indent=4)

    if step % TRACK_VAR_INTERVAL == 0:
        # Metric tracking is currently disabled; the commented code below is the
        # FID / LPIPS / L1 / MSE evaluation. NOTE(review): it requests only
        # ["mse", "l1"] but then reads metrics["vgg"] and metrics["alex"].
        pass
        # print("Computing FID")
        # mu, sigma = get_pushed_loader_stats(
        #     T,
        #     XY_test_sampler.loader,
        #     n_epochs=FID_EPOCHS,
        #     batch_size=BATCH_SIZE,
        #     verbose=True,
        # )
        # fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
        # print(f"FID={fid}")
        # writer.add_scalar("Metrics/FID", fid, step)
        # del mu, sigma

        # print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
        # metrics = get_pushed_loader_metrics(
        #     T,
        #     XY_test_sampler.loader,
        #     n_epochs=FID_EPOCHS,
        #     batch_size=BATCH_SIZE,
        #     verbose=True,
        #     metrics=["mse", "l1"]
        # )
        # print(f"metrics={metrics}")
        # writer.add_scalar("Metrics/LPIPS(VGG)", metrics["vgg"], step)
        # writer.add_scalar("Metrics/LPIPS(Alex)", metrics["alex"], step)
        # writer.add_scalar("Metrics/L1", metrics["l1"], step)
        # writer.add_scalar("Metrics/MSE", metrics["mse"], step)

    gc.collect()
    torch.cuda.empty_cache()

## Clear resources


In [None]:
# Release TensorBoard and tqdm resources. Each is closed in its own try block
# so a failure (or a missing object, if a cell was skipped) in one does not
# prevent the other from being closed; errors are reported, not raised.
try:
    writer.close()
except Exception as e:
    print(e)

try:
    progress_bar.close()
except Exception as e:
    print(e)