## 1. Imports


In [None]:
import sys
import gc
import os

import torch
from torch.utils.data import Subset
import torchvision.datasets as datasets
import numpy as np
from diffusers import DDIMScheduler
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

sys.path.append("..")
from utils import (
    set_random_seed,
    get_ddib_path,
    get_flow_path,
    cod_prob_bound,
)

from sampler import get_paired_dataset

%matplotlib inline 

## 2. Config


In [None]:
# Fix the random seed through the project helper before anything else runs.
SEED = 0x4090
set_random_seed(SEED)

# Dataset selection: the grayscale 28x28 alternatives are kept for reference.
# DATASET, DATASET_PATH, IMG_SIZE, GRAY = "usps2mnist", "./data/", 28, True
# DATASET, DATASET_PATH, IMG_SIZE, GRAY = "mnist2fmnist", "./data/", 28, True
# DATASET, DATASET_PATH, IMG_SIZE, GRAY = "fmnist2usps", "./data/", 28, True

DATASET = "comic_faces_v1"
DATASET_PATH = "./data/face2comics_v1.0.0_by_Sxela"
IMG_SIZE = 512
GRAY = False

# Number of noising steps in the diffusion process.
DIFFUSION_STEPS = 100
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)
# Time steps picked out when plotting a path (every 100th step up to 2 * DIFFUSION_STEPS).
PIVOTAL_LIST = list(range(0, DIFFUSION_STEPS * 2 + 1, 100))
# PIVOTAL_LIST = [20, 50, 100]

# All hyperparameters below are set to the values used for the experiments
# described in the article.
EPSILON = 0.1
# R = 1000
N = 2
P = 1
NUM_QUERY = 100
MAX_NUM_SAMPLE = 6000

# Data sampling settings: class labels kept when a labelled dataset is filtered.
SELECTED_CLASSES = [2]

## 3. Initialize dataset


In [None]:
class PairedDataset2(torch.utils.data.Dataset):
    """Pair two torchvision datasets index-by-index.

    The ``name`` selects a (source, target) combination of USPS, MNIST and
    FashionMNIST. Item ``idx`` is the image at position ``idx`` of the
    (optionally class-filtered) source paired with the image at position
    ``idx`` of the (optionally class-filtered) target; labels are dropped.

    Args:
        name: One of ``"usps2mnist"``, ``"mnist2fmnist"``, ``"fmnist2usps"``.
        root: Dataset root directory; USPS is read from ``root/USPS/raw``.
        train: Passed through to the torchvision dataset constructors.
        transform: Passed through to the torchvision dataset constructors.
        download: Passed through to the torchvision dataset constructors.
        selected_classes: If given, only samples whose label is in this
            collection are kept on both sides.
        reverse: If true, ``__getitem__`` returns ``(target, source)``.

    Raises:
        ValueError: If ``name`` is not one of the supported combinations.
    """

    def __init__(
        self,
        name,
        root,
        train=True,
        transform=None,
        download=False,
        selected_classes=None,
        reverse=False,
    ):
        super().__init__()
        self.reverse = reverse

        if name == "usps2mnist":
            source = datasets.USPS(
                os.path.join(root, "USPS", "raw"), train, transform, download=download
            )
            target = datasets.MNIST(root, train, transform, download=download)
        elif name == "mnist2fmnist":
            source = datasets.MNIST(root, train, transform, download=download)
            target = datasets.FashionMNIST(root, train, transform, download=download)
        elif name == "fmnist2usps":
            source = datasets.FashionMNIST(root, train, transform, download=download)
            target = datasets.USPS(
                os.path.join(root, "USPS", "raw"), train, transform, download=download
            )
        else:
            # Raising a bare string is itself a TypeError; raise a real exception.
            raise ValueError(f"Invalid dataset name: {name!r}")

        source_indices = self._class_indices(source, selected_classes)
        target_indices = self._class_indices(target, selected_classes)

        # Truncate to the number of *kept* samples on the smaller side so that
        # both subsets have the same length and every index has a partner.
        pair_count = min(len(source_indices), len(target_indices))

        self.x = Subset(source, source_indices[:pair_count])
        self.y = Subset(target, target_indices[:pair_count])

    @staticmethod
    def _class_indices(dataset, selected_classes):
        """Return the indices of ``dataset`` whose label is in ``selected_classes``.

        With ``selected_classes`` set to ``None`` every index is returned.
        """
        if selected_classes is None:
            return range(len(dataset))
        return [
            i
            for i, label in enumerate(dataset.targets)
            if label in selected_classes
        ]

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the image pair at ``idx``, swapped when ``reverse`` is set."""
        x, y = self.x[idx][0], self.y[idx][0]

        return (y, x) if self.reverse else (x, y)

In [None]:
# Build the paired dataset selected by DATASET, with images resized to
# IMG_SIZE x IMG_SIZE and normalised with mean 0.5 / std 0.5 per channel.
if DATASET in ("mnist2fmnist", "usps2mnist", "fmnist2usps"):
    transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            # One-element tuples: `(0.5)` is just the float 0.5, not a sequence.
            Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = PairedDataset2(
        DATASET,
        DATASET_PATH,
        transform=transform,
        selected_classes=SELECTED_CLASSES,
    )

elif DATASET == "comic_faces_v1":
    transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # Only the first value returned by the project helper is used here.
    dataset, _ = get_paired_dataset(
        DATASET,
        DATASET_PATH,
        transform,
    )
else:
    # Raising a bare string is itself a TypeError; raise a real exception.
    raise ValueError(f"Invalid dataset: {DATASET!r}")

# Cap the number of pairs that later cells have to process.
if len(dataset) > MAX_NUM_SAMPLE:
    dataset = Subset(dataset, range(MAX_NUM_SAMPLE))

len(dataset)

In [None]:
# Rough memory estimate in MiB for the generated paths.
# `sys.getsizeof` is shallow and does not account for the tensor's data
# buffer, so the byte count is taken from the tensor itself.
# NOTE(review): the factors 2 and 2000 are kept from the original estimate;
# presumably "images per pair" and "steps per path" - confirm.
sample = dataset[0][0]
sample_nbytes = sample.element_size() * sample.nelement()
(
    type(sample),
    round(sample_nbytes * 2 * len(dataset) * 2000 / (1024 * 1024), 2),
)

In [None]:
def plot_data_point(x, gray=GRAY):
    """Show a single normalised image tensor with matplotlib.

    Args:
        x: Image tensor as produced by the notebook's transform
            (normalised with mean 0.5 / std 0.5). Singleton dimensions are
            squeezed away before plotting.
        gray: If true, draw with the "gray" colormap; otherwise the tensor
            is treated as channel-first and permuted to channel-last.
    """
    # Undo Normalize(0.5, 0.5) so values are back in [0, 1]; imshow clips
    # out-of-range float RGB data, which would distort colour images.
    img = x.squeeze().mul(0.5).add(0.5).clamp(0, 1)
    if gray:
        plt.imshow(img.numpy(), cmap="gray")
    else:
        plt.imshow(img.permute(1, 2, 0).numpy())
    plt.axis("off")  # hide the axes
    plt.show()


# Show both images of the first pair.
plot_data_point(dataset[0][0])
plot_data_point(dataset[0][1])

In [None]:
from typing import Optional, Union


def plot_path(
    path: Union[list, torch.Tensor],
    indices: Optional[list] = None,
    gray: bool = False,
):
    """Draw the images of a path in a grid with at most 10 columns.

    Args:
        path: Channel-first images, either a list of tensors or one tensor
            stacked along dimension 0. Values are mapped back from the
            mean-0.5 / std-0.5 normalisation before drawing.
        indices: Optional positions to pick from ``path``; all images are
            drawn when omitted.
        gray: If true, draw with the "gray" colormap.

    Each panel is titled ``X_i`` with ``i`` its position in the drawn
    sequence, except the last one, which is titled ``Y``. Nothing is drawn
    for an empty path.
    """
    if indices is not None:
        path = [path[i] for i in indices]
    if len(path) == 0:
        # torch.stack and a zero-column grid both fail on empty input.
        return
    if isinstance(path, list):
        path = torch.stack(path)

    imgs: np.ndarray = (
        path.to("cpu").permute(0, 2, 3, 1).mul(0.5).add(0.5).numpy().clip(0, 1)
    )

    num_imgs = len(imgs)
    ncols = min(num_imgs, 10)
    # Ceiling division: an exact multiple of 10 must not add an empty row.
    nrows = (num_imgs + ncols - 1) // ncols

    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_title(f"$X_{{{i}}}$", fontsize=16)
        if i == imgs.shape[0] - 1:
            ax.set_title("Y", fontsize=16)
    fig.tight_layout()
    # Release cached GPU memory and collect garbage after plotting.
    torch.cuda.empty_cache()
    gc.collect()

## 4. Concentration of Distance Probability


### DDIB


#### generate asymptotic distribution


In [None]:
# Per-image shape: one channel for grayscale data, three otherwise.
num_channels = 1 if GRAY else 3
image_shape = (num_channels, IMG_SIZE, IMG_SIZE)

# Uninitialised buffer indexed as [time step, sample, channel, height, width].
num_ddib_timesteps = DIFFUSION_STEPS * 2 + 1
x2y_ddib_asymptotic_dataset = torch.empty(
    (num_ddib_timesteps, len(dataset)) + image_shape,
    dtype=torch.float32,
)

In [None]:
# FIXME: 512图片内存占用巨大，需要优化（或者服务器运行尝试）
for i, (x, y) in enumerate(tqdm(dataset)):
    path = get_ddib_path(x, y, SCHEDULER)
    x2y_ddib_asymptotic_dataset[:, i] = path

# x2y_ddib_path_list = torch.stack(x2y_ddib_path_list)
# x2y_ddib_asymptotic_dataset = x2y_ddib_path_list.transpose(0, 1)

In [None]:
# Display the shape of the DDIB path buffer, allocated as
# (2 * DIFFUSION_STEPS + 1, len(dataset), *image_shape).
x2y_ddib_asymptotic_dataset.shape

In [None]:
# Draw every step of the first sample's DDIB path.
# (Pass indices=PIVOTAL_LIST to draw only the pivotal steps.)
first_ddib_path = x2y_ddib_asymptotic_dataset[:, 0]
plot_path(first_ddib_path, gray=GRAY)

#### calculate CoD Prob Bound


In [None]:
# The buffer's first two dimensions are (time steps, samples).
num_ddib_steps, num_ddib_samples = x2y_ddib_asymptotic_dataset.shape[:2]

# Middle time step of the DDIB path.
mid_step = num_ddib_steps // 2

# Distinct sample indices used as query points in the cells below.
Q_point_indices = np.random.choice(num_ddib_samples, size=NUM_QUERY, replace=False)

In [None]:
# Direct transfer: query points taken at the first step, compared against
# all points at the last step.
T_point_dataset = x2y_ddib_asymptotic_dataset[-1]
Q_point_dataset = x2y_ddib_asymptotic_dataset[0][Q_point_indices]

# One bound per query point, averaged over the queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
        for query in Q_point_dataset
    ]
)
x2y_ddim_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={x2y_ddib_asymptotic_dataset.shape[0]-1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={x2y_ddim_prob_bound:.4%}"
)

In [None]:
# First half: query points taken at the first step, compared against all
# points at the middle step.
T_point_dataset = x2y_ddib_asymptotic_dataset[mid_step]
Q_point_dataset = x2y_ddib_asymptotic_dataset[0][Q_point_indices]

# One bound per query point, averaged over the queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
        for query in Q_point_dataset
    ]
)
x2g_ddim_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={mid_step}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={x2g_ddim_prob_bound:.4%}"
)

In [None]:
# Second half: query points taken at the middle step, compared against all
# points at the last step.
T_point_dataset = x2y_ddib_asymptotic_dataset[-1]
Q_point_dataset = x2y_ddib_asymptotic_dataset[mid_step][Q_point_indices]

# One bound per query point, averaged over the queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
        for query in Q_point_dataset
    ]
)
g2y_ddim_prob_bound = prob_bound_list.mean()

print(
    f"t={mid_step} -> t={x2y_ddib_asymptotic_dataset.shape[0] - 1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={g2y_ddim_prob_bound:.4%}"
)

In [None]:
# Step-by-step bounds: for each consecutive pair of steps (t, t + 1), query
# points come from step t and are compared against all points at step t + 1.
x2y_ddim_prob_bound_list = []
num_ddib_transitions = x2y_ddib_asymptotic_dataset.shape[0] - 1
for t in range(num_ddib_transitions):
    T_point_dataset = x2y_ddib_asymptotic_dataset[t + 1]
    Q_point_dataset = x2y_ddib_asymptotic_dataset[t][Q_point_indices]

    # One bound per query point, averaged over the queries.
    prob_bound_list = np.array(
        [
            cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
            for query in Q_point_dataset
        ]
    )
    prob_bound = prob_bound_list.mean()
    print(f"t={t} -> t={t+1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={prob_bound:.4%}")

    x2y_ddim_prob_bound_list.append(prob_bound)

#### Plot


In [None]:
# Vertical marker at the middle step of the DDIB path.
plt.vlines(mid_step, 0, 1, colors="black")

# Red lines: bounds over the full path and over each half of it.
plt.hlines(x2y_ddim_prob_bound, 0, DIFFUSION_STEPS * 2, colors="red")
plt.hlines(x2g_ddim_prob_bound, 0, mid_step, colors="red")
plt.hlines(g2y_ddim_prob_bound, mid_step, DIFFUSION_STEPS * 2, colors="red")

# One short segment per single-step transition t -> t + 1.
for t, prob in enumerate(x2y_ddim_prob_bound_list):
    plt.hlines(prob, t, t + 1)

plt.xlabel("t")
plt.ylabel(f"Lower Bound of P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}")

plt.xlim(0, DIFFUSION_STEPS * 2)
plt.ylim(0, 1)
plt.margins(x=0)
# Raw f-string: "\e" is an invalid escape sequence in a normal string literal.
plt.title(rf"DDIB: {DATASET}, $\epsilon$={EPSILON}, {DIFFUSION_STEPS} steps")
plt.show()

### Flow


#### generate asymptotic distribution


In [None]:
# Build one flow path per (x, y) pair, stack each path into a tensor, then
# stack all paths so that dimension 0 indexes the sample.
x2y_flow_path_list = torch.stack(
    [
        torch.stack(get_flow_path(x, y, DIFFUSION_STEPS + 1))
        for x, y in tqdm(dataset)
    ]
)

# Swap the first two dimensions so that dimension 0 indexes the path step.
x2y_flow_asymptotic_dataset = x2y_flow_path_list.transpose(0, 1)

In [None]:
# Draw the first sample's flow path at the pivotal steps.
# PIVOTAL_LIST goes up to 2 * DIFFUSION_STEPS, while the flow path is
# requested with DIFFUSION_STEPS + 1 points, so pivots beyond the end of this
# path are dropped instead of raising IndexError.
first_flow_path = x2y_flow_path_list[0]
plot_path(
    first_flow_path,
    indices=[i for i in PIVOTAL_LIST if i < len(first_flow_path)],
    gray=GRAY,
)

#### calculate CoD Prob Bound


In [None]:
# Distinct sample indices used as query points for the flow paths.
num_flow_samples = x2y_flow_path_list.shape[0]
Q_point_indices = np.random.choice(num_flow_samples, size=NUM_QUERY, replace=False)

In [None]:
# Direct transfer: query points taken at the first flow step, compared
# against all points at the last flow step.
T_point_dataset = x2y_flow_asymptotic_dataset[-1]
Q_point_dataset = x2y_flow_asymptotic_dataset[0][Q_point_indices]

# One bound per query point, averaged over the queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
        for query in Q_point_dataset
    ]
)
x2y_flow_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={x2y_flow_asymptotic_dataset.shape[0]-1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={x2y_flow_prob_bound:.4%}"
)

In [None]:
# Step-by-step bounds: for each consecutive pair of flow steps (t, t + 1),
# query points come from step t and are compared against all points at t + 1.
x2y_flow_prob_bound_list = []
num_flow_transitions = x2y_flow_asymptotic_dataset.shape[0] - 1
for t in range(num_flow_transitions):
    T_point_dataset = x2y_flow_asymptotic_dataset[t + 1]
    Q_point_dataset = x2y_flow_asymptotic_dataset[t][Q_point_indices]

    # One bound per query point, averaged over the queries.
    prob_bound_list = np.array(
        [
            cod_prob_bound(T_point_dataset, EPSILON, query, "euclidean", N)
            for query in Q_point_dataset
        ]
    )
    prob_bound = prob_bound_list.mean()
    print(f"t={t} -> t={t+1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={prob_bound:.4%}")

    x2y_flow_prob_bound_list.append(prob_bound)

#### Plot


In [None]:
# Red line: bound for the direct transfer from the first to the last step.
plt.hlines(x2y_flow_prob_bound, 0, DIFFUSION_STEPS, colors="red")


# One short segment per single-step transition t -> t + 1.
for t, prob in enumerate(x2y_flow_prob_bound_list):
    plt.hlines(prob, t, t + 1)


plt.xlabel("t")
plt.ylabel(f"Lower Bound of P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}")

plt.xlim(0, DIFFUSION_STEPS)
# Zoom the y-axis to [0, 0.5] when the direct bound is below 0.5.
plt.ylim(0, 0.5 if x2y_flow_prob_bound < 0.5 else 1)
plt.margins(x=0)
# Raw f-string: "\e" is an invalid escape sequence in a normal string literal.
plt.title(rf"Flow: {DATASET}, $\epsilon$={EPSILON}, {DIFFUSION_STEPS} steps")
plt.show()

# TODO LIST

1. 数据集加载
   - [√] FMNIST, MNIST, USPS
   - [√] comic_faces_v1
2. 多种渐变方式
   - [√] 扩散(DDIM)：并无直接 X->Y 的分布转移渐变，只有 X->高斯->Y 分布。但高斯分布本身会造成严重距离聚集。
   - [√] Flow(插值)：X->Y 分布转移渐变，纯粹生成一般随机采样高斯噪声作为 X。
   - [] 薛定谔桥：同 Flow
3. [√]CoD 概率下界计算
4. []画图
   - [√] 直接
   - [√] 逐步
   - [] 选取节点：加噪 1000 步，选取个别节点
5. []级联 OT
