## Initialization

In [1]:
import numpy as np
import os
import sys
from pathlib import Path
import pickle
#import skimage
#from skimage import io as skio
#from skimage import measure as skm
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Resolve the parent of the current working directory ('..' is relative to
# wherever the kernel was started) and put it at the front of sys.path, so the
# project-local import that follows (data.sketch_util) can be found.
_project_folder_ = os.path.realpath(os.path.abspath('..'))
# Guard against inserting the same entry again when the cell is re-run.
if _project_folder_ not in sys.path:
    sys.path.insert(0, _project_folder_)

from data.sketch_util import SketchUtil


In [2]:
# Arguments
# Location of the SVG dataset (its sub-folders are listed as categories later
# on) and the folder the resulting pickle is written to.
# Alternative locations used on another machine:
# dataset_root = '/media/WD1TDisk/craiglee/TUBerlin/svg'
# output_root = '/media/WD1TDisk/craiglee/TUBerlin'
dataset_root = 'I:/TU-Berlin/svg'
output_root = 'I:/TU-Berlin'

# Number of folds the samples of every category are split into.
num_folds = 3

# Create the output folder on first use; otherwise only report that it exists.
if os.path.exists(output_root):
    print('Output root already exists.')
else:
    os.makedirs(output_root)


Output root already exists.


## Helper functions

In [3]:
def list_categories(root_folder):
    """Return the sorted names of the visible sub-directories of ``root_folder``.

    Only direct children are inspected. Entries that are not directories, and
    directories whose name starts with a dot, are left out.

    Args:
        root_folder: path (str or path-like) of the folder to scan.

    Returns:
        list of str: directory names in ascending order.
    """
    return sorted(
        entry.name
        for entry in Path(root_folder).iterdir()
        if entry.is_dir() and not entry.name.startswith('.')
    )

def list_svg_files(root_folder):
    """Return the sorted file names matching ``*.svg`` directly inside ``root_folder``.

    Args:
        root_folder: path (str or path-like) of the folder to scan; sub-folders
            are not searched.

    Returns:
        list of str: bare file names (no directory part) in ascending order.
    """
    names = [svg_path.name for svg_path in Path(root_folder).glob('*.svg')]
    names.sort()
    return names

def strokes_to_points3(strokes):
    """Flatten a list of strokes into one point array with an end-of-stroke flag.

    The points of all strokes are stacked in order and one extra column is
    appended: it is 1 on the last point of every stroke and 0 everywhere else.

    Args:
        strokes: sequence of per-stroke point arrays. Presumably each has
            shape (n_i, 2), since the caller slices columns 0:2 of the result
            as coordinates -- TODO confirm against SketchUtil.

    Returns:
        numpy.ndarray: the stacked points with the flag as last column. When
        ``strokes`` is empty or contains only zero-length strokes, an empty
        float32 array of shape (0, 3) is returned instead of raising.
    """
    # A zero-length stroke has no last point to flag (``state[-1]`` would raise
    # IndexError) and contributes no rows, so it is skipped.
    non_empty = [stroke for stroke in strokes if len(stroke) > 0]
    if not non_empty:
        # np.concatenate refuses an empty sequence; hand back "no points".
        return np.zeros((0, 3), np.float32)

    states = list()
    for stroke in non_empty:
        state = np.zeros((len(stroke), 1), np.float32)
        state[-1, 0] = 1  # mark the final point of this stroke
        states.append(state)
    res = np.concatenate((np.concatenate(non_empty), np.concatenate(states)), axis=1)
    return res
    

## Data processing

In [5]:
# Convert every SVG sketch of every category into a point array, split the
# samples of each category into `num_folds` folds, and pickle the result.
categories = list_categories(dataset_root)
print('Number of categories: {}'.format(len(categories)))

# folds[f] collects (category index, sketch index) pairs assigned to fold f.
folds = [list() for i in range(num_folds)]
# sketches[c][i] / cvxhulls[c][i]: point array and hull of sketch i of category c.
sketches = list()
cvxhulls = list()

# Passed to SketchUtil.normalize_and_simplify; presumably an upper bound on the
# number of points kept per sketch -- TODO confirm against SketchUtil.
MAX_POINTS = 448

# Largest number of points seen in any sketch so far (never reset per category).
max_num_points = 0
# Also passed to normalize_and_simplify; presumably the RDP simplification
# tolerance -- TODO confirm against SketchUtil.
rdp_eps = 0.02

# Dedicated, seeded generator for the fold split, so a fresh "Restart & Run All"
# reproduces the same folds and the global NumPy random state is left untouched.
FOLD_SEED = 0
fold_rng = np.random.RandomState(FOLD_SEED)

for cid, category in enumerate(categories):
    print('Processing {} - {}'.format(cid, category))
    
    cat_sketches = list()
    cat_cvxhulls = list()
    svg_files = list_svg_files(os.path.join(dataset_root, category))
    for svg_file in svg_files:
        try:
            strokes = SketchUtil.parse_tuberlin_svg_file(os.path.join(dataset_root, category, svg_file))
        except Exception:
            # Name the offending file, then re-raise with the original traceback.
            print('Something wrong with {}/{}'.format(category, svg_file))
            raise

        strokes = SketchUtil.normalize_and_simplify(strokes, MAX_POINTS, rdp_eps)
        assert strokes is not None
        points3 = strokes_to_points3(strokes)
        
        num_points = len(points3)
        if num_points > max_num_points:
            max_num_points = num_points
        # A hull is only computed for sketches with more than two points;
        # the others get a None placeholder so indices stay aligned.
        if num_points > 2:
            cat_cvxhulls.append(SketchUtil.convex_hull_padded(points3[:, 0:2]))
        else:
            cat_cvxhulls.append(None)
        cat_sketches.append(points3)
    
    sketches.append(cat_sketches)
    cvxhulls.append(cat_cvxhulls)
    
    # Running maximum over all categories processed so far.
    print('  Max number of points = {}'.format(max_num_points))
    
    # Fold split: shuffle this category's sketch indices and cut them into
    # num_folds nearly equal parts, one part per fold.
    idxes = np.arange(len(cat_sketches))
    fold_rng.shuffle(idxes)
    idxes_split = np.array_split(idxes, num_folds)
    for fidx in range(num_folds):
        folds[fidx].extend((cid, sketch_idx) for sketch_idx in idxes_split[fidx].tolist())

print('Max number of points = {}'.format(max_num_points))
    
to_save = {'categories': categories,
           'sketches': sketches,
           'convex_hulls': cvxhulls,
           'folds': folds,
           'max_num_points': max_num_points}
with open(os.path.join(output_root, 'TUBerlin.pkl'), 'wb') as fh:
    pickle.dump(to_save, fh, pickle.HIGHEST_PROTOCOL)


Number of categories: 250
Processing 0 - airplane
  Max number of points = 441
Processing 1 - alarm clock
  Max number of points = 445
Processing 2 - angel
  Max number of points = 445
Processing 3 - ant
  Max number of points = 446
Processing 4 - apple
  Max number of points = 446
Processing 5 - arm
  Max number of points = 446
Processing 6 - armchair
  Max number of points = 448
Processing 7 - ashtray
  Max number of points = 448
Processing 8 - axe
  Max number of points = 448
Processing 9 - backpack
  Max number of points = 448
Processing 10 - banana
  Max number of points = 448
Processing 11 - barn
  Max number of points = 448
Processing 12 - baseball bat
  Max number of points = 448
Processing 13 - basket
  Max number of points = 448
Processing 14 - bathtub
  Max number of points = 448
Processing 15 - bear (animal)
  Max number of points = 448
Processing 16 - bed
  Max number of points = 448
Processing 17 - bee
  Max number of points = 448
Processing 18 - beer-mug
  Max number of 