In [1]:
import sys

sys.path.append("..")

import numpy as np

%matplotlib inline 

import gc
import json

import torch

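# Needed so that dataloaders for some datasets can read PNG files whose text
# chunks exceed PIL's default size limit (the limit is raised via MAX_TEXT_CHUNK below).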
from PIL import PngImagePlugin
from tqdm.auto import tqdm

from src.tools import get_stats, load_dataset

# Raise PIL's per-chunk limit for PNG text metadata so that images carrying
# large text chunks can still be decoded. MAX_TEXT_CHUNK is expressed in bytes,
# so the value below allows LARGE_ENOUGH_NUMBER MiB.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start from a clean slate when (re-)running the notebook in a live kernel.
# Order matters: collect unreachable Python objects first, then ask PyTorch to
# release its cached (currently unused) CUDA memory.
for _free_memory in (gc.collect, torch.cuda.empty_cache):
    _free_memory()

## Main Config


In [3]:
# Index of the CUDA device made current for the rest of the notebook.
DEVICE_ID = 0

# Datasets to process. The processing loop unpacks each entry as a 3-tuple:
# (dataset name passed to `load_dataset`, dataset root path, image size).
DATASET_LIST = [
    ("MNIST-colored_3", "~/data/MNIST", 32),
]
# Alternative configuration.
# NOTE(review): this entry has a 4th field, but the processing loop unpacks
# exactly 3 values per entry, so enabling it as-is raises
# "ValueError: too many values to unpack". TODO: confirm what the trailing `1`
# means and either drop it or extend the loop before enabling.
# DATASET_LIST = [
#     ("CelebA_high", "~/data/img_align_celeba", 64, 1),
# ]

# Explicit checks instead of `assert`: asserts are stripped under `python -O`
# and give no hint about what went wrong.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this notebook requires a GPU.")
if not 0 <= DEVICE_ID < torch.cuda.device_count():
    raise ValueError(
        f"DEVICE_ID={DEVICE_ID} is out of range: "
        f"{torch.cuda.device_count()} CUDA device(s) are visible."
    )
torch.cuda.set_device(f"cuda:{DEVICE_ID}")

In [4]:
# For every dataset in DATASET_LIST: load it, compute Inception statistics on
# its *test* loader with `get_stats`, and write them to
# "<DATASET>_<IMG_SIZE>_test.json" in the current working directory.
# `mu` / `sigma` are presumably the mean vector and covariance matrix of the
# Inception features (as used for FID) -- confirm against src.tools.get_stats.
STATS_BATCH_SIZE = 256  # used for both the dataloader and get_stats

for DATASET, DATASET_PATH, IMG_SIZE in tqdm(DATASET_LIST, desc="datasets"):
    print(f"Processing {DATASET}")
    # Only the test sampler is used below; `sampler` is kept for parity with
    # load_dataset's return value (presumably the train split -- confirm).
    sampler, test_sampler = load_dataset(
        DATASET, DATASET_PATH, img_size=IMG_SIZE, batch_size=STATS_BATCH_SIZE
    )
    print(f"Dataset {DATASET} loaded")

    mu, sigma = get_stats(
        test_sampler.loader,
        inception=True,
        verbose=True,
        batch_size=STATS_BATCH_SIZE,
    )
    # Cheap scalar summary of sigma, handy for eyeballing runs against each other.
    print(f"Trace of sigma: {np.trace(sigma)}")
    # json cannot serialise numpy arrays directly; convert to nested lists.
    stats = {"mu": mu.tolist(), "sigma": sigma.tolist()}
    print("Stats computed")

    filename = f"{DATASET}_{IMG_SIZE}_test.json"
    with open(filename, "w") as fp:
        json.dump(stats, fp)
    print(f"Stats saved to {filename}")

  0%|                                                                                                                                                              | 0/1 [00:00<?, ?it/s]

Processing MNIST-colored_3
Dataset MNIST-colored_3 loaded


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth

  0%|                                                                                                                                                        | 0.00/91.2M [00:00<?, ?B/s][A
  0%|▏                                                                                                                                                | 128k/91.2M [00:00<02:53, 552kB/s][A
  0%|▌                                                                                                                                               | 384k/91.2M [00:00<01:20, 1.19MB/s][A
  1%|█▏                                                                                                                                              | 768k/91.2M [00:00<00:46, 2.05MB/s][A
  2%|██▎                                      

Trace of sigma: 67.06710443320588
Stats computed


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:16<00:00, 16.52s/it]

States saved to MNIST-colored_3_32_test.json



