In [1]:
import sys

sys.path.append("..")

import numpy as np

%matplotlib inline 

import gc
import json

import torch

# This needed to use dataloaders for some datasets
from PIL import PngImagePlugin
from tqdm.auto import tqdm

from src.tools import get_loader_stats, load_dataset

# Pillow rejects PNG text chunks larger than PngImagePlugin.MAX_TEXT_CHUNK (a guard
# against oversized / decompression-bomb metadata). Images in some datasets exceed
# the default, which breaks their dataloaders, so the cap is raised to 100 MiB.
LARGE_ENOUGH_NUMBER = 100  # cap size, in MiB
_MIB = 1024 * 1024
PngImagePlugin.MAX_TEXT_CHUNK = _MIB * LARGE_ENOUGH_NUMBER

In [2]:
def _free_gpu_memory():
    """Release memory left over from earlier runs in this kernel.

    Unreachable Python objects are collected first so that any CUDA tensors
    they reference become free. PyTorch's cached-but-unused GPU blocks are
    then returned to the driver.
    """
    gc.collect()
    torch.cuda.empty_cache()


_free_gpu_memory()

## Main Config


In [5]:
# Index of the CUDA device used for the statistics pass.
DEVICE_ID = 0

# Datasets to process. Each entry is unpacked by the stats loop as
# (dataset name passed to load_dataset, dataset root path, image size,
#  n_epochs passed to get_loader_stats).
DATASET_LIST = [
    ("MNIST-colored_3", "~/data/MNIST", 64, 1),
]
# Alternative configuration: swap in to compute CelebA statistics instead.
# DATASET_LIST = [
#     ("CelebA_high", "~/data/img_align_celeba", 64, 1),
# ]

# Explicit checks instead of `assert`: asserts are stripped under `python -O`
# and would fail here without any explanatory message.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is not available; this notebook expects a GPU (DEVICE_ID={}).".format(DEVICE_ID)
    )
if not 0 <= DEVICE_ID < torch.cuda.device_count():
    raise ValueError(
        "DEVICE_ID={} is out of range; {} CUDA device(s) visible.".format(
            DEVICE_ID, torch.cuda.device_count()
        )
    )
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

In [None]:
# For every configured dataset, compute `mu` and `sigma` with get_loader_stats
# over its test loader. These are presumably the feature mean and covariance
# (TODO confirm against src.tools). Save them to
# `<DATASET>_<IMG_SIZE>_test.json` in the current working directory.
BATCH_SIZE = 256  # shared by the dataloaders and the stats pass

for DATASET, DATASET_PATH, IMG_SIZE, N_EPOCHS in tqdm(DATASET_LIST):
    print(f"Processing {DATASET}")
    # Only the second returned sampler (test split) is used below; the first is
    # kept only to unpack load_dataset's return value.
    sampler, test_sampler = load_dataset(
        DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=BATCH_SIZE
    )
    print(f"Dataset {DATASET} loaded")

    mu, sigma = get_loader_stats(
        test_sampler.loader, n_epochs=N_EPOCHS, verbose=True, batch_size=BATCH_SIZE
    )
    # Quick scalar sanity check on sigma before writing it out.
    print(f"Trace of sigma: {np.trace(sigma)}")
    # json cannot serialize array objects directly, so convert to nested lists.
    stats = {"mu": mu.tolist(), "sigma": sigma.tolist()}
    print("Stats computed")

    filename = f"{DATASET}_{IMG_SIZE}_test.json"
    with open(filename, "w", encoding="utf-8") as fp:
        json.dump(stats, fp)
    print(f"Stats saved to {filename}")