# In[ ]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image



# Load the CIFAR-10 training split (downloaded into ./data on first use).
# The constructor call must sit on the same line as the opening parenthesis:
# a bare `datasets.CIFAR10` followed by a parenthesised argument list on the
# next line is a separate (and invalid) expression, not a call.
dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
)

# With no `transform` given, each sample is a (PIL image, class index) pair.
# The first training image is the reference input for every demo below.
original_img, label = dataset[0]
original_img.save("original.png")





# Resize to a fixed 64x64 (PIL image in, PIL image out).
resize_transform = transforms.Resize((64, 64))

# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to
# (x - 0.5) / 0.5, i.e. into [-1, 1].
normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

# Per-channel statistics of this particular image. They have to be measured
# on the same [0, 1] scale that ToTensor() produces, because that is what
# Normalize receives. Taking them from the raw 0-255 PIL pixels shifts every
# standardized value far below zero instead of centring it.
_pixels = np.asarray(original_img, dtype=np.float64) / 255.0
_channel_axes = (0, 1) if _pixels.ndim == 3 else None  # (H, W, C) vs (H, W)
_channel_mean = np.atleast_1d(_pixels.mean(axis=_channel_axes))
# Floor the std so a perfectly flat channel does not cause division by zero.
_channel_std = np.maximum(np.atleast_1d(_pixels.std(axis=_channel_axes)), 1e-8)

standardize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=_channel_mean.tolist(),
        std=_channel_std.tolist(),
    )
])

# Random augmentation: horizontal flip (p=0.5) plus a rotation by an angle
# drawn from [-25, 25] degrees, so the output differs on every call.
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(25)
])

# Collapse to a single luminance channel (PIL mode "L").
grayscale_transform = transforms.Grayscale(num_output_channels=1)





def _rescale_to_unit(tensor):
    """Return *tensor* linearly rescaled into [0, 1] so it can be saved as an image.

    ToPILImage scales float tensors by 255 and casts them to 8 bits, which is
    only meaningful for values already in [0, 1]. Normalized and standardized
    tensors contain negatives and values above 1, which would otherwise wrap
    or clip into garbage pixels. Min-max rescaling keeps the relative
    structure of the data visible.
    """
    lo, hi = tensor.min(), tensor.max()
    if hi <= lo:  # constant tensor: nothing to stretch, avoid dividing by zero
        return torch.zeros_like(tensor)
    return (tensor - lo) / (hi - lo)


# Resize operates on and returns a PIL image, so it can be saved directly.
resized_img = resize_transform(original_img)
resized_img.save("resized.png")

# Tensor-producing transforms: bring the values back into [0, 1] before
# converting to PIL for saving.
norm_img_tensor = normalize_transform(original_img)
norm_img = transforms.ToPILImage()(_rescale_to_unit(norm_img_tensor))
norm_img.save("normalized.png")

std_img_tensor = standardize_transform(original_img)
std_img = transforms.ToPILImage()(_rescale_to_unit(std_img_tensor))
std_img.save("standardized.png")

# Random flip/rotation: the saved file differs from run to run.
augmented_img = augment_transform(original_img)
augmented_img.save("augmented.png")

gray_img = grayscale_transform(original_img)
gray_img.save("grayscale.png")





def show(title, img_path):
    """Display the image stored at *img_path* in a matplotlib figure titled *title*.

    Single-channel ("L" mode) images are drawn with a gray colormap over the
    full 0-255 range. Otherwise matplotlib would apply its default
    false-colour colormap and auto-stretch the contrast.
    """
    # Copy the pixels out inside the `with` so the file handle is released
    # before plotting.
    with Image.open(img_path) as img:
        is_gray = img.mode == "L"
        pixels = np.asarray(img)

    if is_gray:
        plt.imshow(pixels, cmap="gray", vmin=0, vmax=255)
    else:
        plt.imshow(pixels)
    plt.title(title)
    plt.axis("off")
    plt.show()


# Caption / file pairs, listed in the order the images were produced above.
_gallery = (
    ("Original", "original.png"),
    ("Resized (64x64)", "resized.png"),
    ("Normalized", "normalized.png"),
    ("Standardized", "standardized.png"),
    ("Augmented (Flip + Rotate)", "augmented.png"),
    ("Grayscale", "grayscale.png"),
)

for _caption, _path in _gallery:
    show(_caption, _path)

print("work done")
