## 1. Imports


In [None]:
import os
import sys
import gc
import json
import warnings

import torch
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, Lambda
from PIL import PngImagePlugin

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

sys.path.append("..")
from src.resnet2 import ResNet_D
from src.unet import UNet
from src.mnistm_utils import MNISTM

from src.tools import (
    set_random_seed,
    unfreeze,
    freeze,
    weights_init_D,
    fig2tensor,
)
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_class_images,
)
from src.samplers import (
    SubsetGuidedSampler,
    SubsetGuidedDataset,
    get_indicies_subset,
)


LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

warnings.filterwarnings("ignore")

%matplotlib inline 

In [None]:
# Start from a clean slate: run the Python garbage collector first, then ask
# PyTorch to release the unused memory held by its CUDA caching allocator.
gc.collect()
torch.cuda.empty_cache()

## 2. Config

Choose the dataset in the first lines of the configuration cell below.


In [None]:
SEED = 0x0000
set_random_seed(SEED)

# Dataset selection: keep exactly one of the following lines uncommented.
DATASET, DATASET_PATH = "fmnist2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2mnistm", "../datasets/"
# DATASET, DATASET_PATH = "mnist2usps", "../datasets/"
# DATASET, DATASET_PATH = "mnist2kmnist", "../datasets/"

IMG_SIZE = 32
# Channel counts of the source (1) and target (2) images; the transform cell
# overrides both to 3 for "mnist2mnistm".
DATASET1_CHANNELS = 1
DATASET2_CHANNELS = 1

# GPU selection: the first entry of DEVICE_IDS becomes the current CUDA device.
DEVICE_IDS = [1]
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU."
assert 0 <= DEVICE_IDS[0] < torch.cuda.device_count(), (
    f"DEVICE_IDS[0]={DEVICE_IDS[0]} is not a valid device index: "
    f"{torch.cuda.device_count()} CUDA device(s) visible."
)
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

# First training step to execute; the training loop skips all steps below it.
# NOTE(review): no checkpoint is loaded in the visible cells, so a non-zero
# value only skips steps - confirm that weights are restored elsewhere.
CONTINUE = 0

# All hyperparameters below are set to the values used for the experiments
# described in the article.


# training algorithm settings
BATCH_SIZE = 32
SUBSET_SIZE = 2
NUM_LABELED = 10  # num of labeled target in training set

T_ITERS = 10  # optimizer steps on T per optimizer step on D
MAX_STEPS = 60000 + 1  # 2501 for testing
COST = "Energy"  # only written to config.json by the visible cells
SCHEDULER_MILESTONES = [10000, 20000, 30000, 40000, 50000]

# optimizer settings
D_LR, T_LR = 1e-4, 1e-4
# NOTE(review): both max-norm values are written to config.json, but the
# visible training loop performs no gradient clipping - confirm intent.
T_GRADIENT_MAX_NORM = float(500)
D_GRADIENT_MAX_NORM = float(500)

# plot settings
GRAY_PLOTS = True

# log settings (all intervals are measured in training steps)
TRACK_VAR_INTERVAL = 1000
PLOT_INTERVAL = 500
CPKT_INTERVAL = 10000

FID_EPOCHS = 1  # only written to config.json by the visible cells

EXP_NAME = f"GNOT_Class_{DATASET}_{SEED}"
OUTPUT_PATH = f"../saved_models/{EXP_NAME}/"

# exist_ok=True makes the cell safe to re-run and avoids the
# check-then-create race of `if not os.path.exists(...)`.
os.makedirs(OUTPUT_PATH, exist_ok=True)

writer = SummaryWriter(f"../logdir/{EXP_NAME}")

In [None]:
# All ten classes (0..9) are used on both sides, and each class keeps its own
# index as its new label (identity relabelling).
source_subset = torch.arange(10)
new_labels_source = {label: label for label in range(10)}
target_subset = torch.arange(10)
new_labels_target = {label: label for label in range(10)}

In [None]:
# Snapshot of the run's hyperparameters, stored next to the checkpoints.
config = {
    "SEED": SEED,
    "DATASET": DATASET,
    "T_ITERS": T_ITERS,
    "D_LR": D_LR,
    "T_LR": T_LR,
    "BATCH_SIZE": BATCH_SIZE,
    "COST": COST,
    "T_GRADIENT_MAX_NORM": T_GRADIENT_MAX_NORM,
    "D_GRADIENT_MAX_NORM": D_GRADIENT_MAX_NORM,
    "TRACK_VAR_INTERVAL": TRACK_VAR_INTERVAL,
    "FID_EPOCHS": FID_EPOCHS,
}

with open(os.path.join(OUTPUT_PATH, "config.json"), "w") as json_file:
    json.dump(config, json_file, indent=4)

# Progress record; the training loop rewrites it at every checkpoint.
log = {"CONTINUE": CONTINUE}
with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
    json.dump(log, log_file, indent=4)

## 3. Initialize samplers


In [None]:
# Default preprocessing for the single-channel dataset pairs: resize to
# IMG_SIZE x IMG_SIZE, convert to a tensor, then Normalize with mean 0.5 and
# std 0.5. The statistics are passed as real 1-tuples; `(0.5)` is just a float.
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize((0.5,), (0.5,)),
    ]
)
target_transform = source_transform

# Pick the source/target dataset classes for the configured DATASET.
if DATASET == "mnist2kmnist":
    source = datasets.MNIST
    target = datasets.KMNIST

elif DATASET == "fmnist2mnist":
    source = datasets.FashionMNIST
    target = datasets.MNIST

elif DATASET == "mnist2usps":
    source = datasets.MNIST
    target = datasets.USPS

elif DATASET == "mnist2mnistm":
    # This pair is trained with 3 channels on both sides and plotted in colour.
    DATASET1_CHANNELS = 3
    DATASET2_CHANNELS = 3
    GRAY_PLOTS = False
    source = datasets.MNIST
    target = MNISTM
    source_transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5,), (0.5,)),
            # Repeat the single channel three times and negate the values.
            Lambda(lambda x: -x.repeat(3, 1, 1)),
        ]
    )
    target_transform = Compose(
        [
            # Explicit (H, W) like the other transforms; for square inputs this
            # gives the same result as Resize(IMG_SIZE).
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

else:
    # Without this branch a typo in DATASET only surfaces later as a
    # NameError on `source` / `target`.
    raise ValueError(
        f"Unknown DATASET {DATASET!r}; expected one of "
        "'mnist2kmnist', 'fmnist2mnist', 'mnist2usps', 'mnist2mnistm'."
    )

In [None]:
# Constructor arguments shared by the source and the target dataset.
dataset_kwargs = dict(root=DATASET_PATH, train=True, download=True)

# --- Source domain: load, select the configured classes, pack into tensors.
source_train = source(transform=source_transform, **dataset_kwargs)
(subset_samples, labels, source_class_indicies) = get_indicies_subset(
    source_train,
    new_labels=new_labels_source,
    classes=len(source_subset),
    subset_classes=source_subset,
)
source_train = torch.utils.data.TensorDataset(
    torch.stack(subset_samples),
    torch.LongTensor(labels),
)

# --- Target domain: same procedure with the target-side settings.
target_train = target(transform=target_transform, **dataset_kwargs)
(target_subset_samples, target_labels, target_class_indicies) = get_indicies_subset(
    target_train,
    new_labels=new_labels_target,
    classes=len(target_subset),
    subset_classes=target_subset,
)
target_train = torch.utils.data.TensorDataset(
    torch.stack(target_subset_samples),
    torch.LongTensor(target_labels),
)

# Both paired datasets share the same class-index tables and differ only in
# `num_labeled`: NUM_LABELED for train_set, "all" for full_set.
class_index_kwargs = dict(
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)
train_set = SubsetGuidedDataset(
    source_train, target_train, num_labeled=NUM_LABELED, **class_index_kwargs
)
full_set = SubsetGuidedDataset(
    source_train, target_train, num_labeled="all", **class_index_kwargs
)

In [None]:
# Sampler over train_set (built with num_labeled=NUM_LABELED), using
# SUBSET_SIZE as subset size; the training loop draws its (X, Y) batches for
# the T updates and the X batch for the D update from it.
T_XY_sampler = SubsetGuidedSampler(train_set, subsetsize=SUBSET_SIZE)
# Sampler over full_set (num_labeled="all") with subset size 1; used for the
# Y batch of the D update and for the plots.
D_XY_sampler = SubsetGuidedSampler(full_set, subsetsize=1)

In [None]:
# Collect unreachable Python objects first so that any CUDA memory they held
# is returned to PyTorch's caching allocator before the cache is emptied
# (same order as the cleanup calls elsewhere in this notebook).
gc.collect()
torch.cuda.empty_cache()
# Remove the output produced so far (e.g. dataset download logs).
clear_output()

## 4. Training


### Models initialization


In [None]:
# Networks: D is created and initialised before T, so the order in which the
# random number generator is consumed stays fixed.
D = ResNet_D(IMG_SIZE, nc=DATASET2_CHANNELS).cuda()
D.apply(weights_init_D)
T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48).cuda()

# Adam for both networks, with the same weight decay.
adam_kwargs = dict(weight_decay=1e-10)
T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, **adam_kwargs)
D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, **adam_kwargs)

# Both learning rates are halved at every step listed in SCHEDULER_MILESTONES.
lr_schedule_kwargs = dict(milestones=SCHEDULER_MILESTONES, gamma=0.5)
T_scheduler, D_scheduler = (
    torch.optim.lr_scheduler.MultiStepLR(optimizer, **lr_schedule_kwargs)
    for optimizer in (T_opt, D_opt)
)

### Plots Test


In [None]:
# Draw one fixed batch for the recurring "Fixed Test Images" plot and merge
# its first two dimensions into a single batch dimension.
X_fixed, Y_fixed = (
    batch.flatten(0, 1) for batch in D_XY_sampler.sample(NUM_LABELED)
)
print(f"[Debug] {X_fixed.shape=}\n{Y_fixed.shape=}")

In [None]:
# Smoke test of the plotting helper on the fixed batch, run before the
# training loop below has updated T.
fig, axes = plot_pushed_images(X_fixed, Y_fixed, T, gray=GRAY_PLOTS)
# Convert the figure to a tensor and log it to TensorBoard (no step is passed).
writer.add_image("class images[sde]", fig2tensor(fig))

In [None]:
# Smoke test of the second plotting helper, which takes the sampler itself
# rather than a fixed batch; the figure is only displayed, not logged.
fig, axes = plot_pushed_random_class_images(D_XY_sampler, T, gray=GRAY_PLOTS)

### Main training cycle and logging


In [None]:
# T_var below divides by (SUBSET_SIZE - 1). The numerator is a tensor, so
# SUBSET_SIZE == 1 would not raise but silently turn the loss into nan/inf.
if SUBSET_SIZE < 2:
    raise ValueError(f"SUBSET_SIZE must be >= 2, got {SUBSET_SIZE}")

progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE)

# Steps below CONTINUE are skipped. range() reads CONTINUE once, so
# incrementing it inside the loop does not affect the iteration.
for step in range(CONTINUE, MAX_STEPS):
    # ------------------------------------------------------------------
    # T phase: T_ITERS optimizer steps on T while D is frozen.
    # ------------------------------------------------------------------
    unfreeze(T)
    # The plotting / checkpoint branches below call T.eval(). Re-enable train
    # mode explicitly in case unfreeze() (src.tools, not visible here) only
    # toggles requires_grad; this is a no-op if it already does so.
    T.train()
    freeze(D)

    for t_iter in range(T_ITERS):
        T_opt.zero_grad()

        X, Y = T_XY_sampler.sample(BATCH_SIZE)
        # Push the flattened batch through T and regroup the result as
        # (batch, SUBSET_SIZE, C, H, W). A single reshape yields exactly the
        # same elements as a permute -> reshape -> permute round trip.
        T_X = T(X.flatten(0, 1))
        T_X = T_X.reshape(-1, SUBSET_SIZE, *T_X.shape[1:])

        # Mean pairwise distance between the pushed samples of each subset,
        # rescaled by SUBSET_SIZE / (SUBSET_SIZE - 1) and halved.
        T_var = (
            0.5
            * torch.cdist(T_X.flatten(start_dim=2), T_X.flatten(start_dim=2)).mean()
            * SUBSET_SIZE
            / (SUBSET_SIZE - 1)
        )
        # Mean Euclidean distance between each pushed sample and its paired Y.
        cost = (Y - T_X).flatten(start_dim=2).norm(dim=2).mean()
        T_loss = cost - T_var - D(T_X.flatten(start_dim=0, end_dim=1)).mean()
        T_loss.backward()
        T_opt.step()

    T_scheduler.step()
    del T_loss, T_X, X, Y
    gc.collect()
    torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # D phase: one optimizer step on D while T is frozen.
    # ------------------------------------------------------------------
    freeze(T)
    unfreeze(D)

    # sample unlabeled Y~Q, X~P
    X, _ = T_XY_sampler.sample(BATCH_SIZE)
    _, Y = D_XY_sampler.sample(BATCH_SIZE)

    # T only produces inputs for D here; no graph through T is needed.
    with torch.no_grad():
        T_X = T(X.flatten(start_dim=0, end_dim=1))

    D_opt.zero_grad()
    D_loss = D(T_X).mean() - D(Y.flatten(start_dim=0, end_dim=1)).mean()
    D_loss.backward()
    D_opt.step()
    D_scheduler.step()

    del D_loss, Y, X, T_X
    gc.collect()
    torch.cuda.empty_cache()

    CONTINUE += 1
    progress_bar.update(1)

    # ------------------------------------------------------------------
    # Periodic plots, shown inline and logged to TensorBoard.
    # ------------------------------------------------------------------
    if step % PLOT_INTERVAL == 0:
        # Clearing the cell output also removes the progress bar widget, so
        # it is recreated at the current position.
        clear_output(wait=True)
        progress_bar.close()
        progress_bar = tqdm(total=MAX_STEPS, initial=CONTINUE)
        print("Plotting")

        inference_T = T
        inference_T.eval()
        print("Fixed Test Images")
        fig, axes = plot_pushed_images(X_fixed, Y_fixed, inference_T, gray=GRAY_PLOTS)
        writer.add_image("Fixed Test Images", fig2tensor(fig), step)

        # pyplot.show() takes no figure argument; it shows all open figures.
        plt.show()
        plt.close(fig)
        print("Random Test Images")
        fig, axes = plot_pushed_random_class_images(
            D_XY_sampler, inference_T, gray=GRAY_PLOTS
        )
        writer.add_image("Random Test Images", fig2tensor(fig), step)
        plt.show()
        plt.close(fig)

    # ------------------------------------------------------------------
    # Periodic checkpoint: models, optimizers, schedulers and log.json.
    # ------------------------------------------------------------------
    if step != 0 and step % CPKT_INTERVAL == 0:
        inference_T = T

        inference_T.eval()
        freeze(T)

        # Unwrap only if a model really is a (Distributed)DataParallel
        # wrapper. Keying this on len(DEVICE_IDS) fails with AttributeError
        # whenever several ids are configured but the models are not wrapped.
        parallel_wrappers = (
            torch.nn.DataParallel,
            torch.nn.parallel.DistributedDataParallel,
        )
        T_plain = T.module if isinstance(T, parallel_wrappers) else T
        D_plain = D.module if isinstance(D, parallel_wrappers) else D

        # File name prefix -> state dict; saved as "<prefix>_<SEED>_<step>.pt".
        checkpoint_states = {
            "T": T_plain.state_dict(),
            "D": D_plain.state_dict(),
            "D_opt": D_opt.state_dict(),
            "T_opt": T_opt.state_dict(),
            "D_scheduler": D_scheduler.state_dict(),
            "T_scheduler": T_scheduler.state_dict(),
        }
        for prefix, state in checkpoint_states.items():
            torch.save(state, os.path.join(OUTPUT_PATH, f"{prefix}_{SEED}_{step}.pt"))

        # CONTINUE was already incremented above, so it names the next step.
        log["CONTINUE"] = CONTINUE
        with open(os.path.join(OUTPUT_PATH, "log.json"), "w") as log_file:
            log_str = json.dumps(log, indent=4)
            log_file.write(log_str)

    gc.collect()
    torch.cuda.empty_cache()

## Clear resources


In [None]:
# Best-effort cleanup. Each resource is closed in its own try block so that a
# failure (or a missing name, if an earlier cell was never run) for one of
# them does not prevent the other from being closed.
try:
    writer.close()
except Exception as e:
    print(e)

try:
    progress_bar.close()
except Exception as e:
    print(e)