## Initialization

In [None]:
import h5py
import numpy as np
import os
import os.path
from pathlib import Path
import pickle
import skimage
from skimage import io as skio
from skimage import measure as skm
import sys
import time
import torch
import warnings
# Suppress every warning for the remainder of the session.
warnings.filterwarnings("ignore")
# NOTE(review): presumably redirects where torch builds/caches compiled
# extensions (likely needed by the rasterizer imported below) -- confirm.
os.environ['TORCH_EXTENSIONS_DIR'] = '/tmp/torch_extensions_jypyter'

# Put the parent directory (resolved to an absolute, symlink-free path) at the
# front of sys.path; presumably this is what makes the `data` and `neuralline`
# packages imported below resolvable -- verify against the repository layout.
_project_folder_ = os.path.realpath(os.path.abspath('..'))
if _project_folder_ not in sys.path:
    sys.path.insert(0, _project_folder_)
from data.sketch_util import SketchUtil
from neuralline.rasterize import Raster


In [None]:
# Arguments
# Directory holding the per-category QuickDraw `.npz` archives (input).
quickdraw_root = '/media/hdd/craiglee/Data/sketchrnn/sketchrnn_data/'
# Directory receiving the HDF5 files, the category pickle and discarded images.
output_root = '/media/hdd/craiglee/Data/sketchrnn_processed_entropyfilter/'

# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(output_root, exist_ok=True)

# Side length and stroke thickness handed to the rasterizer when sketches are
# rendered to images.
img_size = 128
thickness = 1.0
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('[*] Using device: {}'.format(device))


## Helper Functions

In [None]:
def get_categories():
    """Return the category names found under ``quickdraw_root``.

    A category is recognised by a file named ``<category>.full.npz``; the
    returned list holds the ``<category>`` parts, ordered case-insensitively.
    """
    marker = '.full.npz'
    names = []
    for entry in Path(quickdraw_root).glob('*' + marker):
        # Strip the whole '.full.npz' tail from the file name.
        names.append(entry.name[:-len(marker)])
    names.sort(key=str.lower)
    return names

def load_npz(file_path):
    """Load the ``train``, ``valid`` and ``test`` arrays from an ``.npz`` file.

    Args:
        file_path: Path of the ``.npz`` archive to read.

    Returns:
        A ``(train, valid, test)`` tuple with the three arrays stored under
        those keys.

    Raises:
        KeyError: If one of the three keys is missing from the archive.
    """
    # allow_pickle=True is required for archives that store object arrays
    # (e.g. one variable-length array per sketch); recent NumPy versions refuse
    # to load those by default.
    # NOTE(security): unpickling executes arbitrary code -- only load archives
    # from a trusted source.
    with np.load(file_path, encoding='latin1', allow_pickle=True) as npz:
        # Arrays are materialised on access, so they stay valid after the
        # archive handle is closed on leaving the `with` block.
        return npz['train'], npz['valid'], npz['test']

def cvrt_points3(points3_array):
    """Return an int32 copy of the input with columns 0-1 cumulatively summed.

    The first two columns are accumulated down the rows (turning per-row
    offsets into running totals); any remaining columns are copied unchanged.
    The input itself is not modified.
    """
    absolute = np.array(points3_array, dtype=np.int32)
    running_xy = absolute[:, 0:2].cumsum(axis=0)
    absolute[:, 0:2] = running_xy
    return absolute

def cvrt_category_to_points3(points3_arrays, hdf5_group=None):
    """Convert every sketch of one category and optionally store the results.

    Each input array is passed through ``cvrt_points3``, cast to float32, and
    its first two columns are replaced by ``SketchUtil.normalization`` output.
    Inputs with fewer than three rows, or for which normalization returns
    ``None``, are skipped.

    Args:
        points3_arrays: Iterable of per-sketch arrays.
        hdf5_group: Optional group; when given, each kept sketch is written as
            a dataset named after its index in the returned list.

    Returns:
        ``(sketches, max_num_points)``: the list of converted float32 arrays
        and the largest row count among them (0 when the list is empty).
    """
    converted = []
    longest = 0
    for raw in points3_arrays:
        if len(raw) < 3:
            continue

        sketch = cvrt_points3(raw).astype(np.float32)
        xy_norm = SketchUtil.normalization(sketch[:, 0:2])
        if xy_norm is None:
            continue
        sketch[:, 0:2] = xy_norm

        longest = max(longest, len(sketch))

        if hdf5_group is not None:
            # Dataset names are consecutive indices of the kept sketches.
            hdf5_group.create_dataset(str(len(converted)), data=sketch)

        converted.append(sketch)
    return converted, longest

def points3_to_img(points3_array):
    """Rasterize one sketch array into a 2-D uint8 image.

    The array is moved to ``device``, given a leading axis, and rendered with
    ``Raster.to_image`` using the module-level ``img_size`` and ``thickness``.
    Element ``[0, 0]`` of the result is scaled by 255 and cast to uint8.
    """
    batch = torch.from_numpy(points3_array).to(device).unsqueeze(0)
    rendered = Raster.to_image(batch, 1.0, img_size, thickness, device=device)
    rendered_np = np.ascontiguousarray(rendered.cpu().numpy())
    first_plane = rendered_np[0, 0, :, :]
    return (255 * first_plane).astype(np.uint8)

def denoise_by_entropy(points3_arrays, hdf5_group=None, low_frac=0.05, high_frac=0.95):
    """Drop sketches whose rendered image has an extreme Shannon entropy.

    Ref:
    - SketchMate Deep Hashing for Million-Scale Human Sketch Retrieval

    Every sketch is rendered with ``points3_to_img`` and scored with
    ``skm.shannon_entropy``. Sketches are ranked by entropy; those ranked in
    ``[int(n * low_frac), int(n * high_frac))`` are kept, the rest discarded.

    Args:
        points3_arrays: Sequence of per-sketch arrays accepted by
            ``points3_to_img``.
        hdf5_group: Optional group; when given, each kept sketch is written as
            a dataset named after its running index among the kept sketches.
            When ``None`` nothing is written, but the kept/discarded split and
            the returned counts are still computed.
        low_frac: Fraction of lowest-entropy sketches to discard.
        high_frac: Rank fraction from which sketches are discarded as having
            too high an entropy.

    Returns:
        ``(nsketches, max_num_points, pts3_to_discard, imgs_to_discard)``:
        number of kept sketches, largest length among the kept sketches, the
        discarded sketch arrays, and their rendered images (same order).

    Raises:
        ValueError: If the fractions do not satisfy
            ``0 <= low_frac <= high_frac <= 1``.
    """
    if not 0.0 <= low_frac <= high_frac <= 1.0:
        raise ValueError('expected 0 <= low_frac <= high_frac <= 1, got {} and {}'.format(
            low_frac, high_frac))

    imgs = [points3_to_img(pts3_arr) for pts3_arr in points3_arrays]
    entropies = [skm.shannon_entropy(img) for img in imgs]

    num_imgs = len(imgs)
    range_low = int(num_imgs * low_frac)
    range_high = int(num_imgs * high_frac)

    # Sketch indices ordered from lowest to highest entropy.
    sort_indices = np.argsort(np.array(entropies)).tolist()

    pts3_to_discard = []
    imgs_to_discard = []

    nsketches = 0
    max_num_points = 0
    for rank, sidx in enumerate(sort_indices):
        pts3 = points3_arrays[sidx]
        if range_low <= rank < range_high:
            max_num_points = max(max_num_points, len(pts3))
            # Writing is optional and independent of the keep/discard decision,
            # so a missing group no longer marks every sketch as discarded.
            if hdf5_group is not None:
                hdf5_group.create_dataset(str(nsketches), data=pts3)
            nsketches += 1
        else:
            pts3_to_discard.append(pts3)
            imgs_to_discard.append(imgs[sidx])
    return nsketches, max_num_points, pts3_to_discard, imgs_to_discard


In [None]:
# https://stackoverflow.com/questions/35321093/limit-on-number-of-hdf5-datasets
#
# Builds one HDF5 file per split (train/valid/test). Inside each file sketches
# live under /sketch/<category id>/<sketch id>, and a pickle listing the
# category names plus every (category id, sketch id) pair is written alongside.

# True: sketches pass through denoise_by_entropy and discarded ones are saved
# as PNG images. False: every converted sketch is written unfiltered.
use_entropy_filter = True

category_names = get_categories()
print('[*] Number of categories = {}'.format(len(category_names)))
print('[*] ------')
print(category_names)
print('[*] ------')

hdf5_names = ['train', 'valid', 'test']
# One list of (category id, sketch id) pairs per split.
mode_indices = [[] for _ in hdf5_names]
hdf5_files = []

max_num_points = 0
try:
    # Files are opened one by one inside the try block so that a failure on a
    # later file (or anywhere below) still closes the ones already opened.
    for hn in hdf5_names:
        hdf5_path = os.path.join(output_root, 'quickdraw_{}.hdf5'.format(hn))
        hdf5_files.append(h5py.File(hdf5_path, 'w', libver='latest'))
    hdf5_groups = [h5.create_group('/sketch') for h5 in hdf5_files]

    for cid, category_name in enumerate(category_names):
        print('[*] Processing {}th category: {}'.format(cid + 1, category_name))

        # Open the npz file
        # NOTE(review): categories are discovered from '<name>.full.npz' files
        # but the data is read from '<name>.npz' -- confirm this is intended.
        train_valid_test = load_npz(os.path.join(quickdraw_root, category_name + '.npz'))

        for mid, mode in enumerate(hdf5_names):
            # Under a mode to create a group: train/valid/test
            hdf5_category_group = hdf5_groups[mid].create_group(str(cid))

            if use_entropy_filter:
                # Convert without writing; only sketches that survive the
                # entropy filter are written to the group.
                pts3_arrays, npts3 = cvrt_category_to_points3(train_valid_test[mid])

                # Denoise
                nsketches, npts3, pts3_discard, imgs_discard = denoise_by_entropy(
                    pts3_arrays, hdf5_category_group)

                # Keep the rendered images of discarded sketches for inspection.
                discard_folder = os.path.join(output_root, 'discard', mode, category_name)
                os.makedirs(discard_folder, exist_ok=True)
                for img_id, img in enumerate(imgs_discard):
                    skio.imsave('{}/{}.png'.format(discard_folder, img_id), img)
            else:
                pts3_arrays, npts3 = cvrt_category_to_points3(
                    train_valid_test[mid], hdf5_category_group)
                nsketches = len(pts3_arrays)

            if npts3 > max_num_points:
                max_num_points = npts3

            hdf5_category_group.attrs['num_sketches'] = nsketches
            mode_indices[mid].extend((cid, sid) for sid in range(nsketches))

    # The same global statistics are stored in every split file.
    for gp in hdf5_groups:
        gp.attrs['num_categories'] = len(category_names)
        gp.attrs['max_points'] = max_num_points

    for hf in hdf5_files:
        hf.flush()
finally:
    for hf in hdf5_files:
        hf.close()

pkl_save = {'categories': category_names, 'indices': mode_indices}
with open(os.path.join(output_root, 'categories.pkl'), 'wb') as fh:
    pickle.dump(pkl_save, fh, pickle.HIGHEST_PROTOCOL)

print('max_num_points = {}'.format(max_num_points))
print('All done.')
