## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings
from typing import List

import torch
import torch.nn as nn
import numpy as np
from diffusers import DDIMScheduler
from tensorboardX import SummaryWriter

import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.enot import SDE, integrate
from src.resnet2 import ResNet_D
from src.cunet import CUNet

from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
    get_linked_sde_pushed_loader_metrics,
    get_linked_sde_pushed_loader_stats,
)  # for wandb
from src.fid_score import calculate_frechet_distance
from src.plotters_paired import (
    plot_linked_sde_pushed_images,
    plot_linked_sde_pushed_random_paired_images,
)

from src.samplers import PairedLoaderSampler, get_paired_sampler

LARGE_ENOUGH_NUMBER = 100
# Accept PNG text chunks of up to LARGE_ENOUGH_NUMBER MiB (i.e. 100 MiB).
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Drop unreachable Python objects first, then release cached-but-unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## 2. Config

Choose the dataset in the first lines of the cell below.


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset choice: keep exactly one (DATASET, DATASET_PATH, MAP_NAME, REVERSE)
# assignment below uncommented.
# face2comic
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = 'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', "comic2face", False
# colored mask -> face
# DATASET, DATASET_PATH, MAP_NAME, REVERSE = "celeba_mask", "../datasets/CelebAMask-HQ", "colored_mask2face", False
# sketch -> photo
DATASET, DATASET_PATH, MAP_NAME, REVERSE = (
    "FS2K",
    "../datasets/FS2K/",
    "sketch2photo",
    False,
)

IMG_SIZE = 256
DATASET1_CHANNELS = 3
DATASET2_CHANNELS = 3

# number of noising steps of the diffusion scheduler
DIFFUSION_STEPS = 1000
# noising steps after which pivotal points are recorded; assumed sorted ascending
# (the sampler loops up to PIVOTAL_LIST[-1])
PIVOTAL_LIST = [50, 100, 200]

# GPU choosing
DEVICE_IDS = [0]
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for this notebook but is not available")

# CONTINUE[0]: step to resume from (checkpoint step + 1); 0 means no checkpoint.
# CONTINUE[1]: ("Adapt" only) index of the SDE to resume; SDEs with a smaller
# index are treated as fully trained. [0, 0] trains from scratch.
CONTINUE = [
    0,
    0,
]

# All hyperparameters below are set to the values used in the experiments described in the article.

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'

BATCH_SIZE = 2
T_ITERS = 10
MAX_STEPS = 5000 + 1  # 2501 for testing
# presumably normalises the squared shift norm by the 3 * H * W image elements
INTEGRAL_SCALE = 1 / (3 * IMG_SIZE * IMG_SIZE)
EPSILON_SCHEDULER_LAST_ITER = 20000

# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
BETA_D, BETA_T = 0.9, 0.9
T_GRADIENT_MAX_NORM = float(500)
D_GRADIENT_MAX_NORM = float(500)

# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 5  # num of shifts time
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

# plot settings
GRAY_PLOTS = False
STEPS_TO_SHOW = 10

# log settings
SMART_INTERVALS = False
INTERVAL_SHRINK_START_TIME = 0.98
TRACK_VAR_INTERVAL = 10
PLOT_INTERVAL = 500
CPKT_INTERVAL = 500

FID_EPOCHS = 1

EXP_NAME = f"Ours_Paired_{DATASET}_{STRATEGY}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Hyperparameters recorded for this run (key order is preserved in the JSON file).
config = {
    "SEED": SEED,
    "DIFFUSION_STEPS": DIFFUSION_STEPS,
    "PIVOTAL_LIST": PIVOTAL_LIST,
    "DATASET": DATASET,
    "T_ITERS": T_ITERS,
    "D_LR": D_LR,
    "T_LR": T_LR,
    "STRATEGY": STRATEGY,
    "BATCH_SIZE": BATCH_SIZE,
    "UNET_BASE_FACTOR": UNET_BASE_FACTOR,
    "N_STEPS": N_STEPS,
    "EPSILON": EPSILON,
    "USE_POSITIONAL_ENCODING": USE_POSITIONAL_ENCODING,
    "TIME_DIM": TIME_DIM,
    "INTEGRAL_SCALE": INTEGRAL_SCALE,
    "ONE_STEP_INIT_ITERS": ONE_STEP_INIT_ITERS,
    "T_GRADIENT_MAX_NORM": T_GRADIENT_MAX_NORM,
    "D_GRADIENT_MAX_NORM": D_GRADIENT_MAX_NORM,
    "PREDICT_SHIFT": PREDICT_SHIFT,
    "SMART_INTERVALS": SMART_INTERVALS,
    "INTERVAL_SHRINK_START_TIME": INTERVAL_SHRINK_START_TIME,
    "USE_GRADIENT_CHECKPOINT": USE_GRADIENT_CHECKPOINT,
    "N_LAST_STEPS_WITHOUT_NOISE": N_LAST_STEPS_WITHOUT_NOISE,
    "TRACK_VAR_INTERVAL": TRACK_VAR_INTERVAL,
    "EPSILON_SCHEDULER_LAST_ITER": EPSILON_SCHEDULER_LAST_ITER,
    "FID_EPOCHS": FID_EPOCHS,
}
with open(os.path.join(OUTPUT_PATH, "config.json"), "w") as json_file:
    json.dump(config, json_file, indent=4)

# Resume bookkeeping; log["CONTINUE"] is the same list object as CONTINUE.
log = {"CONTINUE": CONTINUE}
with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
    json.dump(log, log_file, indent=4)

In [None]:
# MAP_NAME is "<source>2<target>"; the reference statistics belong to the side
# being generated (the target, or the source when REVERSE is set).
_stats_side = MAP_NAME.split("2")[0 if REVERSE else 1]
filename = f"../stats/{DATASET}_{_stats_side}_{IMG_SIZE}_test.json"

with open(filename, "r") as fp:
    data_stats = json.load(fp)
mu_data = data_stats["mu"]
sigma_data = data_stats["sigma"]
del data_stats

## 3. Initialize samplers


In [None]:
# Paired train/test samplers; `reverse` swaps the source and target roles.
XY_sampler, XY_test_sampler = get_paired_sampler(
    DATASET,
    DATASET_PATH,
    reverse=REVERSE,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
)

# Free loader-construction leftovers and hide the verbose setup output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

### Pivotal sampler


In [None]:
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: PairedLoaderSampler,
    batch_size: int = 4,
) -> List[torch.Tensor]:
    """Sample a paired batch and build the pivotal path from source to target.

    Both images are noised in lockstep, one ``SCHEDULER.add_noise`` call per
    step, for ``min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])`` steps. A snapshot is
    kept after every step whose 1-based index appears in ``PIVOTAL_LIST``.

    The returned list has this layout:
    ``[source, noisy sources..., noisy targets (less noisy first)..., target]``.
    The target's most-noised snapshot is left out, so the path turns at the
    source's last pivotal point.

    NOTE(review): each step feeds the already-noised tensor back into
    ``add_noise`` with timestep ``t``. This compounds noise rather than drawing
    directly from the marginal at ``t``. Confirm this is intended.
    """
    noisy_src, noisy_tgt = XY_sampler.sample(batch_size)
    src_points, tgt_points = [noisy_src], [noisy_tgt]

    keep_after = set(PIVOTAL_LIST)
    for t in range(min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])):
        timestep = torch.Tensor([t]).long()
        # RNG order matters for reproducibility: source noise, then target noise.
        noisy_src = SCHEDULER.add_noise(noisy_src, torch.randn_like(noisy_src), timestep)
        noisy_tgt = SCHEDULER.add_noise(noisy_tgt, torch.randn_like(noisy_tgt), timestep)
        if t + 1 in keep_after:
            src_points.append(noisy_src)
            tgt_points.append(noisy_tgt)

    # Walk up the source chain, then down the target chain, skipping the
    # target's deepest point (the source's last point replaces it).
    return src_points + tgt_points[::-1][1:]


# def sample_step_t_pivotal(
#     XY_sampler: PairedLoaderSampler,
#     batch_size: int = 4,
#     pivotal_step: int = 0,
# ):
#     pivotal_path = sample_all_pivotal(XY_sampler, batch_size)
#     pivotal_t, pivotal_t_next = (
#         pivotal_path[pivotal_step],
#         pivotal_path[pivotal_step + 1],
#     )
#     return pivotal_t, pivotal_t_next

### Mapping plotters


In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
) -> None:
    """Draw the pivotal path of one source/target pair as a single row of images.

    ``source`` and ``target`` are noised in lockstep with ``SCHEDULER.add_noise``
    for ``min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])`` steps. A snapshot is kept
    after each step listed in ``PIVOTAL_LIST``. The path runs along the source
    snapshots and back down the target snapshots, skipping the target's
    most-noised one. Panels are titled ``X_0, X_1, ...`` and the last one ``Y``.

    Args:
        source: a single image; assumed (C, H, W) with values in [-1, 1]
            (it is mapped to [0, 1] with ``* 0.5 + 0.5``). TODO confirm.
        target: the paired target image, same layout as ``source``.
        gray: pass ``cmap="gray"`` to ``imshow``.

    Returns:
        None. The figure is left open for the inline backend to display.
    """
    pivotal_path = []

    source_list = [source]
    target_list = [target]
    for i in range(min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])):
        source = SCHEDULER.add_noise(
            source, torch.randn_like(source), torch.Tensor([i]).long()
        )
        target = SCHEDULER.add_noise(
            target, torch.randn_like(target), torch.Tensor([i]).long()
        )
        if (i + 1) in PIVOTAL_LIST:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()

    pivotal_path.extend(source_list)
    pivotal_path.extend(target_list[1:])  # just using source's last pivotal point

    # (N, C, H, W) -> (N, H, W, C) in [0, 1] for imshow
    imgs: np.ndarray = (
        torch.stack(pivotal_path)
        .to("cpu")
        .permute(0, 2, 3, 1)
        .mul(0.5)
        .add(0.5)
        .numpy()
        .clip(0, 1)
    )
    nrows, ncols = 1, len(pivotal_path)
    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        # Escaped braces keep the LaTeX subscript group, so X_{10} renders correctly.
        ax.set_title(f"$X_{{{i}}}$", fontsize=24)
        if i == imgs.shape[0] - 1:
            ax.set_title("Y", fontsize=24)

    torch.cuda.empty_cache()
    gc.collect()

## 4. Training


### Models initialization


In [None]:
# One (SDE, critic) pair per hop of the pivotal path: with PIVOTAL_LIST sorted
# ascending, sample_all_pivotal returns 1 + 2 * len(PIVOTAL_LIST) points, i.e.
# 2 * len(PIVOTAL_LIST) consecutive hops.
# NOTE: construction order presumably consumes the RNG seeded in the config cell;
# reordering these statements would change the initial weights.
SDEs, BETA_NETs = [], []
SDE_OPTs, BETA_NET_OPTs = [], []
SDE_SCHEDULERs, BETA_NET_SCHEDULERs = [], []

for i in range(len(PIVOTAL_LIST) * 2):
    # Shift network wrapped into an SDE integrator with N_STEPS steps.
    T = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()

    T = SDE(
        shift_model=T,
        epsilon=EPSILON,
        n_steps=N_STEPS,
        time_dim=TIME_DIM,
        n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
        use_positional_encoding=USE_POSITIONAL_ENCODING,
        use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
        predict_shift=PREDICT_SHIFT,
        image_input=IMAGE_INPUT,
    ).cuda()
    SDEs.append(T)

    # Critic (potential) network for this hop.
    D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
    D.apply(weights_init_D)
    BETA_NETs.append(D)

    T_opt = torch.optim.Adam(
        T.parameters(), lr=T_LR, weight_decay=1e-10, betas=(BETA_T, 0.999)
    )
    D_opt = torch.optim.Adam(
        D.parameters(), lr=D_LR, weight_decay=1e-10, betas=(BETA_D, 0.999)
    )
    SDE_OPTs.append(T_opt)
    BETA_NET_OPTs.append(D_opt)

    # Halve both learning rates at each milestone; the schedulers are stepped
    # once per training step by the training cells.
    T_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        T_opt, milestones=[15000, 25000, 40000, 55000, 70000], gamma=0.5
    )
    D_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        D_opt, milestones=[15000, 25000, 40000, 55000, 70000], gamma=0.5
    )
    SDE_SCHEDULERs.append(T_scheduler)
    BETA_NET_SCHEDULERs.append(D_scheduler)


# On a fresh multi-GPU run, wrap now; when resuming, the loading cell wraps
# after loading because checkpoints hold the unwrapped modules' state.
# NOTE(review): the parameter counts are only printed in the multi-GPU branch;
# confirm this indentation is intended.
if len(DEVICE_IDS) > 1 and CONTINUE[0] == 0 and CONTINUE[1] == 0:
    for i in range(len(SDEs)):
        SDEs[i] = nn.DataParallel(SDEs[i], device_ids=DEVICE_IDS)
        BETA_NETs[i] = nn.DataParallel(BETA_NETs[i], device_ids=DEVICE_IDS)

        print(f"T{i} params:", np.sum([np.prod(p.shape) for p in SDEs[i].parameters()]))
        print(
            f"D{i} params:",
            np.sum([np.prod(p.shape) for p in BETA_NETs[i].parameters()]),
        )

### Load weights for continue training


In [None]:
def _load_checkpoint_pair(i: int, ckpt_dir: str) -> None:
    """Restore SDE/critic pair ``i`` (models, optimizers, LR schedulers) from ``ckpt_dir``.

    Files are read under the names written by the training cells'
    checkpointing code: ``{prefix}{i}_{SEED}.pt``. The models must still be
    unwrapped (not ``nn.DataParallel``) here, because checkpoints store the
    inner module's ``state_dict``.
    """
    targets = (
        ("T", SDEs[i]),
        ("T_opt", SDE_OPTs[i]),
        ("T_scheduler", SDE_SCHEDULERs[i]),
        ("D", BETA_NETs[i]),
        ("D_opt", BETA_NET_OPTs[i]),
        ("D_scheduler", BETA_NET_SCHEDULERs[i]),
    )
    for prefix, target in targets:
        fname = f"{prefix}{i}_{SEED}.pt"
        target.load_state_dict(torch.load(os.path.join(ckpt_dir, fname)))
        print(f"{ckpt_dir} {fname}, loaded")


def _wrap_pair_data_parallel(i: int) -> None:
    """Wrap pair ``i`` in ``nn.DataParallel`` in place when several GPUs are used.

    The wrapper is stored back into ``SDEs`` / ``BETA_NETs``; rebinding a loop
    variable instead would silently discard it.
    """
    if len(DEVICE_IDS) > 1:
        SDEs[i] = nn.DataParallel(SDEs[i], device_ids=DEVICE_IDS)
        BETA_NETs[i] = nn.DataParallel(BETA_NETs[i], device_ids=DEVICE_IDS)


if CONTINUE[0] > 0 or CONTINUE[1] > 0:
    print("Loading weights for continue training")
    if STRATEGY == "Adapt":
        for i in range(len(SDEs)):
            if i < CONTINUE[1]:
                # Fully trained pairs were checkpointed at their final step.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{MAX_STEPS - 1}/")
                _load_checkpoint_pair(i, CKPT_DIR)
            elif i == CONTINUE[1] and CONTINUE[0] > 0:
                # Pair interrupted mid-training. With CONTINUE[0] == 0 it has no
                # checkpoint yet (there is no "iter-1"), so it starts fresh.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
                _load_checkpoint_pair(i, CKPT_DIR)
            # Wrap after loading; pairs after CONTINUE[1] are only wrapped.
            _wrap_pair_data_parallel(i)

    if STRATEGY == "Fix":
        # All pairs are trained together and share one checkpoint directory.
        CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
        for i in range(len(SDEs)):
            _load_checkpoint_pair(i, CKPT_DIR)
            _wrap_pair_data_parallel(i)

In [None]:
# writer.add_graph(
#     SDEs[0], torch.rand(BATCH_SIZE, DATASET1_CHANNELS, IMG_SIZE, IMG_SIZE).cuda()
# )

### Plots Test


In [None]:
# Batches drawn once and reused by later plots, so figures from different steps
# show the same pairs.
X_fixed, Y_fixed = XY_sampler.sample(BATCH_SIZE)
X_test_fixed, Y_test_fixed = XY_test_sampler.sample(BATCH_SIZE)

In [None]:
plot_all_pivotal(X_test_fixed[0], Y_test_fixed[0])  # pivotal path of the first fixed test pair

In [None]:
# Plot the fixed training pairs pushed through the SDE list (presumably chained
# in order) before training, and log the figure to TensorBoard.
fig, axes = plot_linked_sde_pushed_images(X_fixed, Y_fixed, SDEs, gray=GRAY_PLOTS)

writer.add_image("paired images[linked sde]", fig2tensor(fig))

In [None]:
# Same visual check on freshly sampled training pairs (not logged to TensorBoard).
fig, axes = plot_linked_sde_pushed_random_paired_images(
    XY_sampler, SDEs, plot_n_samples=BATCH_SIZE, gray=GRAY_PLOTS
)

### Main training cycle and logging


In [None]:
# Release Python garbage and cached CUDA memory left over from the plotting cells.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def epsilon_scheduler(step):
    """Return the SDE noise level for ``step``.

    The level ramps linearly from 0 at step 0 to ``EPSILON`` at step
    ``EPSILON_SCHEDULER_LAST_ITER`` and is capped at ``EPSILON`` afterwards.
    With the config defaults (MAX_STEPS = 5001 < 20000), the cap is never
    reached unless EPSILON is 0.
    """
    ramp = EPSILON * (step / EPSILON_SCHEDULER_LAST_ITER)
    return ramp if ramp < EPSILON else EPSILON

#### Fix strategy


In [None]:
# "Fix" strategy: all SDE/critic pairs are trained together. Each outer step runs
# T_ITERS SDE updates for every pair, then one critic update for every pair.
# Pair i maps point i of a freshly sampled pivotal path onto point i + 1.
if STRATEGY == "Fix":
    progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])

    for step in range(MAX_STEPS):
        # Skip steps already covered by the checkpoint being resumed.
        if step < CONTINUE[0]:
            continue

        for T, D in zip(SDEs, BETA_NETs):
            freeze(T)
            freeze(D)

        new_epsilon = epsilon_scheduler(step)
        writer.add_scalar("Epsilon", new_epsilon, step)
        # ----- SDE (T) updates -----
        for _ in range(T_ITERS):
            # One pivotal path is shared by all pairs within this inner iteration.
            fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)

            for i, (T, T_opt, D) in enumerate(zip(SDEs, SDE_OPTs, BETA_NETs)):
                freeze(D)
                unfreeze(T)

                # nn.DataParallel wrappers expose the wrapped SDE as `.module`.
                if len(DEVICE_IDS) > 1:
                    T.module.set_epsilon(new_epsilon)
                else:
                    T.set_epsilon(new_epsilon)

                T_opt.zero_grad()

                X0, X1 = fixed_trajectory[i], fixed_trajectory[i + 1]

                trajectory, times, shifts = T(X0)
                XN = trajectory[:, -1]

                # Transport cost: squared L2 norm of the flattened shifts (assumes
                # shifts is (batch, n_steps, ...) -- TODO confirm), presumably
                # integrated over the time grid times[0], then rescaled.
                norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
                integral = INTEGRAL_SCALE * integrate(norm, times[0])

                # T minimises transport cost minus the critic's score on its output.
                T_loss = (integral + D(X1) - D(XN)).mean()
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
                del trajectory, X0, X1, XN, times, shifts
                gc.collect()
                torch.cuda.empty_cache()
            del fixed_trajectory
            gc.collect()
            torch.cuda.empty_cache()
        # LR schedulers advance once per outer step, not once per inner iteration.
        for T_scheduler in SDE_SCHEDULERs:
            T_scheduler.step()
        # wandb.log({f"T gradient norm": T_gradient_norm.item()}, step=step)
        # wandb.log({f"Mean norm": torch.sqrt(norm).mean().item()}, step=step)
        # wandb.log({f"T_loss": T_loss.item()}, step=step)

        # ----- critic (D) updates: one per pair per outer step -----
        for T, D in zip(SDEs, BETA_NETs):
            freeze(T)
            freeze(D)

        fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)
        for i, (D, D_opt, D_scheduler, T) in enumerate(
            zip(BETA_NETs, BETA_NET_OPTs, BETA_NET_SCHEDULERs, SDEs)
        ):
            freeze(T)
            unfreeze(D)

            D_opt.zero_grad()

            X0, X1 = fixed_trajectory[i], fixed_trajectory[i + 1]
            trajectory, times, shifts = T(X0)
            XN = trajectory[:, -1]
            norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
            integral = INTEGRAL_SCALE * integrate(norm, times[0])

            # Exactly the negated SDE objective: the critic plays the max side.
            D_loss = (-integral - D(X1) + D(XN)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

            del trajectory, X0, X1, XN, times, shifts
            gc.collect()
            torch.cuda.empty_cache()
        del fixed_trajectory
        gc.collect()
        torch.cuda.empty_cache()
        # wandb.log({f"D gradient norm": D_gradient_norm.item()}, step=step)
        # wandb.log({f"D_loss": D_loss.item()}, step=step)
        # wandb.log({f"integral": integral.mean().item()}, step=step)
        # wandb.log({f"D_X1": D_X1.mean().item()}, step=step)
        # wandb.log({f"D_XN": D_XN.mean().item()}, step=step)

        # Incremented before checkpointing, so the saved log holds step + 1,
        # matching the resume convention of CONTINUE[0].
        CONTINUE[0] += 1
        progress_bar.update(1)

        if step % PLOT_INTERVAL == 0:
            progress_bar.close()
            clear_output(wait=True)
            progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])
            print("Plotting")

            # NOTE: this loop rebinds `T`; harmless here because the zip loops of
            # the next step rebind it. freeze/unfreeze presumably restore
            # train/eval mode -- confirm in src.tools.
            inference_SDEs = SDEs
            for T in inference_SDEs:
                T.eval()
            print("Fixed Test Images")
            fig, axes = plot_linked_sde_pushed_images(
                X_test_fixed, Y_test_fixed, inference_SDEs, gray=GRAY_PLOTS
            )
            writer.add_image("Fixed Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)
            print("Random Test Images")
            fig, axes = plot_linked_sde_pushed_random_paired_images(
                XY_test_sampler,
                inference_SDEs,
                plot_n_samples=BATCH_SIZE,
                gray=GRAY_PLOTS,
            )
            writer.add_image("Random Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)

        # Checkpoint every pair into iter{step}/; models are saved unwrapped.
        if step != 0 and step % CPKT_INTERVAL == 0:
            CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
            os.makedirs(CKPT_DIR, exist_ok=True)
            for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
                zip(
                    SDEs,
                    SDE_OPTs,
                    SDE_SCHEDULERs,
                    BETA_NETs,
                    BETA_NET_OPTs,
                    BETA_NET_SCHEDULERs,
                )
            ):
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
            log["CONTINUE"] = CONTINUE
            with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                log_str = json.dumps(log, indent=4)
                log_file.write(log_str)

        # Metric tracking is currently a no-op here (the metric code is commented out).
        if step % TRACK_VAR_INTERVAL == 0:
            pass
            # print("Computing FID")
            # mu, sigma = get_linked_sde_pushed_loader_stats(
            #     SDEs,
            #     XY_test_sampler.loader,
            #     n_epochs=FID_EPOCHS,
            #     batch_size=BATCH_SIZE,
            #     verbose=True,
            # )
            # fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
            # print(f"FID={fid}")
            # writer.add_scalar("Metrics/FID", fid, step)
            # del mu, sigma

            # print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
            # metrics = get_linked_sde_pushed_loader_metrics(
            #     SDEs,
            #     XY_test_sampler.loader,
            #     n_epochs=FID_EPOCHS,
            #     batch_size=BATCH_SIZE,
            #     verbose=True,
            #     log_metrics=["mse", "l1"]
            # )
            # print(f"metrics={metrics}")
            # writer.add_scalar("Metrics/LPIPS(VGG)", metrics["vgg"], step)
            # writer.add_scalar("Metrics/LPIPS(Alex)", metrics["alex"], step)
            # writer.add_scalar("Metrics/L1", metrics["l1"], step)
            # writer.add_scalar("Metrics/MSE", metrics["mse"], step)

        gc.collect()
        torch.cuda.empty_cache()

#### Adapt strategy


In [None]:
# "Adapt" strategy: the SDE/critic pairs are trained one after another. SDE i
# receives the source batch pushed through the already-trained SDEs 0..i-1 and
# learns to map it onto point i + 1 of a freshly sampled pivotal path.
if STRATEGY == "Adapt":
    for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
        zip(
            SDEs,
            SDE_OPTs,
            SDE_SCHEDULERs,
            BETA_NETs,
            BETA_NET_OPTs,
            BETA_NET_SCHEDULERs,
        )
    ):
        # Pairs below CONTINUE[1] were fully trained in a previous run.
        if i < CONTINUE[1]:
            continue
        progress_bar = tqdm(
            total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(SDEs)}:"
        )
        # Distinct loop names keep the current pair (T, D) bound correctly.
        for _T, _D in zip(SDEs, BETA_NETs):
            freeze(_T)
            freeze(_D)

        for step in range(MAX_STEPS):
            # Skip steps already covered by the checkpoint being resumed.
            if step < CONTINUE[0]:
                continue

            # ----- SDE (T) updates -----
            for _ in range(T_ITERS):
                freeze(D)
                unfreeze(T)

                new_epsilon = epsilon_scheduler(step)
                if len(DEVICE_IDS) > 1:
                    T.module.set_epsilon(new_epsilon)
                else:
                    T.set_epsilon(new_epsilon)
                writer.add_scalar(f"Epsilon/eps{i}", new_epsilon, step)

                # === sampler training data ===
                fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)
                X0, X1 = fixed_trajectory[0], fixed_trajectory[i + 1]
                with torch.no_grad():
                    # Push the source through the previously trained SDEs 0..i-1
                    # (each returns (trajectory, times, shifts); keep the end point).
                    for prev_sde in SDEs[:i]:
                        X0 = prev_sde(X0)[0][:, -1]
                X0 = X0.requires_grad_()
                # === mapping and optimize ===
                T_opt.zero_grad()

                trajectory, times, shifts = T(X0)
                XN = trajectory[:, -1]

                # Transport cost, presumably integrated over the time grid times[0].
                norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
                integral = INTEGRAL_SCALE * integrate(norm, times[0])
                T_loss = (integral + D(X1) - D(XN)).mean()
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
                del fixed_trajectory, trajectory, X0, X1, XN, times, shifts
                gc.collect()
                torch.cuda.empty_cache()
            T_scheduler.step()
            # wandb.log({f"T gradient norm": T_gradient_norm.item()}, step=step)
            # wandb.log({f"Mean norm": torch.sqrt(norm).mean().item()}, step=step)
            # wandb.log({f"T_loss": T_loss.item()}, step=step)

            # ----- critic (D) update: the negated SDE objective -----
            freeze(T)
            unfreeze(D)
            # === sampler training data ===
            fixed_trajectory = sample_all_pivotal(XY_sampler, BATCH_SIZE)
            X0, X1 = fixed_trajectory[0], fixed_trajectory[i + 1]
            with torch.no_grad():
                for prev_sde in SDEs[:i]:
                    X0 = prev_sde(X0)[0][:, -1]
            X0 = X0.requires_grad_()
            # === mapping and optimize ===
            D_opt.zero_grad()
            trajectory, times, shifts = T(X0)
            XN = trajectory[:, -1]

            norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
            integral = INTEGRAL_SCALE * integrate(norm, times[0])
            D_loss = (-integral - D(X1) + D(XN)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

            del fixed_trajectory, trajectory, X0, X1, XN, times, shifts, D_loss
            gc.collect()
            torch.cuda.empty_cache()
            # wandb.log({f"D gradient norm": D_gradient_norm.item()}, step=step)
            # wandb.log({f"integral": integral.mean().item()}, step=step)

            # Incremented before checkpointing, so the saved log holds step + 1.
            CONTINUE[0] += 1
            progress_bar.update(1)

            if step % PLOT_INTERVAL == 0:
                progress_bar.close()
                clear_output(wait=True)
                progress_bar = tqdm(
                    total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(SDEs)}:"
                )
                print("Plotting")

                # Use a separate loop name: rebinding `T` here would make every
                # later step train (and checkpoint) the last SDE instead of SDE i.
                inference_SDEs = SDEs
                for sde in inference_SDEs:
                    sde.eval()

                print("Fixed Test Images")
                fig, axes = plot_linked_sde_pushed_images(
                    X_test_fixed, Y_test_fixed, inference_SDEs, gray=GRAY_PLOTS
                )
                writer.add_image("Fixed Test Images", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

                print("Random Test Images")
                fig, axes = plot_linked_sde_pushed_random_paired_images(
                    XY_test_sampler,
                    inference_SDEs,
                    plot_n_samples=BATCH_SIZE,
                    gray=GRAY_PLOTS,
                )
                writer.add_image("Random Test Images", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

            # Checkpoint only the pair currently being trained; models are saved unwrapped.
            if step != 0 and step % CPKT_INTERVAL == 0:
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
                os.makedirs(CKPT_DIR, exist_ok=True)
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
                log = dict(CONTINUE=CONTINUE)
                with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                    log_str = json.dumps(log, indent=4)
                    # write the serialized text; writing the dict raises TypeError
                    log_file.write(log_str)

            # Metrics are only computed while training the last pair (full chain).
            if i == len(SDEs) - 1 and step % TRACK_VAR_INTERVAL == 0:
                print("Computing FID")
                mu, sigma = get_linked_sde_pushed_loader_stats(
                    SDEs,
                    XY_test_sampler.loader,
                    n_epochs=FID_EPOCHS,
                    batch_size=BATCH_SIZE,
                    verbose=True,
                )
                fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
                writer.add_scalar(f"Metrics/FID{i}", fid, step)
                del mu, sigma

                # NOTE(review): log_metrics lists only mse/l1, yet "vgg"/"alex" are
                # read below -- confirm the helper always returns LPIPS scores.
                print("Computing LPIPS(vgg) LPIPS(alex) L1 MSE")
                metrics = get_linked_sde_pushed_loader_metrics(
                    SDEs,
                    XY_test_sampler.loader,
                    n_epochs=FID_EPOCHS,
                    batch_size=BATCH_SIZE,
                    verbose=True,
                    log_metrics=["mse", "l1"],
                )
                print(f"metrics={metrics}")
                writer.add_scalar("LPIPS(VGG)", metrics["vgg"], step)
                writer.add_scalar("LPIPS(Alex)", metrics["alex"], step)
                writer.add_scalar("L1", metrics["l1"], step)
                writer.add_scalar("MSE", metrics["mse"], step)

            gc.collect()
            torch.cuda.empty_cache()
        # Finish this pair's progress bar before the next pair creates its own.
        progress_bar.close()
        CONTINUE[0] = 0  # reset training steps to 0
        CONTINUE[1] += 1

## Clear resources


In [None]:
# Close the TensorBoard writer and the last progress bar independently. A failure
# in one (or a missing progress bar when no training cell ran) must not keep the
# other open. Cleanup is best-effort: errors are printed, not raised.
try:
    writer.close()
except Exception as e:
    print(e)

try:
    progress_bar.close()
except Exception as e:
    print(e)