## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings
from typing import List

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import numpy as np
from diffusers import DDIMScheduler
from tensorboardX import SummaryWriter
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, Lambda


import matplotlib.pyplot as plt
from PIL import PngImagePlugin
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.resnet2 import ResNet_D
from src.mnistm_utils import MNISTM
from src.unet import UNet
from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
)
from src.plotters import (
    plot_linked_pushed_images,
    plot_linked_pushed_random_class_images,
)
from src.samplers import (
    SubsetGuidedSampler,
    SubsetGuidedDataset,
    get_indicies_subset,
)

# Multiplier (in MiB) used on the next line for Pillow's PNG text-chunk limit.
LARGE_ENOUGH_NUMBER = 100
# Raise Pillow's limit on the size of a PNG text chunk to LARGE_ENOUGH_NUMBER MiB.
# NOTE(review): presumably needed for PNGs with large embedded metadata — confirm.
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

# Silence every warning for the rest of the session (deprecation notices included).
warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Collect unreachable Python objects first, then ask PyTorch to release the
# CUDA memory its caching allocator is no longer using.
gc.collect()
torch.cuda.empty_cache()

## 2. Config


In [None]:
# Global seed: passed to set_random_seed and also embedded in EXP_NAME and in
# every checkpoint file name.
SEED = 0x3060
set_random_seed(SEED)

# dataset choosing: keep exactly one of the following assignments active
DATASET, DATASET_PATH = "fmnist2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2mnistm", "../datasets/"
# DATASET, DATASET_PATH = "mnist2usps", "../datasets/"
# DATASET, DATASET_PATH = "mnist2kmnist", "../datasets/"

# IMG_SIZE is the side length the images are resized to; the channel counts are
# the ones handed to the UNet / ResNet_D constructors below.
IMG_SIZE = 32
DATASET1_CHANNELS = 1
DATASET2_CHANNELS = 1

# number of training timesteps of the diffusion scheduler used for adding noise
DIFFUSION_STEPS = 1000
# noising-step counts after which a pivotal point is recorded; the last entry
# bounds how many noising steps are applied
PIVOTAL_LIST = [40, 100]  # [0, 100] for testing,  [0, 20, 50, 100]

# GPU choosing
DEVICE_IDS = [0]
assert torch.cuda.is_available()
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# Resume position: CONTINUE[0] is the training step to resume from and
# CONTINUE[1] is the index of the (T, D) stage to resume from ('Adapt' only).
CONTINUE = [0, 0]  # [step, stage index]

# All hyperparameters below are set to the values used for the experiments described in the article

# training algorithm settings
STRATEGY = "Fix"  # 'Fix' or 'Adapt'
# number of T (transport map) updates performed per D (discriminator) update
T_ITERS = 10
# range(MAX_STEPS) ends at step 10000, which is a multiple of the plot and
# checkpoint intervals, so the last step is plotted and saved
MAX_STEPS = 10000 + 1
# model settings
UNET_BASE_FACTOR = 48
# data sample settings
BATCH_SIZE = 32
# SUBSET_SIZE is passed to the T sampler as `subsetsize`; SUBSET_CLASS is the
# only index that receives a non-zero entry in SUBSET_WEIGHT
SUBSET_SIZE = 4
SUBSET_CLASS = 4
# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
T_GRADIENT_MAX_NORM = float(1e4)
D_GRADIENT_MAX_NORM = float(1e4)
# steps at which MultiStepLR halves the learning rates (gamma=0.5)
SCHEDULER_MILESTONES = [2000, 4000, 6000, 8000]

# plot settings
GRAY_PLOTS = True
PLOT_N_SAMPLES = 10

# log settings (all measured in training steps)
TRACK_VAR_INTERVAL = 500
PLOT_INTERVAL = 500
CKPT_INTERVAL = 1000

# NOTE(review): only stored in config.json here; not read anywhere else in this notebook.
FID_EPOCHS = 1

# Experiment name; also used as the checkpoint and TensorBoard directory name.
EXP_NAME = f"DNOT_Unpair_{DATASET}_{STRATEGY}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# Create the checkpoint directory. ``exist_ok=True`` replaces the separate
# existence check, so re-running the cell (or a concurrent run creating the
# same directory) cannot fail between the check and the creation.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# TensorBoard writer; all scalars and images of this run go to ../logdir/<EXP_NAME>.
writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# All ten classes are kept on both sides and every class index maps to itself.
source_subset = torch.arange(10)
new_labels_source = {label: label for label in range(10)}
target_subset = torch.arange(10)
new_labels_target = {label: label for label in range(10)}

# One weight per source class: everything is zero except SUBSET_CLASS.
SUBSET_WEIGHT = [0] * len(source_subset)
SUBSET_WEIGHT[SUBSET_CLASS] = 1.0

In [None]:
# Snapshot of the hyper-parameters of this run, keyed by their constant names
# (note that SUBSET_CLASS is stored under the key "SUB_CLASS").
config = {
    "SEED": SEED,
    "DATASET": DATASET,
    "IMG_SIZE": IMG_SIZE,
    "DATASET1_CHANNELS": DATASET1_CHANNELS,
    "DATASET2_CHANNELS": DATASET2_CHANNELS,
    "DIFFUSION_STEPS": DIFFUSION_STEPS,
    "PIVOTAL_LIST": PIVOTAL_LIST,
    "STRATEGY": STRATEGY,
    "T_ITERS": T_ITERS,
    "MAX_STEPS": MAX_STEPS,
    "UNET_BASE_FACTOR": UNET_BASE_FACTOR,
    "BATCH_SIZE": BATCH_SIZE,
    "SUBSET_SIZE": SUBSET_SIZE,
    "SUB_CLASS": SUBSET_CLASS,
    "D_LR": D_LR,
    "T_LR": T_LR,
    "T_GRADIENT_MAX_NORM": T_GRADIENT_MAX_NORM,
    "D_GRADIENT_MAX_NORM": D_GRADIENT_MAX_NORM,
    "SCHEDULER_MILESTONES": SCHEDULER_MILESTONES,
    "GRAY_PLOTS": GRAY_PLOTS,
    "PLOT_N_SAMPLES": PLOT_N_SAMPLES,
    "TRACK_VAR_INTERVAL": TRACK_VAR_INTERVAL,
    "PLOT_INTERVAL": PLOT_INTERVAL,
    "CKPT_INTERVAL": CKPT_INTERVAL,
    "FID_EPOCHS": FID_EPOCHS,
}
# Store the configuration next to the checkpoints.
json_str = json.dumps(config, indent=4)
with open(os.path.join(OUTPUT_PATH, "config.json"), "w") as json_file:
    json_file.write(json_str)

# The log keeps the resume position; it references the CONTINUE list itself,
# so later in-place updates of CONTINUE are visible through `log`.
log = {"CONTINUE": CONTINUE}
log_str = json.dumps(log, indent=4)
with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
    log_file.write(log_str)

## 3. Initialize samplers


In [None]:
# Default preprocessing for the single-channel datasets: resize to
# IMG_SIZE x IMG_SIZE, convert to a tensor and normalise with mean 0.5 / std 0.5
# (the plotting code later undoes this with ``mul(0.5).add(0.5)``).
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize((0.5), (0.5)),
    ]
)
target_transform = source_transform

# Pick the dataset classes for the chosen translation task.
if DATASET == "mnist2kmnist":
    source = datasets.MNIST
    target = datasets.KMNIST

elif DATASET == "fmnist2mnist":
    source = datasets.FashionMNIST
    target = datasets.MNIST

elif DATASET == "mnist2usps":
    source = datasets.MNIST
    target = datasets.USPS

elif DATASET == "mnist2mnistm":
    # MNIST-M is a colour dataset, so both sides work with three channels.
    DATASET1_CHANNELS = 3
    DATASET2_CHANNELS = 3
    GRAY_PLOTS = False
    source = datasets.MNIST
    target = MNISTM
    source_transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5), (0.5)),
            # replicate the single channel three times and flip the sign
            Lambda(lambda x: -x.repeat(3, 1, 1)),
        ]
    )
    target_transform = Compose(
        [Resize(IMG_SIZE), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # config.json was written by the config cell *before* the three overrides
    # above, so it still holds the single-channel defaults. Refresh the saved
    # configuration so it matches the models that are actually built.
    config.update(
        DATASET1_CHANNELS=DATASET1_CHANNELS,
        DATASET2_CHANNELS=DATASET2_CHANNELS,
        GRAY_PLOTS=GRAY_PLOTS,
    )
    with open(os.path.join(OUTPUT_PATH, "config.json"), "w") as json_file:
        json.dump(config, json_file, indent=4)

else:
    # Fail here with a clear message instead of a NameError on `source` later.
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of "
        "'mnist2kmnist', 'fmnist2mnist', 'mnist2usps', 'mnist2mnistm'"
    )

In [None]:
# Load the source training split (downloaded into DATASET_PATH if missing).
source_train = source(
    root=DATASET_PATH, train=True, download=True, transform=source_transform
)
# NOTE(review): get_indicies_subset lives in src.samplers; judging by how the
# results are used below it returns the selected samples, their labels and
# per-class index lists — confirm against its implementation.
subset_samples, labels, source_class_indicies = get_indicies_subset(
    source_train,
    new_labels=new_labels_source,
    classes=len(source_subset),
    subset_classes=source_subset,
)
# Replace the dataset object by an in-memory TensorDataset of the selected samples.
source_train = torch.utils.data.TensorDataset(
    torch.stack(subset_samples), torch.LongTensor(labels)
)


# Same procedure for the target dataset.
target_train = target(
    root=DATASET_PATH, train=True, download=True, transform=target_transform
)
target_subset_samples, target_labels, target_class_indicies = get_indicies_subset(
    target_train,
    new_labels=new_labels_target,
    classes=len(target_subset),
    subset_classes=target_subset,
)
target_train = torch.utils.data.TensorDataset(
    torch.stack(target_subset_samples), torch.LongTensor(target_labels)
)

# Dataset wrapper used by the sampler for the T updates.
train_set = SubsetGuidedDataset(
    source_train,
    target_train,
    num_labeled="all",
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)

# NOTE(review): full_set is built with exactly the same arguments as train_set.
full_set = SubsetGuidedDataset(
    source_train,
    target_train,
    num_labeled="all",
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)

In [None]:
# Sampler feeding the transport-map (T) updates: subsets of SUBSET_SIZE samples.
# NOTE(review): the exact meaning of `weight` is defined in src.samplers; here
# only the SUBSET_CLASS entry is non-zero.
T_XY_sampler = SubsetGuidedSampler(
    train_set, subsetsize=SUBSET_SIZE, weight=SUBSET_WEIGHT
)
# Sampler used for the discriminator targets and for plotting: subsets of size 1.
D_XY_sampler = SubsetGuidedSampler(full_set, subsetsize=1, weight=SUBSET_WEIGHT)

In [None]:
# Free cached CUDA memory and unreachable Python objects, then wipe the cell
# output produced while the samplers were being built.
torch.cuda.empty_cache()
gc.collect()
clear_output()

### pivotal sampler


In [None]:
# Scheduler whose ``add_noise`` is used to build the pivotal paths.
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)


def sample_all_pivotal(
    XY_sampler: SubsetGuidedSampler,
    batch_size: int = 4,
) -> List[torch.Tensor]:
    """Sample one source/target batch and build its pivotal path.

    Both batches are noised step by step with ``SCHEDULER.add_noise``; every
    call receives the result of the previous call. A snapshot of both batches
    is kept after each step count listed in ``PIVOTAL_LIST``.

    Args:
        XY_sampler: sampler whose ``sample(batch_size)`` returns a
            ``(source, target)`` pair of tensors.
        batch_size: value forwarded to ``XY_sampler.sample``.

    Returns:
        ``[source, noised sources ..., noised targets in reverse order ...,
        target]``. The most-noised target snapshot is left out, so the path
        switches from the source side to the target side right after the
        last source pivotal point.
    """
    noisy_source, noisy_target = XY_sampler.sample(batch_size)
    forward_points = [noisy_source]
    backward_points = [noisy_target]

    n_noise_steps = min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])
    for t in range(n_noise_steps):
        timestep = torch.Tensor([t]).long()
        # Source noise is drawn before target noise, keeping the RNG order fixed.
        noisy_source = SCHEDULER.add_noise(
            noisy_source, torch.randn_like(noisy_source), timestep
        )
        noisy_target = SCHEDULER.add_noise(
            noisy_target, torch.randn_like(noisy_target), timestep
        )
        if (t + 1) not in PIVOTAL_LIST:
            continue
        forward_points.append(noisy_source)
        backward_points.append(noisy_target)

    # Walk the target snapshots from noisy to clean and skip the noisiest one:
    # only the source's last pivotal point is used.
    # (keeping it as well would map between the 2 last pivotal points)
    return forward_points + backward_points[::-1][1:]


# def sample_step_t_pivotal(
#     XY_sampler: PairedLoaderSampler,
#     batch_size: int = 4,
#     pivotal_step: int = 0,
# ):
#     pivotal_path = sample_all_pivotal(XY_sampler, batch_size)
#     pivotal_t, pivotal_t_next = (
#         pivotal_path[pivotal_step],
#         pivotal_path[pivotal_step + 1],
#     )
#     return pivotal_t, pivotal_t_next

### mapping plotters


In [None]:
def plot_all_pivotal(
    source: torch.Tensor,
    target: torch.Tensor,
    gray: bool = False,
) -> None:
    """Plot the pivotal path built from one source and one target image.

    Both images are noised step by step with ``SCHEDULER.add_noise`` (each call
    receives the output of the previous one) and a snapshot is stored after
    every step count listed in ``PIVOTAL_LIST``. The row that is drawn
    contains: the source, its noised snapshots, the noised target snapshots in
    reverse order (the most-noised one is skipped) and finally the clean
    target, which is titled "Y".

    Args:
        source: a single image; the ``permute(0, 2, 3, 1)`` below treats it as
            channel-first ``(C, H, W)``.
        target: a single image with the same layout as ``source``.
        gray: if true, draw the images with the "gray" colormap.

    Returns:
        None. The figure is created through ``pyplot`` and left open.
    """
    pivotal_path = []

    source_list = [source]
    target_list = [target]
    for i in range(min(DIFFUSION_STEPS, PIVOTAL_LIST[-1])):
        source = SCHEDULER.add_noise(
            source, torch.randn_like(source), torch.Tensor([i]).long()
        )
        target = SCHEDULER.add_noise(
            target, torch.randn_like(target), torch.Tensor([i]).long()
        )
        if (i + 1) in PIVOTAL_LIST:
            source_list.append(source)
            target_list.append(target)

    target_list.reverse()

    pivotal_path.extend(source_list)
    pivotal_path.extend(target_list[1:])  # just using source's last pivotal point
    # pivotal_path.extend(target_list[:]) # 2 last pivotal points mapping

    # Channel-last layout for imshow and undo the mean-0.5 / std-0.5 normalisation.
    imgs: np.ndarray = (
        torch.stack(pivotal_path)
        .to("cpu")
        .permute(0, 2, 3, 1)
        .mul(0.5)
        .add(0.5)
        .numpy()
        .clip(0, 1)
    )
    nrows, ncols = 1, len(pivotal_path)
    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    last_index = imgs.shape[0] - 1
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        # Brace the subscript so that indices with several digits (i >= 10)
        # are subscripted as a whole instead of only their first digit.
        title = "Y" if i == last_index else f"$X_{{{i}}}$"
        ax.set_title(title, fontsize=24)

    torch.cuda.empty_cache()
    gc.collect()

## 4. Training


### Models initialization


In [None]:
# Per-stage containers: transport maps (T), discriminators (D), their
# optimizers and their learning-rate schedulers. Index k in every list belongs
# to the same stage.
Ts, Ds = [], []
T_OPTs, D_OPTs = [], []
T_SCHEDULERs, D_SCHEDULERs = [], []

# Two stages per entry of PIVOTAL_LIST, i.e. one per transition of the path
# returned by sample_all_pivotal.
for _ in range(len(PIVOTAL_LIST) * 2):
    T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    Ts.append(T)

    D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
    D.apply(weights_init_D)
    Ds.append(D)

    T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, weight_decay=1e-10)
    D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, weight_decay=1e-10)
    T_OPTs.append(T_opt)
    D_OPTs.append(D_opt)

    # Halve the learning rate at every milestone.
    T_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        T_opt, milestones=SCHEDULER_MILESTONES, gamma=0.5
    )
    D_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        D_opt, milestones=SCHEDULER_MILESTONES, gamma=0.5
    )
    T_SCHEDULERs.append(T_scheduler)
    D_SCHEDULERs.append(D_scheduler)


# Wrap for multi-GPU training only on a fresh start; when resuming, the
# checkpoint cell wraps the models after the weights have been loaded.
wrap_in_data_parallel = len(DEVICE_IDS) > 1 and CONTINUE[0] == 0 and CONTINUE[1] == 0
for i in range(len(Ts)):
    if wrap_in_data_parallel:
        Ts[i] = nn.DataParallel(Ts[i], device_ids=DEVICE_IDS)
        Ds[i] = nn.DataParallel(Ds[i], device_ids=DEVICE_IDS)

    # Report the model sizes for every run; previously this was only printed
    # inside the multi-GPU branch, so single-GPU runs showed nothing.
    print(f"T{i} params:", np.sum([np.prod(p.shape) for p in Ts[i].parameters()]))
    print(
        f"D{i} params:",
        np.sum([np.prod(p.shape) for p in Ds[i].parameters()]),
    )

### Load weights for continue training


In [None]:
def _load_stage_checkpoint(stage: int, ckpt_dir: str) -> None:
    """Restore model, optimizer and scheduler state of one stage from ``ckpt_dir``."""
    stage_objects = (
        ("T", Ts[stage]),
        ("T_opt", T_OPTs[stage]),
        ("T_scheduler", T_SCHEDULERs[stage]),
        ("D", Ds[stage]),
        ("D_opt", D_OPTs[stage]),
        ("D_scheduler", D_SCHEDULERs[stage]),
    )
    for prefix, obj in stage_objects:
        file_name = f"{prefix}{stage}_{SEED}.pt"
        obj.load_state_dict(torch.load(os.path.join(ckpt_dir, file_name)))
        print(f"{ckpt_dir} {file_name}, loaded")


def _wrap_stage_data_parallel(stage: int) -> None:
    """Wrap the models of one stage in ``nn.DataParallel`` when several GPUs are used."""
    if len(DEVICE_IDS) > 1:
        # Write the wrappers back into the lists: rebinding a loop variable
        # would leave Ts/Ds unwrapped, and the training cells (which save
        # ``T.module.state_dict()`` on multi-GPU runs) would then fail.
        Ts[stage] = nn.DataParallel(Ts[stage], device_ids=DEVICE_IDS)
        Ds[stage] = nn.DataParallel(Ds[stage], device_ids=DEVICE_IDS)


if CONTINUE[0] > 0 or CONTINUE[1] > 0:
    print("Loading weights for continue training")
    if STRATEGY == "Adapt":
        for i in range(len(Ts)):
            if i < CONTINUE[1]:
                # Stage already finished: its final checkpoint is the last step.
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{MAX_STEPS - 1}/")
                _load_stage_checkpoint(i, CKPT_DIR)
            elif i == CONTINUE[1] and CONTINUE[0] > 0:
                # Stage in progress: CONTINUE[0] is the next step to run, so
                # the checkpoint belongs to the step before it. With
                # CONTINUE[0] == 0 the stage has not started and there is
                # nothing to load (there is no "iter-1" directory).
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
                _load_stage_checkpoint(i, CKPT_DIR)
            # Later stages keep their fresh initialisation. State dicts are
            # saved unwrapped, hence the wrapping happens after loading.
            _wrap_stage_data_parallel(i)

    if STRATEGY == "Fix":
        # All stages are trained jointly, so they share one checkpoint directory.
        CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{CONTINUE[0] - 1}/")
        for i in range(len(Ts)):
            _load_stage_checkpoint(i, CKPT_DIR)
            _wrap_stage_data_parallel(i)

In [None]:
# writer.add_graph(
#     SDEs[0], torch.rand(BATCH_SIZE, DATASET1_CHANNELS, IMG_SIZE, IMG_SIZE).cuda()
# )

### Plots Test


In [None]:
# Fixed batch reused for every "Fix Test Images" plot during training.
X_fixed, Y_fixed = D_XY_sampler.sample(10)
# Merge the first two dimensions
# (presumably batch and subset — confirm against SubsetGuidedSampler.sample).
X_fixed, Y_fixed = X_fixed.flatten(0, 1), Y_fixed.flatten(0, 1)

In [None]:
# Visual check of the pivotal path for the first fixed source/target pair.
plot_all_pivotal(X_fixed[0], Y_fixed[0], gray=GRAY_PLOTS)

In [None]:
# Visual check: push the fixed batch through the (still untrained) maps in Ts.
fig, axes = plot_linked_pushed_images(X_fixed, Y_fixed, Ts, gray=GRAY_PLOTS)

In [None]:
# Same check, with PLOT_N_SAMPLES samples drawn from D_XY_sampler by the plotter.
fig, axes = plot_linked_pushed_random_class_images(
    D_XY_sampler, Ts, plot_n_samples=PLOT_N_SAMPLES, gray=GRAY_PLOTS
)

### Main training cycle and logging


In [None]:
# Release the memory held by the plotting checks before training starts.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def get_cat_pivotal_tr(X_sampler, Y_sampler, batch_size):
    """Build a pivotal path whose two halves come from different samplers.

    A path is sampled from ``X_sampler`` with ``batch_size`` and another one
    from ``Y_sampler`` with ``batch_size * SUBSET_SIZE``. The result takes the
    entries up to and including the middle index from the first path and the
    remaining entries from the second path.

    Returns:
        The concatenated list of tensors.
    """
    x_path = sample_all_pivotal(X_sampler, batch_size)
    y_path = sample_all_pivotal(Y_sampler, batch_size * SUBSET_SIZE)
    assert len(x_path) == len(y_path)

    split = len(x_path) // 2 + 1
    tr = [*x_path[:split], *y_path[split:]]

    # Drop the references to the unused halves before the caches are emptied,
    # so that the memory of those tensors can actually be released.
    del x_path, y_path

    gc.collect()
    torch.cuda.empty_cache()
    return tr

#### Fix strategy


In [None]:
# "Fix" strategy: all stages are trained jointly. In every outer step each
# stage i learns to map pivotal point i of a sampled path to pivotal point
# i + 1: first T_ITERS updates of every T, then one update of every D.
if STRATEGY == "Fix":
    progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])

    for step in range(MAX_STEPS):
        # Skip the steps that were already done before a resume.
        if step < CONTINUE[0]:
            continue

        for T, D in zip(Ts, Ds):
            freeze(T)
            freeze(D)
        # ---- transport-map (T) updates ----
        for _ in range(T_ITERS):
            fixed_trajectory = sample_all_pivotal(T_XY_sampler, BATCH_SIZE)

            for i, (T, T_opt, D) in enumerate(zip(Ts, T_OPTs, Ds)):
                freeze(D)
                unfreeze(T)

                T_opt.zero_grad()

                X, Y = fixed_trajectory[i], fixed_trajectory[i + 1]

                # Run T on the flattened batch and restore the layout
                # (batch, SUBSET_SIZE, channels, IMG_SIZE, IMG_SIZE).
                T_X = (
                    T(X.flatten(0, 1))
                    .permute(1, 2, 3, 0)
                    .reshape(DATASET2_CHANNELS, IMG_SIZE, IMG_SIZE, -1, SUBSET_SIZE)
                    .permute(3, 4, 0, 1, 2)
                )
                # Half of the mean pairwise distance between the outputs of one
                # subset; the factor SUBSET_SIZE / (SUBSET_SIZE - 1) compensates
                # for the zero diagonal included in the mean (needs SUBSET_SIZE > 1).
                T_var = (
                    0.5
                    * torch.cdist(
                        T_X.flatten(start_dim=2), T_X.flatten(start_dim=2)
                    ).mean()
                    * SUBSET_SIZE
                    / (SUBSET_SIZE - 1)
                )
                # Mean Euclidean distance between the mapped samples and Y.
                cost = (Y - T_X).flatten(start_dim=2).norm(dim=2).mean()
                T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()
                # Logged once per inner iteration, always under the same `step`.
                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()
            del fixed_trajectory, X, Y, T_X, T_loss
            gc.collect()
            torch.cuda.empty_cache()
        # NOTE(review): every T scheduler is stepped T_ITERS times per outer
        # step here, while the 'Adapt' strategy steps it once per outer step —
        # confirm that this difference is intended.
        for _ in range(T_ITERS):
            for T_scheduler in T_SCHEDULERs:
                T_scheduler.step()

        for T, D in zip(Ts, Ds):
            freeze(T)
            freeze(D)

        # ---- discriminator (D) updates ----
        # First half of the path comes from T_XY_sampler, second half from D_XY_sampler.
        fixed_trajectory = get_cat_pivotal_tr(T_XY_sampler, D_XY_sampler, BATCH_SIZE)
        for i, (D, D_opt, D_scheduler, T) in enumerate(
            zip(Ds, D_OPTs, D_SCHEDULERs, Ts)
        ):
            freeze(T)
            unfreeze(D)

            D_opt.zero_grad()

            X, Y = fixed_trajectory[i], fixed_trajectory[i + 1]
            # T only provides inputs for D here, so no graph is built for it.
            with torch.no_grad():
                T_X = T(X.flatten(start_dim=0, end_dim=1))

            D_loss = D(T_X).mean() - D(Y.flatten(start_dim=0, end_dim=1)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

        del fixed_trajectory, X, Y, T_X, D_loss
        gc.collect()
        torch.cuda.empty_cache()

        # CONTINUE[0] now holds the next step to run; checkpoints written below
        # are therefore found again under iter{CONTINUE[0] - 1}.
        CONTINUE[0] += 1
        progress_bar.update(1)

        if step % PLOT_INTERVAL == 0:
            # Recreate the bar because clear_output wipes the cell output.
            progress_bar.close()
            clear_output(wait=True)
            progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE[0])
            print("Plotting")

            inference_Ts = Ts
            for T in inference_Ts:
                T.eval()
            print("Fix Test Images")
            fig, axes = plot_linked_pushed_images(
                X_fixed, Y_fixed, inference_Ts, gray=GRAY_PLOTS
            )
            writer.add_image("Fix Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)
            print("Random Test Images")
            fig, axes = plot_linked_pushed_random_class_images(
                D_XY_sampler,
                inference_Ts,
                plot_n_samples=PLOT_N_SAMPLES,
                gray=GRAY_PLOTS,
            )
            writer.add_image("Random Test Images", fig2tensor(fig), step)
            plt.show(fig)
            plt.close(fig)

        if step != 0 and step % CKPT_INTERVAL == 0:
            CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
            os.makedirs(CKPT_DIR, exist_ok=True)
            for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
                zip(
                    Ts,
                    T_OPTs,
                    T_SCHEDULERs,
                    Ds,
                    D_OPTs,
                    D_SCHEDULERs,
                )
            ):
                # Multi-GPU models are saved through `.module`, i.e. without
                # the DataParallel wrapper.
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
            # Persist the resume position together with the checkpoint.
            log["CONTINUE"] = CONTINUE
            with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                log_str = json.dumps(log, indent=4)
                log_file.write(log_str)

        if step % TRACK_VAR_INTERVAL == 0:
            # after training, using test_transport.ipynb to get fid acc ...
            pass

        gc.collect()
        torch.cuda.empty_cache()

#### Adapt strategy


In [None]:
# "Adapt" strategy: the stages are trained one after another. Stage i receives
# the clean source batch pushed through the already trained maps Ts[0..i-1]
# and learns to map it to pivotal point i + 1 of the sampled path.
if STRATEGY == "Adapt":
    for i, (T, T_opt, T_scheduler, D, D_opt, D_scheduler) in enumerate(
        zip(
            Ts,
            T_OPTs,
            T_SCHEDULERs,
            Ds,
            D_OPTs,
            D_SCHEDULERs,
        )
    ):
        # Stages finished before a resume are skipped.
        if i < CONTINUE[1]:
            continue
        progress_bar = tqdm(
            total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(Ts)}:"
        )
        for _T, _D in zip(Ts, Ds):
            freeze(_T)
            freeze(_D)

        for step in range(MAX_STEPS):
            # Skip the steps that were already done before a resume.
            if step < CONTINUE[0]:
                continue

            # ---- transport-map (T) updates ----
            for t_iter in range(T_ITERS):
                freeze(D)
                unfreeze(T)

                # === sampler training data ===
                fixed_trajectory = sample_all_pivotal(T_XY_sampler, BATCH_SIZE)
                X, Y = fixed_trajectory[0], fixed_trajectory[i + 1]
                X = X.flatten(0, 1)
                # Push the source through the previous (frozen) stages.
                with torch.no_grad():
                    for j in range(i):
                        X = Ts[j](X)
                X.requires_grad_()
                # === mapping and optimize ===
                T_opt.zero_grad()

                # Restore the layout (batch, SUBSET_SIZE, channels, IMG_SIZE, IMG_SIZE).
                T_X = (
                    T(X)
                    .permute(1, 2, 3, 0)
                    .reshape(DATASET2_CHANNELS, IMG_SIZE, IMG_SIZE, -1, SUBSET_SIZE)
                    .permute(3, 4, 0, 1, 2)
                )
                # Half of the mean pairwise distance inside each subset; the
                # factor SUBSET_SIZE / (SUBSET_SIZE - 1) compensates for the
                # zero diagonal included in the mean (needs SUBSET_SIZE > 1).
                T_var = (
                    0.5
                    * torch.cdist(
                        T_X.flatten(start_dim=2), T_X.flatten(start_dim=2)
                    ).mean()
                    * SUBSET_SIZE
                    / (SUBSET_SIZE - 1)
                )
                # Mean Euclidean distance between the mapped samples and Y.
                cost = (Y - T_X).flatten(start_dim=2).norm(dim=2).mean()
                T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()

                writer.add_scalar(f"T_loss/T{i}", T_loss.item(), step)
                T_loss.backward()
                T_gradient_norm = torch.nn.utils.clip_grad_norm_(
                    T.parameters(), max_norm=T_GRADIENT_MAX_NORM
                )
                T_opt.step()

                del fixed_trajectory, X, Y, T_X, T_loss
                gc.collect()
                torch.cuda.empty_cache()
            T_scheduler.step()

            # ---- discriminator (D) update ----
            freeze(T)
            unfreeze(D)
            # === sampler training data ===
            fixed_trajectory = get_cat_pivotal_tr(
                T_XY_sampler, D_XY_sampler, BATCH_SIZE
            )
            X, Y = fixed_trajectory[0], fixed_trajectory[i + 1]
            X = X.flatten(0, 1)
            with torch.no_grad():
                for j in range(i):
                    X = Ts[j](X)
            X.requires_grad_()
            # === mapping and optimize ===
            D_opt.zero_grad()
            T_X = T(X)
            D_loss = D(T_X).mean() - D(Y.flatten(start_dim=0, end_dim=1)).mean()
            writer.add_scalar(f"D_loss/D{i}", D_loss.item(), step)
            D_loss.backward()
            D_gradient_norm = torch.nn.utils.clip_grad_norm_(
                D.parameters(), max_norm=D_GRADIENT_MAX_NORM
            )
            D_opt.step()
            D_scheduler.step()

            del fixed_trajectory, X, Y, T_X, D_loss
            gc.collect()
            torch.cuda.empty_cache()

            # CONTINUE[0] now holds the next step to run; checkpoints written
            # below are therefore found again under iter{CONTINUE[0] - 1}.
            CONTINUE[0] += 1
            progress_bar.update(1)

            if step % PLOT_INTERVAL == 0:
                # Recreate the bar because clear_output wipes the cell output.
                progress_bar.close()
                clear_output(wait=True)
                progress_bar = tqdm(
                    total=MAX_STEPS, initial=CONTINUE[0], desc=f"{i + 1}/{len(Ts)}:"
                )
                print("Plotting")

                inference_Ts = Ts
                # Use a separate loop name here. Iterating with ``T`` would
                # rebind the stage's own model to Ts[-1]; from the first plot
                # on, the wrong network would be trained and checkpointed as
                # stage i while T_opt still holds stage i's parameters.
                for _T in inference_Ts:
                    _T.eval()

                print("Fixed Test Images")
                fig, axes = plot_linked_pushed_images(
                    X_fixed, Y_fixed, inference_Ts, gray=GRAY_PLOTS
                )
                writer.add_image(f"Fix Test Images/T{i + 1}", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

                print("Random Test Images")
                fig, axes = plot_linked_pushed_random_class_images(
                    D_XY_sampler,
                    inference_Ts,
                    plot_n_samples=PLOT_N_SAMPLES,
                    gray=GRAY_PLOTS,
                )
                writer.add_image(f"Random Test Images/T{i + 1}", fig2tensor(fig), step)
                plt.show(fig)
                plt.close(fig)

            if step != 0 and step % CKPT_INTERVAL == 0:
                CKPT_DIR = os.path.join(OUTPUT_PATH, f"iter{step}/")
                os.makedirs(CKPT_DIR, exist_ok=True)
                # Multi-GPU models are saved through `.module`, i.e. without
                # the DataParallel wrapper.
                if len(DEVICE_IDS) > 1:
                    torch.save(
                        T.module.state_dict(),
                        os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt"),
                    )
                    torch.save(
                        D.module.state_dict(),
                        os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt"),
                    )
                else:
                    torch.save(
                        T.state_dict(), os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
                    )
                    torch.save(
                        D.state_dict(), os.path.join(CKPT_DIR, f"D{i}_{SEED}.pt")
                    )

                torch.save(
                    D_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"D_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    T_opt.state_dict(),
                    os.path.join(CKPT_DIR, f"T_opt{i}_{SEED}.pt"),
                )
                torch.save(
                    D_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"D_scheduler{i}_{SEED}.pt"),
                )
                torch.save(
                    T_scheduler.state_dict(),
                    os.path.join(CKPT_DIR, f"T_scheduler{i}_{SEED}.pt"),
                )
                # Persist the resume position together with the checkpoint.
                log = dict(CONTINUE=CONTINUE)
                with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
                    log_str = json.dumps(log, indent=4)
                    log_file.write(log_str)

            if i == len(Ts) - 1 and step % TRACK_VAR_INTERVAL == 0:
                # after training, using test_transport.ipynb to get fid acc ...
                pass

            gc.collect()
            torch.cuda.empty_cache()
        # Stage finished: the next stage starts again at step 0.
        CONTINUE[0] = 0  # reset training steps to 0
        CONTINUE[1] += 1

## Clear resources


In [None]:
# Close the two resources independently, so that a failure while closing the
# writer does not leave the progress bar open (and vice versa). Errors are
# only reported, because this cell is best-effort cleanup.
try:
    writer.close()
except Exception as e:
    print(e)

try:
    progress_bar.close()
except Exception as e:
    print(e)

In [None]:
# Final plots with the trained maps. `i` and `step` are the values left over
# from the training cell, so that cell has to be run before this one.
print("Plotting")

inference_Ts = Ts
for T in inference_Ts:
    T.eval()

print("Fixed Test Images")
fig, axes = plot_linked_pushed_images(X_fixed, Y_fixed, inference_Ts, gray=GRAY_PLOTS)
writer.add_image(f"Fix Test Images/T{i + 1}", fig2tensor(fig), step)
plt.show(fig)
plt.close(fig)

print("Random Test Images")
fig, axes = plot_linked_pushed_random_class_images(
    D_XY_sampler,
    inference_Ts,
    plot_n_samples=PLOT_N_SAMPLES,
    gray=GRAY_PLOTS,
)
writer.add_image(f"Random Test Images/T{i + 1}", fig2tensor(fig), step)
plt.show(fig)
plt.close(fig)

# If the "Clear resources" cell ran before this one, the add_image calls above
# write through a re-created event file writer (tensorboardX recreates it on
# demand after close()). Close again so these last images are flushed to disk.
writer.close()