In [1]:
import os, sys
dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: sys.path.append(dir1)   

In [2]:
import torch
import torch.utils.data

from utils import dataset, transform

In [3]:
data_root = '../dataset/cityscapes'
train_list = '../dataset/cityscapes/list/fine_train.txt'
batch_size = 8 # calculate the mean and std by using batch size

dct_transform = transform.Compose([
    # transform.Resize(size=(val_h, val_w)),
    transform.GetDctCoefficient(),
    transform.ToTensor()])
    # transform.Normalize(mean=dct_mean, std=dct_std)])
dct_data = dataset.SemData(split='train', img_type='dct', data_root=data_root, data_list=train_list, transform=dct_transform)
dataloader = torch.utils.data.DataLoader(dct_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
data, label = next(iter(dataloader))
data[0].mean(), data[0].std()

Totally 2975 samples in train set.
Starting Checking image&label pair train list...
Checking image&label pair train list done!


(tensor(-8.0169), tensor(119.8144))

In [4]:
# https://towardsdatascience.com/how-to-calculate-the-mean-and-standard-deviation-normalizing-datasets-in-pytorch-704bd7d05f4c
def get_mean_and_std(dataloader):
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data, _ in dataloader:
        # Mean over batch, height and width, but not over the channels
        channels_sum += torch.mean(data, dim=[0,2,3])
        channels_squared_sum += torch.mean(data**2, dim=[0,2,3])
        num_batches += 1
    
    mean = channels_sum / num_batches

    # std = sqrt(E[X^2] - (E[X])^2)
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

In [5]:
mean_1, std_1 = get_mean_and_std(dataloader)

In [6]:
def mean_std_for_loader(loader):
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
    for data, _ in loader:
        this_batch_size = data.size()[0]
        weight = this_batch_size / loader.batch_size
        channels_sum += weight*torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += weight*torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += weight
    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
    return mean, std

In [7]:
mean_2, std_2 = mean_std_for_loader(dataloader)

In [8]:
def batch_mean_and_sd(loader):
    cnt = 0
    fst_moment = torch.empty(192)
    snd_moment = torch.empty(192)

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2,
                                  dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    mean, std = fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)        
    return mean,std

In [9]:
mean_3, std_3 = batch_mean_and_sd(dataloader)

In [21]:
value_scale = 1 # default: 255 for RGB
dct_mean = mean_1.tolist()
dct_mean = [item * value_scale for item in dct_mean]
dct_std = std_1.tolist()
dct_std = [item * value_scale for item in dct_std]

dct_transform = transform.Compose([
    # transform.Resize(size=(val_h, val_w)),
    transform.GetDctCoefficient(),
    transform.ToTensor(),
    transform.Normalize(mean=dct_mean, std=dct_std)])
dct_data = dataset.SemData(split='train', img_type='dct', data_root=data_root, data_list=train_list, transform=dct_transform)
dataloader = torch.utils.data.DataLoader(dct_data, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

Totally 2975 samples in train set.
Starting Checking image&label pair train list...
Checking image&label pair train list done!


In [25]:
data, label = next(iter(dataloader))
for i in range(len(data)):
    print(data[i].mean(), data[i].std())

tensor(-0.0012) tensor(0.7803)
tensor(-0.0014) tensor(0.7367)
tensor(-0.0030) tensor(0.8586)
tensor(-0.0013) tensor(0.8623)
tensor(-0.0009) tensor(0.7645)
tensor(-0.0018) tensor(0.8970)
tensor(-0.0019) tensor(0.6839)
tensor(-0.0021) tensor(0.7027)
tensor(-0.0030) tensor(0.7283)
tensor(-0.0003) tensor(0.6937)
tensor(-0.0034) tensor(0.9224)
tensor(-0.0050) tensor(0.9387)
tensor(-0.0038) tensor(0.9510)
tensor(-0.0008) tensor(0.8867)
tensor(-0.0010) tensor(0.9082)
tensor(-0.0033) tensor(0.9494)


In [12]:
mean_1, std_1

(tensor([-1.2301e+03,  1.9670e-02, -1.1034e-02, -1.4480e-02, -7.9868e-03,
         -6.9404e-03, -6.6573e-04, -1.5744e-02,  3.8558e+00,  3.3794e-04,
          3.6194e-03, -3.4312e-03, -3.2777e-03, -9.4393e-04,  1.3963e-03,
         -3.3344e-03,  1.3149e-02,  8.7037e-03,  7.8148e-03, -2.4318e-03,
         -2.8676e-04,  6.8836e-05, -1.0469e-03,  3.8229e-04,  2.7241e-01,
         -2.4652e-03, -2.4627e-03, -1.7762e-03,  8.5411e-04, -5.2705e-04,
          2.4735e-05, -2.7470e-03, -7.6050e-03,  3.3366e-03,  4.5456e-04,
         -9.9936e-04,  1.0653e-04, -6.9176e-05,  4.8235e-05,  1.3264e-04,
          2.3608e-02, -1.8564e-03, -3.0589e-04, -7.1253e-04,  1.9206e-04,
         -1.2259e-03, -1.0184e-04, -3.0074e-03, -7.2273e-04,  6.5256e-04,
         -1.5977e-04, -3.1606e-04,  3.5526e-05, -7.3496e-05,  7.1139e-05,
         -9.2909e-05, -2.6865e-02, -1.6012e-03,  2.8991e-04, -1.9029e-03,
          1.1725e-04, -2.8448e-03,  6.1576e-05, -7.7002e-03, -9.4736e+01,
         -4.3182e-02,  6.3000e-04, -2.

In [13]:
mean_2, std_2

(tensor([-1.2302e+03,  1.9250e-02, -1.1149e-02, -1.4518e-02, -7.9962e-03,
         -6.9432e-03, -6.6613e-04, -1.5741e-02,  3.8555e+00,  3.2565e-04,
          3.5939e-03, -3.4209e-03, -3.2758e-03, -9.4420e-04,  1.3975e-03,
         -3.3315e-03,  1.3075e-02,  8.6962e-03,  7.7847e-03, -2.4286e-03,
         -2.8923e-04,  6.9406e-05, -1.0478e-03,  3.8247e-04,  2.7243e-01,
         -2.4557e-03, -2.4607e-03, -1.7781e-03,  8.5281e-04, -5.2712e-04,
          2.5532e-05, -2.7457e-03, -7.5480e-03,  3.3411e-03,  4.5942e-04,
         -1.0016e-03,  1.0696e-04, -6.8842e-05,  4.7495e-05,  1.3312e-04,
          2.3617e-02, -1.8531e-03, -3.0279e-04, -7.1154e-04,  1.9238e-04,
         -1.2258e-03, -1.0202e-04, -3.0047e-03, -7.1673e-04,  6.5473e-04,
         -1.5881e-04, -3.1517e-04,  3.5688e-05, -7.3601e-05,  7.1016e-05,
         -9.2394e-05, -2.6851e-02, -1.6000e-03,  2.9015e-04, -1.9015e-03,
          1.1672e-04, -2.8428e-03,  6.1661e-05, -7.6932e-03, -9.4728e+01,
         -4.3172e-02,  5.9435e-04, -2.

In [26]:
mean_3, std_3

(tensor([-1.2302e+03,  1.9250e-02, -1.1149e-02, -1.4518e-02, -7.9962e-03,
         -6.9432e-03, -6.6613e-04, -1.5741e-02,  3.8555e+00,  3.2565e-04,
          3.5939e-03, -3.4209e-03, -3.2758e-03, -9.4420e-04,  1.3975e-03,
         -3.3315e-03,  1.3075e-02,  8.6962e-03,  7.7847e-03, -2.4286e-03,
         -2.8923e-04,  6.9406e-05, -1.0478e-03,  3.8247e-04,  2.7243e-01,
         -2.4557e-03, -2.4607e-03, -1.7781e-03,  8.5281e-04, -5.2712e-04,
          2.5532e-05, -2.7457e-03, -7.5480e-03,  3.3411e-03,  4.5942e-04,
         -1.0016e-03,  1.0696e-04, -6.8842e-05,  4.7495e-05,  1.3312e-04,
          2.3617e-02, -1.8531e-03, -3.0279e-04, -7.1154e-04,         nan,
                 nan, -1.0202e-04, -3.0047e-03,         nan,         nan,
         -1.5881e-04, -3.1517e-04,         nan,         nan,  7.1016e-05,
         -9.2394e-05,         nan,         nan,  2.9015e-04, -1.9015e-03,
                 nan,         nan,  6.1661e-05, -7.6932e-03,         nan,
                 nan,  5.9435e-04, -2.