### Transform MNIST Dataset

In [1]:
import torch
import torchvision

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder

from torchvision.transforms import v2

In [2]:
import os
import json

import numpy as np

from PIL import Image

In [3]:
import warnings

warnings.filterwarnings('ignore')

In [4]:
# Create our MNIST dataset implementation
class MNISTDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform

        self.len_dataset = 0
        self.data_list = []

        for path_dir, dir_list, file_list in os.walk(path):
            if path_dir == path:
                self.classes = sorted(dir_list)
                # dict with class names and they positions in one_hot vector
                self.class_to_index = {
                    cls_name: i for i, cls_name in enumerate(self.classes)
                }
                continue

            cls = path_dir.split('/')[-1]

            for name_file in file_list:
                file_path = os.path.join(path_dir, name_file)
                self.data_list.append((file_path, self.class_to_index[cls]))

            self.len_dataset += len(file_list)

    # always implement length of dataset here
    def __len__(self):
        return self.len_dataset

    # data with they class
    def __getitem__(self, index):
        file_path, target = self.data_list[index]
        sample = np.array(Image.open(file_path))

        if self.transform is not None:
            sample = self.transform(sample)
        # return sample and it's class position in one_hot vector
        return sample, target

In [5]:
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # norm only one color channel
        v2.Normalize(mean=(0.5, ), std=(0.5, ))
    ]
)

In [6]:
# create datasets
path = os.path.join(os.getcwd(), 'mnist')
train_data = MNISTDataset(os.path.join(path, 'training'), transform=transform)
test_data = MNISTDataset(os.path.join(path, 'testing'), transform=transform)

In [7]:
img, cls = test_data[2]

print('Img:')
print(f'    {type(img)}')
print(f'    {img.shape}')
print(f'    {img.dtype}')
print(f'    min = {img.min()}, max = {img.max()}')
print('cls:')
print(f'    {cls}')

Img:
    <class 'torchvision.tv_tensors._image.Image'>
    torch.Size([1, 28, 28])
    torch.float32
    min = -1.0, max = 1.0
cls:
    2


In [8]:
# split data
train_data, val_data = random_split(train_data, [0.8, 0.2])

In [9]:
# do batches
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

In [10]:
imgs, cls = next(iter(train_loader))

print('Imgs:')
print(f'    {type(imgs)}')
print(f'    {imgs.shape}')
print(f'    {imgs.dtype}')
print(f'    min = {img.min()}, max = {img.max()}')
print('cls:')
print(f'    {type(cls)}')
print(f'    {cls.shape}')
print(f'    {cls.dtype}')

Imgs:
    <class 'torch.Tensor'>
    torch.Size([16, 1, 28, 28])
    torch.float32
    min = -1.0, max = 1.0
cls:
    <class 'torch.Tensor'>
    torch.Size([16])
    torch.int64


#### Do the same with ImageFolder

In [11]:
# ImageFolder always create 3 color channels, so add one more transformation
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Grayscale(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.5, ), std=(0.5, ))
    ]
)

In [12]:
path = os.path.join(os.getcwd(), 'mnist')
train_data = ImageFolder(os.path.join(path, 'training'), transform=transform)
test_data = ImageFolder(os.path.join(path, 'testing'), transform=transform)

In [13]:
img, cls = test_data[2]

print('Img:')
print(f'    {type(img)}')
print(f'    {img.shape}')
print(f'    {img.dtype}')
print(f'    min = {img.min()}, max = {img.max()}')
print('cls:')
print(f'    {cls}')

Img:
    <class 'torchvision.tv_tensors._image.Image'>
    torch.Size([1, 28, 28])
    torch.float32
    min = -1.0, max = 0.992156982421875
cls:
    0


In [14]:
train_data, val_data = random_split(train_data, [0.8, 0.2])

In [15]:
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

In [16]:
imgs, cls = next(iter(train_loader))

print('Imgs:')
print(f'    {type(imgs)}')
print(f'    {imgs.shape}')
print(f'    {imgs.dtype}')
print(f'    min = {img.min()}, max = {img.max()}')
print('cls:')
print(f'    {type(cls)}')
print(f'    {cls.shape}')
print(f'    {cls.dtype}')

Imgs:
    <class 'torch.Tensor'>
    torch.Size([16, 1, 28, 28])
    torch.float32
    min = -1.0, max = 0.992156982421875
cls:
    <class 'torch.Tensor'>
    torch.Size([16])
    torch.int64
