# Normalization of Datasets

## MEDMNIST: PathMNIST

In [28]:
from medmnist import PathMNIST
from torchvision.transforms import v2
from pathlib import Path
import torch
import numpy as np

In [29]:
root = Path("./data/") / "PathMNIST"
root.mkdir(parents=True, exist_ok=True)
data = PathMNIST(
    root=root,
    split="train",
    download=True,
    transform=v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ]
    ),
)
imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()
imgs = imgs.reshape(-1, data.imgs.shape[-1])
print(f"PathMNIST Dataset Shape: {data.imgs.shape}")
print(f"PathMNIST Dataset Means: {np.round(imgs.mean(0), 3)}")
print(f"PathMNIST Dataset Stds: {np.round(imgs.std(0), 3)}")

PathMNIST Dataset Shape: (89996, 28, 28, 3)
PathMNIST Dataset Means: [0.238 0.238 0.238]
PathMNIST Dataset Stds: [0.358 0.309 0.352]


## MEDMNIST: PathMNIST

In [30]:
from medmnist import RetinaMNIST


root = Path("./data/") / "RetinaMNIST"
root.mkdir(parents=True, exist_ok=True)
data = RetinaMNIST(
    root=root,
    split="train",
    download=True,
    transform=v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ]
    ),
)
imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()
imgs = imgs.reshape(-1, data.imgs.shape[-1])
print(f"RetinaMNIST Dataset Shape: {data.imgs.shape}")
print(f"RetinaMNIST Dataset Means: {np.round(imgs.mean(0), 3)}")
print(f"RetinaMNIST Dataset Stds: {np.round(imgs.std(0), 3)}")

RetinaMNIST Dataset Shape: (1080, 28, 28, 3)
RetinaMNIST Dataset Means: [0.399 0.245 0.156]
RetinaMNIST Dataset Stds: [0.298 0.201 0.151]


## CIFAR10

In [31]:
from torchvision.datasets import CIFAR10

root = Path("./data/") / "CIFAR10"
root.mkdir(parents=True, exist_ok=True)
data = CIFAR10(
    root=root,
    train=True,
    download=True,
    transform=v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ]
    ),
)
imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()
print(f"CIFAR10 Dataset Shape: {imgs.shape}")
imgs = imgs.reshape(-1, imgs.shape[-1])
print(f"CIFAR10 Dataset Means: {np.round(imgs.mean(0), 3)}")
print(f"CIFAR10 Dataset Stds: {np.round(imgs.std(0), 3)}")

CIFAR10 Dataset Shape: (50000, 32, 32, 3)
CIFAR10 Dataset Means: [0.328 0.328 0.328]
CIFAR10 Dataset Stds: [0.278 0.269 0.268]


----------------------