In [42]:
import numpy as np
import os
import pickle
from pathlib import Path
from tqdm.notebook import tqdm

In [6]:
# Root folder holding one sub-directory per CROHME extractor run
# (e.g. 'b=96_d=2012_t=5'); resolved relative to the notebook's working directory.
crohme_output_dir = os.path.join(os.getcwd(), *('data', 'crohme_extractor_outputs'))

In [43]:
def _read_classes(classes_path):
    """Return the whitespace-separated class tokens listed in ``classes_path``."""
    # `with` closes the handle deterministically instead of leaking it until GC.
    with open(classes_path, 'r') as f:
        return f.read().split()


def _load_relabeled(dirs, token_to_int, desc):
    """Load each directory's pickled rows and re-encode their labels as unified integer ids.

    For every directory, the argmax of a row's ``'label'`` indexes that directory's own
    ``classes.txt``; the resulting token is mapped through ``token_to_int``. Rows are
    updated in place and returned as one flat list, in directory order.
    """
    merged = []
    for d in tqdm(dirs, desc=desc):
        classes = _read_classes(os.path.join(d, 'classes.txt'))
        # NOTE(review): test dirs are also read from the extractor's 'train' split
        # (unchanged from the original); presumably the test-year extraction put every
        # sample into its train split -- confirm.
        data_path = os.path.join(d, 'train', 'train.pickle')
        # pickle.load can execute arbitrary code: only use on trusted extractor output.
        with open(data_path, 'rb') as f:
            data = pickle.load(f)

        # Convert per-directory label vectors (presumably one-hot) -> token -> unified int id.
        for row in data:
            class_token = classes[int(np.argmax(row['label']))]
            row['label'] = token_to_int[class_token]
            merged.append(row)
    return merged


def merge(train_dirs, test_dirs):
    """Merge several CROHME extractor output directories into one train and one test set.

    Each directory must contain ``classes.txt`` (whitespace-separated tokens) and
    ``train/train.pickle`` (a pickled sequence of dict-like rows with a ``'label'`` vector
    whose argmax indexes that directory's classes).

    Because the extractor may drop traces that are too small, different directories can
    list different classes; all ``classes.txt`` files are therefore unioned into a single
    label space before relabeling.

    Args:
        train_dirs: list of extractor output directories merged into the training set.
        test_dirs: list of extractor output directories merged into the test set.

    Returns:
        ``(merged_train_data, merged_test_data, int_to_token)``: the relabeled rows from all
        train dirs, the relabeled rows from all test dirs (each in directory order), and a
        dict from unified integer id to token (plain ``str``). Ids follow the sorted order
        of the union of all tokens.
    """
    # Unified token-to-int mapping across all datasets.
    all_tokens = set()
    for d in [*train_dirs, *test_dirs]:
        all_tokens.update(_read_classes(os.path.join(d, 'classes.txt')))
    token_to_int = {t: i for i, t in enumerate(sorted(all_tokens))}

    merged_train_data = _load_relabeled(train_dirs, token_to_int, 'Merge train')
    merged_test_data = _load_relabeled(test_dirs, token_to_int, 'Merge test')

    int_to_token = {i: t for t, i in token_to_int.items()}
    return merged_train_data, merged_test_data, int_to_token

def create_token_dataset(t_vals):
    """Merge the CROHME extractor outputs for several ``t`` settings and pickle the result.

    For every ``t`` in ``t_vals``, reads ``b=96_d=2011,2013_t={t}`` (train) and
    ``b=96_d=2012_t={t}`` (test) under ``crohme_output_dir``, merges them with ``merge``,
    prints a size summary, and writes ``train.pickle``, ``test.pickle`` and
    ``int_to_token.pickle`` to
    ``<cwd>/data/tokens/b=96_train=2011,2013_test=2012_c=all_t=<comma-joined t_vals>/``
    (created if missing).

    Args:
        t_vals: iterable of extractor ``t`` values; any iterable is accepted.

    Returns:
        None. Side effects: creates the output directory, writes three pickle files,
        and prints progress to stdout.
    """
    # Materialize once: t_vals is used three times below, so a one-shot iterator
    # would otherwise leave the train/test directory lists empty.
    t_vals = list(t_vals)
    t_val_str = ",".join(str(t) for t in t_vals)
    save_dir = os.path.join(os.getcwd(), 'data', 'tokens', f'b=96_train=2011,2013_test=2012_c=all_t={t_val_str}')
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    train_dirs = [os.path.join(crohme_output_dir, f'b=96_d=2011,2013_t={t}') for t in t_vals]
    test_dirs = [os.path.join(crohme_output_dir, f'b=96_d=2012_t={t}') for t in t_vals]

    merged_train_data, merged_test_data, int_to_token = merge(train_dirs=train_dirs, test_dirs=test_dirs)

    print(f'{len(merged_train_data)} training examples.')
    # Previously mislabeled as "training examples" (copy-paste of the line above).
    print(f'{len(merged_test_data)} test examples.')
    print(f'{len(int_to_token)} total classes.')

    train_write_path = os.path.join(save_dir, 'train.pickle')
    test_write_path = os.path.join(save_dir, 'test.pickle')
    int_to_token_write_path = os.path.join(save_dir, 'int_to_token.pickle')

    outputs = (
        (train_write_path, merged_train_data),
        (test_write_path, merged_test_data),
        (int_to_token_write_path, int_to_token),
    )
    for write_path, payload in outputs:
        with open(write_path, 'wb') as f:
            pickle.dump(payload, f)

    print(f'Wrote train set to {train_write_path}.')
    print(f'Wrote test set to {test_write_path}.')
    print(f'Wrote int-to-token dict to {int_to_token_write_path}.')

In [44]:
%%time

create_token_dataset(t_vals=[5])

HBox(children=(HTML(value='Merge train'), FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(HTML(value='Merge test'), FloatProgress(value=0.0, max=1.0), HTML(value='')))


61984 training examples.
16707 training examples.
101 total classes.
Wrote train set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=5\train.pickle.
Wrote test set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=5\test.pickle.
Wrote int-to-token dict to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=5\int_to_token.pickle.
Wall time: 6.7 s


In [45]:
%%time

create_token_dataset(t_vals=[3, 5, 7])

HBox(children=(HTML(value='Merge train'), FloatProgress(value=0.0, max=3.0), HTML(value='')))




HBox(children=(HTML(value='Merge test'), FloatProgress(value=0.0, max=3.0), HTML(value='')))


185952 training examples.
50121 training examples.
101 total classes.
Wrote train set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=3,5,7\train.pickle.
Wrote test set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=3,5,7\test.pickle.
Wrote int-to-token dict to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=3,5,7\int_to_token.pickle.
Wall time: 17 s


In [46]:
%%time

create_token_dataset(t_vals=[1, 3, 5, 7, 9])

HBox(children=(HTML(value='Merge train'), FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(HTML(value='Merge test'), FloatProgress(value=0.0, max=5.0), HTML(value='')))


309828 training examples.
83526 training examples.
101 total classes.
Wrote train set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=1,3,5,7,9\train.pickle.
Wrote test set to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=1,3,5,7,9\test.pickle.
Wrote int-to-token dict to C:\Users\Jamin Chen\Development\10617_Project\data\tokens\b=96_train=2011,2013_test=2012_c=all_t=1,3,5,7,9\int_to_token.pickle.
Wall time: 30.2 s
