## This Python notebook reads the MNIST data in the format specified on Yann LeCun's page and returns a list of tuples. Each tuple consists of the digit label and the image matrix of size 28 x 28. We then expand this dataset by rotating and translating each image by various amounts and finally return the complete dataset

In [35]:
import numpy as np
from matplotlib import pyplot as plt 
import matplotlib
import struct
import os
import h5py
from scipy import ndimage as nd

%matplotlib inline

In [36]:
# Function to read the MNIST data from ubyte format
def ParseMNISTData(dataType='training', filePath='./Data'):
    """Yield ``(label, image)`` pairs read from MNIST ubyte files.

    Args:
        dataType: 'training' reads the ``train-*`` files, 'testing' reads the
            ``t10k-*`` files.
        filePath: Directory that contains the image and label files.

    Yields:
        Tuples ``(label, image)`` where ``label`` is the uint8 label at a
        given index and ``image`` is the ``rowCount x colCount`` uint8 array
        at the same index.

    Raises:
        ValueError: If ``dataType`` is neither 'training' nor 'testing'.
            Because this function is a generator, the error is raised when
            iteration starts, not when the function is called.
    """
    # Compare by value: `is` tests object identity and only works for strings
    # that happen to be the same interned object.
    if dataType == 'training':
        filePrefix = 'train'
    elif dataType == 'testing':
        filePrefix = 't10k'
    else:
        raise ValueError('Datatype must be either "training" or "testing"')

    dataDir = os.path.abspath(filePath)
    imgFileName = os.path.join(dataDir, filePrefix + '-images.idx3-ubyte')
    lblFileName = os.path.join(dataDir, filePrefix + '-labels.idx1-ubyte')

    with open(lblFileName, 'rb') as inFile:
        # Consume the 8-byte header (two big-endian uint32 fields); the values
        # are not needed, but unpacking fails loudly on a truncated file.
        struct.unpack('>II', inFile.read(8))
        labelData = np.fromfile(inFile, dtype=np.uint8)

    with open(imgFileName, 'rb') as inFile:
        # 16-byte header: only the last two fields (rows, columns) are used.
        _, _, rowCount, colCount = struct.unpack('>IIII', inFile.read(16))
        imageData = np.fromfile(inFile, dtype=np.uint8).reshape(
            len(labelData), rowCount, colCount)

    for label, image in zip(labelData, imageData):
        yield label, image

In [37]:
# Display the number image
def DrawNumber(numberImage):
    """Show a digit image in a new matplotlib figure.

    The image is drawn with the Greys colormap and nearest-neighbour
    interpolation, with x ticks on the top edge and y ticks on the left edge.
    """
    digitAxes = plt.figure().add_subplot(1, 1, 1)
    renderedImage = digitAxes.imshow(numberImage, cmap=matplotlib.cm.Greys)
    renderedImage.set_interpolation('nearest')
    # Place the tick marks: x axis along the top, y axis along the left.
    for axisObj, edgeName in ((digitAxes.xaxis, 'top'), (digitAxes.yaxis, 'left')):
        axisObj.set_ticks_position(edgeName)
    plt.show()

### Functions for manipulation of images to augment the dataset for better learning

In [38]:
# Angles that RotateImages passes to nd.rotate, one rotated copy per entry.
rotateAngles = [-15, -10, -5, 5, 10, 15]

# Per-axis offsets that TranslateImage passes to nd.shift, one shifted copy
# per entry: 1, 2 and 3 steps in each direction along each axis.
translateDistance = [
    (1, 0), (0, 1), (-1, 0), (0, -1),
    (2, 0), (0, 2), (-2, 0), (0, -2),
    (3, 0), (0, 3), (-3, 0), (0, -3),
]

In [39]:
def RotateImages(originalImage, angles=None):
    """Return rotated copies of an image, one per angle.

    Args:
        originalImage: Image array to rotate.
        angles: Optional iterable of rotation angles, passed one at a time to
            ``nd.rotate``. Defaults to the module-level ``rotateAngles``.

    Returns:
        A list with one rotated array per angle, in the order of ``angles``.
        Each keeps the input shape (``reshape=False``) and is filled using the
        'constant' boundary mode.
    """
    if angles is None:
        angles = rotateAngles

    return [
        nd.rotate(originalImage, angleVal, reshape=False, mode='constant')
        for angleVal in angles
    ]

In [40]:
def TranslateImage(originalImage, shifts=None):
    """Return shifted copies of an image, one per offset.

    Args:
        originalImage: Image array to translate.
        shifts: Optional iterable of per-axis offsets, passed one at a time to
            ``nd.shift``. Defaults to the module-level ``translateDistance``.

    Returns:
        A list with one shifted array per offset, in the order of ``shifts``.
        Each is filled using the 'constant' boundary mode.
    """
    if shifts is None:
        shifts = translateDistance

    return [
        nd.shift(originalImage, shiftVal, mode='constant')
        for shiftVal in shifts
    ]

In [41]:
def GetDataset(dataType='training', filePath='./Data'):
    """Read an entire MNIST split into a list.

    Args:
        dataType: Split name forwarded to ``ParseMNISTData``.
        filePath: Data directory forwarded to ``ParseMNISTData``.

    Returns:
        A list of every item yielded by ``ParseMNISTData`` for this split.
    """
    # Single-argument print(...) behaves the same as a Python 2 print
    # statement and also parses on Python 3.
    print('Reading Data from Input File')
    dataSet = list(ParseMNISTData(dataType, filePath))
    print('Done')
    return dataSet

In [42]:
def AugmentDataset(originalDataset):
    """Expand a dataset with rotated and translated copies of every image.

    Args:
        originalDataset: Iterable of items whose element 0 is the label and
            element 1 is the image.

    Returns:
        A new list. For each input item it holds the original item, then one
        ``(label, image)`` tuple per result of ``RotateImages``, then one per
        result of ``TranslateImage``. The input is not modified.
    """
    completeDataset = []
    for dataItem in originalDataset:
        digitLabel, digitImage = dataItem[0], dataItem[1]

        completeDataset.append(dataItem)
        completeDataset.extend(
            (digitLabel, rotatedImage)
            for rotatedImage in RotateImages(digitImage))
        completeDataset.extend(
            (digitLabel, shiftedImage)
            for shiftedImage in TranslateImage(digitImage))

    # The former `del originalDataset` is gone on purpose: it only unbound
    # the local name and never released the caller's list.
    return completeDataset

In [45]:
def CreateFinalDataset(dataType='training', filePath='./Data'):
    """Read an MNIST split, augment it if training, and write it as HDF5.

    Writes ``<split>Data.h5`` under ``filePath`` with a 'data' dataset of
    images and a 'label' dataset of digit labels, both gzip-compressed. Also
    writes ``<split>Data.txt`` containing the absolute path of the HDF5 file.

    Args:
        dataType: 'training' (read, augment, write) or 'testing' (read and
            write without augmentation).
        filePath: Directory holding the input files; outputs go here too.

    Raises:
        ValueError: If ``filePath`` does not exist, or if ``dataType`` is
            neither 'training' nor 'testing'.
    """
    if not os.path.exists(filePath):
        raise ValueError('Invalid Path provided.')

    # Compare by value: `is` tests object identity, not string equality.
    isTraining = dataType == 'training'
    if isTraining:
        print('Training DataType')
        baseName = 'trainingData'
    elif dataType == 'testing':
        print('Testing Datatype')
        baseName = 'testingData'
    else:
        raise ValueError('Datatype must be either "training" or "testing"')

    outputDir = os.path.abspath(filePath)
    outFileName = os.path.join(outputDir, baseName + '.h5')
    linkFileName = os.path.join(outputDir, baseName + '.txt')

    completeDataSet = GetDataset(dataType, filePath)
    if isTraining:
        print('Augmenting Training Data')
        completeDataSet = AugmentDataset(completeDataSet)

    # Items are (label, image) tuples, so unzipping gives labels first. The
    # earlier unpack order wrote labels to 'data' and images to 'label'.
    digitLabelSet, imageMatrixSet = zip(*completeDataSet)
    comp_kwargs = {'compression': 'gzip', 'compression_opts': 1}
    with h5py.File(outFileName, 'w') as outFile:
        print('Writing Data for HDF5 File')
        outFile.create_dataset('data', data=imageMatrixSet, **comp_kwargs)
        outFile.create_dataset('label', data=digitLabelSet, **comp_kwargs)
    with open(linkFileName, 'w') as outFile:
        outFile.write(outFileName + '\n')

    print('Done')

In [46]:
# Build the training set (with augmentation) and write it to the default data directory.
CreateFinalDataset('training')

Training DataType
Reading Data from Input File
Done
Augmenting Training Data
1140000
Writing Data for HDF5 File
Done
