In [None]:
!pip install torch torchvision
import torch
import torchvision
import torchvision.transforms as transforms



**Step 1: Load the dataset without normalization**

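First, load the CIFAR-10 training dataset. To calculate the original mean and standard deviation, you must use a minimal transform that only converts the images to tensors (`transforms.ToTensor()` also scales the pixel values to the range [0, 1]) and applies no normalization.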

In [None]:
# Minimal preprocessing pipeline: convert each image to a tensor and nothing else,
# so the statistics computed later describe the un-normalized data.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split; downloaded into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

100%|██████████| 170M/170M [00:06<00:00, 27.4MB/s]


**Step 2: Stack all images into a single tensor**

Next, iterate through the training dataset and combine all the image tensors into one large tensor. This will create a tensor of shape [N, C, H, W], where N is the number of images, C is the number of channels, and H and W are the height and width.

In [None]:
# Collect every image tensor from the dataset (the label of each sample is
# discarded), then join them along a new leading dimension.
all_images = [image for image, _ in trainset]
all_images_tensor = torch.stack(all_images)

print(f"Shape of the stacked images tensor: {all_images_tensor.shape}")

Shape of the stacked images tensor: torch.Size([50000, 3, 32, 32])


In [None]:
# Display the image tensor stored at index 10 of the Python list of samples
# (the bare expression is rendered as the cell's output).
all_images[10]

tensor([[[0.2078, 0.2118, 0.2196,  ..., 0.1843, 0.1608, 0.0941],
         [0.1804, 0.2078, 0.2118,  ..., 0.1647, 0.1529, 0.1098],
         [0.1765, 0.1961, 0.1804,  ..., 0.1490, 0.1412, 0.1137],
         ...,
         [0.2784, 0.2902, 0.3137,  ..., 0.2000, 0.1804, 0.1922],
         [0.2941, 0.3098, 0.3176,  ..., 0.2392, 0.2510, 0.1882],
         [0.3333, 0.3333, 0.3373,  ..., 0.2392, 0.2510, 0.1922]],

        [[0.2549, 0.2471, 0.2353,  ..., 0.2000, 0.1765, 0.1098],
         [0.2314, 0.2431, 0.2314,  ..., 0.1804, 0.1686, 0.1255],
         [0.2314, 0.2353, 0.2039,  ..., 0.1647, 0.1569, 0.1294],
         ...,
         [0.3255, 0.3255, 0.3333,  ..., 0.2118, 0.1922, 0.1961],
         [0.3216, 0.3333, 0.3333,  ..., 0.2549, 0.2627, 0.1961],
         [0.3255, 0.3294, 0.3373,  ..., 0.2549, 0.2627, 0.1961]],

        [[0.2078, 0.2039, 0.1961,  ..., 0.1961, 0.1725, 0.1059],
         [0.1608, 0.1765, 0.1725,  ..., 0.1765, 0.1647, 0.1216],
         [0.1490, 0.1608, 0.1333,  ..., 0.1608, 0.1529, 0.

In [None]:
# Display the first image of the stacked tensor, i.e. the slice at index 0
# along the leading dimension (the bare expression is rendered as the cell's output).
all_images_tensor[0]

tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.

**Step 3: Reshape and compute the channel-wise mean**

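To calculate the mean for each channel, you need to rearrange the tensor so that all pixel values for each channel lie along a single dimension. Because the tensor is laid out as [N, C, H, W], calling `all_images_tensor.view(3, -1)` directly would **not** do this: it only cuts the flattened data into three equal chunks, each of which still mixes all three channels (the telltale symptom is three nearly identical "channel" means). Instead, first move the channel axis to the front with `all_images_tensor.permute(1, 0, 2, 3)`, then flatten with `.reshape(3, -1)`. This produces a tensor of shape [3, 51200000] (3 channels, and 50,000 * 32 * 32 pixels per channel). Then, take the mean along the pixel dimension (dim=1).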

In [None]:
# Reshape the tensor to group all pixel values for each channel
channel_wise_pixels = all_images_tensor.view(3,-1)

# Calculate the mean for each channel
mean = channel_wise_pixels.mean(dim=1)

print(f"Mean for each channel: {mean}")

Mean for each channel: tensor([0.4741, 0.4727, 0.4733])


**Step 4: Reshape and compute the channel-wise standard deviation**

The same process is used for the standard deviation. After reshaping, calculate the standard deviation along the pixel dimension.

In [None]:
# Group all pixel values by channel before computing the statistic.
# The tensor is [N, C, H, W]; a bare .view(3, -1) would merely split the
# flattened buffer into three chunks that each mix all channels, so the channel
# axis is moved to the front first. reshape() is required (not view()) because
# the permuted tensor is no longer contiguous.
num_channels = all_images_tensor.shape[1]
channel_wise_pixels = all_images_tensor.permute(1, 0, 2, 3).reshape(num_channels, -1)

# Standard deviation over the flattened pixel axis, one value per channel.
# Commonly reported values for the CIFAR-10 training set are roughly
# (0.2470, 0.2435, 0.2616) -- verify against the output after re-running.
std = channel_wise_pixels.std(dim=1)
print(f"Standard deviation for each channel: {std}")

Standard deviation for each channel: tensor([0.2521, 0.2520, 0.2506])
