In [7]:
from dataset import get_datasets
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

In [8]:
def get_mean_and_std(dataloader, count):
    mean = torch.zeros(3, device=device)
    std = torch.zeros(3, device=device)
    
    for images, _ in tqdm(dataloader):
        images = images.to(device)
        mean += images.sum(axis=[0, 2, 3])
        std += (images**2).sum(axis=[0, 2, 3])

    mean /= count
    std /= count

    return mean, std

In [11]:
image_dim = 224

transform = transforms.Compose([
    transforms.Resize((image_dim, image_dim)),
    transforms.ToTensor()
])

train_dataset, _, _ = get_datasets(transform)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=0)

mean, std = get_mean_and_std(train_dataloader, len(train_dataset)*image_dim*image_dim)

100%|██████████| 41/41 [00:51<00:00,  1.26s/it]


In [12]:
print(f"The mean is: {mean}")
print(f"The std is: {std}")

The mean is: tensor([0.4772, 0.4597, 0.4612], device='cuda:0')
The std is: tensor([0.2997, 0.2808, 0.2837], device='cuda:0')
