In [1]:
import os

import numpy as np
import scipy as sc
import pandas as pd

import torch
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter, MaxNLocator

# Seed every global RNG this notebook draws from so that Restart & Run All is
# reproducible. NumPy must be seeded too: the reduced training subset is
# sampled with np.random.choice, which torch.manual_seed does not affect.
SEED = 42
_ = torch.manual_seed(SEED)  # `_ =` keeps the returned Generator out of the cell output
np.random.seed(SEED)

In [2]:
# Normalize expects one mean/std entry per channel as a sequence. `(0.5)` is
# just the float 0.5, so spell out one-element tuples for the single
# grayscale channel.
normalize = transforms.Normalize((0.5,), (0.5,))
# ToTensor converts the image to a float tensor scaled to [0, 1]; with
# mean = std = 0.5, Normalize then maps that range to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), normalize])


# Training split (downloaded to ../data/ on first use).
full_train_set = torchvision.datasets.FashionMNIST(root="../data/", train=True, download=True, transform=transform)
full_train_set_size = len(full_train_set)

# Held-out test split, using the same transform as the training split.
test_set = torchvision.datasets.FashionMNIST(root="../data/", train=False, download=True, transform=transform)
test_set_size = len(test_set)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw



In [3]:
class ReducedFashionMNISTDataset(Dataset):
    """
    Custom reduced dataset class for workflow logic.

    Keeps the data and target tensors it is given and serves
    ``(data, target)`` pairs by index.
    """

    def __init__(self, data, target, classes):
        """
        Args:
            data (tensor): The data
            target (tensor): The target class of the data
            classes (string list): Ordered list of the classes represented by target
        """
        self.data = data
        self.target = target
        self.classes = classes

    def __len__(self):
        """Return the number of datapoints, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the ``(data, target)`` pair stored at ``idx``.

        Args:
            idx: Index of the requested datapoint(s). A tensor index is
                converted to a plain Python int/list before being applied.

        Returns:
            tuple: ``(self.data[idx], self.target[idx])``
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The same index is applied to both tensors so data and target stay paired.
        return (self.data[idx], self.target[idx],)

    def __repr__(self):
        """Return a short summary containing the number of datapoints."""
        return f'Dataset Reduced FashionMNIST\n    Number of datapoints: {self.__len__()}'

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return self.__repr__()

In [4]:
reduction_factor = 0.001 # Ratio of the full train set used in the reduced one

# Number of samples to keep. round() instead of bare int() so float error in
# the product cannot silently drop a sample (e.g. 100 * 0.29 == 28.999...).
reduced_train_set_size = int(round(full_train_set_size * reduction_factor))

# Pick distinct sample positions; passing an int n draws from range(n).
indices = np.random.choice(full_train_set_size, reduced_train_set_size, replace=False)

# Normalize the data with the same fixed scaling that `transform` applies to
# the other sets: raw 8-bit pixels / 255 -> [0, 1] (ToTensor), then
# 2 * x - 1 -> [-1, 1] (Normalize with mean = std = 0.5). Dividing by the
# sample's own (max - min) without subtracting the minimum only matched this
# when the drawn subset happened to contain both a 0 and a 255 pixel.
# NOTE(review): the result stays float64 and keeps the raw tensor's shape
# (no channel axis is added) - confirm this is what downstream code expects.
red_data = full_train_set.data[indices].double()
red_data = 2. * (red_data / 255.) - 1.

red_labels = full_train_set.targets[indices]

del indices # Free some memory

reduced_train_set = ReducedFashionMNISTDataset(red_data, red_labels, full_train_set.classes)