## Imports


In [None]:
import os
import sys
import gc
import warnings

import torch
import torchvision
import numpy as np
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, Lambda
from PIL import PngImagePlugin

from IPython.display import clear_output

sys.path.append("..")
from src.cunet import CUNet
from src.enot import SDE
from src.unet import UNet
from src.mnistm_utils import MNISTM
from src.tools import (
    set_random_seed,
)
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_class_images,
    plot_sde_pushed_images,
    plot_sde_pushed_random_class_images,
    plot_linked_pushed_images,
    plot_linked_pushed_random_class_images,
    plot_linked_sde_pushed_images,
    plot_linked_sde_pushed_random_class_images,
)
from src.samplers import (
    SubsetGuidedSampler,
    SubsetGuidedDataset,
    get_indicies_subset,
)


# Blanket filter: every warning is ignored from this point on.
warnings.filterwarnings("ignore")

# Set PngImagePlugin.MAX_TEXT_CHUNK to LARGE_ENOUGH_NUMBER * 1024 * 1024.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

%matplotlib inline 

In [None]:
# Run the Python garbage collector first, then ask PyTorch to release the
# CUDA memory it is holding in its cache.
gc.collect()
torch.cuda.empty_cache()

## General Config

In [None]:
# Seed used for set_random_seed and embedded in checkpoint names below.
SEED = 0x0000
set_random_seed(SEED)

# Dataset pair to evaluate: keep exactly one of the following lines active.
DATASET, DATASET_PATH = "fmnist2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2mnistm", "../datasets/"
# DATASET, DATASET_PATH = "mnist2usps", "../datasets/"
# DATASET, DATASET_PATH = "mnist2kmnist", "../datasets/"

# Image side length after resizing, and channel counts of the two domains
# (the channel counts are overridden for "mnist2mnistm" further down).
IMG_SIZE = 32
DATASET1_CHANNELS = 1
DATASET2_CHANNELS = 1

# GPU selection: the first entry of DEVICE_IDS becomes the current CUDA device.
DEVICE_IDS = [1]
# Explicit checks instead of `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this notebook, but it is not available")
if DEVICE_IDS[0] >= torch.cuda.device_count():
    raise RuntimeError(
        f"DEVICE_IDS[0]={DEVICE_IDS[0]} is out of range: "
        f"only {torch.cuda.device_count()} CUDA device(s) visible"
    )
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# Sampling settings
BATCH_SIZE = 32
SUBSET_SIZE = 2
SUBSET_CLASS = 3  # index that receives weight 1.0 in SUBSET_WEIGHT

# Plot settings
GRAY_PLOTS = True

FID_EPOCHS = 1

In [None]:
# Use all ten class labels (0..9) on both sides; every label maps to itself.
source_subset = torch.arange(10)
new_labels_source = {label: label for label in range(10)}
target_subset = torch.arange(10)
new_labels_target = {label: label for label in range(10)}

# Per-class weights: zero everywhere except position SUBSET_CLASS, which is 1.0.
SUBSET_WEIGHT = [0] * len(source_subset)
SUBSET_WEIGHT[SUBSET_CLASS] = 1.0

In [None]:
# Default preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a tensor,
# then normalise with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
# The mean/std are written as real one-element tuples; `(0.5)` is just a float.
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize((0.5,), (0.5,)),
    ]
)
target_transform = source_transform

# Choose the source/target dataset classes from the DATASET setting.
if DATASET == "mnist2kmnist":
    source = torchvision.datasets.MNIST
    target = torchvision.datasets.KMNIST

elif DATASET == "fmnist2mnist":
    source = torchvision.datasets.FashionMNIST
    target = torchvision.datasets.MNIST

elif DATASET == "mnist2usps":
    source = torchvision.datasets.MNIST
    target = torchvision.datasets.USPS

elif DATASET == "mnist2mnistm":
    # This pair uses three channels on both sides and colour plots.
    DATASET1_CHANNELS = 3
    DATASET2_CHANNELS = 3

    GRAY_PLOTS = False
    source = torchvision.datasets.MNIST
    target = MNISTM
    source_transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5,), (0.5,)),
            # Repeat the single channel three times and negate the values.
            Lambda(lambda x: -x.repeat(3, 1, 1)),
        ]
    )
    target_transform = Compose(
        [Resize(IMG_SIZE), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

else:
    # Fail here with a clear message instead of a NameError on `source`
    # or `target` in a later cell.
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of "
        "'mnist2kmnist', 'fmnist2mnist', 'mnist2usps', 'mnist2mnistm'"
    )

## Initialize samplers


In [None]:
def _subset_as_tensor_dataset(dataset_cls, transform, new_labels, subset_classes):
    """Load the test split of ``dataset_cls`` and repack it as a TensorDataset.

    The split is passed through ``get_indicies_subset``; the returned samples
    are stacked into one tensor and paired with the labels as a LongTensor.

    Returns:
        (samples, labels, class_indicies, tensor_dataset)
    """
    raw_test = dataset_cls(
        root=DATASET_PATH, train=False, download=True, transform=transform
    )
    samples, labels, class_indicies = get_indicies_subset(
        raw_test,
        new_labels=new_labels,
        classes=len(subset_classes),
        subset_classes=subset_classes,
    )
    tensor_dataset = torch.utils.data.TensorDataset(
        torch.stack(samples), torch.LongTensor(labels)
    )
    return samples, labels, class_indicies, tensor_dataset


# Source domain first, then target domain.
(
    source_subset_samples,
    source_labels,
    source_class_indicies,
    source_test,
) = _subset_as_tensor_dataset(
    source, source_transform, new_labels_source, source_subset
)

(
    target_subset_samples,
    target_labels,
    target_class_indicies,
    target_test,
) = _subset_as_tensor_dataset(
    target, target_transform, new_labels_target, target_subset
)

# Combine both tensor datasets with their per-class index lists.
full_set_test = SubsetGuidedDataset(
    source_test,
    target_test,
    num_labeled="all",
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)

# Sampler over the combined set, weighted by SUBSET_WEIGHT, with subsetsize=1.
XY_test_sampler = SubsetGuidedSampler(
    full_set_test, subsetsize=1, weight=SUBSET_WEIGHT
)

# for accuracy
# Both loaders share the same settings: fixed order, BATCH_SIZE items per
# batch, 8 worker processes. Optional extra setting: pin_memory=True.
_test_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
X_test_loader = torch.utils.data.DataLoader(source_test, **_test_loader_kwargs)
Y_test_loader = torch.utils.data.DataLoader(target_test, **_test_loader_kwargs)

In [None]:
# Collect garbage first so that memory freed by the collector can also be
# released from PyTorch's CUDA cache by empty_cache(); then clear cell output.
gc.collect()
torch.cuda.empty_cache()
clear_output()

In [None]:
# Draw 10 samples once and merge the first two dimensions of each returned
# tensor; these fixed tensors are reused by the "fix" plots of every section.
X_test_fixed, Y_test_fixed = (
    batch.flatten(0, 1) for batch in XY_test_sampler.sample(10)
)

## GNOT

In [None]:
# Directory holding the saved GNOT weights for this dataset and seed.
EXP_NAME = f"GNOT_Unpair_{DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # Name the missing path so the failure is actionable.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Build the UNet for the GNOT section (base_factor=48), then call .cuda() on it.
T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48)
T = T.cuda()

### load weights

In [None]:
print("Loading weights")

w_path = os.path.join(LOAD_PATH, "T_10000_no_z.pt")  # user setting

# map_location="cpu" keeps torch.load from restoring tensors onto whichever
# GPU they were saved from; load_state_dict then copies the values into T.
T.load_state_dict(torch.load(w_path, map_location="cpu"))
print(f"{w_path}, loaded")

### plot

In [None]:
# Plot with plot_pushed_images on the fixed test tensors; savefig=True with
# the save path under ./figs/Unpair/GNOT.
fig, axes = plot_pushed_images(
    X_test_fixed, Y_test_fixed, T,
    gray=GRAY_PLOTS, savefig=True, save_path="./figs/Unpair/GNOT/fix",
)

In [None]:
# Same model, but samples come from XY_test_sampler (plot_n_samples=10).
fig, axes = plot_pushed_random_class_images(
    XY_test_sampler, T,
    plot_n_samples=10, gray=GRAY_PLOTS,
    savefig=True, save_path="./figs/Unpair/GNOT/random",
)

## ENOT

In [None]:
# SDE network settings
EPSILON = 0  # [0 , 1, 10]
IMAGE_INPUT = True
PREDICT_SHIFT = True
N_STEPS = 5  #
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
ONE_STEP_INIT_ITERS = 0
USE_GRADIENT_CHECKPOINT = False
N_LAST_STEPS_WITHOUT_NOISE = 1

In [None]:
# Directory holding the saved ENOT weights for this dataset and seed.
EXP_NAME = f"ENOT_Unpair_{DATASET}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # Name the missing path so the failure is actionable.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Shift network: a CUNet built from the channel counts and TIME_DIM.
T = CUNet(
    DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
)
T = T.cuda()

# Wrap the shift network in an SDE configured from the settings cell.
sde = SDE(
    shift_model=T,
    epsilon=EPSILON,
    n_steps=N_STEPS,
    time_dim=TIME_DIM,
    n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
    predict_shift=PREDICT_SHIFT,
    image_input=IMAGE_INPUT,
)
sde = sde.cuda()

# Report the total element count over all parameters of the SDE.
_sde_param_sizes = [np.prod(p.shape) for p in sde.parameters()]
print("sde params:", np.sum(_sde_param_sizes))

### load weights


In [None]:
print("Loading weights")

w_path = os.path.join(LOAD_PATH, f"T_{SEED}_5000.pt")  # user setting

# map_location="cpu" keeps torch.load from restoring tensors onto whichever
# GPU they were saved from; load_state_dict then copies the values into sde.
sde.load_state_dict(torch.load(w_path, map_location="cpu"))

print(f"{w_path}, loaded")

### plot

In [None]:
# Plot with plot_sde_pushed_images on the fixed test tensors; trajectories
# are switched off and the save path is under ./figs/Unpair/ENOT.
fig, axes = plot_sde_pushed_images(
    X_test_fixed, Y_test_fixed, sde,
    gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/ENOT/fix",
)

In [None]:
# Same SDE, but samples come from XY_test_sampler (plot_n_samples=10).
fig, axes = plot_sde_pushed_random_class_images(
    XY_test_sampler, sde,
    plot_n_samples=10, gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/ENOT/random",
)

## DNOT

In [None]:
# the step number adding noise in diffusion process
DIFFUSION_STEPS = 1000
PIVOTAL_LIST = [20, 50, 100]  # [0, 100] for testing,  [0, 20, 50, 100]
# training algorithm settings
STRATEGY = "Adapt"  # 'Fix' or 'Adapt'
# model settings
UNET_BASE_FACTOR = 48

In [None]:
# Directory holding the saved DNOT weights for this dataset, strategy and seed.
# NOTE(review): the name uses "Class" while GNOT/ENOT use "Unpair" - confirm
# that this is the intended experiment.
EXP_NAME = f"DNOT_Class_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # Name the missing path so the failure is actionable.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Build two UNets per PIVOTAL_LIST entry, each moved with .cuda().
# The loop index is not needed, so a comprehension replaces the append loop.
Ts = [
    UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=UNET_BASE_FACTOR).cuda()
    for _ in range(len(PIVOTAL_LIST) * 2)
]

### load weights


In [None]:
print("Loading weights")

CKPT_DIR = os.path.join(LOAD_PATH, f"iter{2000}")  # user setting
for i, T in enumerate(Ts):
    # One checkpoint file per model, indexed by its position in Ts.
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # map_location="cpu" keeps torch.load from restoring tensors onto
    # whichever GPU they were saved from; load_state_dict copies them into T.
    T.load_state_dict(torch.load(w_path, map_location="cpu"))
    print(f"{w_path}, loaded")

### plot

In [None]:
# Plot with plot_linked_pushed_images using the whole list Ts on the fixed
# test tensors; trajectories are off, save path under ./figs/Unpair/DNOT.
fig, axes = plot_linked_pushed_images(
    X_test_fixed, Y_test_fixed, Ts,
    gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/DNOT/fix",
)

In [None]:
# Same model list, but samples come from XY_test_sampler (plot_n_samples=10).
fig, axes = plot_linked_pushed_random_class_images(
    XY_test_sampler, Ts,
    plot_n_samples=10, gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/DNOT/random",
)

## DENOT

In [None]:
# Training-algorithm setting; also part of the experiment directory name.
STRATEGY = "Fix"  # 'Fix' or 'Adapt'

# Number of noise-adding steps in the diffusion process.
DIFFUSION_STEPS = 1000
# Alternatives: [0, 100] for testing, or [0, 20, 50, 100].
PIVOTAL_LIST = [20, 50, 100]

# SDE network settings

# Network size / encoding (passed to CUNet and SDE in the init cell).
UNET_BASE_FACTOR = 128
TIME_DIM = 128
USE_POSITIONAL_ENCODING = True
USE_GRADIENT_CHECKPOINT = False

# SDE arguments.
EPSILON = 0  # candidate values: 0, 1, 10
N_STEPS = 5  # number of shift steps
N_LAST_STEPS_WITHOUT_NOISE = 1
PREDICT_SHIFT = True
IMAGE_INPUT = True

# Not referenced by the cells of this section.
ONE_STEP_INIT_ITERS = 0

In [None]:
# Directory holding the saved DENOT weights for this dataset, strategy and seed.
# NOTE(review): the name uses "Class" while GNOT/ENOT use "Unpair" - confirm
# that this is the intended experiment.
EXP_NAME = f"DENOT_Class_{DATASET}_{STRATEGY}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # Name the missing path so the failure is actionable.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

### init model

In [None]:
# Build two SDE models per PIVOTAL_LIST entry. Each SDE wraps its own CUNet
# shift network. The loop index is unused, and the CUNet gets its own name
# instead of sharing `T` with the SDE that wraps it.
SDEs = []

for _ in range(len(PIVOTAL_LIST) * 2):
    shift_model = CUNet(
        DATASET1_CHANNELS, DATASET2_CHANNELS, TIME_DIM, base_factor=UNET_BASE_FACTOR
    ).cuda()

    SDEs.append(
        SDE(
            shift_model=shift_model,
            epsilon=EPSILON,
            n_steps=N_STEPS,
            time_dim=TIME_DIM,
            n_last_steps_without_noise=N_LAST_STEPS_WITHOUT_NOISE,
            use_positional_encoding=USE_POSITIONAL_ENCODING,
            use_gradient_checkpoint=USE_GRADIENT_CHECKPOINT,
            predict_shift=PREDICT_SHIFT,
            image_input=IMAGE_INPUT,
        ).cuda()
    )

### load weights

In [None]:
print("Loading weights")

CKPT_DIR = os.path.join(LOAD_PATH, f"iter{10000}/")  # user setting
for i, T in enumerate(SDEs):
    # One checkpoint file per model, indexed by its position in SDEs.
    w_path = os.path.join(CKPT_DIR, f"T{i}_{SEED}.pt")
    # map_location="cpu" keeps torch.load from restoring tensors onto
    # whichever GPU they were saved from; load_state_dict copies them into T.
    T.load_state_dict(torch.load(w_path, map_location="cpu"))
    # Log the path actually loaded; CKPT_DIR already ends in "/", so the
    # former f"{CKPT_DIR}/..." message printed a double slash.
    print(f"{w_path}, loaded")

### plot

In [None]:
# Plot with plot_linked_sde_pushed_images using the whole list SDEs on the
# fixed test tensors; trajectories are off, save path under ./figs/Unpair/DENOT.
fig, axes = plot_linked_sde_pushed_images(
    X_test_fixed, Y_test_fixed, SDEs,
    gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/DENOT/fix",
)

In [None]:
# Same model list, but samples come from XY_test_sampler (plot_n_samples=10).
fig, axes = plot_linked_sde_pushed_random_class_images(
    XY_test_sampler, SDEs,
    plot_n_samples=10, gray=GRAY_PLOTS, plot_trajectory=False,
    savefig=True, save_path="./figs/Unpair/DENOT/random",
)