In [26]:
import os
import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET

import pandas as pd

from tqdm import trange, tqdm

In [27]:
# --- Configuration -----------------------------------------------------------
# Granularity of the InkML samples to convert: 'word', 'line' or 'paragraph'.
level = 'word'
# Directory holding the train/validation/test split list files.
icfhr_datasplit_dir = './VNOnDB_ICFHR2018_dataSplit'
# Directory holding the InkML files for the chosen granularity.
inkml_dir = './InkData_{}'.format(level)

# Output label files, one per split.
out_label_train, out_label_validation, out_label_test = (
    './{}_{}.csv'.format(split_name, level)
    for split_name in ('train', 'validation', 'test')
)

In [28]:
# Paths of the text files that list which InkML files belong to each split.
train_set, val_set, test_set = (
    os.path.join(icfhr_datasplit_dir, split_list_name)
    for split_list_name in ('train_set.txt', 'validation_set.txt', 'test_set.txt')
)

In [29]:
def _read_split_list(list_path):
    """Return the entries of a split list file, one per non-blank line.

    Trailing whitespace (including the newline) is stripped from every line.
    Blank lines are skipped: an empty entry would later be joined onto the
    InkML directory and resolve to the directory itself instead of a file.
    """
    with open(list_path) as list_file:
        stripped_lines = (line.rstrip() for line in list_file)
        return [entry for entry in stripped_lines if entry]


train_ink_files = _read_split_list(train_set)
val_ink_files = _read_split_list(val_set)
test_ink_files = _read_split_list(test_set)

print('train_ink_files:', len(train_ink_files))
print('val_ink_files:', len(val_ink_files))
print('test_ink_files:', len(test_ink_files))

train_ink_files: 153
val_ink_files: 38
test_ink_files: 64


In [30]:
def _inside_inkml_dir(file_names):
    """Return `file_names` with each entry joined onto `inkml_dir`."""
    return [os.path.join(inkml_dir, file_name) for file_name in file_names]


# NOTE: this cell rebinds the same names it reads, so running it a second time
# without re-running the cell that loads the lists prefixes the directory twice.
train_ink_files = _inside_inkml_dir(train_ink_files)
val_ink_files = _inside_inkml_dir(val_ink_files)
test_ink_files = _inside_inkml_dir(test_ink_files)

In [31]:
def make_image_file(coord_groups, output_path: str, line_width=2, dpi=300):
    """Draw strokes as black polylines and save the drawing to `output_path`.

    Args:
        coord_groups: Iterable of strokes; each stroke is a sequence of
            (x, y) points. Strokes without any point are skipped.
        output_path: Destination path passed to `Figure.savefig`.
        line_width: Line width used for every stroke.
        dpi: Resolution the figure is created with.
    """
    figure = plt.figure(dpi=dpi)
    try:
        axes = figure.gca()
        axes.set_aspect('equal', adjustable='box')
        # Flip the y axis so that larger y values are drawn lower in the image.
        axes.invert_yaxis()
        axes.axis('off')

        for group in coord_groups:
            data = np.asarray(group)
            if data.size == 0:
                # A stroke with no points has no x/y columns to plot.
                continue
            axes.plot(data[:, 0], data[:, 1], linewidth=line_width, c='black')
        figure.savefig(output_path, bbox_inches='tight')
    finally:
        # Close this specific figure, even when drawing or saving failed, so
        # figures do not pile up in memory when called in a long loop.
        plt.close(figure)

In [32]:
def convert(ink_files, out_img_dir, out_label_path, line_width=2, dpi=300):
    """Render every sample of the given InkML files to PNG and write a label table.

    For each `traceGroup` element directly under the root of each file, one
    image named `<file stem>_<traceGroup id>.png` is written to `out_img_dir`,
    and one (id, label) row is collected, where the label is the text of the
    first `Tg_Truth` descendant. The rows are written to `out_label_path` as a
    tab-separated CSV with an index column.

    Args:
        ink_files: Paths of the InkML files to convert.
        out_img_dir: Directory for the rendered images; created if missing.
        out_label_path: Path of the label CSV to write.
        line_width: Stroke width forwarded to `make_image_file`.
        dpi: Figure resolution forwarded to `make_image_file`.

    Raises:
        AssertionError: If a coordinate entry does not have exactly two values.
    """
    # makedirs also creates missing parent directories and tolerates an
    # already existing target.
    os.makedirs(out_img_dir, exist_ok=True)

    # Collect plain records and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame on
    # every call, which made the loop quadratic.
    records = []
    for inkml_file in tqdm(ink_files, desc='Progress'):
        root = ET.parse(inkml_file).getroot()
        file_stem = os.path.splitext(os.path.basename(inkml_file))[0]

        for sample in root.findall('traceGroup'):
            sample_id = file_stem + '_' + sample.get('id')
            sample_label = sample.find('.//Tg_Truth').text
            records.append({'id': sample_id, 'label': sample_label})

            coord_groups = []
            for trace_tag in sample.findall('trace'):
                coord_group = []
                # Points are comma separated; a trace without text has none.
                for coord_text in (trace_tag.text or '').split(','):
                    # split() without arguments ignores any run of whitespace,
                    # so fragments holding only spaces/newlines come out empty.
                    values = coord_text.split()
                    if not values:
                        continue
                    coords = np.array([int(value) for value in values])
                    assert len(coords) == 2, (
                        f'expected 2 values per point, got {coord_text!r} in {inkml_file}'
                    )
                    coord_group.append(coords)
                if coord_group:
                    # Strokes without points draw nothing, so they are dropped.
                    coord_groups.append(coord_group)
            make_image_file(coord_groups, os.path.join(out_img_dir, sample_id) + '.png', line_width, dpi)

    annotations = pd.DataFrame(records, columns=['id', 'label'])
    annotations.to_csv(out_label_path, sep='\t')

In [33]:
def convert_label_only(ink_files, out_label_path):
    """Write the (id, label) table of the given InkML files without rendering images.

    One row is produced per `traceGroup` element directly under the root of
    each file: the id is `<file stem>_<traceGroup id>` and the label is the
    text of the first `Tg_Truth` descendant. The table is written to
    `out_label_path` as a tab-separated CSV with an index column.

    Args:
        ink_files: Paths of the InkML files to read.
        out_label_path: Path of the label CSV to write.
    """
    # Collect plain records and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame on
    # every call, which made the loop quadratic.
    records = []
    for inkml_file in tqdm(ink_files, desc='Progress'):
        root = ET.parse(inkml_file).getroot()
        file_stem = os.path.splitext(os.path.basename(inkml_file))[0]

        for sample in root.findall('traceGroup'):
            records.append({
                'id': file_stem + '_' + sample.get('id'),
                'label': sample.find('.//Tg_Truth').text,
            })

    annotations = pd.DataFrame(records, columns=['id', 'label'])
    annotations.to_csv(out_label_path, sep='\t')

In [34]:
# Write one label file per split, in train / validation / test order.
for split_ink_files, split_label_path in (
    (train_ink_files, out_label_train),
    (val_ink_files, out_label_validation),
    (test_ink_files, out_label_test),
):
    convert_label_only(split_ink_files, split_label_path)

Progress: 100%|██████████| 153/153 [06:34<00:00,  2.58s/it]
Progress: 100%|██████████| 38/38 [00:55<00:00,  1.46s/it]
Progress: 100%|██████████| 64/64 [01:24<00:00,  1.33s/it]
