Loads the handwritten math symbol dataset into a .pickle file  
Dataset can be downloaded from https://www.kaggle.com/datasets/xainano/handwrittenmathsymbols  
From the download extract the 'extracted_images' directory from 'archive.zip'  
(there is another extracted_images folder outside of archive.zip for some reason but it has less samples)

Both dataset and .pickle file are too large to upload to GitHub

In [1]:
import pickle, os
from PIL import Image
from torchvision.transforms import PILToTensor

In [2]:
# Function that takes a .jpeg file and returns a FloatTensor

transform = PILToTensor()
def jpeg_to_tensor(path):
    # .jpeg file to PIL image
    image = Image.open(path)
    # PIL image to pytorch float tensor
    image_tensor = transform(image).float()
    image_tensor.requires_grad_(True)
    return image_tensor

In [3]:
# Maps each label to an integer (index)
# NN will return an int, which will be converted to the label by simply indexing this list
# ex. if the NN predicts '+', it should output 3 (label_map[3] == '+')
label_map = os.listdir('extracted_images')

In [4]:
train, val, test = [], [], []
# samples are sorted into different folders
# name of folder is the label for the images it contains
for label in os.listdir('extracted_images'):
    curr_path = 'extracted_images/' + label
    # list of samples/files in directory
    sample_files = os.listdir(curr_path)
    # convert all files to pytorch tensors
    samples = [jpeg_to_tensor(curr_path + '/' + x) for x in sample_files]
    # integer representation of label should be stored in dataset
    y = label_map.index(label)

    num_samples = len(sample_files)
    print(f'Loading {num_samples} samples for label \'{label}\'')

    # move 80% of samples to training set
    for _ in range(int(0.8 * num_samples)):
        x = samples.pop()
        train.append((x, y))
    
    # move 10% of samples to val set
    for _ in range(int(0.1 * num_samples)):
        x = samples.pop()
        val.append((x, y))

    # move remaining samples to test set
    for _ in range(len(samples)):
        x = samples.pop()
        test.append((x, y))

Loading 1300 samples for label '!'
Loading 14294 samples for label '('
Loading 14355 samples for label ')'
Loading 25112 samples for label '+'
Loading 1906 samples for label ','
Loading 33997 samples for label '-'
Loading 6914 samples for label '0'
Loading 26520 samples for label '1'
Loading 26141 samples for label '2'
Loading 10909 samples for label '3'
Loading 7396 samples for label '4'
Loading 3545 samples for label '5'
Loading 3118 samples for label '6'
Loading 2909 samples for label '7'
Loading 3068 samples for label '8'
Loading 3737 samples for label '9'
Loading 13104 samples for label '='
Loading 12367 samples for label 'A'
Loading 2546 samples for label 'alpha'
Loading 1339 samples for label 'ascii_124'
Loading 8651 samples for label 'b'
Loading 2025 samples for label 'beta'
Loading 5802 samples for label 'C'
Loading 2986 samples for label 'cos'
Loading 4852 samples for label 'd'
Loading 137 samples for label 'Delta'
Loading 868 samples for label 'div'
Loading 3003 samples for 

In [5]:
dataset = {
    'train': train,
    'val': val,
    'test': test,
    'label_map': label_map
}

with open('dataset.pickle', 'wb') as new_file:
    pickle.dump(dataset, new_file, pickle.HIGHEST_PROTOCOL)