In [1]:
from PIL import Image  # has to be a separate cell to avoid import errors

In [5]:
from utils.data_loading import load_cifar10

train_data = load_cifar10(train=True, raw=True, return_loader=False)

In [6]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# pick 9 random images
indices = np.random.choice(len(train_data), 9, replace=False)

# create a 3x3 subplot grid
fig = make_subplots(
    rows=3,
    cols=3,
    subplot_titles=[train_data.classes[train_data[i][1]] for i in indices],
)

for idx, img_idx in enumerate(indices):
    row = idx // 3 + 1
    col = idx % 3 + 1

    img = train_data[img_idx][0].numpy()  # [C,H,W]
    img = np.transpose(img, (1, 2, 0))  # -> [H,W,C]
    img = (img * 255).astype(np.uint8)  # scale to [0,255] for go.Image

    fig.add_trace(go.Image(z=img), row=row, col=col)

# update layout
fig.update_layout(width=320 * 3, height=320 * 3, showlegend=False)


In [12]:
import torch

counts = torch.bincount(torch.tensor(train_data.targets, dtype=torch.int64))
for label, count in enumerate(counts):
    print(
        f"Label {label}: {count.item()} samples, {count.item() / len(train_data) * 100:.2f}% of dataset"
    )


Label 0: 5000 samples, 10.00% of dataset
Label 1: 5000 samples, 10.00% of dataset
Label 2: 5000 samples, 10.00% of dataset
Label 3: 5000 samples, 10.00% of dataset
Label 4: 5000 samples, 10.00% of dataset
Label 5: 5000 samples, 10.00% of dataset
Label 6: 5000 samples, 10.00% of dataset
Label 7: 5000 samples, 10.00% of dataset
Label 8: 5000 samples, 10.00% of dataset
Label 9: 5000 samples, 10.00% of dataset


In [6]:
import torch

# train_data already loaded with transforms.ToTensor() but NO Normalize
mean = torch.zeros(3)
std = torch.zeros(3)
n_samples = 0

for img, _ in train_data:
    # img is [C, H, W] float in [0,1]
    n_pixels = img.numel() // 3  # H*W
    mean += img.view(3, -1).mean(dim=1)
    std += img.view(3, -1).std(dim=1)
    n_samples += 1

mean /= n_samples
std /= n_samples

print("mean:", mean)
print("std:", std)


mean: tensor([0.4914, 0.4822, 0.4465])
std: tensor([0.2023, 0.1994, 0.2010])
