## Imports

In [51]:
import torch as torch
import torchvision as torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_and_extract_archive
from PIL import Image
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
import os
import re
from mlxtend.data import loadlocal_mnist

import matplotlib.pyplot as plt

import shutil

## Definition of Dataset Class

In [None]:
class MnistAugmentationDataset(Dataset):
    """MNIST dataset with an optional augmentation pass cached on disk as CSV.

    On first use the raw MNIST archives are downloaded into ``base_folder``.
    The selected split is then loaded, reshaped to ``(N, 28, 28)``, optionally
    run through ``seq`` and written to
    ``../data/MNIST/augmented/<name>/{images,labels}.csv``. Later
    instantiations with the same ``name`` read those CSV files back instead of
    recomputing the augmentation.
    """

    # (url, md5) pairs handed to ``download_and_extract_archive``.
    resources = [
        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    # Folder holding the extracted, un-augmented MNIST files.
    base_folder = '../data/MNIST/base'

    # Not referenced anywhere in this class; kept for external users.
    training_file = 'training.pt'
    test_file = 'test.pt'

    def download(self):
        """Download and extract every archive in ``resources`` into ``base_folder``.

        After extraction, every ``.gz`` file left in ``base_folder`` is
        deleted so that only the extracted files remain.
        """
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.base_folder, filename=filename, md5=md5)

        for f in os.listdir(self.base_folder):
            if f.lower().endswith('.gz'):
                os.remove(os.path.join(self.base_folder, f))

    def __init__(self, name, seq=None, train=True, transform=None, force_augment=False, force_download=False):
        """Load (or build and cache) the augmented dataset called ``name``.

        Args:
            name: Sub-folder name under ``../data/MNIST/augmented/`` used as
                the cache location for this augmentation.
            seq: Optional callable invoked as ``seq(images=array)`` on the
                ``(N, 28, 28)`` image array; its result replaces the images.
                ``None`` leaves the images untouched.
            train: If true, read the ``train-*`` files, otherwise the
                ``t10k-*`` files. Only consulted when the cache is (re)built.
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
            force_augment: Rebuild the cache even if its folder already exists.
            force_download: Delete ``base_folder`` and download it again.
        """
        self.transform = transform

        if force_download or not os.path.exists(self.base_folder):
            # Start from a clean folder so stale files never mix with new ones.
            if os.path.exists(self.base_folder):
                shutil.rmtree(self.base_folder)
            os.makedirs(self.base_folder)
            self.download()

        # NOTE(review): the cache path depends on ``name`` only, not on
        # ``train``. Requesting the test split with a name whose train split
        # is already cached returns the cached train data unless
        # ``force_augment=True`` is passed. Use distinct names per split.
        self.root_folder = '../data/MNIST/augmented/' + name
        if force_augment or not os.path.exists(self.root_folder):
            if os.path.exists(self.root_folder):
                shutil.rmtree(self.root_folder)
            os.makedirs(self.root_folder)
            if train:
                self.image_data, self.label_data = loadlocal_mnist(
                    images_path=self.base_folder + '/train-images-idx3-ubyte',
                    labels_path=self.base_folder + '/train-labels-idx1-ubyte')
            else:
                self.image_data, self.label_data = loadlocal_mnist(
                    images_path=self.base_folder + '/t10k-images-idx3-ubyte',
                    labels_path=self.base_folder + '/t10k-labels-idx1-ubyte')

            # One 28x28 image per entry, the layout handed to ``seq``.
            self.image_data = np.reshape(self.image_data, (-1, 28, 28))

            if seq is not None:
                self.image_data = seq(images=self.image_data)

            # Save the information to CSV files: one flattened image per row.
            print('Saving augmented data to filesystem...')
            image_output_data = np.reshape(self.image_data, (-1, (28*28)))
            np.savetxt(fname=self.root_folder + '/images.csv',
                X=image_output_data, delimiter=',', fmt='%d')
            np.savetxt(fname=self.root_folder + '/labels.csv',
                X=self.label_data, delimiter=',', fmt='%d')
            print('Done!')
        else:
            # Cache hit: restore the arrays written by an earlier run.
            image_input_data = np.loadtxt(fname=self.root_folder + '/images.csv', dtype='uint8', delimiter=',')
            self.image_data = np.reshape(image_input_data, (-1, 28, 28))
            self.label_data = np.loadtxt(fname=self.root_folder + '/labels.csv', dtype='uint8', delimiter=',')

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.image_data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        Args:
            index: Position of the sample; a tensor index is converted to a
                plain Python value first.

        Returns:
            A tuple of the image (a mode ``'L'`` PIL image, or whatever
            ``transform`` returns for it) and the matching label entry.
        """
        if torch.is_tensor(index):
            # The original assigned to an undefined ``idx`` here, which raised
            # UnboundLocalError for every tensor index.
            index = index.tolist()

        image = self.image_data[index]
        label = self.label_data[index]

        image = Image.fromarray(image, mode='L')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Demo: build (or load from cache) the 'flip' variant of the training split
# and display its first sample.
# Augmentation pipeline with a single step: Flipud with probability 1.
seq = iaa.Sequential([iaa.Flipud(1)])
ds = MnistAugmentationDataset(name='flip', seq=seq)

# Indexing goes through MnistAugmentationDataset.__getitem__.
image, label = ds[0]
plt.imshow(image)