In [2]:
import os
import struct
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class FashionMNISTDataset(Dataset):
    """Fashion-MNIST style dataset backed by raw (uncompressed) IDX files.

    The whole image and label files are loaded into memory as ``uint8``
    NumPy arrays when the dataset is constructed.

    Args:
        image_path: Path to an IDX3 image file (16-byte big-endian header
            followed by ``num * rows * cols`` unsigned bytes).
        label_path: Path to an IDX1 label file (8-byte big-endian header
            followed by ``num`` unsigned bytes).
        transform: Optional callable applied to the PIL image of each sample.

    Raises:
        ValueError: If a file is not a valid IDX file of the expected kind,
            is truncated, or the image and label counts differ.
    """

    # IDX magic numbers: 0x00000803 (uint8, 3 dims) and 0x00000801 (uint8, 1 dim).
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049

    def __init__(self, image_path, label_path, transform=None):
        self.images = self.read_images(image_path)
        self.labels = self.read_labels(label_path)
        # Mismatched files would otherwise only surface as an IndexError (or
        # silently unused images) somewhere in the middle of training.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"Image/label count mismatch: {len(self.images)} images "
                f"in {image_path!r} vs {len(self.labels)} labels in {label_path!r}"
            )
        self.transform = transform

    def read_images(self, path):
        """Read an IDX3 image file.

        Returns:
            A ``uint8`` array of shape ``(num, rows, cols)``.

        Raises:
            ValueError: On a short header, wrong magic number or truncated data.
        """
        with open(path, 'rb') as f:
            header = f.read(16)
            if len(header) != 16:
                raise ValueError(f"{path!r}: IDX image header is truncated")
            magic, num, rows, cols = struct.unpack(">IIII", header)
            if magic != self._IMAGE_MAGIC:
                raise ValueError(
                    f"{path!r}: bad IDX image magic number {magic} "
                    f"(expected {self._IMAGE_MAGIC}); is the file still gzipped?"
                )
            expected = num * rows * cols
            # Read exactly what the header promises so trailing bytes are
            # ignored instead of breaking the reshape below.
            pixels = np.fromfile(f, dtype=np.uint8, count=expected)
        if pixels.size != expected:
            raise ValueError(
                f"{path!r}: expected {expected} pixel bytes, found {pixels.size}"
            )
        return pixels.reshape(num, rows, cols)

    def read_labels(self, path):
        """Read an IDX1 label file.

        Returns:
            A 1-D ``uint8`` array holding exactly the number of labels
            declared in the header.

        Raises:
            ValueError: On a short header, wrong magic number or truncated data.
        """
        with open(path, 'rb') as f:
            header = f.read(8)
            if len(header) != 8:
                raise ValueError(f"{path!r}: IDX label header is truncated")
            magic, num = struct.unpack(">II", header)
            if magic != self._LABEL_MAGIC:
                raise ValueError(
                    f"{path!r}: bad IDX label magic number {magic} "
                    f"(expected {self._LABEL_MAGIC}); is the file still gzipped?"
                )
            # Honour the declared count rather than reading to end-of-file.
            labels = np.fromfile(f, dtype=np.uint8, count=num)
        if labels.size != num:
            raise ValueError(
                f"{path!r}: expected {num} labels, found {labels.size}"
            )
        return labels

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a PIL image built from the stored 2-D ``uint8`` array,
        passed through ``transform`` when one was given. ``label`` is a plain
        Python ``int``.
        """
        # A 2-D uint8 array is interpreted by PIL as 8-bit grayscale ("L"),
        # so no explicit mode is needed.
        image = Image.fromarray(self.images[idx])
        # Plain int so default batching yields int64 label tensors; np.uint8
        # scalars collate to uint8, which CrossEntropyLoss does not accept as
        # class-index targets.
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Preprocessing applied to every sample: convert the PIL image to a tensor,
# then normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split, read from the raw IDX files under ./datasets/FashionMNIST.
train_dataset = FashionMNISTDataset(
    image_path='./datasets/FashionMNIST/train-images-idx3-ubyte',
    label_path='./datasets/FashionMNIST/train-labels-idx1-ubyte',
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)


60000