## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import torchvision.datasets as datasets
import numpy as np
from diffusers import DDIMScheduler
from tensorboardX import SummaryWriter
from torchvision.transforms import Compose, ToTensor, Resize, Normalize


import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.resnet2 import ResNet_D
from src.unet import UNet
from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
    get_all_pivotal,
    get_step_t_pivotal,
    linked_push,
)
from src.plotters import (
    plot_linked_pushed_images,
    plot_linked_pushed_random_class_images,
)
from src.samplers import (
    SubsetGuidedSampler,
    SubsetGuidedDataset,
    get_indicies_subset,
)

# Raise Pillow's limit on PNG text-chunk size to LARGE_ENOUGH_NUMBER MiB
# (100 MiB); presumably so PNGs carrying large text/metadata chunks can be read.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Silence all Python warnings for the rest of the notebook session.
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Collect unreachable Python objects, then return cached-but-unused CUDA
# memory held by PyTorch's allocator.
gc.collect()
torch.cuda.empty_cache()

## 2. Config


In [None]:
SEED = 0x3060
set_random_seed(SEED)

# Dataset pair "<source>2<target>" and the directory torchvision downloads into.
DATASET, DATASET_PATH = "fmnist2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2fmnist", "../datasets/"

# DATASET, DATASET_PATH = "usps2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2usps", "../datasets/"

# DATASET, DATASET_PATH = "usps2fmnist", "../datasets/"
# DATASET, DATASET_PATH = "fmnist2usps", "../datasets/"

# NOTE: the MNIST-M pairs below are not handled by the dataset-selection cell
# (it only maps MNIST, FashionMNIST and USPS) and would raise there.
# DATASET, DATASET_PATH = "mnistm2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2mnistm", "../datasets/"

IMG_SIZE = 32
DATASET1_CHANNELS = 1  # source-domain image channels
DATASET2_CHANNELS = 1  # target-domain image channels

# Number of noising steps of the diffusion scheduler. PIVOTAL_LIST is passed to
# get_all_pivotal together with the scheduler (presumably scheduler timesteps).
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]

# GPU selection. An explicit check is used instead of `assert`, which would be
# stripped when Python runs with -O.
DEVICE_ID = 1
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a CUDA-capable GPU.")
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

# Resume state: [next training step, index of the stage (T_i/D_i pair) to
# resume]. [0, 0] starts training from scratch.
CONTINUE = [0, 0]

# All hyperparameters below are set to the values used for the experiments
# described in the article.

# training algorithm settings
STRATEGY = "Adapt"  # 'Fix' (all stages jointly) or 'Adapt' (one stage at a time)
T_ITERS = 10  # transport-map updates per critic update
# The training loops run steps 0 .. MAX_STEPS - 1, so the final step 5000 is
# reached (and checkpointed).
MAX_STEPS = 5000 + 1

# data sample settings
BATCH_SIZE = 16
SUBSET_SIZE = 4  # passed as `subsetsize` to the transport-map sampler
SUBSET_CLASS = 3  # the class that receives weight 1.0 in SUBSET_WEIGHT
NUM_LABELED = "all"  # "all" or an int value, such as 10

# model settings
UNET_BASE_FACTOR = 48

# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
T_GRADIENT_MAX_NORM = float(1e5)
D_GRADIENT_MAX_NORM = float(1e5)
SCHEDULER_MILESTONES = [1500, 3000, 4000]  # MultiStepLR milestones, in steps

# plot settings
GRAY_PLOTS = True
PLOT_N_SAMPLES = 8

# log settings (intervals are measured in training steps)
TRACK_VAR_INTERVAL = 500
PLOT_INTERVAL = 500
CKPT_INTERVAL = 1000

FID_EPOCHS = 1

EXP_NAME = f"DNOT_Unpair_{DATASET}_{SUBSET_CLASS}_{STRATEGY}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# Persist the experiment configuration next to the checkpoints so a run can be
# reproduced. NUM_LABELED is included because it changes the training set.
config = dict(
    SEED=SEED,
    DATASET=DATASET,
    IMG_SIZE=IMG_SIZE,
    DATASET1_CHANNELS=DATASET1_CHANNELS,
    DATASET2_CHANNELS=DATASET2_CHANNELS,
    DIFFUSION_STEPS=DIFFUSION_STEPS,
    PIVOTAL_LIST=PIVOTAL_LIST,
    STRATEGY=STRATEGY,
    T_ITERS=T_ITERS,
    MAX_STEPS=MAX_STEPS,
    UNET_BASE_FACTOR=UNET_BASE_FACTOR,
    BATCH_SIZE=BATCH_SIZE,
    SUBSET_SIZE=SUBSET_SIZE,
    SUB_CLASS=SUBSET_CLASS,
    NUM_LABELED=NUM_LABELED,
    D_LR=D_LR,
    T_LR=T_LR,
    T_GRADIENT_MAX_NORM=T_GRADIENT_MAX_NORM,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    SCHEDULER_MILESTONES=SCHEDULER_MILESTONES,
    GRAY_PLOTS=GRAY_PLOTS,
    PLOT_N_SAMPLES=PLOT_N_SAMPLES,
    TRACK_VAR_INTERVAL=TRACK_VAR_INTERVAL,
    PLOT_INTERVAL=PLOT_INTERVAL,
    CKPT_INTERVAL=CKPT_INTERVAL,
    FID_EPOCHS=FID_EPOCHS,
)
with open(
    os.path.join(OUTPUT_PATH, "config.json"), "w", encoding="utf-8"
) as json_file:
    json.dump(config, json_file, indent=4)

# log.json records the resume state; the training loops rewrite it at each
# checkpoint.
log = dict(CONTINUE=CONTINUE)
with open(os.path.join(OUTPUT_PATH, "log.json"), "w", encoding="utf-8") as log_file:
    json.dump(log, log_file, indent=4)

## 3. Initialize samplers


### Data sampler

In [None]:
# Both domains keep all ten classes, and the relabelling maps are identities.
source_subset = torch.arange(10)
new_labels_source = {label: label for label in range(10)}
target_subset = torch.arange(10)
new_labels_target = {label: label for label in range(10)}

# One weight per source class: 1.0 for SUBSET_CLASS, 0 everywhere else.
SUBSET_WEIGHT = [0] * len(source_subset)
SUBSET_WEIGHT[SUBSET_CLASS] = 1.0

In [None]:
# Resize to IMG_SIZE x IMG_SIZE, convert to a [0, 1] tensor, then map to
# [-1, 1] via (x - 0.5) / 0.5. The one-element tuples are the single-channel
# mean/std (a bare `(0.5)` is just a float, not a tuple).
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize((0.5,), (0.5,)),
    ]
)
target_transform = source_transform

# "<source>2<target>" -> (source dataset class, target dataset class)
DATASET_PAIRS = {
    "fmnist2mnist": (datasets.FashionMNIST, datasets.MNIST),
    "mnist2fmnist": (datasets.MNIST, datasets.FashionMNIST),
    "mnist2usps": (datasets.MNIST, datasets.USPS),
    "usps2mnist": (datasets.USPS, datasets.MNIST),
    "usps2fmnist": (datasets.USPS, datasets.FashionMNIST),
    "fmnist2usps": (datasets.FashionMNIST, datasets.USPS),
}

try:
    source, target = DATASET_PAIRS[DATASET]
except KeyError:
    # ValueError subclasses Exception, so existing `except Exception` still works.
    raise ValueError(f"{DATASET} not support now...") from None

In [None]:
def _load_tensor_subset(dataset_cls, transform, relabel_map, kept_classes):
    """Load the train split of ``dataset_cls`` and materialize the kept classes.

    Returns ``(tensor_dataset, samples, labels, class_indicies)``, where the last
    three are exactly what ``get_indicies_subset`` returned and
    ``tensor_dataset`` stacks ``samples`` with ``labels`` as a LongTensor.
    """
    raw_split = dataset_cls(
        root=DATASET_PATH, train=True, download=True, transform=transform
    )
    samples, relabelled, class_index = get_indicies_subset(
        raw_split,
        new_labels=relabel_map,
        classes=len(kept_classes),
        subset_classes=kept_classes,
    )
    as_tensors = torch.utils.data.TensorDataset(
        torch.stack(samples), torch.LongTensor(relabelled)
    )
    return as_tensors, samples, relabelled, class_index


source_train, subset_samples, labels, source_class_indicies = _load_tensor_subset(
    source, source_transform, new_labels_source, source_subset
)
target_train, target_subset_samples, target_labels, target_class_indicies = (
    _load_tensor_subset(target, target_transform, new_labels_target, target_subset)
)

# Both guided datasets share the same tensors and class-index structures; they
# differ only in how many labelled pairs they expose.
_shared_guided_kwargs = dict(
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)
train_set = SubsetGuidedDataset(
    source_train, target_train, num_labeled=NUM_LABELED, **_shared_guided_kwargs
)
full_set = SubsetGuidedDataset(
    source_train, target_train, num_labeled="all", **_shared_guided_kwargs
)

In [None]:
# Samplers over the guided datasets. The transport-map sampler uses train_set
# with subsetsize=SUBSET_SIZE; the critic sampler uses full_set with
# subsetsize=1. Both receive SUBSET_WEIGHT (1.0 on SUBSET_CLASS, 0 elsewhere);
# presumably these are per-class sampling weights — see src.samplers.
T_XY_sampler = SubsetGuidedSampler(
    train_set, subsetsize=SUBSET_SIZE, weight=SUBSET_WEIGHT
)
D_XY_sampler = SubsetGuidedSampler(full_set, subsetsize=1, weight=SUBSET_WEIGHT)

In [None]:
# Collect unreachable Python objects first so the tensors they own are actually
# released, and only then return the now-unused cached blocks to the CUDA driver
# (the reverse order leaves GC-freed tensors inside PyTorch's cache).
gc.collect()
torch.cuda.empty_cache()
clear_output()

### Pivotal sampler


In [None]:
# DDIM scheduler shared by every pivotal-path computation in this notebook.
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: SubsetGuidedSampler,
    batch_size: int = 4,
) -> list[torch.Tensor]:
    """Sample one (source, target) batch and return its pivotal path.

    Args:
        XY_sampler: sampler whose ``sample(batch_size)`` yields a
            ``(source_batch, target_batch)`` pair.
        batch_size: number of entries to draw.

    Returns:
        The result of ``get_all_pivotal`` for the sampled pair, computed with
        the module-level ``SCHEDULER`` and ``PIVOTAL_LIST``.
    """
    x_batch, y_batch = XY_sampler.sample(batch_size)
    path = get_all_pivotal(x_batch, y_batch, SCHEDULER, PIVOTAL_LIST)
    return path

In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
):
    """Draw the pivotal path between one source and one target image in a row.

    Each path element is mapped from [-1, 1] back to [0, 1] for display. Panels
    are titled ``$X_i$``, except the last, which is titled ``Y``. The figure is
    left open (not returned) so the notebook's inline backend displays it.

    Args:
        source: a single source image, channel-first.
        target: a single target image, channel-first.
        gray: render panels with the ``"gray"`` colormap.
    """
    path = get_all_pivotal(source, target, SCHEDULER, PIVOTAL_LIST)

    # (N, C, H, W) -> (N, H, W, C), undo the [-1, 1] normalization, clamp.
    frames = torch.stack(path).cpu().permute(0, 2, 3, 1) * 0.5 + 0.5
    frames: np.ndarray = frames.numpy().clip(0, 1)

    n_frames = len(path)
    last_idx = frames.shape[0] - 1
    fig = plt.figure(figsize=(1.5 * n_frames, 1.5), dpi=150)
    for idx, frame in enumerate(frames):
        ax = fig.add_subplot(1, n_frames, idx + 1)
        ax.imshow(frame, cmap="gray" if gray else None)
        for axis in (ax.get_yaxis(), ax.get_xaxis()):
            axis.set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_title("Y" if idx == last_idx else f"$X_{idx}$", fontsize=24)

    torch.cuda.empty_cache()
    gc.collect()

## 4. Initialize models


### init models


In [None]:
# 2 * len(PIVOTAL_LIST) stages, each with a transport map T (UNet), a critic D
# (ResNet_D), an Adam optimizer per network, and a MultiStepLR schedule that
# multiplies the lr by 0.5 at every milestone in SCHEDULER_MILESTONES.
Ts, Ds = [], []
T_OPTs, D_OPTs = [], []
T_SCHEDULERs, D_SCHEDULERs = [], []

for i in range(2 * len(PIVOTAL_LIST)):
    # Keep the construction order T -> D -> D's custom init fixed; presumably the
    # layer initializers draw from the seeded RNG, so reordering would change
    # the initial weights.
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
    D.apply(weights_init_D)

    T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, weight_decay=1e-10)
    D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, weight_decay=1e-10)
    T_scheduler, D_scheduler = (
        torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=SCHEDULER_MILESTONES, gamma=0.5
        )
        for optimizer in (T_opt, D_opt)
    )

    Ts.append(T)
    Ds.append(D)
    T_OPTs.append(T_opt)
    D_OPTs.append(D_opt)
    T_SCHEDULERs.append(T_scheduler)
    D_SCHEDULERs.append(D_scheduler)

### Load weights for continue training


In [None]:
def _load_stage_checkpoint(i, ckpt_dir, T, T_opt, T_scheduler, D, D_opt, D_scheduler):
    """Load the six state dicts of stage ``i`` from ``ckpt_dir``, logging each file.

    Files are named ``<prefix>{i}_{SEED}.pt`` and are loaded in the order
    T, T_opt, T_scheduler, D, D_opt, D_scheduler.
    """
    for prefix, target_obj in (
        ("T", T),
        ("T_opt", T_opt),
        ("T_scheduler", T_scheduler),
        ("D", D),
        ("D_opt", D_opt),
        ("D_scheduler", D_scheduler),
    ):
        file_name = f"{prefix}{i}_{SEED}.pt"
        target_obj.load_state_dict(torch.load(os.path.join(ckpt_dir, file_name)))
        print(f"{ckpt_dir} {file_name}, loaded")


_stages = list(zip(Ts, T_OPTs, T_SCHEDULERs, Ds, D_OPTs, D_SCHEDULERs))

if CONTINUE[0] > 0 or CONTINUE[1] > 0:
    print("Loading weights for continue training")
    if STRATEGY == "Adapt":
        for i, stage in enumerate(_stages):
            if i < CONTINUE[1]:
                # Finished stages: the checkpoint written at the last training
                # step (exists when MAX_STEPS - 1 is a multiple of CKPT_INTERVAL).
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{MAX_STEPS - 1}/")
            elif i == CONTINUE[1] and CONTINUE[0] > 0:
                # The stage being resumed: training increments CONTINUE[0] before
                # checkpointing, so the last saved step is CONTINUE[0] - 1.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
            else:
                # Later stages, or the current stage resuming at step 0, have no
                # checkpoint yet (loading would look for a nonexistent "iter-1/").
                continue
            _load_stage_checkpoint(i, CKPT_DIR, *stage)

    if STRATEGY == "Fix":
        if CONTINUE[0] > 0:
            # All stages are trained jointly, so they share one checkpoint dir.
            CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
            for i, stage in enumerate(_stages):
                _load_stage_checkpoint(i, CKPT_DIR, *stage)
        else:
            print("STRATEGY 'Fix' with CONTINUE[0] == 0: no checkpoint to load")

## 5. Test plots


In [None]:
# Draw a fixed batch of test pairs for plotting. D_XY_sampler was built with
# subsetsize=1, so merging the first two dims (presumably batch and subset)
# yields a plain image batch.
X_test_fixed, Y_test_fixed = (
    batch.flatten(0, 1) for batch in D_XY_sampler.sample(PLOT_N_SAMPLES)
)

In [None]:
# Preview the pivotal path for the first fixed test pair
# (X_test_fixed[0] -> Y_test_fixed[0]).
plot_all_pivotal(
    X_test_fixed[0],
    Y_test_fixed[0],
    gray=GRAY_PLOTS,
)

In [None]:
# Baseline plot of the fixed test images and their images under the transport
# maps Ts (freshly initialized unless weights were loaded for resuming).
fig, axes = plot_linked_pushed_images(
    X_test_fixed,
    Y_test_fixed,
    Ts,
    gray=GRAY_PLOTS,
)

In [None]:
# Same baseline plot, but for PLOT_N_SAMPLES fresh random samples drawn from
# D_XY_sampler.
fig, axes = plot_linked_pushed_random_class_images(
    D_XY_sampler,
    Ts,
    plot_n_samples=PLOT_N_SAMPLES,
    gray=GRAY_PLOTS,
)

## 6. Train


In [None]:
# Free unreachable Python objects, then release cached CUDA memory before
# training starts.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def get_cat_pivotal_tr(XY_sampler1, XY_sampler2, batch_size):
    """Splice two independently sampled pivotal paths at their midpoint.

    Elements ``0 .. len // 2`` come from a path drawn with ``XY_sampler1``; the
    remaining elements come from a path drawn with ``XY_sampler2``.

    Args:
        XY_sampler1: sampler for the first half of the path.
        XY_sampler2: sampler for the second half of the path.
        batch_size: batch size passed to ``sample_all_pivotal`` for both.

    Returns:
        A list with the same length as each sampled path.

    Raises:
        AssertionError: if the two sampled paths differ in length.
    """
    first_path = sample_all_pivotal(XY_sampler1, batch_size)
    second_path = sample_all_pivotal(XY_sampler2, batch_size)
    assert len(first_path) == len(second_path)

    cut = len(first_path) // 2 + 1
    spliced = [*first_path[:cut], *second_path[cut:]]

    # Drop the halves that are not kept so the collection below can free them.
    del first_path[cut:]
    del second_path[:cut]

    gc.collect()
    torch.cuda.empty_cache()
    return spliced

### Fix strategy


In [None]:
# Progress bar over training steps for the "Fix" strategy; it starts at the
# resume step CONTINUE[0].
if STRATEGY == "Fix":
    pbar = tqdm(total=MAX_STEPS, initial=max(0, CONTINUE[0]), leave=True)

#### training

In [None]:
if STRATEGY == "Fix":
    # "Fix" strategy: all stages (T_i, D_i) are trained jointly. Each outer step
    # runs T_ITERS updates of every transport map, then one update of every critic.
    for step in range(max(0, CONTINUE[0]), MAX_STEPS):
        # ---------------- transport maps T_i ----------------
        for k in range(T_ITERS):
            # One pivotal path per inner iteration, shared by all stages.
            pivotal_path = sample_all_pivotal(T_XY_sampler, BATCH_SIZE)
            for i, (T, T_opt, D) in enumerate(zip(Ts, T_OPTs, Ds)):
                freeze(D)
                unfreeze(T)
                # Stage i maps pivotal element i to pivotal element i + 1.
                X, Y = pivotal_path[i], pivotal_path[i + 1]
                T_opt.zero_grad()
                # Run T with the first two dims of X merged, then reshape the
                # output to (-1, SUBSET_SIZE, C, H, W), presumably matching Y.
                T_X = (
                    T(X.flatten(0, 1))
                    .permute(1, 2, 3, 0)
                    .reshape(DATASET2_CHANNELS, IMG_SIZE, IMG_SIZE, -1, SUBSET_SIZE)
                    .permute(3, 4, 0, 1, 2)
                )
                # 0.5 * mean pairwise L2 distance between the SUBSET_SIZE outputs
                # of each entry; the SUBSET_SIZE / (SUBSET_SIZE - 1) factor turns
                # the mean over the full distance matrix into the mean over its
                # off-diagonal (non-zero) entries.
                T_var = (
                    0.5
                    * torch.cdist(
                        T_X.flatten(start_dim=2), T_X.flatten(start_dim=2)
                    ).mean()
                    * SUBSET_SIZE
                    / (SUBSET_SIZE - 1)
                )
                # Mean per-sample L2 distance between Y and T(X).
                cost = (Y - T_X).flatten(start_dim=2).norm(dim=2).mean()
                T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
        # LR schedules advance once per outer step, not per inner iteration.
        for T_scheduler in T_SCHEDULERs:
            T_scheduler.step()

        # ---------------- critics D_i ----------------
        for T, D in zip(Ts, Ds):
            freeze(T)
            freeze(D)
        # Two options for the critic's path:
        #   1. concatenated path: pivotal_path = get_cat_pivotal_tr(T_XY_sampler, D_XY_sampler, BATCH_SIZE)
        #   2. a single path (used here): pivotal_path = sample_all_pivotal(D_XY_sampler, BATCH_SIZE)
        pivotal_path = sample_all_pivotal(D_XY_sampler, BATCH_SIZE)
        for i, (D, D_opt, D_scheduler, T) in enumerate(
            zip(Ds, D_OPTs, D_SCHEDULERs, Ts)
        ):
            freeze(T)
            unfreeze(D)
            X, Y = pivotal_path[i], pivotal_path[i + 1]
            D_opt.zero_grad()
            # T is frozen during the critic update; no autograd graph needed.
            with torch.no_grad():
                T_X = T(X.flatten(start_dim=0, end_dim=1))
            # Critic score on T's outputs minus critic score on the real Y samples.
            D_loss = D(T_X).mean() - D(Y.flatten(start_dim=0, end_dim=1)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

        # CONTINUE[0] is the next step to run; it is what log.json stores.
        CONTINUE[0] += 1
        pbar.update(1)

        if step % PLOT_INTERVAL == 0:
            clear_output(wait=True)
            print(f"{step = }, Plotting")

            print("Fix Test Images")
            fig, axes = plot_linked_pushed_images(
                X_test_fixed,
                Y_test_fixed,
                Ts,
                gray=GRAY_PLOTS,
            )
            writer.add_image("Fix Test Images", fig2tensor(fig), step)
            # plt.show() takes no figure argument (its `block` flag is
            # keyword-only in most backends), so passing `fig` positionally
            # only works by accident in the inline backend.
            plt.show()
            plt.close(fig)
            print("Random Test Images")
            fig, axes = plot_linked_pushed_random_class_images(
                D_XY_sampler,
                Ts,
                plot_n_samples=PLOT_N_SAMPLES,
                gray=GRAY_PLOTS,
            )
            writer.add_image("Random Test Images", fig2tensor(fig), step)
            plt.show()
            plt.close(fig)

        if step != 0 and step % CKPT_INTERVAL == 0:
            # Every stage's state goes into the same iter{step}/ directory.
            CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
            os.makedirs(CKPT_DIR, exist_ok=True)
            for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
                zip(
                    Ts,
                    T_OPTs,
                    T_SCHEDULERs,
                    Ds,
                    D_OPTs,
                    D_SCHEDULERs,
                )
            ):
                torch.save(T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"))
                torch.save(D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"))

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )

            log["CONTINUE"] = CONTINUE
            with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                log_str = json.dumps(log, indent=4)
                log_file.write(log_str)

        if step % TRACK_VAR_INTERVAL == 0:
            # Intentionally empty: after training, use test_transport.ipynb to
            # compute FID, accuracy, etc.
            pass

### Adapt strategy


In [None]:
# Two stacked progress bars for the "Adapt" strategy: "total" counts stages
# (one per entry of Ts), and "single" counts training steps within the current
# stage, starting at the resume step CONTINUE[0].
if STRATEGY == "Adapt":
    spbar = tqdm(total=len(Ts), initial=0, position=0, desc="total")
    pbar = tqdm(
        total=MAX_STEPS,
        initial=max(0, CONTINUE[0]),
        position=1,
        desc="single",
        leave=True,
    )

#### training

In [None]:
if STRATEGY == "Adapt":
    # "Adapt" strategy: stages are trained one after another. Stage i trains
    # (T_i, D_i) for MAX_STEPS steps while every other network stays frozen.
    for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
        zip(
            Ts,
            T_OPTs,
            T_SCHEDULERs,
            Ds,
            D_OPTs,
            D_SCHEDULERs,
        )
    ):
        if i < CONTINUE[1]:
            # Stage already finished in a previous run.
            spbar.update(1)
            continue

        for _T, _D in zip(Ts, Ds):
            freeze(_T)
            freeze(_D)

        for step in range(max(0, CONTINUE[0]), MAX_STEPS):
            # ---------------- transport map T_i ----------------
            for k in range(T_ITERS):
                freeze(D)
                unfreeze(T)

                pivotal_path = sample_all_pivotal(T_XY_sampler, BATCH_SIZE)
                # The input starts from the first path element and is passed to
                # linked_push with the earlier maps Ts[:i] (presumably applying
                # them in order); the target is path element i + 1.
                X, Y = pivotal_path[0], pivotal_path[i + 1]
                X = linked_push(
                    Ts[:i],
                    X.flatten(0, 1),
                    "T_X",
                )

                T_opt.zero_grad()
                # Reshape T's output to (-1, SUBSET_SIZE, C, H, W), presumably
                # matching Y.
                T_X = (
                    T(X)
                    .permute(1, 2, 3, 0)
                    .reshape(DATASET2_CHANNELS, IMG_SIZE, IMG_SIZE, -1, SUBSET_SIZE)
                    .permute(3, 4, 0, 1, 2)
                )
                # 0.5 * mean pairwise L2 distance between the SUBSET_SIZE outputs
                # of each entry; SUBSET_SIZE / (SUBSET_SIZE - 1) discounts the
                # zero diagonal of the distance matrix.
                T_var = (
                    0.5
                    * torch.cdist(
                        T_X.flatten(start_dim=2), T_X.flatten(start_dim=2)
                    ).mean()
                    * SUBSET_SIZE
                    / (SUBSET_SIZE - 1)
                )
                # Mean per-sample L2 distance between Y and T(X).
                cost = (Y - T_X).flatten(start_dim=2).norm(dim=2).mean()
                T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
            # LR schedule advances once per outer step.
            T_scheduler.step()

            # ---------------- critic D_i ----------------
            freeze(T)
            unfreeze(D)
            # Two options for the critic's path:
            #   1. concatenated path: pivotal_path = get_cat_pivotal_tr(T_XY_sampler, D_XY_sampler, BATCH_SIZE)
            #   2. a single path (used here): pivotal_path = sample_all_pivotal(D_XY_sampler, BATCH_SIZE)
            pivotal_path = sample_all_pivotal(D_XY_sampler, BATCH_SIZE)
            X, Y = pivotal_path[0], pivotal_path[i + 1]
            X = linked_push(
                Ts[:i],
                X.flatten(0, 1),
                "T_X",
            )

            D_opt.zero_grad()
            # T is frozen here. As in the "Fix" strategy, run it under no_grad so
            # no autograd graph is built for its forward pass. D's gradients are
            # unaffected.
            with torch.no_grad():
                T_X = T(X)
            # Critic score on T's outputs minus critic score on the real Y samples.
            D_loss = D(T_X).mean() - D(Y.flatten(start_dim=0, end_dim=1)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

            # CONTINUE[0] is the next step to run; it is what log.json stores.
            CONTINUE[0] += 1
            pbar.update(1)

            if step % PLOT_INTERVAL == 0:
                clear_output(wait=True)
                print(f"training {i}-th T, {step = }, Plotting")

                print("Fixed Test Images")
                fig, axes = plot_linked_pushed_images(
                    X_test_fixed,
                    Y_test_fixed,
                    Ts,
                    gray=GRAY_PLOTS,
                )
                writer.add_image(
                    f"Fix Test Images/T[0..={i}]",
                    fig2tensor(fig),
                    step,
                )
                # plt.show() takes no figure argument (`block` is keyword-only
                # in most backends).
                plt.show()
                plt.close(fig)

                print("Random Test Images")
                fig, axes = plot_linked_pushed_random_class_images(
                    D_XY_sampler,
                    Ts,
                    plot_n_samples=PLOT_N_SAMPLES,
                    gray=GRAY_PLOTS,
                )
                writer.add_image(
                    f"Random Test Images/T[0..={i}]",
                    fig2tensor(fig),
                    step,
                )
                plt.show()
                plt.close(fig)

            if step != 0 and step % CKPT_INTERVAL == 0:
                # Only the current stage is saved. Every stage writes into the
                # same iter{step}/ directories under its own index.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
                os.makedirs(CKPT_DIR, exist_ok=True)
                torch.save(T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"))
                torch.save(D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"))

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
                log = dict(CONTINUE=CONTINUE)
                with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                    log_str = json.dumps(log, indent=4)
                    log_file.write(log_str)

            if i == len(Ts) - 1 and step % TRACK_VAR_INTERVAL == 0:
                # Intentionally empty: after training, use test_transport.ipynb
                # to compute FID, accuracy, etc.
                pass

            gc.collect()
            torch.cuda.empty_cache()

        # Stage finished: the next stage starts from step 0.
        CONTINUE[0] = 0
        pbar.reset()
        CONTINUE[1] += 1
        spbar.update(1)

## 7. Clear resources


In [None]:
# Close the TensorBoard writer and the progress bars. spbar is created only by
# the "Adapt" strategy, so each resource is looked up and closed independently.
# A missing one is skipped instead of aborting the cleanup of the rest.
for _resource_name in ("writer", "pbar", "spbar"):
    _resource = globals().get(_resource_name)
    if _resource is None:
        continue
    try:
        _resource.close()
    except Exception as e:
        # Best-effort cleanup: report the failure and keep closing the others.
        print(e)