In [77]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import pyvtorch.aux as vtaux

import karras.dnnlib as dnnlib
import Images.calculate_metrics as calc
import Images.generate_images as gen

## Dataset statistics

### Tiny ImageNet

In [6]:
dataset_name = "tiny"

In [None]:
dataset_kwargs = calc.get_dataset_kwargs(dataset_name, image_path=None)
dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs, random_seed=0)

In [None]:
x_loader = lambda index : dataset_obj[index][1]
stats = vtaux.calculate_stats(len(dataset_obj), x_loader)

  all_x_mean.append(np.array(*[torch.mean(x, axis=(1,2)).detach().cpu()]))
  all_x_sqs_mean.append(np.array(*[torch.mean(x**2, axis=(1,2)).detach().cpu()]))
100%|██████████| 100000/100000 [01:10<00:00, 1423.94it/s]


In [56]:
stats

{'x_shape': torch.Size([3, 64, 64]),
 'x_dim': 3,
 'all_x_min': array([2., 0., 0., ..., 0., 0., 0.], shape=(100000,)),
 'all_x_max': array([255., 255., 255., ..., 217., 255., 255.], shape=(100000,)),
 'all_x_mean': array([[132.87378 , 147.03076 , 122.403076],
        [103.99927 ,  78.821045,  22.131592],
        [160.11401 , 144.73926 , 141.33447 ],
        ...,
        [ 60.133057,  26.569336,   8.109619],
        [166.11646 , 131.41968 ,  56.260254],
        [207.63867 ,  81.926025,  30.67212 ]],
       shape=(100000, 3), dtype=float32),
 'all_stds_x': array([[53.979   , 56.426582, 64.05145 ],
        [77.35214 , 59.091846, 31.75845 ],
        [39.368923, 69.23086 , 70.82298 ],
        ...,
        [52.178   , 31.1346  , 10.568505],
        [65.04069 , 55.074238, 67.50283 ],
        [38.18556 , 52.340843, 47.3147  ]],
       shape=(100000, 3), dtype=float32),
 'x_min': np.float64(0.0),
 'x_max': np.float64(255.0),
 'x_mean': array([122.45973, 114.25749, 101.36358], dtype=float32),
 '

In [60]:
for k in ["x_min", "x_max", "x_mean", "x_std"]:
    print(k, "=", stats[k] / 127.5 - 1)

x_min = -1.0
x_max = 1.0
x_mean = [-0.03953153 -0.10386282 -0.20499152]
x_std = [-0.44711548 -0.46227467 -0.4368255 ]


### ImageNet

In [None]:
dataset_name = "img512"

In [None]:
dataset_kwargs = calc.get_dataset_kwargs(dataset_name, image_path=None)
dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs, random_seed=0)

In [None]:
# x_loader = lambda index : dataset_obj[index][1]
x_loader = lambda index : torch.Tensor(dataset_obj[index][1])
stats = vtaux.calculate_stats(len(dataset_obj), x_loader)

100%|██████████| 1281167/1281167 [2:28:37<00:00, 143.67it/s]  


In [None]:
stats

{'x_shape': torch.Size([3, 512, 512]),
 'x_dim': 3,
 'all_x_min': array([0., 0., 9., ..., 7., 0., 0.], shape=(1281167,)),
 'all_x_max': array([255., 255., 255., ..., 252., 255., 255.], shape=(1281167,)),
 'all_x_mean': array([[130.78575 , 128.22478 , 105.861206],
        [140.95712 , 129.00656 , 104.63354 ],
        [110.517365, 118.52455 ,  97.95478 ],
        ...,
        [163.70988 , 156.12029 , 139.55792 ],
        [ 61.800724,  48.505   ,  40.783894],
        [ 76.47135 ,  93.902725,  77.668526]],
       shape=(1281167, 3), dtype=float32),
 'all_stds_x': array([[83.564186, 81.642586, 95.86758 ],
        [67.36732 , 61.5589  , 62.421864],
        [46.838287, 36.040924, 38.12787 ],
        ...,
        [52.24701 , 51.36213 , 60.686085],
        [85.13034 , 71.48992 , 65.17409 ],
        [51.21757 , 62.511326, 53.195656]],
       shape=(1281167, 3), dtype=float32),
 'x_min': np.float64(0.0),
 'x_max': np.float64(255.0),
 'x_mean': array([123.69064, 116.78469, 103.86867], dtype=float3

In [None]:
for k in ["x_min", "x_max", "x_mean", "x_std"]:
    print(k, "=", stats[k] / 127.5 - 1)

x_min = -1.0
x_max = 1.0
x_mean = [-0.02987731 -0.08404166 -0.1853438 ]
x_std = [-0.4403078  -0.4555725  -0.42949784]
