In [2]:
import sys

sys.path.append("..")

import numpy as np

%matplotlib inline 

import gc
import json

import torch

# Needed so the dataloaders can read some datasets' PNG files (see the MAX_TEXT_CHUNK override below)
from PIL import PngImagePlugin
from tqdm.auto import tqdm

from src.samplers import get_loader_sampler, get_paired_sampler
from src.tools import get_loader_stats

# Raise Pillow's cap on the size of PNG text chunks to LARGE_ENOUGH_NUMBER MiB,
# so images that carry large text chunks can still be decoded by the dataloaders.
LARGE_ENOUGH_NUMBER = 100  # size limit in MiB
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER << 20  # 1 MiB == 2**20 bytes

In [3]:
# Free memory left over from earlier cells: first collect unreachable Python
# objects, then release the unoccupied memory held by PyTorch's CUDA cache.
for _release_memory in (gc.collect, torch.cuda.empty_cache):
    _release_memory()

## Main Config


### Paired datasets (X–Y: image–image)


In [4]:
DEVICE_ID = 0

BATCH_SIZE = 64
FID_EPOCHS = 1

# Base entries: (dataset name, dataset root, image size, "<source>2<target>" map name).
# Each dataset is expanded into two runs below: USE_Y=False (source/X side) and
# USE_Y=True (target/Y side), so a path or size only has to be edited in one place.
PAIRED_DATASETS = [
    ("celeba_mask", "../datasets/CelebAMask-HQ", 256, "colored_mask2face"),
    ("FS2K", "../datasets/FS2K/", 256, "sketch2photo"),
    ("comic_faces_v1", "../datasets/face2comics_v1.0.0_by_Sxela", 256, "face2comic"),
]
# Same contents and order as listing every (..., False) / (..., True) pair by hand.
DATASET_LIST = [
    (*paired, use_y) for paired in PAIRED_DATASETS for use_y in (False, True)
]

assert torch.cuda.is_available(), "A CUDA device is required to compute the stats"
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

In [5]:
# For every paired dataset, compute statistics (`mu`, `sigma` — presumably the
# feature mean and covariance used as FID reference stats; confirm in
# src.tools.get_loader_stats) on the test split and save them as JSON.
for DATASET, DATASET_PATH, IMG_SIZE, MAP_NAME, USE_Y in tqdm(DATASET_LIST):
    print(f"Processing dataset: {DATASET}, image size: {IMG_SIZE}, use Y: {USE_Y}")
    train_sampler, test_sampler = get_paired_sampler(
        DATASET, DATASET_PATH, IMG_SIZE, batch_size=BATCH_SIZE
    )
    print(f"Dataset {DATASET} loaded")
    # Only the test split's loader is used for the statistics.
    mu, sigma = get_loader_stats(
        test_sampler.loader,
        n_epochs=FID_EPOCHS,
        batch_size=BATCH_SIZE,
        use_Y=USE_Y,
        verbose=True,
    )
    print(f"Trace of sigma: {np.trace(sigma)}")
    stats = {"mu": mu.tolist(), "sigma": sigma.tolist()}
    print("Stats computed")

    # MAP_NAME looks like "<source>2<target>"; name the file after the measured
    # side (target when USE_Y, source otherwise). maxsplit=1 keeps a target
    # name that itself contains "2" intact.
    domain = MAP_NAME.split("2", 1)[1 if USE_Y else 0]
    filename = f"{DATASET}_{domain}_{IMG_SIZE}_test.json"

    with open(filename, "w") as fp:
        json.dump(stats, fp)
    print(f"Stats saved to {filename}")

    # Drop this dataset's samplers and release cached GPU memory before the
    # next dataset is loaded, mirroring the cleanup cell at the top.
    del train_sampler, test_sampler
    gc.collect()
    torch.cuda.empty_cache()

  0%|          | 0/6 [00:00<?, ?it/s]

Processing dataset: celeba_mask, image size: 256, use Y: False
Dataset celeba_mask loaded


'Epoch 1/1: Processing batch 47/47'

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Trace of sigma: 63.48624262914124
Stats computed
States saved to celeba_mask_colored_mask_256_test.json
Processing dataset: celeba_mask, image size: 256, use Y: True
Dataset celeba_mask loaded


'Epoch 1/1: Processing batch 8/47'

### Single datasets (X–Y: image–label)


In [None]:
# This section declares its own config (overriding the paired-data values
# above) so it does not depend on the paired-data cells having been run.
DEVICE_ID = 0

BATCH_SIZE = 64
FID_EPOCHS = 1
# Entries: (dataset name, dataset root, image size).
DATASET_LIST = [
    # ("CelebA_high", "../datasets/img_align_celeba", 64),  # point at your local CelebA copy
    ("MNIST-colored_3", "../datasets/", 32),
    ("MNIST-colored_2", "../datasets/", 32),
]

assert torch.cuda.is_available(), "A CUDA device is required to compute the stats"
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

In [None]:
# For every single (image-label) dataset, compute statistics (`mu`, `sigma` —
# presumably feature mean and covariance for FID; confirm in
# src.tools.get_loader_stats) on the test-split images only and save them as JSON.
for DATASET, DATASET_PATH, IMG_SIZE in tqdm(DATASET_LIST):
    print(f"Processing {DATASET}")

    train_sampler, test_sampler = get_loader_sampler(
        DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE
    )
    print(f"Dataset {DATASET} loaded")

    # use_Y=False: Y is a label here, so statistics are taken over X only.
    mu, sigma = get_loader_stats(
        test_sampler.loader,
        n_epochs=FID_EPOCHS,
        batch_size=BATCH_SIZE,
        use_Y=False,
        verbose=True,
    )
    print(f"Trace of sigma: {np.trace(sigma)}")
    stats = {"mu": mu.tolist(), "sigma": sigma.tolist()}
    print("Stats computed")

    filename = f"{DATASET}_{IMG_SIZE}_test.json"
    with open(filename, "w") as fp:
        json.dump(stats, fp)
    print(f"Stats saved to {filename}")

    # Drop this dataset's samplers and release cached GPU memory before the
    # next dataset is loaded.
    del train_sampler, test_sampler
    gc.collect()
    torch.cuda.empty_cache()