In [28]:
import torch

import os

def load_data(root_dir, file_name):
    filepath = os.path.join(root_dir, file_name)
    with open(filepath, 'rb') as f:
        data = f.read()
    # data[]0 und data[1] are 0
    # data[2] and data[3] contain type and dims
    num_type, num_dims = data[2], data[3]
    assert num_type == 0x08
    shape = [int.from_bytes(data[4*(i+1) :4 * (i+2)], 'big' ) for i in range(num_dims)]
    print(f"found {num_dims}-dim array of type {type} and shape {shape}")

    parsed = torch.frombuffer(bytearray(data), dtype=torch.uint8, offset=(4*(num_dims + 1)))
    return parsed.view(shape)


valid_label = load_data('./data/', 't10k-labels.idx1-ubyte')

valid_data = load_data('./data/', 't10k-images.idx3-ubyte')

train_data = load_data('./data/', 'train-images.idx3-ubyte')

train_label = load_data('./data/', 'train-labels.idx1-ubyte')

train_data


found 1-dim array of type <class 'type'> and shape [10000]
found 3-dim array of type <class 'type'> and shape [10000, 28, 28]
found 3-dim array of type <class 'type'> and shape [60000, 28, 28]
found 1-dim array of type <class 'type'> and shape [60000]


tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0,

In [29]:
from torch.utils.data import Dataset, DataLoader

class MNISTDataset(Dataset):
    
    def __init__(self, root_dir, train=True):
        self.root_dir = root_dir
        self.train = train
        self.data = load_data(root_dir, f"{'train' if self.train else 't10k'}-images.idx3-ubyte")
        self.label = load_data(root_dir, f"{'train' if self.train else 't10k'}-labels.idx1-ubyte")


    def __getitem__(self, idx):
        return (self.data[idx], self.label[idx])

    def __len__(self):
        return len(self.data)


mnist = MNISTDataset('data')

print(mnist)

for i in range(len(mnist)):
    image, label = mnist[i]


data train-images.idx3-ubyte
found 3-dim array of type <class 'type'> and shape [60000, 28, 28]
found 1-dim array of type <class 'type'> and shape [60000]
tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],

In [30]:
dataloader = DataLoader(mnist, batch_size=4, shuffle=True, num_workers=0)

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched[0].size(), sample_batched[1])
    if i_batch == 10:
        break
    

0 torch.Size([4, 28, 28]) tensor([9, 2, 6, 8], dtype=torch.uint8)
1 torch.Size([4, 28, 28]) tensor([7, 3, 9, 9], dtype=torch.uint8)
2 torch.Size([4, 28, 28]) tensor([3, 8, 8, 8], dtype=torch.uint8)
3 torch.Size([4, 28, 28]) tensor([0, 8, 5, 8], dtype=torch.uint8)
4 torch.Size([4, 28, 28]) tensor([6, 2, 0, 5], dtype=torch.uint8)
5 torch.Size([4, 28, 28]) tensor([2, 1, 8, 3], dtype=torch.uint8)
6 torch.Size([4, 28, 28]) tensor([0, 3, 1, 9], dtype=torch.uint8)
7 torch.Size([4, 28, 28]) tensor([2, 1, 9, 6], dtype=torch.uint8)
8 torch.Size([4, 28, 28]) tensor([1, 0, 2, 2], dtype=torch.uint8)
9 torch.Size([4, 28, 28]) tensor([9, 9, 1, 8], dtype=torch.uint8)
10 torch.Size([4, 28, 28]) tensor([3, 8, 4, 9], dtype=torch.uint8)


In [33]:
def get_mean_and_std(data):
    sum_mean, sum_stdev = 0,0
    for img in data:
        sum_mean += torch.sum(img)
    mean =sum_mean / (data.shape[0] * data.shape[1] * data.shape[2])

    for img in data:
        sum_stdev += torch.sum((img-mean)**2)
    stdev = (sum_stdev / (data.shape[0] * data.shape[1] * data.shape[2]) )**0.5

    return mean, stdev

print(train_data.shape)
train_data_scaled = train_data / 255
mean, stdev = get_mean_and_std(train_data_scaled)

train_data_normalized = (train_data_scaled - mean) / stdev
print(train_data_normalized)

torch.Size([60000, 28, 28])
tensor([[[-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         ...,
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241]],

        [[-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         ...,
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241]],

        [[-0.4241, -0.4241, -0.4241,  ..., -0.4241, -0.4241, -0.4241],
         [-0.4241