In [5]:
import torch
import numpy as np
import lmdb
import os
import cv2
import pickle
from tqdm import  tqdm

In [6]:
path = "./QuickDraw"

In [7]:
dir_list = sorted(os.listdir(path))

In [9]:
for i,d in enumerate(dir_list):
    print(str(i)+":"+d)

0:The Eiffel Tower.npz
1:The Great Wall of China.npz
2:The Mona Lisa.npz
3:aircraft carrier.npz
4:airplane.npz
5:alarm clock.npz
6:ambulance.npz
7:angel.npz
8:animal migration.npz
9:ant.npz
10:anvil.npz
11:apple.npz
12:arm.npz
13:asparagus.npz
14:axe.npz
15:backpack.npz
16:banana.npz
17:bandage.npz
18:barn.npz
19:baseball bat.npz
20:baseball.npz
21:basket.npz
22:basketball.npz
23:bat.npz
24:bathtub.npz
25:beach.npz
26:bear.npz
27:beard.npz
28:bed.npz
29:bee.npz
30:belt.npz
31:bench.npz
32:bicycle.npz
33:binoculars.npz
34:bird.npz
35:birthday cake.npz
36:blackberry.npz
37:blueberry.npz
38:book.npz
39:boomerang.npz
40:bottlecap.npz
41:bowtie.npz
42:bracelet.npz
43:brain.npz
44:bread.npz
45:bridge.npz
46:broccoli.npz
47:broom.npz
48:bucket.npz
49:bulldozer.npz
50:bus.npz
51:bush.npz
52:butterfly.npz
53:cactus.npz
54:cake.npz
55:calculator.npz
56:calendar.npz
57:camel.npz
58:camera.npz
59:camouflage.npz
60:campfire.npz
61:candle.npz
62:cannon.npz
63:canoe.npz
64:car.npz
65:carrot.npz
66:ca

In [5]:
def off2abs(off_vector, img_size=128):
    new_sketch = np.zeros_like(off_vector)
    new_sketch[:,2] = off_vector[:,2]
    a = np.cumsum(off_vector[:,:2], axis=0)
    h_max = np.max(a[:,1])-np.min(a[:,1])
    w_max = np.max(a[:,0])-np.min(a[:,0])
    if h_max > w_max:
        scale_factor = h_max
    else:
        scale_factor = w_max
    new_sketch[:,0] = (a[:,0]-np.min(a[:,0]))/scale_factor*256
    new_sketch[:,1] = (a[:,1]-np.min(a[:,1]))/scale_factor*256
    return new_sketch

In [None]:
train_idx = 0
valid_idx = 0
test_idx = 0

train_label={}
valid_label={}
test_label = {}

train_env = lmdb.open(f"./Data/train_QuickDraw",map_size = 2*2**33)
test_env = lmdb.open(f"./Data/test_QuickDraw",map_size = 2**29)
valid_env = lmdb.open(f"./Data/valid_QuickDraw",map_size = 2**29)

for label,file in enumerate(dir_list):
    np_file = np.load(os.path.join(path,file),allow_pickle=True,encoding='latin1')
    category = file.split('.')[0]
    
    print(f"category: {category}, train_data.")
        
    for sketch in tqdm(np_file['train']):
        abs_sketch = off2abs(sketch)
        txn = train_env.begin(write=True)
        txn.put(key=str(train_idx).encode(), value=abs_sketch)
        txn.commit()
        train_label[str(train_idx)] = label
        train_idx += 1
        
    print(f"category: {category}, test_data.")
    for sketch in tqdm(np_file['test']):
        abs_sketch = off2abs(sketch)
        txn = test_env.begin(write=True)
        txn.put(key=str(test_idx).encode(), value=abs_sketch)
        txn.commit()
        test_label[str(test_idx)] = label
        test_idx += 1
        
    print(f"category: {category}, valid_data.")
    for sketch in tqdm(np_file['valid']):
        abs_sketch = off2abs(sketch)
        txn = valid_env.begin(write=True)
        txn.put(key=str(valid_idx).encode(), value=abs_sketch)
        txn.commit()
        valid_label[str(valid_idx)] = label
        valid_idx += 1
        
train_env.close()
valid_env.close()
test_env.close()

In [7]:
with open('./Data/train_QuickDraw.pkl', 'wb') as f:   
    pickle.dump(train_label,f)

with open('./Data/valid_QuickDraw.pkl', 'wb') as f:
    pickle.dump(valid_label,f)
    
with open('./Data/test_QuickDraw.pkl', 'wb') as f:
    pickle.dump(test_label,f)

In [129]:
import random
def draw_three(sketch, random_color=False,img_size=256):
    thickness = int(img_size * 0.025)
    canvas = np.ones((img_size, img_size, 3), dtype='uint8') * 255
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    else:
        color = (0, 0, 0)
    pen_now = np.array([sketch[0,0], sketch[0,1]])
    first_zero = False
    i = 0
    j = 4
    for stroke in sketch:
        delta_x_y = stroke[0:0 + 2]-pen_now
        state = stroke[2:]
        if int(state) == -1:
            break
        if first_zero:  # 首个零是偏移量, 不画
            pen_now += delta_x_y
            first_zero = False
            continue
        if i==j:
            cv2.line(canvas, tuple(pen_now), tuple(pen_now + delta_x_y), color, thickness=thickness)
        if random_color:
                color = (color[0]+10, color[1]+10, color[2]+10)
        if int(state) == 1:  # next stroke
            i=i+1
            first_zero = True
            if random_color:
                color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            else:
                color = (0, 0, 0)
        pen_now += delta_x_y
        
    return cv2.resize(canvas, (img_size, img_size))

In [130]:
train_env = lmdb.open(f"./Data/train_QuickDraw",map_size = 2*2**33)
test_env = lmdb.open(f"./Data/test_QuickDraw",map_size = 2**29)
valid_env = lmdb.open(f"./Data/valid_QuickDraw",map_size = 2**29)

txn = test_env.begin(write=True)
test_sketch = np.fromstring(txn.get(key=str(18500).encode()), dtype=np.uint16).reshape(-1,3)
test_sketch[:,:2] = test_sketch[:,:2]

  test_sketch = np.fromstring(txn.get(key=str(18500).encode()), dtype=np.uint16).reshape(-1,3)


In [131]:
im = -np.ones([256,3], dtype=np.int16)
im[:len(test_sketch),:] = test_sketch
raster_imgae = draw_three(im)

from PIL import Image
img = Image.fromarray(raster_imgae, 'RGB')
img.save('test2.png')
img.show()

In [11]:
raster_imgae.shape

(256, 256, 3)

In [12]:
im

array([[ 76,  22,   0],
       [ 71,  67,   0],
       [ 59, 115,   0],
       [ 51, 135,   0],
       [ 47, 150,   0],
       [ 38, 176,   0],
       [ 32, 190,   0],
       [ 21, 208,   0],
       [ 17, 216,   0],
       [  4, 237,   0],
       [  0, 248,   1],
       [ 82,   0,   0],
       [ 83,   6,   0],
       [ 96,  47,   0],
       [104,  62,   0],
       [122,  90,   0],
       [129, 105,   0],
       [144, 131,   0],
       [151, 152,   0],
       [162, 173,   0],
       [166, 186,   0],
       [166, 205,   0],
       [164, 211,   0],
       [164, 229,   0],
       [169, 241,   0],
       [169, 248,   0],
       [173, 256,   1],
       [133, 246,   0],
       [134, 242,   0],
       [134, 231,   0],
       [132, 216,   0],
       [128, 203,   0],
       [120, 186,   0],
       [109, 174,   0],
       [ 93, 166,   0],
       [ 76, 170,   0],
       [ 58, 178,   0],
       [ 50, 183,   0],
       [ 34, 207,   0],
       [ 29, 216,   0],
       [ 25, 230,   0],
       [ 24, 239

In [13]:
im/1.0

array([[ 76.,  22.,   0.],
       [ 71.,  67.,   0.],
       [ 59., 115.,   0.],
       [ 51., 135.,   0.],
       [ 47., 150.,   0.],
       [ 38., 176.,   0.],
       [ 32., 190.,   0.],
       [ 21., 208.,   0.],
       [ 17., 216.,   0.],
       [  4., 237.,   0.],
       [  0., 248.,   1.],
       [ 82.,   0.,   0.],
       [ 83.,   6.,   0.],
       [ 96.,  47.,   0.],
       [104.,  62.,   0.],
       [122.,  90.,   0.],
       [129., 105.,   0.],
       [144., 131.,   0.],
       [151., 152.,   0.],
       [162., 173.,   0.],
       [166., 186.,   0.],
       [166., 205.,   0.],
       [164., 211.,   0.],
       [164., 229.,   0.],
       [169., 241.,   0.],
       [169., 248.,   0.],
       [173., 256.,   1.],
       [133., 246.,   0.],
       [134., 242.,   0.],
       [134., 231.,   0.],
       [132., 216.,   0.],
       [128., 203.,   0.],
       [120., 186.,   0.],
       [109., 174.,   0.],
       [ 93., 166.,   0.],
       [ 76., 170.,   0.],
       [ 58., 178.,   0.],
 

In [14]:
im

array([[ 76,  22,   0],
       [ 71,  67,   0],
       [ 59, 115,   0],
       [ 51, 135,   0],
       [ 47, 150,   0],
       [ 38, 176,   0],
       [ 32, 190,   0],
       [ 21, 208,   0],
       [ 17, 216,   0],
       [  4, 237,   0],
       [  0, 248,   1],
       [ 82,   0,   0],
       [ 83,   6,   0],
       [ 96,  47,   0],
       [104,  62,   0],
       [122,  90,   0],
       [129, 105,   0],
       [144, 131,   0],
       [151, 152,   0],
       [162, 173,   0],
       [166, 186,   0],
       [166, 205,   0],
       [164, 211,   0],
       [164, 229,   0],
       [169, 241,   0],
       [169, 248,   0],
       [173, 256,   1],
       [133, 246,   0],
       [134, 242,   0],
       [134, 231,   0],
       [132, 216,   0],
       [128, 203,   0],
       [120, 186,   0],
       [109, 174,   0],
       [ 93, 166,   0],
       [ 76, 170,   0],
       [ 58, 178,   0],
       [ 50, 183,   0],
       [ 34, 207,   0],
       [ 29, 216,   0],
       [ 25, 230,   0],
       [ 24, 239