In [20]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms

In this notebook we calculate the mean and the STD for the whole consolidated Oxford Flowers 102 dataset.

In [21]:
# Define transform
basic_transform = transforms.Compose([transforms.Resize((224, 224)),
                                      transforms.ToTensor()])

In [22]:
# Load the Oxford Flowers 102 dataset
default_train_data = datasets.Flowers102(root = "data",
                                 split = "train",
                                 download = True,
                                 transform = basic_transform)
default_val_data = datasets.Flowers102(root = "data",
                                split= "val",
                                download = True,
                                transform = basic_transform)
default_test_data = datasets.Flowers102(root = "data",
                                split= "test",
                                download = True,
                                transform = basic_transform)

In [23]:
# Check results
len(default_train_data), len(default_val_data), len(default_test_data)

(1020, 1020, 6149)

In [24]:
# Check image datype
type(default_train_data[0][0])

torch.Tensor

In [25]:
# Check shape
default_train_data[0][0].shape

torch.Size([3, 224, 224])

In [26]:
# Check class datatype
type(default_train_data[0][1])

int

In [27]:
# Concat all data
all_flower_images = ConcatDataset([default_train_data, default_val_data, default_test_data])
len(all_flower_images)

8189

In [28]:
# Create loader for the whole dataset
all_flower_loader = DataLoader(all_flower_images, batch_size = 64, shuffle = False)

In [29]:
# Define function to calculate the mean and the standard deviation for the whole dataset
def calculate_mean_std(loader):
    mean = 0.0
    std = 0.0
    total_images = 0

    # Iterate over the dateset
    for images, _ in loader:
        # Calculate the mean and std for each batch(channel, height, widht)
        batch_mean = torch.mean(images, dim = [ 0, 2, 3])
        batch_std = torch.std(images, dim = [0, 2, 3])

        # Accumulate the mean and std
        mean += batch_mean
        std += batch_std
        total_images += 1

    # Average the results across all batches
    mean /= total_images
    std /= total_images

    return mean, std

# Calculate mean and std for the entire dataset
mean, std = calculate_mean_std(all_flower_loader)

In [30]:
# Check results
mean, std

(tensor([0.4356, 0.3777, 0.2879]), tensor([0.2884, 0.2371, 0.2529]))

In [31]:
# Round results
mean_rounded = torch.round(mean, decimals = 3)
std_rounded = torch.round(std, decimals =  3)
mean_rounded, std_rounded

(tensor([0.4360, 0.3780, 0.2880]), tensor([0.2880, 0.2370, 0.2530]))