In [1]:
# Datasets from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/simplified
%matplotlib inline 
import numpy as np
import matplotlib.pyplot as plt
import cv2
import ndjson
import time
import random
import glob
import re
import os

N_ITEMS_PER_CLASS = 20
IMG_SIZE = 28  # 28x28

In [2]:
def select_samples_from_class(data):
    ''' Select only recognized samples. '''
    random.seed(0)
    indexes = [i for i, sample in enumerate(data) if sample['recognized']]
    random.shuffle(indexes)
    return indexes[:N_ITEMS_PER_CLASS]

In [3]:
def save_drawings(data, sample_indexes, class_title):
    ''' Generates multiple images per sample. '''
    for i in sample_indexes:     
        lines = []
        for line in data[i]['drawing']:
            # Scale img to IMG_SIZE
            scaler = IMG_SIZE / 255.0  
            line[0] = [int(l * scaler) for l in line[0]]
            line[1] = [int(l * scaler) for l in line[1]] 
            
            # Connect lines
            unique_vertices = list(zip(line[0], line[1]))
            for j in range(len(unique_vertices)-1):
                lines.append([unique_vertices[j], unique_vertices[j+1]])

            # Generate images
            os.makedirs('datasets/quickdraw/{}/'.format(class_title), exist_ok=True)
            img = np.zeros((IMG_SIZE, IMG_SIZE))
            for t, line in enumerate(lines):
                img = cv2.line(img, line[0], line[1], thickness=1, color=(1,1,1))
                cv2.imwrite('datasets/quickdraw/{}/{}-{}-{}.png'.format(class_title, class_title, i, t), img*255)
#             plt.figure()
#             plt.imshow(img)

In [4]:
file_paths = [f for f in glob.glob("raw_datasets/quickdraw/*.ndjson")]
file_paths

['raw_datasets/quickdraw/baseball.ndjson',
 'raw_datasets/quickdraw/cello.ndjson',
 'raw_datasets/quickdraw/bicycle.ndjson',
 'raw_datasets/quickdraw/banana.ndjson',
 'raw_datasets/quickdraw/carrot.ndjson',
 'raw_datasets/quickdraw/airplane.ndjson']

In [5]:
for file_path in file_paths:
    with open(file_path) as f:
        match = re.match(r'.*/(\w*).ndjson', file_path)
        if match:
            class_title = match.group(1)

        data = ndjson.load(f)
        sample_indexes = select_samples_from_class(data)
        save_drawings(data, sample_indexes, class_title)