## 1. Imports


In [None]:
import sys
import gc

import torch
from torch.utils.data import Subset
import torchvision.datasets as datasets
import numpy as np
from diffusers import DDIMScheduler
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

sys.path.append("..")
from utils import (
    set_random_seed,
    get_ddim_path,
    get_flow_path,
    cod_prob_bound,
)

%matplotlib inline 

## 2. Config


In [None]:
# Fixed seed handed to the project's set_random_seed helper.
SEED = 0x4090
set_random_seed(SEED)

# Dataset pair ("<source>2<target>") and the root directory passed to the
# torchvision dataset constructors. Uncomment exactly one line.
# DATASET, DATASET_PATH = "fmnist2mnist", "./data/"
# DATASET, DATASET_PATH = "usps2mnist", "./data/"
# DATASET, DATASET_PATH = "mnist2fmnist", "./data/"
DATASET, DATASET_PATH = "usps2fmnist", "./data/"

# Side length (pixels) every image is resized to.
IMG_SIZE = 28

# Number of diffusion steps: used as the scheduler's num_train_timesteps and,
# plus one, as the third argument of get_flow_path.
DIFFUSION_STEPS = 100
SCHEDULER = DDIMScheduler(num_train_timesteps=DIFFUSION_STEPS)
# Optional list of path indices handed to plot_path; None plots every path point.
PIVOTAL_LIST = None
# PIVOTAL_LIST = [20, 50, 100]
# All hyperparameters below are set to the values used for the experiments
# described in the article.
# EPSILON: tolerance in the reported bound P{DMAX(N) <= (1 + EPSILON) * DMIN(N)}.
EPSILON = 0.25
# R = 1000
# N: last argument of cod_prob_bound; printed as DMAX(N) / DMIN(N).
N = 2
# NOTE(review): P is not referenced in the visible part of this notebook - confirm it is still needed.
P = 1
# Number of query points drawn (without replacement) for each bound estimate.
NUM_QUERY = 10

# Data sampling: only samples whose label equals SUBSET_CLASS are kept.
SUBSET_CLASS = 2

# Plot settings: passed as `gray` to plot_path.
GRAY_PLOTS = True

## 3. Initialize samplers


In [None]:
# Both domains share one preprocessing pipeline: resize to IMG_SIZE x IMG_SIZE,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        # One-element tuples: "(0.5)" is just a parenthesized float, not a sequence.
        Normalize((0.5,), (0.5,)),
    ]
)
target_transform = source_transform

# Maps each supported "<source>2<target>" name to its (source, target)
# torchvision dataset classes.
_DATASET_PAIRS = {
    "fmnist2mnist": (datasets.FashionMNIST, datasets.MNIST),
    "mnist2fmnist": (datasets.MNIST, datasets.FashionMNIST),
    "usps2mnist": (datasets.USPS, datasets.MNIST),
    "usps2fmnist": (datasets.USPS, datasets.FashionMNIST),
}

try:
    source, target = _DATASET_PAIRS[DATASET]
except KeyError:
    # Raising a plain string is itself a TypeError in Python 3; raise a real
    # exception that names the offending value and the accepted ones.
    raise ValueError(
        f"Invalid dataset: {DATASET!r} (expected one of {sorted(_DATASET_PAIRS)})"
    ) from None

In [None]:
def _class_indices(dataset, wanted_label):
    """Return the positions in ``dataset.targets`` whose label equals ``wanted_label``."""
    return [pos for pos, lbl in enumerate(dataset.targets) if lbl == wanted_label]


# Training split of each domain, constructed with download=True under DATASET_PATH.
source_dataset = source(
    transform=source_transform, download=True, train=True, root=DATASET_PATH
)
target_dataset = target(
    transform=target_transform, download=True, train=True, root=DATASET_PATH
)

# Keep only the samples labelled SUBSET_CLASS in each domain.
source_indices = _class_indices(source_dataset, SUBSET_CLASS)
target_indices = _class_indices(target_dataset, SUBSET_CLASS)

source_dataset = Subset(source_dataset, source_indices)
target_dataset = Subset(target_dataset, target_indices)

# Cell output: sizes of the two filtered datasets.
len(source_dataset), len(target_dataset)

In [None]:
def plot_data_point(x):
    """Show a single sample as a grayscale image with the axes hidden.

    ``x`` is squeezed and converted with ``.numpy()`` before plotting
    (presumably a CPU tensor holding one image - confirm against callers).
    """
    image = x.squeeze().numpy()
    plt.imshow(image, cmap="gray")
    plt.axis("off")  # hide the axes
    plt.show()


# Preview the first sample of each domain.
plot_data_point(source_dataset[0][0])
plot_data_point(target_dataset[0][0])

In [None]:
from typing import Union


def plot_path(
    path: Union[list, torch.Tensor],
    indices: Union[list, None] = None,
    gray: bool = False,
):
    """Plot the images of a path on a grid with at most 10 columns.

    Args:
        path: List of image tensors or an already stacked tensor. After
            stacking it is permuted with ``(0, 2, 3, 1)``, so it must be 4-D
            (presumably ``(steps, channels, height, width)`` - confirm).
        indices: Optional positions of ``path`` to plot; ``None`` plots
            every entry.
        gray: Draw with the ``"gray"`` colormap when True.

    Pixel values are mapped with ``x * 0.5 + 0.5`` and clipped to ``[0, 1]``
    before display. Each panel is titled ``X_k`` where ``k`` is the entry's
    index in the original ``path``; the last panel is titled ``Y``. Nothing
    is drawn for an empty selection.
    """
    if indices is not None:
        # Remember the original positions so titles show the real step index
        # instead of the position inside the filtered list.
        labels = list(indices)
        path = [path[i] for i in labels]
    else:
        labels = list(range(len(path)))

    if len(path) == 0:
        # Nothing to draw; avoids a zero-width figure / stacking an empty list.
        return

    if isinstance(path, list):
        path = torch.stack(path)

    imgs: np.ndarray = (
        path.to("cpu").permute(0, 2, 3, 1).mul(0.5).add(0.5).numpy().clip(0, 1)
    )

    n_imgs = len(imgs)
    ncols = min(n_imgs, 10)
    # Ceiling division: exactly as many rows as needed, so a count that is a
    # multiple of 10 no longer reserves an extra empty row.
    nrows = -(-n_imgs // ncols)

    fig = plt.figure(figsize=(1.5 * ncols, 1.5 * nrows), dpi=150)
    for pos, (label, img) in enumerate(zip(labels, imgs)):
        ax = fig.add_subplot(nrows, ncols, pos + 1)
        if gray:
            ax.imshow(img, cmap="gray")
        else:
            ax.imshow(img)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])
        is_last = pos == n_imgs - 1
        ax.set_title("Y" if is_last else f"$X_{{{label}}}$", fontsize=16)
    fig.tight_layout()
    # Same cleanup as before: release cached CUDA memory and run the collector.
    torch.cuda.empty_cache()
    gc.collect()


## 4. Concentration of Distance Probability


### DDIM


#### X2G: generate asymptotic distribution


In [None]:
# Build one stacked path per source sample with get_ddim_path(..., reverse=False),
# then stack all samples into a single tensor.
x_path_list = torch.stack(
    [
        torch.stack(get_ddim_path(x, SCHEDULER, reverse=False))
        for x, _ in tqdm(source_dataset, total=len(source_dataset))
    ]
)
# Swap the first two axes so the leading index selects a path point instead of a sample.
x2g_asymptotic_dataset = x_path_list.transpose(0, 1)

In [None]:
# Shape of the stacked paths: dim 0 indexes samples, dim 1 indexes path points.
x_path_list.shape

In [None]:
# Visualize the full path of the first source sample.
plot_path(
    x_path_list[0],
    # indices=PIVOTAL_LIST,  # NOTE(review): disabled here, unlike the other plot_path calls - confirm intentional
    gray=GRAY_PLOTS,
)

#### X2G: calculate CoD Prob Bound


In [None]:
# Draw NUM_QUERY distinct sample indices to act as query points.
Q_point_indices = np.random.choice(x_path_list.shape[0], NUM_QUERY, replace=False)
# Queries are taken at the first path point; the reference set is the last path point.
Q_point_dataset = x2g_asymptotic_dataset[0][Q_point_indices]
T_point_dataset = x2g_asymptotic_dataset[-1]

# One bound per query point, averaged over all queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
        for Q_point in Q_point_dataset
    ]
)
x2g_ddim_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={x2g_asymptotic_dataset.shape[0]-1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={x2g_ddim_prob_bound:.4%}"
)

In [None]:
# Per-step bounds: queries at path point t against the whole set at path point t + 1.
# The query indices are whatever Q_point_indices currently holds.
x2g_ddim_prob_bound_list = []
for t in range(len(x2g_asymptotic_dataset) - 1):
    Q_point_dataset = x2g_asymptotic_dataset[t][Q_point_indices]
    T_point_dataset = x2g_asymptotic_dataset[t + 1]

    prob_bound_list = np.array(
        [
            cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
            for Q_point in Q_point_dataset
        ]
    )
    # Average over the query points to get this step's bound.
    prob_bound = prob_bound_list.mean()
    print(f"t={t} -> t={t+1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={prob_bound:.4%}")

    x2g_ddim_prob_bound_list.append(prob_bound)

#### G2Y: generate asymptotic distribution


In [None]:
# Build one stacked path per target sample with get_ddim_path (default direction),
# then stack all samples into a single tensor.
y_path_list = torch.stack(
    [
        torch.stack(get_ddim_path(y, SCHEDULER))
        for y, _ in tqdm(target_dataset, total=len(target_dataset))
    ]
)
# Swap the first two axes so the leading index selects a path point instead of a sample.
g2y_asymptotic_dataset = y_path_list.transpose(0, 1)

In [None]:
# Shape of the stacked paths: dim 0 indexes samples, dim 1 indexes path points.
y_path_list.shape

In [None]:
# Visualize the path of the first target sample; PIVOTAL_LIST (if not None)
# restricts the plot to the listed path indices.
plot_path(
    y_path_list[0],
    indices=PIVOTAL_LIST,
    gray=GRAY_PLOTS,
)

#### G2Y: calculate CoD Prob Bound


In [None]:
# Draw NUM_QUERY distinct sample indices to act as query points.
Q_point_indices = np.random.choice(y_path_list.shape[0], NUM_QUERY, replace=False)
# Queries are taken at the first path point; the reference set is the last path point.
Q_point_dataset = g2y_asymptotic_dataset[0][Q_point_indices]
T_point_dataset = g2y_asymptotic_dataset[-1]

# One bound per query point, averaged over all queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
        for Q_point in Q_point_dataset
    ]
)
g2y_ddim_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={g2y_asymptotic_dataset.shape[0]-1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={g2y_ddim_prob_bound:.4%}"
)

In [None]:
# Per-step bounds: queries at path point t against the whole set at path point t + 1.
# The query indices are whatever Q_point_indices currently holds.
g2y_ddim_prob_bound_list = []
for t in range(len(g2y_asymptotic_dataset) - 1):
    Q_point_dataset = g2y_asymptotic_dataset[t][Q_point_indices]
    T_point_dataset = g2y_asymptotic_dataset[t + 1]

    prob_bound_list = np.array(
        [
            cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
            for Q_point in Q_point_dataset
        ]
    )
    # Average over the query points to get this step's bound.
    prob_bound = prob_bound_list.mean()
    print(f"t={t} -> t={t+1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={prob_bound:.4%}")

    g2y_ddim_prob_bound_list.append(prob_bound)

In [None]:
# Number of per-step G2Y bounds collected (one per consecutive pair of path points).
len(g2y_ddim_prob_bound_list)

## Flow


### generate asymptotic distribution


In [None]:
# Pair source and target samples position by position (zip stops at the shorter
# dataset) and build a path of DIFFUSION_STEPS + 1 points for each pair with
# get_flow_path, then stack all pairs into a single tensor.
flow_path_list = torch.stack(
    [
        torch.stack(get_flow_path(x, y, DIFFUSION_STEPS + 1))
        for (x, _), (y, _) in tqdm(
            zip(source_dataset, target_dataset),
            total=min(len(source_dataset), len(target_dataset)),
        )
    ]
)
# Swap the first two axes so the leading index selects a path point instead of a pair.
x2y_flow_asymptotic_dataset = flow_path_list.transpose(0, 1)

In [None]:
# Visualize the path of the first source/target pair; PIVOTAL_LIST (if not None)
# restricts the plot to the listed path indices.
plot_path(
    flow_path_list[0],
    indices=PIVOTAL_LIST,
    gray=GRAY_PLOTS,
)

### calculate CoD Prob Bound


In [None]:
# Draw NUM_QUERY distinct pair indices to act as query points for the flow bounds.
# Kept in its own cell, so re-running the bound cells does not redraw the indices.
Q_point_indices = np.random.choice(flow_path_list.shape[0], NUM_QUERY, replace=False)

In [None]:
# Queries are taken at the first path point; the reference set is the last path point.
Q_point_dataset = x2y_flow_asymptotic_dataset[0][Q_point_indices]
T_point_dataset = x2y_flow_asymptotic_dataset[-1]

# One bound per query point, averaged over all queries.
prob_bound_list = np.array(
    [
        cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
        for Q_point in Q_point_dataset
    ]
)
x2y_flow_prob_bound = prob_bound_list.mean()

print(
    f"t={0} -> t={x2y_flow_asymptotic_dataset.shape[0]-1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={x2y_flow_prob_bound:.4%}"
)

In [None]:
# Per-step bounds: queries at path point t against the whole set at path point t + 1.
# The query indices are whatever Q_point_indices currently holds.
x2y_flow_prob_bound_list = []
for t in range(len(x2y_flow_asymptotic_dataset) - 1):
    Q_point_dataset = x2y_flow_asymptotic_dataset[t][Q_point_indices]
    T_point_dataset = x2y_flow_asymptotic_dataset[t + 1]

    prob_bound_list = np.array(
        [
            cod_prob_bound(T_point_dataset, EPSILON, Q_point, "euclidean", N)
            for Q_point in Q_point_dataset
        ]
    )
    # Average over the query points to get this step's bound.
    prob_bound = prob_bound_list.mean()
    print(f"t={t} -> t={t+1}: P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}>={prob_bound:.4%}")

    x2y_flow_prob_bound_list.append(prob_bound)

## Plot


In [None]:
# Red line: overall bound measured from the first to the last path point.
plt.hlines(x2y_flow_prob_bound, 0, DIFFUSION_STEPS, colors="red")


# One short segment per step t -> t + 1 at that step's bound.
for t, prob in enumerate(x2y_flow_prob_bound_list):
    plt.hlines(prob, t, t + 1)


plt.xlabel("t")
plt.ylabel(f"Lower Bound of P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}")

plt.xlim(0, DIFFUSION_STEPS)
# Zoom the y-axis to [0, 0.5] when the overall bound is below 0.5.
plt.ylim(0, 0.5 if x2y_flow_prob_bound < 0.5 else 1)
plt.margins(x=0)
# Raw f-string: "\e" in "$\epsilon$" is an invalid escape sequence in a normal
# string literal (SyntaxWarning on recent Python versions).
plt.title(rf"Flow: {DATASET}, $\epsilon$={EPSILON}, {DIFFUSION_STEPS} steps")
plt.show()


In [None]:
# Overall X2G bound (first path point vs. last path point), averaged over the query points.
x2g_ddim_prob_bound

In [None]:
# Black vertical line separating the X2G half (left) from the G2Y half (right).
plt.vlines(DIFFUSION_STEPS, 0, 1, colors="black")

# Red lines: overall bound of each half, measured from its first to its last path point.
plt.hlines(x2g_ddim_prob_bound, 0, DIFFUSION_STEPS, colors="red")
plt.hlines(g2y_ddim_prob_bound, DIFFUSION_STEPS, DIFFUSION_STEPS * 2, colors="red")

# One short segment per step; G2Y segments are shifted right by DIFFUSION_STEPS.
for t, prob in enumerate(x2g_ddim_prob_bound_list):
    plt.hlines(prob, t, t + 1)
for t, prob in enumerate(g2y_ddim_prob_bound_list):
    plt.hlines(prob, DIFFUSION_STEPS + t, DIFFUSION_STEPS + t + 1)

plt.xlabel("t")
plt.ylabel(f"Lower Bound of P{{DMAX({N})<=(1+{EPSILON})DMIN({N})}}")

plt.xlim(0, DIFFUSION_STEPS * 2)
plt.ylim(0, 1)
plt.margins(x=0)
# Raw f-string: "\e" in "$\epsilon$" is an invalid escape sequence in a normal
# string literal (SyntaxWarning on recent Python versions).
plt.title(rf"DDIM: {DATASET}, $\epsilon$={EPSILON}, {DIFFUSION_STEPS} steps")
plt.show()

# TODO LIST

- [ ] USPS 数据集

- [ ] 加噪 1000 步，选取个别节点

0. [ ] 画图
1. 数据集
   - [x] FMNIST, MNIST
   - [ ] comic_faces_v1
2. 多种渐变方式
   - [x] 扩散(DDIM)：并无直接 X->Y 的分布转移渐变，只有 X->高斯->Y 分布。但高斯分布本身会造成严重距离聚集。
   - [x] Flow(插值)：X->Y 分布转移渐变，纯粹生成一般随机采样高斯噪声作为 X。
   - [ ] 薛定谔桥：同 Flow
3. [ ] 级联 OT
