## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings
from typing import List

import torch
import torch.nn as nn
import numpy as np
from diffusers import DDIMScheduler
from tensorboardX import SummaryWriter

import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.resnet2 import ResNet_D
from src.losses import VGGPerceptualLoss as VGGLoss
from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
)  # for wandb
from src.plotters import (
    plot_linked_pushed_images,
    plot_linked_pushed_random_paired_images,
)

from src.samplers import PairedLoaderSampler, get_paired_sampler

# Raise Pillow's PNG text-chunk size limit (MAX_TEXT_CHUNK) to
# LARGE_ENOUGH_NUMBER MiB.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Start from a clean slate: collect Python garbage, then release PyTorch's
# cached (unused) CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## 2. Config

Choose the dataset in the first lines of the cell below.


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# dataset choosing
# face2comic
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "comic2face", False
# colored mask -> face
DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
    "celeba_mask",
    "../datasets/CelebAMask-HQ",
    "colored_mask2face",
    False,
)
# sketch -> photo
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
#     "FS2K",
#     "../datasets/FS2K/",
#     "sketch2photo",
#     False,
# )

IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# number of noising timesteps in the diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [50, 80, 120]

# GPU choosing
DEVICE_IDS = [1]
assert torch.cuda.is_available()
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# first is for step, setting the value (checkpoints step + 1);
# last is for sdes, setting the value be (num of train-finished sde + 1).
CONTINUE = [4001, 0]

# All hyperparameters below are set to the values used for the experiments described in the article

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'
BATCH_SIZE = 8
T_ITERS = 10
MAX_STEPS = 10000 + 1  # 2501 for testing
COST = "vgg"  #'mse' # 'mae' # 'vgg'
# model settings
# network settings
NOT = True  # Train Neural optimal transport or pure regression
T_TYPE = "U2Net"  # 'UNet' # or  ('ResNet_pix2pix' - not implemented)
UNET_BASE_FACTOR = 48  # For UNet
D_TYPE = (
    "ResNet"  # or 'ResNet_pix2pix' - DOES NOT WORK WELL (it is actually not a resnet:)
)
D_USE_BATCHNORM = False  # For ResNet_D

# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
T_GRADIENT_MAX_NORM = float(500)
D_GRADIENT_MAX_NORM = float(500)
SCHEDULER_MILESTONES = [2500, 4000]

# plot settings
GRAY_PLOTS = False
PLOT_N_SAMPLES = 8
# log settings
TRACK_VAR_INTERVAL = 500
PLOT_INTERVAL = 500
CKPT_INTERVAL = 2000

FID_EPOCHS = 1

# The experiment name encodes dataset, training strategy and seed; it names
# both the checkpoint directory and the TensorBoard log directory.
EXP_NAME = f"DNOT_Paired_{DATASET}_{STRATEGY}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# exist_ok=True replaces the check-then-create pair, which could fail if the
# directory appeared between the existence check and the makedirs call.
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Snapshot of every hyperparameter of this run, written next to the checkpoints.
config = {
    "SEED": SEED,
    "DATASET": DATASET,
    "MAP_NAME": MAP_NAME,
    "REVERSE": REVERSE,
    "IMG_SIZE": IMG_SIZE,
    "DATASET1_CHANNELS": DATASET1_CHANNELS,
    "DATASET2_CHANNELS": DATASET2_CHANNELS,
    "DIFFUSION_STEPS": DIFFUSION_STEPS,
    "PIVOTAL_LIST": PIVOTAL_LIST,
    "STRATEGY": STRATEGY,
    "BATCH_SIZE": BATCH_SIZE,
    "T_ITERS": T_ITERS,
    "MAX_STEPS": MAX_STEPS,
    "COST": COST,
    "NOT": NOT,
    "T_TYPE": T_TYPE,
    "UNET_BASE_FACTOR": UNET_BASE_FACTOR,
    "D_TYPE": D_TYPE,
    "D_USE_BATCHNORM": D_USE_BATCHNORM,
    "D_LR": D_LR,
    "T_LR": T_LR,
    "T_GRADIENT_MAX_NORM": T_GRADIENT_MAX_NORM,
    "D_GRADIENT_MAX_NORM": D_GRADIENT_MAX_NORM,
    "GRAY_PLOTS": GRAY_PLOTS,
    "PLOT_N_SAMPLES": PLOT_N_SAMPLES,
    "TRACK_VAR_INTERVAL": TRACK_VAR_INTERVAL,
    "PLOT_INTERVAL": PLOT_INTERVAL,
    "CKPT_INTERVAL": CKPT_INTERVAL,
    "FID_EPOCHS": FID_EPOCHS,
}
config_path = os.path.join(OUTPUT_PATH, "config.json")
with open(config_path, "w") as json_file:
    json_str = json.dumps(config, indent=4)
    json_file.write(json_str)

# Resume bookkeeping: log.json stores the current CONTINUE pair.
log = {"CONTINUE": CONTINUE}
log_path = os.path.join(OUTPUT_PATH, "log.json")
with open(log_path, "w") as log_file:
    log_str = json.dumps(log, indent=4)
    log_file.write(log_str)

In [None]:
# if not REVERSE:
#     filename = f"../stats/{DATASET}_{MAP_NAME.split('2')[1]}_{IMG_SIZE}_test.json"
# else:
#     filename = f"../stats/{DATASET}_{MAP_NAME.split('2')[0]}_{IMG_SIZE}_test.json"

# with open(filename, "r") as fp:
#     data_stats = json.load(fp)
#     mu_data, sigma_data = data_stats["mu"], data_stats["sigma"]
# del data_stats

In [None]:
# The VGG perceptual loss module is only instantiated (on the current CUDA
# device) when it is the selected cost.
if COST == "vgg":
    vgg_loss = VGGLoss().cuda()

## 3. Initialize samplers


In [None]:
# Build the paired samplers for the selected dataset (presumably the train and
# test splits, judging by the names -- confirm in src.samplers).
XY_sampler, XY_test_sampler = get_paired_sampler(
    DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE, reverse=REVERSE
)

# Release memory and wipe whatever the cell printed so far.
torch.cuda.empty_cache()
gc.collect()
clear_output()

### pivotal sampler


In [None]:
# Shared noise scheduler; in this notebook only its add_noise() method is used
# (by sample_all_pivotal and plot_all_pivotal).
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: PairedLoaderSampler,
    batch_size: int = 4,
) -> List[torch.Tensor]:
    """Sample one paired batch and build its pivotal path by repeated noising.

    Source and target are noised step by step with ``SCHEDULER.add_noise``
    (each step is applied on top of the previous result).  After every number
    of applied steps listed in ``PIVOTAL_LIST`` the current noised source and
    target are recorded as pivotal points.

    Args:
        XY_sampler: paired sampler whose ``sample(batch_size)`` returns a
            ``(source, target)`` pair of tensors.
        batch_size: number of pairs to draw.

    Returns:
        ``[source, noised sources ..., noised targets in reverse order ..., target]``
        where the most-noised target is dropped, so the most-noised source is
        the only point at the middle of the path.
    """
    source, target = XY_sampler.sample(batch_size)

    noised_sources, noised_targets = [source], [target]
    n_noise_steps = min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])
    for n_applied in range(1, n_noise_steps + 1):
        timestep = torch.Tensor([n_applied - 1]).long()
        # The source noise is drawn before the target noise, so the random
        # stream is consumed in a fixed order.
        source = SCHEDULER.add_noise(source, torch.randn_like(source), timestep)
        target = SCHEDULER.add_noise(target, torch.randn_like(target), timestep)
        if n_applied in PIVOTAL_LIST:
            noised_sources.append(source)
            noised_targets.append(target)

    # Walk the target side from most to least noised; skipping its first entry
    # keeps only the source's last pivotal point in the middle.
    # (Use the full reversed list instead to map between 2 last pivotal points.)
    return noised_sources + noised_targets[::-1][1:]


# def sample_step_t_pivotal(
#     XY_sampler: PairedLoaderSampler,
#     batch_size: int = 4,
#     pivotal_step: int = 0,
# ):
#     pivotal_path = sample_all_pivotal(XY_sampler, batch_size)
#     pivotal_t, pivotal_t_next = (
#         pivotal_path[pivotal_step],
#         pivotal_path[pivotal_step + 1],
#     )
#     return pivotal_t, pivotal_t_next

### mapping plotters


In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
) -> tuple:
    """Plot the pivotal path of a single (source, target) image pair in one row.

    The path is built exactly like in ``sample_all_pivotal``: both images are
    noised step by step with ``SCHEDULER.add_noise`` and recorded whenever the
    number of applied steps is in ``PIVOTAL_LIST``.

    Args:
        source: single source image; the stack/permute below implies a
            channel-first (C, H, W) tensor with values roughly in [-1, 1]
            (they are mapped by ``x * 0.5 + 0.5`` and clipped to [0, 1]).
        target: single target image with the same layout as ``source``.
        gray: if True, images are drawn with the "gray" colormap.

    Returns:
        ``(fig, axes)``: the matplotlib figure and the list of its axes.
    """
    source_list = [source]
    target_list = [target]
    for i in range(min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])):
        source = SCHEDULER.add_noise(
            source, torch.randn_like(source), torch.Tensor([i]).long()
        )
        target = SCHEDULER.add_noise(
            target, torch.randn_like(target), torch.Tensor([i]).long()
        )
        if (i + 1) in PIVOTAL_LIST:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()

    # Only the source's last pivotal point is kept in the middle of the path;
    # use target_list[:] instead to show both last pivotal points.
    pivotal_path = source_list + target_list[1:]

    # (N, C, H, W) -> (N, H, W, C), rescaled to [0, 1] for imshow.
    imgs: np.ndarray = (
        torch.stack(pivotal_path)
        .to("cpu")
        .permute(0, 2, 3, 1)
        .mul(0.5)
        .add(0.5)
        .numpy()
        .clip(0, 1)
    )
    nrows, ncols = 1, len(pivotal_path)
    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    last_index = imgs.shape[0] - 1
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        # Braces around the index keep multi-digit indices fully subscripted
        # ("$X_10$" would subscript only the "1").
        title = "Y" if i == last_index else f"$X_{{{i}}}$"
        ax.set_title(title, fontsize=24)

    return fig, fig.axes

## 4. Training


### Models initialization


In [None]:
from src.u2net import U2NET
from src.unet import UNet


# Parallel lists: entry i holds the transport map T, the critic D and their
# optimizers / LR schedulers for the i-th transition of the pivotal path.
Ts, Ds = [], []
T_OPTs, D_OPTs = [], []
T_SCHEDULERs, D_SCHEDULERs = [], []

# sample_all_pivotal returns 2 * len(PIVOTAL_LIST) + 1 points, i.e.
# 2 * len(PIVOTAL_LIST) transitions, each with its own (T, D) pair.
for i in range(len(PIVOTAL_LIST) * 2):
    if T_TYPE == "UNet":
        T = UNet(
            DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR
        ).cuda()
    elif T_TYPE == "U2Net":
        T = U2NET(in_ch=DATASET1_CHANNELS, out_ch=DATASET2_CHANNELS).cuda()
    else:
        raise NotImplementedError("Unknown T_TYPE: {}".format(T_TYPE))
    Ts.append(T)

    if D_TYPE == "ResNet":
        D = ResNet_D(
            IMG_SIZE,
            nc=DATASET2_CHANNELS,
            bn=D_USE_BATCHNORM,
        ).cuda()
    else:
        raise NotImplementedError("Unknown D_TYPE: {}".format(D_TYPE))
    # Initialize the critic exactly once (it used to be applied twice in a row,
    # the second call simply overwriting the first initialization).
    D.apply(weights_init_D)
    Ds.append(D)

    T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, weight_decay=1e-10)
    D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, weight_decay=1e-10)
    T_OPTs.append(T_opt)
    D_OPTs.append(D_opt)

    # Learning rates are halved at every milestone in SCHEDULER_MILESTONES.
    T_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        T_opt, milestones=SCHEDULER_MILESTONES, gamma=0.5
    )
    D_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        D_opt, milestones=SCHEDULER_MILESTONES, gamma=0.5
    )
    T_SCHEDULERs.append(T_scheduler)
    D_SCHEDULERs.append(D_scheduler)


# On a fresh multi-GPU run the models are wrapped right away; when resuming,
# wrapping is left to the checkpoint-loading cell so that state dicts are
# loaded into the bare modules first.
# NOTE(review): parameter counts are only printed in this branch -- confirm
# whether they should also be printed for single-GPU / resumed runs.
if len(DEVICE_IDS) > 1 and CONTINUE[0] == 0 and CONTINUE[1] == 0:
    for i in range(len(Ts)):
        Ts[i] = nn.DataParallel(Ts[i], device_ids=DEVICE_IDS)
        Ds[i] = nn.DataParallel(Ds[i], device_ids=DEVICE_IDS)

        print(f"T{i} params:", sum(p.numel() for p in Ts[i].parameters()))
        print(f"D{i} params:", sum(p.numel() for p in Ds[i].parameters()))

### Load weights for continue training


In [None]:
def _load_pair_checkpoint(index: int, ckpt_dir: str) -> None:
    """Restore T/D weights and optimizer/scheduler states of pair ``index``.

    Files are read from ``ckpt_dir`` using the naming scheme of the training
    cells: ``<prefix><index>_<SEED>.pt``.
    """
    stateful_objects = (
        ("T", Ts[index]),
        ("T_opt", T_OPTs[index]),
        ("T_scheduler", T_SCHEDULERs[index]),
        ("D", Ds[index]),
        ("D_opt", D_OPTs[index]),
        ("D_scheduler", D_SCHEDULERs[index]),
    )
    for prefix, obj in stateful_objects:
        filename = f"{prefix}{index}_{SEED}.pt"
        obj.load_state_dict(torch.load(os.path.join(ckpt_dir, filename)))
        print(f"{ckpt_dir} {filename}, loaded")


if CONTINUE[0] > 0 or CONTINUE[1] > 0:
    print("Loading weights for continue training")
    # Checkpoint of the interrupted step (CONTINUE[0] stores "step + 1").
    resume_dir = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
    # Checkpoint written at the last step of a fully trained pair.
    final_dir = os.path.join(OUTPUT_PATH, f"iter{MAX_STEPS - 1}/")

    if STRATEGY in ("Adapt", "Fix"):
        for i in range(len(Ts)):
            if STRATEGY == "Fix":
                # All pairs are trained jointly and share one checkpoint step.
                _load_pair_checkpoint(i, resume_dir)
            elif i < CONTINUE[1]:
                # Adapt: pairs before CONTINUE[1] are already fully trained.
                _load_pair_checkpoint(i, final_dir)
            elif i == CONTINUE[1] and CONTINUE[0] > 0:
                # Adapt: the pair that was being trained when the run stopped.
                # With CONTINUE[0] == 0 it has no checkpoint yet (none is ever
                # written for step 0), so it simply starts from scratch.
                _load_pair_checkpoint(i, resume_dir)
            # Adapt pairs after CONTINUE[1] keep their fresh initialization.

            if len(DEVICE_IDS) > 1:
                # Write the wrapped modules back into the lists: rebinding a
                # loop variable would leave Ts/Ds unwrapped, and the training
                # cells access `.module` when saving on multi-GPU runs.
                Ts[i] = nn.DataParallel(Ts[i], device_ids=DEVICE_IDS)
                Ds[i] = nn.DataParallel(Ds[i], device_ids=DEVICE_IDS)

In [None]:
# writer.add_graph(
#     Ts[0], torch.rand(BATCH_SIZE, DATASET1_CHANNELS, IMG_SIZE, IMG_SIZE).cuda()
# )

### Plots Test


In [None]:
# Draw one batch from each sampler once and keep it, so that the same images
# are plotted every time (the training cells plot X_test_fixed / Y_test_fixed).
X_fixed, Y_fixed = XY_sampler.sample(PLOT_N_SAMPLES)
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(PLOT_N_SAMPLES)

In [None]:
# Sanity check: show the pivotal path of the first fixed training pair.
fig, axes = plot_all_pivotal(X_fixed[0], Y_fixed[0], GRAY_PLOTS)

In [None]:
# Sanity check of the project plotter on the fixed training batch with the
# current state of the maps Ts (see src.plotters for what exactly is drawn).
fig, axes = plot_linked_pushed_images(X_fixed, Y_fixed, Ts, gray=GRAY_PLOTS)

In [None]:
# Same sanity check, but the plotter draws its own PLOT_N_SAMPLES pairs from
# the sampler (see src.plotters for details).
fig, axes = plot_linked_pushed_random_paired_images(
    XY_sampler, Ts, plot_n_samples=PLOT_N_SAMPLES, gray=GRAY_PLOTS
)

### Main training cycle and logging


In [None]:
# Free Python garbage and PyTorch's cached CUDA memory before training starts.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def get_cat_pivotal_tr(X_sampler, Y_sampler, batch_size):
    """Build a pivotal path whose two halves come from two independent draws.

    One path is sampled from ``X_sampler`` and another from ``Y_sampler`` (in
    that order).  The result takes the points up to and including the middle
    one from the first path and the remaining points from the second path.
    The unused halves are dropped and memory is reclaimed before returning.
    """
    first_path = sample_all_pivotal(X_sampler, batch_size)
    second_path = sample_all_pivotal(Y_sampler, batch_size)
    assert len(first_path) == len(second_path)

    split_at = len(first_path) // 2 + 1
    tr = [*first_path[:split_at], *second_path[split_at:]]

    # Every tensor that is not part of `tr` loses its last reference here, so
    # the collection / cache release below can actually free it.
    del first_path, second_path

    gc.collect()
    torch.cuda.empty_cache()
    return tr

#### Fix strategy


In [None]:
# "Fix" strategy: all (T, D) pairs are trained jointly; at every step pair i
# learns the transition from pivotal point i to pivotal point i + 1 of a
# freshly sampled pivotal path.
if STRATEGY == "Fix":
    progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])

    for step in range(MAX_STEPS):
        # Skip the steps that were already done before resuming.
        if step < CONTINUE[0]:
            continue

        for T, D in zip(Ts, Ds):
            freeze(T)
            freeze(D)
        gc.collect()
        torch.cuda.empty_cache()
        # --- T_ITERS updates of every transport map per outer step ---
        for _ in range(T_ITERS):
            fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)

            for i, (T, T_opt, D) in enumerate(zip(Ts, T_OPTs, Ds)):
                freeze(D)
                unfreeze(T)

                X, Y = fixed_trajectory[i], fixed_trajectory[i + 1]

                # === mapping and optimize ===
                T_opt.zero_grad()

                T_X = T(X)

                # Transport cost between the target pivot and the first three
                # output channels of the map.
                if COST == "rmse":
                    T_loss = (Y - T_X[:, :3]).flatten(start_dim=1).norm(dim=1).mean()
                elif COST == "mse":
                    T_loss = (
                        (Y - T_X[:, :3]).flatten(start_dim=1).square().sum(dim=1).mean()
                    )
                elif COST == "mae":
                    T_loss = (
                        (Y - T_X[:, :3]).flatten(start_dim=1).abs().sum(dim=1).mean()
                    )
                elif COST == "vgg":
                    T_loss = vgg_loss(Y, T_X[:, :3]).mean()
                else:
                    raise Exception("Unknown COST")
                # With NOT enabled the critic score of the mapped batch is
                # subtracted; otherwise this is pure regression on the cost.
                if NOT:
                    T_loss -= D(T_X).mean()

                # Logged with the outer step, i.e. T_ITERS values per step.
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
                # del X, Y, T_X, T_loss
                # gc.collect()
                # torch.cuda.empty_cache()
            # del fixed_trajectory
            # gc.collect()
            # torch.cuda.empty_cache()
        # The T schedulers advance once per outer step, not per T iteration.
        for T_scheduler in T_SCHEDULERs:
            T_scheduler.step()

        # --- one update of every critic per outer step (NOT only) ---
        if NOT:
            for T, D in zip(Ts, Ds):
                freeze(T)
                freeze(D)
            # The critic path is stitched from two independent sampler draws
            # (see get_cat_pivotal_tr).
            fixed_trajectory = get_cat_pivotal_tr(XY_sampler, XY_sampler, BATCH_SIZE)
            for i, (D, D_opt, D_scheduler, T) in enumerate(
                zip(Ds, D_OPTs, D_SCHEDULERs, Ts)
            ):
                freeze(T)
                unfreeze(D)

                X, Y = fixed_trajectory[i], fixed_trajectory[i + 1]
                # === mapping ===
                with torch.no_grad():
                    T_X = T(X)
                # === optimize ===
                D_opt.zero_grad()
                D_loss = D(T_X).mean() - D(Y).mean()
                writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
                D_loss.backward()
                D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    D.parameters(), max_norm=D_GRADIENT_MAX_NORM
                )
                D_opt.step()
                D_scheduler.step()
                # del X, Y, T_X, D_loss
                # torch.cuda.empty_cache()
                # gc.collect()
            # del fixed_trajectory
            # gc.collect()
            # torch.cuda.empty_cache()

        # CONTINUE[0] now holds "step + 1", the step to resume from.
        CONTINUE[0] += 1
        progress_bar.update(1)

        if step % PLOT_INTERVAL == 0:
            # clear_output wipes the cell output, so the progress bar is
            # re-created at the current position.
            progress_bar.close()
            clear_output(wait=True)
            progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])
            print("Plotting")

            inference_Ts = Ts
            for T in inference_Ts:
                T.eval()
            print("Fixed Test Images")
            fig, axes = plot_linked_pushed_images(
                X_test_fixed, Y_test_fixed, inference_Ts, gray=GRAY_PLOTS
            )
            writer.add_image("Fixed Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)
            print("Random Test Images")
            fig, axes = plot_linked_pushed_random_paired_images(
                XY_test_sampler,
                inference_Ts,
                plot_n_samples=PLOT_N_SAMPLES,
                gray=GRAY_PLOTS,
            )
            writer.add_image("Random Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)

        if step != 0 and step % CKPT_INTERVAL == 0:
            # Every pair is saved into the same per-step directory.
            CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
            os.makedirs(CKPT_DIR, exist_ok=True)
            for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
                zip(
                    Ts,
                    T_OPTs,
                    T_SCHEDULERs,
                    Ds,
                    D_OPTs,
                    D_SCHEDULERs,
                )
            ):
                # On multi-GPU runs the models are expected to be DataParallel
                # wrappers; the bare module's weights are what gets saved.
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
            # Persist the resume position together with the checkpoint.
            log["CONTINUE"] = CONTINUE
            with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                log_str = json.dumps(log, indent=4)
                log_file.write(log_str)

        # Metric tracking is currently disabled; the code is kept for reference.
        if step % TRACK_VAR_INTERVAL == 0:
            pass
            # print("Computing FID")
            # mu, sigma = get_linked_sde_pushed_loader_stats(
            #     Ts,
            #     XY_test_sampler.loader,
            #     n_epochs=FID_EPOCHS,
            #     batch_size=BATCH_SIZE,
            #     verbose=True,
            # )
            # fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
            # print(f"FID={fid}")
            # writer.add_scalar("Metrics/FID", fid, step)
            # del mu, sigma

            # print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
            # metrics = get_linked_sde_pushed_loader_metrics(
            #     Ts,
            #     XY_test_sampler.loader,
            #     n_epochs=FID_EPOCHS,
            #     batch_size=BATCH_SIZE,
            #     verbose=True,
            #     log_metrics=["mse", "l1"]
            # )
            # print(f"metrics={metrics}")
            # writer.add_scalar("Metrics/LPIPS(VGG)", metrics["vgg"], step)
            # writer.add_scalar("Metrics/LPIPS(Alex)", metrics["alex"], step)
            # writer.add_scalar("Metrics/L1", metrics["l1"], step)
            # writer.add_scalar("Metrics/MSE", metrics["mse"], step)

        # gc.collect()
        # torch.cuda.empty_cache()

#### Adapt strategy


In [None]:
# "Adapt" strategy: the (T, D) pairs are trained one after another.  Pair i
# receives the source batch pushed through the already trained maps
# Ts[0] .. Ts[i - 1] and learns to reach pivotal point i + 1.
if STRATEGY == "Adapt":
    for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
        zip(
            Ts,
            T_OPTs,
            T_SCHEDULERs,
            Ds,
            D_OPTs,
            D_SCHEDULERs,
        )
    ):
        # Pairs before CONTINUE[1] were fully trained in a previous run.
        if i < CONTINUE[1]:
            continue

        progress_bar = tqdm(
            total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(Ts)}:"
        )
        for _T, _D in zip(Ts, Ds):
            freeze(_T)
            freeze(_D)

        for step in range(MAX_STEPS):
            # Skip the steps that were already done before resuming.
            if step < CONTINUE[0]:
                continue

            # --- T_ITERS updates of the current transport map ---
            for _ in range(T_ITERS):
                freeze(D)
                unfreeze(T)

                # === sampler training data ===
                fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)
                X, Y = fixed_trajectory[0], fixed_trajectory[i + 1]
                # Push the source through the previously trained maps, in
                # order (this used to apply the current map Ts[i] i times).
                with torch.no_grad():
                    for previous_T in Ts[:i]:
                        X = previous_T(X)
                X = X.requires_grad_()
                # === mapping and optimize ===
                T_opt.zero_grad()

                T_X = T(X)
                # Transport cost between the target pivot and the first three
                # output channels of the map.
                if COST == "rmse":
                    T_loss = (Y - T_X[:, :3]).flatten(start_dim=1).norm(dim=1).mean()
                elif COST == "mse":
                    T_loss = (
                        (Y - T_X[:, :3]).flatten(start_dim=1).square().sum(dim=1).mean()
                    )
                elif COST == "mae":
                    T_loss = (
                        (Y - T_X[:, :3]).flatten(start_dim=1).abs().sum(dim=1).mean()
                    )
                elif COST == "vgg":
                    T_loss = vgg_loss(Y, T_X[:, :3]).mean()
                else:
                    raise Exception("Unknown COST")

                # With NOT enabled the critic score of the mapped batch is
                # subtracted; otherwise this is pure regression on the cost.
                if NOT:
                    T_loss -= D(T_X).mean()
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
                # del fixed_trajectory, X, Y, T_X, T_loss
                # gc.collect()
                # torch.cuda.empty_cache()
            # The T scheduler advances once per outer step.
            T_scheduler.step()

            # --- one update of the current critic (NOT only) ---
            if NOT:
                freeze(T)
                unfreeze(D)
                # === sampler training data ===
                fixed_trajectory = get_cat_pivotal_tr(
                    XY_sampler, XY_sampler, BATCH_SIZE
                )
                X, Y = fixed_trajectory[0], fixed_trajectory[i + 1]
                # Same chaining through the previously trained maps as above.
                with torch.no_grad():
                    for previous_T in Ts[:i]:
                        X = previous_T(X)
                X = X.requires_grad_()
                # === mapping ===
                with torch.no_grad():
                    T_X = T(X)
                # === optimize ===
                D_opt.zero_grad()
                D_loss = D(T_X).mean() - D(Y).mean()
                writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
                D_loss.backward()
                D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    D.parameters(), max_norm=D_GRADIENT_MAX_NORM
                )
                D_opt.step()
                D_scheduler.step()

                # del fixed_trajectory, X, Y, T_X, D_loss
                # gc.collect()
                # torch.cuda.empty_cache()
            # CONTINUE[0] now holds "step + 1", the step to resume from.
            CONTINUE[0] += 1
            progress_bar.update(1)

            if step % PLOT_INTERVAL == 0:
                # clear_output wipes the cell output, so the progress bar is
                # re-created at the current position.
                progress_bar.close()
                clear_output(wait=True)
                progress_bar = tqdm(
                    total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(Ts)}:"
                )
                print("Plotting")

                inference_Ts = Ts
                # A dedicated loop variable is essential here: reusing `T`
                # would rebind the map being trained to Ts[-1] for all
                # following steps (training and checkpointing the wrong net).
                for inference_T in inference_Ts:
                    inference_T.eval()

                print("Fixed Test Images")
                fig, axes = plot_linked_pushed_images(
                    X_test_fixed, Y_test_fixed, inference_Ts, gray=GRAY_PLOTS
                )
                writer.add_image(f"Fix Test Images/T{i + 1}", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

                print("Random Test Images")
                fig, axes = plot_linked_pushed_random_paired_images(
                    XY_test_sampler,
                    inference_Ts,
                    plot_n_samples=PLOT_N_SAMPLES,
                    gray=GRAY_PLOTS,
                )
                writer.add_image(f"Random Test Images/T{i + 1}", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

            if step != 0 and step % CKPT_INTERVAL == 0:
                # Only the pair currently being trained is saved.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
                os.makedirs(CKPT_DIR, exist_ok=True)
                # On multi-GPU runs the models are expected to be DataParallel
                # wrappers; the bare module's weights are what gets saved.
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
                # Persist the resume position together with the checkpoint.
                log = dict(CONTINUE=CONTINUE)
                with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                    log_str = json.dumps(log, indent=4)
                    # Write the serialized string; passing the dict itself
                    # raises TypeError.
                    log_file.write(log_str)

            # Metric tracking is currently disabled; kept for reference.
            if i == len(Ts) - 1 and step % TRACK_VAR_INTERVAL == 0:
                pass
                # print("Computing FID")
                # mu, sigma = get_linked_sde_pushed_loader_stats(
                #     Ts,
                #     XY_test_sampler.loader,
                #     n_epochs=FID_EPOCHS,
                #     batch_size=BATCH_SIZE,
                #     verbose=True,
                # )
                # fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
                # writer.add_scalar(f"Metrics/FID{i}", fid, step)
                # del mu, sigma

                # print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
                # metrics = get_linked_sde_pushed_loader_metrics(
                #     Ts,
                #     XY_test_sampler.loader,
                #     n_epochs=FID_EPOCHS,
                #     batch_size=BATCH_SIZE,
                #     verbose=True,
                #     log_metrics=["mse", "l1"]
                # )
                # print(f"metrics={metrics}")
                # writer.add_scalar("LPIPS(VGG)", metrics["vgg"], step)
                # writer.add_scalar("LPIPS(Alex)", metrics["alex"], step)
                # writer.add_scalar("L1", metrics["l1"], step)
                # writer.add_scalar("MSE", metrics["mse"], step)

            # gc.collect()
            # torch.cuda.empty_cache()
        # The next pair starts from step 0.
        CONTINUE[0] = 0  # reset training steps to 0
        CONTINUE[1] += 1

## Clear resources


In [None]:
# Best-effort cleanup.  Each resource gets its own try block so that a failure
# while closing the writer (or a missing progress bar, e.g. when no training
# cell was run) does not prevent the other one from being closed.
try:
    writer.close()
except Exception as e:
    print(e)

try:
    progress_bar.close()
except Exception as e:
    print(e)

In [None]:
# Final qualitative check with all maps in eval mode.
# NOTE: `i` and `step` are whatever the last executed training cell left
# behind, so this cell must be run after one of the training cells.
inference_Ts = Ts
for inference_T in inference_Ts:
    inference_T.eval()

print("Fixed Test Images")
fig, axes = plot_linked_pushed_images(
    X_test_fixed, Y_test_fixed, inference_Ts, gray=GRAY_PLOTS
)
writer.add_image(f"Fix Test Images/T{i + 1}", fig2tensor(fig), step)
plt.show(fig)
plt.close(fig)

print("Random Test Images")
fig, axes = plot_linked_pushed_random_paired_images(
    XY_test_sampler,
    inference_Ts,
    plot_n_samples=PLOT_N_SAMPLES,
    gray=GRAY_PLOTS,
)
writer.add_image(f"Random Test Images/T{i + 1}", fig2tensor(fig), step)
plt.show(fig)
plt.close(fig)

# The writer was already closed in the "Clear resources" cell; logging the
# images above makes it write again, so close it once more to make sure they
# are flushed to disk.
writer.close()