In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

In [None]:


def compute_channel_stats(batches):
    """Compute per-channel mean/std and the global mean over image batches.

    Statistics are accumulated in float64 so that summing a very large number
    of pixel values does not lose precision, then cast back to the default
    floating-point dtype for the returned tensors.

    Args:
        batches: Iterable yielding ``(imgs, target)`` pairs, where ``imgs`` is
            a ``(B, C, H, W)`` tensor. The target is ignored. A
            ``DataLoader`` over an image dataset satisfies this.

    Returns:
        A tuple ``(mean_c, std_c, global_mean)``: two ``(C,)`` tensors with
        the per-channel mean and population standard deviation, and a 0-dim
        tensor with the mean over all pixels and channels.

    Raises:
        ValueError: If ``batches`` yields no pixels at all.
    """
    sum_c = None    # per-channel sum of values (float64)
    sum_sq = None   # per-channel sum of squared values (float64)
    n_pix = 0       # number of pixels seen per channel

    for imgs, _ in batches:
        # Promote before reducing: float32 running sums drift noticeably once
        # hundreds of millions of values have been added.
        imgs = imgs.to(torch.float64)
        batch, channels, height, width = imgs.shape

        if sum_c is None:
            # Size the accumulators from the data instead of assuming RGB.
            sum_c = torch.zeros(channels, dtype=torch.float64)
            sum_sq = torch.zeros(channels, dtype=torch.float64)

        n_pix += batch * height * width
        sum_c += imgs.sum(dim=(0, 2, 3))
        sum_sq += (imgs * imgs).sum(dim=(0, 2, 3))

    if sum_c is None or n_pix == 0:
        raise ValueError("cannot compute statistics: no pixels were provided")

    mean_c = sum_c / n_pix
    # E[x^2] - E[x]^2 can come out a hair below zero through rounding when a
    # channel is (nearly) constant; clamp so sqrt never produces NaN.
    var_c = torch.clamp(sum_sq / n_pix - mean_c ** 2, min=0.0)
    std_c = torch.sqrt(var_c)

    # The sum over all pixels & channels is the sum of the per-channel sums,
    # so no separate pass over the data is needed.
    global_mean = sum_c.sum() / (n_pix * sum_c.numel())

    out_dtype = torch.get_default_dtype()
    return mean_c.to(out_dtype), std_c.to(out_dtype), global_mean.to(out_dtype)


# The guard is true inside a notebook kernel and when run as a script; it also
# keeps DataLoader worker processes from re-running this code on platforms
# that start workers with "spawn".
if __name__ == "__main__":
    # 1. DataLoader (no normalization! we're computing it)
    # NOTE: default batching stacks images, so every image must have the same
    # size; add a resize/crop transform if the folder mixes sizes.
    ds = datasets.ImageFolder(
        "path/to/images",
        transform=transforms.ToTensor()        # yields [0,1] float32 tensor
    )
    loader = DataLoader(ds, batch_size=64, num_workers=8, pin_memory=True)

    # 2. Accumulate and compute means & stds
    mean_c, std_c, global_mean = compute_channel_stats(loader)

    print("Per-channel mean:", mean_c)
    print("Per-channel std: ", std_c)
    print("Global image mean (all pixels & channels):", global_mean)
