## 1. Imports


In [None]:
import os
import sys
import gc
import warnings

import torch
import torchvision
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, Lambda
from PIL import PngImagePlugin

from IPython.display import clear_output

sys.path.append("..")
from src.unet import UNet
from src.mnistm_utils import MNISTM
from src.fid_score import calculate_frechet_distance
from src.tools import (
    set_random_seed,
    get_loader_stats,
    get_pushed_loader_stats,
    get_pushed_loader_metrics,
    get_pushed_loader_accuracy,
)
from src.plotters import (
    plot_pushed_images,
    plot_pushed_random_class_images,
)
from src.samplers import (
    SubsetGuidedSampler,
    SubsetGuidedDataset,
    get_indicies_subset,
)


# Raise Pillow's limit for PNG text chunks to 100 MiB; presumably so PNGs that
# carry large metadata blocks can still be opened (confirm the original motivation).
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

# Silence every warning for cleaner notebook output (this hides deprecation warnings too).
warnings.simplefilter("ignore")

%matplotlib inline 

In [None]:
# Release memory before the experiment starts: first let the garbage collector
# drop unreachable Python objects, then hand PyTorch's cached-but-unused CUDA
# blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

## 2. Init Config

The configuration file `config.json` is saved in `saved_models/EXP_NAME/`.



In [None]:
SEED = 0x3060
set_random_seed(SEED)

# dataset choosing
DATASET, DATASET_PATH = "fmnist2mnist", "../datasets/"
# DATASET, DATASET_PATH = "mnist2mnistm", "../datasets/"
# DATASET, DATASET_PATH = "mnist2usps", "../datasets/"
# DATASET, DATASET_PATH = "mnist2kmnist", "../datasets/"

IMG_SIZE = 32
# Channel counts of the source / target images; the dataset cell switches
# both to 3 for "mnist2mnistm".
DATASET1_CHANNELS = 1
DATASET2_CHANNELS = 1

# GPU choosing
DEVICE_IDS = [1]
# Explicit check rather than `assert`, which is stripped when run with `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for this notebook, but no GPU is available")
torch.cuda.set_device(f"cuda:{DEVICE_IDS[0]}")

CONTINUE = 0

# All hyperparameters below are set to the values used for the experiments
# described in the article.


# training algorithm settings
# NOTE(review): BATCH_SIZE and SUBSET_CLASS are used by the evaluation cells;
# T_ITERS, MAX_STEPS, COST and SCHEDULER_MILESTONES are not referenced in this
# notebook and appear to be kept for parity with the training configuration.
BATCH_SIZE = 32
SUBSET_CLASS = 3

T_ITERS = 10
MAX_STEPS = 60000 + 1  # 2501 for testing
COST = "Energy"
SCHEDULER_MILESTONES = [10000, 20000, 30000, 40000, 50000]

# plot settings
GRAY_PLOTS = True

FID_EPOCHS = 1

# Directory of the trained experiment whose weights are evaluated below.
EXP_NAME = f"GNOT_Unpair_{DATASET}_{SUBSET_CLASS}_{SEED}"
LOAD_PATH = f"../saved_models/{EXP_NAME}/"

if not os.path.exists(LOAD_PATH):
    # Name the missing directory so a wrong DATASET/SUBSET_CLASS/SEED combo is obvious.
    raise FileNotFoundError(f"no such file or directory: {LOAD_PATH}")

In [None]:
# All ten classes are kept in both domains, and every label maps to itself.
source_subset = torch.arange(10)
new_labels_source = {label: label for label in range(10)}
target_subset = torch.arange(10)
new_labels_target = {label: label for label in range(10)}

# One-hot class weights: only SUBSET_CLASS gets nonzero weight for sampling.
SUBSET_WEIGHT = [0] * len(source_subset)
SUBSET_WEIGHT[SUBSET_CLASS] = 1.0

In [None]:
# ResNet-18 with a 10-class head; used below to score the pushed images.
classifier = torchvision.models.resnet18()
classifier.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

# Default preprocessing for both domains: resize to IMG_SIZE x IMG_SIZE and
# normalize with mean 0.5 / std 0.5, i.e. map [0, 1] pixel values to [-1, 1].
source_transform = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize((0.5,), (0.5,)),
    ]
)
target_transform = source_transform

CLASSIFIERS_PATH = "../saved_models/classifiers/"

# Single-channel experiments: DATASET -> (source dataset class, target dataset
# class, classifier checkpoint file named after the target dataset).
GRAY_DATASETS = {
    "mnist2kmnist": (torchvision.datasets.MNIST, torchvision.datasets.KMNIST, "kmnist.pt"),
    "fmnist2mnist": (torchvision.datasets.FashionMNIST, torchvision.datasets.MNIST, "mnist.pt"),
    "mnist2usps": (torchvision.datasets.MNIST, torchvision.datasets.USPS, "usps.pt"),
}

if DATASET in GRAY_DATASETS:
    source, target, classifier_ckpt = GRAY_DATASETS[DATASET]
    # Replace the 3-channel stem so the ResNet accepts 1-channel images.
    classifier.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
elif DATASET == "mnist2mnistm":
    DATASET1_CHANNELS = 3
    DATASET2_CHANNELS = 3
    GRAY_PLOTS = False
    source = torchvision.datasets.MNIST
    target = MNISTM
    classifier_ckpt = "mnistm.pt"
    # Grayscale MNIST is replicated to 3 channels to match the RGB target; it is
    # also negated (presumably an intensity inversion chosen for MNIST-M — confirm).
    source_transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5,), (0.5,)),
            Lambda(lambda x: -x.repeat(3, 1, 1)),
        ]
    )
    target_transform = Compose(
        [
            Resize((IMG_SIZE, IMG_SIZE)),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
else:
    # Fail here instead of with a NameError on `source`/`target` in a later cell.
    raise ValueError(
        f"Unknown DATASET {DATASET!r}; expected one of "
        f"{sorted([*GRAY_DATASETS, 'mnist2mnistm'])}"
    )

# map_location="cpu": restore tensors on the CPU instead of on whichever GPU the
# checkpoint was saved from (which need not be DEVICE_IDS[0]); the model is
# moved to the selected CUDA device right after.
classifier.load_state_dict(
    torch.load(os.path.join(CLASSIFIERS_PATH, classifier_ckpt), map_location="cpu")
)

classifier.cuda()
classifier.eval()

## 3. Initialize samplers


In [None]:
def _subset_tensor_dataset(dataset_cls, transform, new_labels, subset_classes):
    """Load the test split of ``dataset_cls`` restricted to ``subset_classes``.

    The split is filtered and relabelled by ``get_indicies_subset`` and returned
    as an in-memory ``TensorDataset``, together with the per-class index
    structure that ``get_indicies_subset`` reports.
    """
    raw_test = dataset_cls(
        root=DATASET_PATH, train=False, download=True, transform=transform
    )
    samples, labels, class_indicies = get_indicies_subset(
        raw_test,
        new_labels=new_labels,
        classes=len(subset_classes),
        subset_classes=subset_classes,
    )
    tensor_dataset = torch.utils.data.TensorDataset(
        torch.stack(samples), torch.LongTensor(labels)
    )
    return tensor_dataset, class_indicies


source_test, source_class_indicies = _subset_tensor_dataset(
    source, source_transform, new_labels_source, source_subset
)
target_test, target_class_indicies = _subset_tensor_dataset(
    target, target_transform, new_labels_target, target_subset
)

# Pair the two domains class-by-class and draw from them with the one-hot
# SUBSET_WEIGHT.
full_set_test = SubsetGuidedDataset(
    source_test,
    target_test,
    num_labeled="all",
    in_indicies=source_class_indicies,
    out_indicies=target_class_indicies,
)
XY_test_sampler = SubsetGuidedSampler(full_set_test, subsetsize=1, weight=SUBSET_WEIGHT)

# Sequential (unshuffled) loaders over each domain, used for FID and accuracy.
# pin_memory=True is intentionally left disabled.
_test_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
X_test_loader = torch.utils.data.DataLoader(source_test, **_test_loader_kwargs)
Y_test_loader = torch.utils.data.DataLoader(target_test, **_test_loader_kwargs)

In [None]:
# Collect garbage *before* emptying the CUDA cache: tensors only become
# releasable once their Python objects are freed, so the reverse order
# (as originally written) leaves that memory cached.
gc.collect()
torch.cuda.empty_cache()
clear_output()

## 4. Testing

### init models



In [None]:
# Model T that produces the "pushed" images evaluated below; the positional
# arguments are presumably (in_channels, out_channels) — confirm in src.unet.
T = UNet(DATASET1_CHANNELS, DATASET2_CHANNELS, base_factor=48)
T = T.cuda()

### Load trained weights for testing

In [None]:
print("Loading weights")

w_path = os.path.join(LOAD_PATH, "T_10000_no_z.pt")  # user setting: checkpoint to evaluate

# map_location="cpu": restore the checkpoint on the CPU rather than on whichever
# GPU it was saved from; load_state_dict then copies it into T's CUDA parameters.
T.load_state_dict(torch.load(w_path, map_location="cpu"))
# Put T in inference mode before any plotting/metric cell runs it, so layers
# such as dropout/batch-norm (if the UNet has any) behave deterministically.
T.eval()
print(f"{w_path}, loaded")

### Plots Test


In [None]:
# Draw one fixed paired batch for the plots and merge its first two
# dimensions into a single batch dimension.
X_test_fixed, Y_test_fixed = (
    batch.flatten(0, 1) for batch in XY_test_sampler.sample(10)
)

In [None]:
# Visual check on the fixed paired samples; presumably shows X, T(X) and Y
# side by side — see src.plotters.plot_pushed_images to confirm the layout.
fig, axes = plot_pushed_images(X_test_fixed, Y_test_fixed, T, gray=GRAY_PLOTS)

In [None]:
# Visual check on samples drawn from XY_test_sampler by the plotting helper
# itself (presumably a fresh random draw on each run — confirm in src.plotters).
fig, axes = plot_pushed_random_class_images(XY_test_sampler, T, gray=GRAY_PLOTS)

### main testing


In [None]:
clear_output(wait=True)
print("Plotting")

# Alias kept for readability; switching to eval mode also affects T itself.
inference_T = T
inference_T.eval()

print("Fixed Test Images")
fig, axes = plot_pushed_images(X_test_fixed, Y_test_fixed, inference_T, gray=GRAY_PLOTS)
# pyplot.show() takes no figure argument: it renders all open figures (here
# only `fig`). Passing `fig` positionally is bound to the inline backend's
# `close` flag by accident and raises TypeError on other backends.
plt.show()
plt.close(fig)

print("Random Test Images")
fig, axes = plot_pushed_random_class_images(
    XY_test_sampler, inference_T, gray=GRAY_PLOTS
)
plt.show()
plt.close(fig)

In [None]:
print("Computing FID")

# Reference statistics of the real target-domain test images.
target_stats = get_loader_stats(
    Y_test_loader, BATCH_SIZE, FID_EPOCHS, verbose=True, use_Y=False
)
target_mu, target_sigma = target_stats

# Statistics of the source test images after being pushed through T.
pushed_stats = get_pushed_loader_stats(
    T, X_test_loader, n_epochs=FID_EPOCHS, batch_size=BATCH_SIZE, verbose=True
)
gen_mu, gen_sigma = pushed_stats

# Frechet distance between the pushed and the real target distributions.
fid = calculate_frechet_distance(gen_mu, gen_sigma, target_mu, target_sigma)
print(f"FID={fid}")

In [None]:
print("Computing Accuracy")
# Score the pushed source test images with the target-domain classifier.
accuracy = get_pushed_loader_accuracy(T, X_test_loader, classifier)
# Report the result like the FID and metrics cells do; the bare assignment
# previously produced no visible output.
print(f"Accuracy={accuracy}")

In [None]:
print("Computing Metrics")

# Metrics computed on the paired loader exposed by the guided test sampler.
METRIC_NAMES = ["LPIPS", "PSNR", "SSIM", "MSE", "MAE"]
paired_loader = XY_test_sampler.loader
metrics = get_pushed_loader_metrics(
    T,
    paired_loader,
    n_epochs=FID_EPOCHS,
    verbose=True,
    log_metrics=METRIC_NAMES,
)
print(f"metrics={metrics}")