1. Create a function that loads an NDJSON file and returns its lines as a list of strings

In [None]:
def load_ndjson(file):
    """Read an NDJSON file and return its records as a list of strings.

    Each element of the result is one line of the file with leading and
    trailing whitespace removed. Blank (whitespace-only) lines are skipped
    so every element can be handed straight to ``json.loads``.

    Args:
        file: Path of the NDJSON file to read.

    Returns:
        list[str]: The non-empty, stripped lines in file order.
    """
    # Decode as UTF-8 explicitly rather than relying on the platform default
    # encoding, which differs between operating systems.
    with open(file, 'r', encoding='utf-8') as f:
        # Iterate the handle lazily instead of readlines() so only one copy
        # of the file's lines is ever held in memory.
        stripped = (line.strip() for line in f)
        return [record for record in stripped if record]

2. Each NDJSON file contains all the data for one class. Every drawing of that class is stored as an array of strokes. To use this dataset with a CNN, we need to convert those strokes into images.

In [None]:
from PIL import Image, ImageDraw

def stroke_to_img(strokes, img_size=(256, 256)):
    """Render a list of strokes onto a new single-channel image.

    Args:
        strokes: Iterable of strokes. For each stroke, element 0 and
            element 1 are paired up point by point to form the polyline
            (presumably the x and y coordinate sequences -- confirm against
            the dataset format).
        img_size: ``(width, height)`` of the image to create.

    Returns:
        A PIL image in mode ``'L'`` with a 255-valued background and the
        strokes drawn with value 0, two pixels wide.
    """
    canvas = Image.new('L', img_size, 255)
    pen = ImageDraw.Draw(canvas)
    for stroke in strokes:
        first_axis, second_axis = stroke[0], stroke[1]
        # Pair the two coordinate sequences into a list of points.
        points = [(a, b) for a, b in zip(first_axis, second_axis)]
        pen.line(points, fill=0, width=2)
    return canvas

3. Function to process one class given its name. It loads the class's NDJSON file from `./raw/`, iterates through the drawings to convert them from strokes to images, and splits them into train (the first 80%) and test (the remaining 20%) datasets.

In [None]:
import json, os
def process_class(class_name):
    """Convert every drawing of one class to PNG, split into train/test.

    Loads ``./raw/{class_name}.ndjson`` with ``load_ndjson``, parses each
    record as JSON, renders its ``'drawing'`` entry with ``stroke_to_img``
    and saves the result as ``{class_name}{index}.png``. The first 80% of
    the records (in file order) are written to
    ``./processed_dataset/train/{class_name}`` and the rest to
    ``./processed_dataset/test/{class_name}``. A progress counter is
    printed on a single, continuously overwritten line.

    Args:
        class_name: Name of the class; also the stem of the NDJSON file.
    """
    ndjson = load_ndjson(f'./raw/{class_name}.ndjson')

    train_dir = f'./processed_dataset/train/{class_name}'
    test_dir = f'./processed_dataset/test/{class_name}'
    for directory in (train_dir, test_dir):
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(directory, exist_ok=True)

    items = len(ndjson)
    train_items = int(items * 0.8)
    for index, item in enumerate(ndjson):
        drawing = json.loads(item)['drawing']
        img = stroke_to_img(drawing)
        target_dir = train_dir if index < train_items else test_dir
        img.save(f'{target_dir}/{class_name}{index}.png')
        # '\r' returns the cursor so the next update overwrites this line.
        print(f'Processing class {class_name}: {index + 1}/{items}', end='\r')

4. Use a for loop to iterate over all classes and process each one with the previous function

In [None]:
# Classes to convert; each name must match a ./raw/<name>.ndjson file.
classes = [
    "apple",
    "cat",
    "computer",
    "fish",
    "clock",
    "moon",
    "bird",
    "tree",
    "eyeglasses",
    "ice_cream",
]

# Convert one class at a time and report when each one is finished.
for class_name in classes:
    process_class(class_name)
    print(f'Processed class {class_name}')