In [16]:
import numpy as np
import ndjson
import json
import os

In [17]:
numpy_path = 'data/numpy/'
datapath = 'data/images_{0}.npy'
sequencepath = 'data/sequences_{0}.npy'
labelpath = 'data/labels_{0}.json'
testsize = 10000
valsize = 5000

#first: stroke
#second: x/y
#third: point

# [1~254], 255: new stroke, 0: the end/default
# maximum length: 300
MAX_LENGTH = 200
simplify_path = 'data/simplify/'

# Select 100 Categories, Get all the categories

In [18]:
def file_name(file_dir): 
    names = []
    for root, dirs, files in os.walk(file_dir):
        names.append(files) 
    names = [file.split('.')[0] for file in names[0]]
    return names

In [19]:
names = file_name(numpy_path)[:100]
name_dict = {}
for i in range(len(names)):
    name_dict[i] = names[i]

In [20]:
with open('categories.json', 'w') as fd:
    json.dump(names, fd)
with open('categories_dict.json', 'w') as fd:
    json.dump(name_dict, fd)

# Filter images with too few strokes

In [34]:
select_mask = {}

for k in range(len(names)):
    op = open(simplify_path+names[k]+'.ndjson')
    data = ndjson.load(op)
    data_size = len(data)
    select_length = np.zeros((data_size), dtype=np.int16)
    
    for i in range(len(data)):
        instance = data[i]['drawing']
        count = 0
        for stroke in instance:
            count += len(stroke[0])
        select_length[i] = count
    mean_length = int(select_length.mean())
    mask = (select_length > (mean_length - 50)) * (select_length < (mean_length + 50)) * (select_length < MAX_LENGTH)
    print('processing '+names[k]+' '+str(mask.astype(int).sum()))
    np.save(simplify_path+names[k]+'_length.npy', select_length[mask])
    print(select_length[mask].shape)
    select_mask[names[k]] = mask

processing squiggle 111792
(111792,)
processing bread 119123
(119123,)
processing violin 213951
(213951,)
processing bush 107552
(107552,)
processing eyeglasses 222123
(222123,)
processing soccer ball 112554
(112554,)
processing string bean 111364
(111364,)
processing shovel 116184
(116184,)
processing zebra 140548
(140548,)
processing kangaroo 171430
(171430,)
processing spoon 123767
(123767,)
processing submarine 121899
(121899,)
processing underwear 123378
(123378,)
processing hot air balloon 125357
(125357,)
processing pickup truck 126779
(126779,)
processing snowman 337081
(337081,)
processing chair 219567
(219567,)
processing cloud 118882
(118882,)
processing giraffe 125214
(125214,)
processing axe 122798
(122798,)
processing matches 136000
(136000,)
processing aircraft carrier 111974
(111974,)
processing camel 120620
(120620,)
processing saxophone 115938
(115938,)
processing streetlight 121459
(121459,)
processing drums 133960
(133960,)
processing camouflage 134987
(134987,)
pro

# Combine all the numpy file together, generate category, train/val/test split

In [35]:
# generate category, data
category_train = []
category_test = []
category_val = []
image_train = None
image_test = None
image_val = None
for i in range(len(names)):
    data = np.load(numpy_path+names[i]+'.npy')
    data = data[select_mask[names[i]]]
    print(names[i])
    print(data.shape)
    instance_num = data.shape[0]
    
    image_test = np.concatenate((image_test, data[-testsize:]), axis=0) if image_test is not None else data[-testsize:]
    image_val = np.concatenate((image_val, data[-testsize-valsize:-testsize]), axis=0) if image_val is not None else data[-testsize-valsize:-testsize]
    image_train = np.concatenate((image_train, data[:-testsize-valsize]), axis=0) if image_train is not None else data[:-testsize-valsize]
    # set category
    category_test.extend([i] * (testsize))
    category_val.extend([i] * (valsize))
    category_train.extend([i] * (instance_num-testsize-valsize))
    

squiggle
(111792, 784)
bread
(119123, 784)
violin
(213951, 784)
bush
(107552, 784)
eyeglasses
(222123, 784)
soccer ball
(112554, 784)
string bean
(111364, 784)
shovel
(116184, 784)
zebra
(140548, 784)
kangaroo
(171430, 784)
spoon
(123767, 784)
submarine
(121899, 784)
underwear
(123378, 784)
hot air balloon
(125357, 784)
pickup truck
(126779, 784)
snowman
(337081, 784)
chair
(219567, 784)
cloud
(118882, 784)
giraffe
(125214, 784)
axe
(122798, 784)
matches
(136000, 784)
aircraft carrier
(111974, 784)
camel
(120620, 784)
saxophone
(115938, 784)
streetlight
(121459, 784)
drums
(133960, 784)
camouflage
(134987, 784)
grass
(120876, 784)
snorkel
(150871, 784)
laptop
(252804, 784)
hot tub
(113760, 784)
car
(179553, 784)
passport
(146016, 784)
flying saucer
(146286, 784)
lobster
(136723, 784)
cactus
(129078, 784)
apple
(143992, 784)
helicopter
(156378, 784)
compass
(126530, 784)
pear
(116156, 784)
cannon
(136894, 784)
spider
(200796, 784)
fan
(132048, 784)
bandage
(142039, 784)
cruise ship
(119

In [36]:
# save data
with open(labelpath.format('train'), 'w') as fd:
    json.dump(category_train, fd)
with open(labelpath.format('val'), 'w') as fd:
    json.dump(category_val, fd)
with open(labelpath.format('test'), 'w') as fd:
    json.dump(category_test, fd)
    
np.save(datapath.format('train'), image_train)
np.save(datapath.format('val'), image_val)
np.save(datapath.format('test'), image_test)

In [37]:
print('image_train: ' + str(len(image_train)))
print('image_val: ' + str(len(image_val)))
print('image_test: ' + str(len(image_test)))
del image_train, image_val, image_test
len(name_dict)

image_train: 12871701
image_val: 500000
image_test: 1000000


100

# Preprocessing NDJSON file

In [38]:
def check_outrange(np_data, i, marker):
    if marker + 1 >= MAX_LENGTH:
        np_data[i,marker,0] = 0
        np_data[i,marker,1] = 0
        return True
    return False

In [39]:
def clamp(n, minn, maxn):
    return max(min(maxn, n), minn)

In [40]:
# preprocessing
for k in range(len(names)):
    op = open(simplify_path+names[k]+'.ndjson')
    data = ndjson.load(op)
    data_size = len(data)
    np_data = np.zeros((data_size,MAX_LENGTH,2), dtype=np.uint8)
    reverse_data = np.zeros((data_size,MAX_LENGTH,2), dtype=np.uint8)
    print('saving '+names[k])
    for i in range(len(data)):
        instance = data[i]['drawing']
        marker = 0
        for stroke in instance:
            for point in range(len(stroke[0])):
                # out of range, mannually end
                if check_outrange(np_data, i, marker):
                    continue
                else:
                    np_data[i,marker,0] = clamp(stroke[0][point], 1, 254)
                    np_data[i,marker,1] = clamp(stroke[1][point], 1, 254)
                    marker += 1
            # End of stroke
            if check_outrange(np_data, i, marker):
                continue
            else:
                np_data[i,marker,0] = 255
                np_data[i,marker,1] = 255
                marker += 1
        if check_outrange(np_data, i, marker):
            continue
        else:
            np_data[i,marker,0] = 0
            np_data[i,marker,1] = 0
        idx = np.where(np_data[i][:,0] == 0)[0][0]
        reverse_data[i][-(idx+1):, 0] = np_data[i][:idx+1, 0]
        reverse_data[i][-(idx+1):, 1] = np_data[i][:idx+1, 1]
    np.save(simplify_path+names[k]+'.npy', np_data[select_mask[names[k]]])
    np.save(simplify_path+names[k]+'_new.npy', reverse_data[select_mask[names[k]]])
    print(np_data[select_mask[names[k]]].shape)

saving squiggle
(111792, 200, 2)
saving bread
(119123, 200, 2)
saving violin
(213951, 200, 2)
saving bush
(107552, 200, 2)
saving eyeglasses
(222123, 200, 2)
saving soccer ball
(112554, 200, 2)
saving string bean
(111364, 200, 2)
saving shovel
(116184, 200, 2)
saving zebra
(140548, 200, 2)
saving kangaroo
(171430, 200, 2)
saving spoon
(123767, 200, 2)
saving submarine
(121899, 200, 2)
saving underwear
(123378, 200, 2)
saving hot air balloon
(125357, 200, 2)
saving pickup truck
(126779, 200, 2)
saving snowman
(337081, 200, 2)
saving chair
(219567, 200, 2)
saving cloud
(118882, 200, 2)
saving giraffe
(125214, 200, 2)
saving axe
(122798, 200, 2)
saving matches
(136000, 200, 2)
saving aircraft carrier
(111974, 200, 2)
saving camel
(120620, 200, 2)
saving saxophone
(115938, 200, 2)
saving streetlight
(121459, 200, 2)
saving drums
(133960, 200, 2)
saving camouflage
(134987, 200, 2)
saving grass
(120876, 200, 2)
saving snorkel
(150871, 200, 2)
saving laptop
(252804, 200, 2)
saving hot tub
(11

In [43]:
# generate category, data
sequence_train = None
sequence_test = None
sequence_val = None

for i in range(len(names)):
    print('processing '+names[i])
    np_data = np.load(simplify_path+names[i]+'_new.npy')
    sequence_test = np.concatenate((sequence_test, np_data[-testsize:]), axis=0) if sequence_test is not None else np_data[-testsize:]
    sequence_val = np.concatenate((sequence_val, np_data[-testsize-valsize:-testsize]), axis=0) if sequence_val is not None else np_data[-testsize-valsize:-testsize]
    sequence_train = np.concatenate((sequence_train, np_data[:-testsize-valsize]), axis=0) if sequence_train is not None else np_data[:-testsize-valsize]

np.save(sequencepath.format('train'), sequence_train)
np.save(sequencepath.format('val'), sequence_val)
np.save(sequencepath.format('test'), sequence_test)

processingsquiggle
processingbread
processingviolin
processingbush
processingeyeglasses
processingsoccer ball
processingstring bean
processingshovel
processingzebra
processingkangaroo
processingspoon
processingsubmarine
processingunderwear
processinghot air balloon
processingpickup truck
processingsnowman
processingchair
processingcloud
processinggiraffe
processingaxe
processingmatches
processingaircraft carrier
processingcamel
processingsaxophone
processingstreetlight
processingdrums
processingcamouflage
processinggrass
processingsnorkel
processinglaptop
processinghot tub
processingcar
processingpassport
processingflying saucer
processinglobster
processingcactus
processingapple
processinghelicopter
processingcompass
processingpear
processingcannon
processingspider
processingfan
processingbandage
processingcruise ship
processingblueberry
processingcrab
processingelbow
processingcooler
processingcircle
processingsleeping bag
processingyoga
processingowl
processingcup
processingbackpack


In [44]:
print('sequence_train: ' + str(len(sequence_train)))
print('sequence_val: ' + str(len(sequence_val)))
print('sequence_test: ' + str(len(sequence_test)))

sequence_train: 12871701
sequence_val: 500000
sequence_test: 1000000


In [45]:
len(names)

100

In [None]:
del sequence_train, sequence_val, sequence_test