# Training data

## Reading the data

In this notebook we will load the data that is used for training and convert it to the input representation that we need to train the network. The data is acquired here: "https://fki.tic.heia-fr.ch/databases/iam-on-line-handwriting-database".

The dataset consists of two separate directories. One directory contains the $x$ and $y$ coordinates and the corresponding timestamps in XML format for each separate line of a text. The other directory contains the texts. Both directories are unpacked into "./data/raw_data".

In [None]:
from pathlib import Path

# Root folders of the raw dataset, given relative to the working directory:
# one for the XML stroke files and one for the text (transcription) files.
path_strokefiles = Path("./../../data/raw_data/strokefiles")
path_textfiles = Path("./../../data/raw_data/textfiles")

In [None]:
import os

# Walk the text-file root and keep one path per directory that holds files;
# directories without files are skipped.
# NOTE(review): only the first file name reported for a directory is used --
# confirm that no directory holds additional text files that are needed.
textfile_paths = [
    Path(dirpath, filenames[0])
    for dirpath, _, filenames in os.walk(path_textfiles)
    if filenames
]

print(textfile_paths[:3])

First, a function is defined to read the text lines from a text file. These files need to be parsed and split up into lines, which is done with a regular expression.

In [None]:
import re

# Use the first text file as a running example and show its raw content.
# NOTE(review): the printed label is misspelled ("exmaple"); it is left as is
# because it is runtime output.
example_textfile_path = textfile_paths[0]
print(f'exmaple path = {example_textfile_path}\n')
print(example_textfile_path.read_text())

def get_file_lines(textfile_path):
    """Return the text lines stored in the ``CSR:`` section of a text file.

    Everything after the ``CSR:`` marker (and the whitespace following it) up
    to the first ``~`` or the end of the file is taken, stripped of
    surrounding whitespace, and split on newlines.

    Args:
        textfile_path: Path of the text file to read.

    Returns:
        A list with one string per text line.

    Raises:
        ValueError: If the file contains no ``CSR:`` marker.
    """
    # A context manager guarantees the file handle is closed after reading.
    with open(textfile_path) as f:
        content = f.read()

    match = re.search(r"CSR:\s*([^~]*)", content)
    if match is None:
        raise ValueError(f"no 'CSR:' section found in {textfile_path}")

    return match.group(1).strip().split("\n")
    
# Parse the example text file into its individual text lines and show them.
example_textlines = get_file_lines(example_textfile_path)
print(example_textlines)

In [None]:
# Pair every directory below the stroke-file root that holds files with the
# names of those files; directories without files are skipped.
stroke_paths = [
    (dirpath, filenames)
    for dirpath, _, filenames in os.walk(path_strokefiles)
    if filenames
]

print(stroke_paths[:3])

Next, the paths of the xml files containing the $x$ and $y$ coordinates and timestamps for the lines in a certain text file are retrieved.

In [None]:
def get_stroke_paths(textfile_path):
    """Return the stroke files that belong to a text file, sorted by name.

    The stroke files are looked up in the folder below ``path_strokefiles``
    that mirrors the last two directory levels of ``textfile_path``. A stroke
    file belongs to the text file when the two dash-separated name parts in
    front of its second dash equal the two parts of the text file's stem
    (split at the first dash).

    Args:
        textfile_path: ``pathlib.Path`` of the text file.

    Returns:
        A list of paths to the matching stroke files, or ``None`` when the
        stroke folder does not exist, the text file's stem contains no dash,
        or no stroke file matches.
    """
    strokefiles_root_folder = path_strokefiles / textfile_path.parts[-3] / textfile_path.parts[-2]

    if not strokefiles_root_folder.is_dir():
        return None

    m = re.search(r"(.*?)-(.*)", textfile_path.stem)
    if m is None:
        # A stem without a dash cannot be matched against any stroke file.
        return None
    wanted = m.groups()

    res = []
    for filename in sorted(os.listdir(strokefiles_root_folder)):
        candidate = re.search(r"(.*?)-(.*?)-.*", filename)
        # Entries that do not follow the naming scheme (for example stray
        # hidden files) yield no match and are ignored instead of crashing.
        if candidate is not None and candidate.groups() == wanted:
            res.append(strokefiles_root_folder / filename)

    return res or None

# Look up the stroke files that belong to the example text file
# (None when get_stroke_paths finds no matching files).
example_textline_strokefile_paths = get_stroke_paths(example_textfile_path)
print(example_textline_strokefile_paths)

These files are then parsed using the xml.etree.ElementTree module, an efficient Python module that provides APIs for parsing XML files.

In [None]:
import xml.etree.ElementTree as ET
import numpy as np

def read_file(filename):
    """Read the strokes of one XML stroke file into a padded array.

    Every ``StrokeSet/Stroke`` element below the document root becomes one
    row; every ``Point`` child contributes its ``x``, ``y`` and ``time``
    attributes as floats.

    Args:
        filename: Path of the XML file to parse.

    Returns:
        A float array of shape ``(n_strokes, max_points, 3)`` holding
        ``[x, y, time]`` per point. Slots beyond the end of a stroke are
        ``[0, 0, -1]``, so a negative time marks padding. A file without
        strokes yields an array of shape ``(0, 0, 3)``.
    """
    root = ET.parse(filename).getroot()

    # Convert the attribute strings explicitly instead of relying on NumPy's
    # implicit string-to-float cast during assignment.
    strokes = [
        [
            [float(point.attrib["x"]), float(point.attrib["y"]), float(point.attrib["time"])]
            for point in stroke.findall("./Point")
        ]
        for stroke in root.findall("./StrokeSet/Stroke")
    ]

    # default=0 keeps a file without any stroke from raising in max().
    max_stroke_len = max((len(points) for points in strokes), default=0)

    s = np.zeros((len(strokes), max_stroke_len, 3))
    s[:, :, 2] = -1

    for i, points in enumerate(strokes):
        # An empty stroke has nothing to copy; assigning an empty list to a
        # (0, 3) slice would fail to broadcast.
        if points:
            s[i, :len(points)] = points

    return s

# Load the strokes of every line of the example text, one array per stroke file.
# NOTE(review): this fails if get_stroke_paths returned None for the example.
example_textline_strokes = [read_file(file) for file in example_textline_strokefile_paths]
print(example_textline_strokes[0])

In [None]:
import matplotlib.pyplot as plt

def plot_strokes(strokes):
    """Plot each stroke as its own line and display the figure.

    Only points whose third column (the time) is non-negative are drawn;
    read_file marks padded slots with a time of -1.

    Args:
        strokes: Iterable of 2-D arrays with columns ``x``, ``y``, ``time``.
    """
    for stroke in strokes:
        is_real_point = stroke[:, 2] >= 0
        xs = stroke[is_real_point, 0]
        ys = stroke[is_real_point, 1]
        plt.plot(xs, ys)

    plt.show()
        
# Show the first example text line together with a plot of its strokes;
# the [:1] slices limit the loop to at most one line.
for (textline, textline_strokes) in zip(example_textlines[:1], example_textline_strokes[:1]):
    print(textline)
    plot_strokes(textline_strokes)

## Encoding the data

The encoding alphabet is made by collecting all unique characters in the dataset.

In [None]:
import pickle

def get_alphabet(textfile_paths):
    """Collect the unique characters of the ``CSR:`` sections of text files.

    For every file, the text after the ``CSR:`` marker (up to the first ``~``
    or the end of the file) is stripped of surrounding whitespace and its
    characters are added to the alphabet. Newlines between the text lines
    are part of that text and therefore end up in the alphabet too.

    Args:
        textfile_paths: Iterable of paths to the text files.

    Returns:
        A sorted list of the unique characters.

    Raises:
        ValueError: If a file contains no ``CSR:`` marker.
    """
    all_chars = set()

    for file in textfile_paths:
        with open(file) as f:
            content = f.read()

        match = re.search(r"CSR:\s*([^~]*)", content)
        if match is None:
            raise ValueError(f"no 'CSR:' section found in {file}")

        all_chars.update(match.group(1).strip())

    # Set iteration order for strings differs between interpreter runs;
    # sorting makes the character -> index mapping reproducible.
    return sorted(all_chars)

# A stored alphabet takes precedence over a freshly computed one, so the scan
# over all text files is only done when no stored alphabet exists.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load an alphabet file that was written by this project.
if Path("./../../data/processed_data/alphabet").is_file():
    with open("./../../data/processed_data/alphabet", "rb") as f:
        alphabet = pickle.load(f)
else:
    alphabet = get_alphabet(textfile_paths)

print(alphabet, len(alphabet))

In [None]:
def encode_textline(textline, alphabet):
    """Translate each character of ``textline`` into its index in ``alphabet``.

    Args:
        textline: String to encode.
        alphabet: Sequence of characters defining the encoding.

    Returns:
        A list of integer indices, one per character.

    Raises:
        ValueError: If a character does not occur in ``alphabet``.
    """
    return list(map(alphabet.index, textline))

def decode_textline(encodedline, alphabet):
    """Translate a sequence of indices back into characters of ``alphabet``.

    Args:
        encodedline: Iterable of indices into ``alphabet``.
        alphabet: Sequence of characters defining the encoding.

    Returns:
        A list with one character per index (not a joined string).
    """
    decoded = []
    for value in encodedline:
        decoded.append(alphabet[value])
    return decoded

Now all the data is normalized and stored as explained in the notebook Input representation. These functions are imported using import_ipynb.

In [None]:
import import_ipynb
from Input_Representation import (
    normalize_strokes,
    convert_to_rtps,
    resample_strokes,
    SSE,
    makeSMatrix,
    newton_step,
    get_relative_distances,
    fit_curve_newton_step,
    fit_datapoints,
    get_control_points,
    parameterize_curve,
    length_vecs,
    dot_vecs,
    calc_angles,
    split_datapoints,
    stitch_curves,
    convert_stroke_to_bezier_curves,
    plot_bezier_curves,
    plot_rtps,
    strokes_to_bezier,
    scale_timestamps
)

In [None]:
# Build the training set: one feature sequence per text line in two
# representations (converted with convert_to_rtps and strokes_to_bezier),
# plus the encoded text line as target.
rtp_features = []
bezier_features = []
target = []

for i, textfile_path in enumerate(textfile_paths):
    # Report progress as the fraction of text files handled so far.
    if i % 100 == 0:
        print(i/len(textfile_paths))

    textline_strokefile_paths = get_stroke_paths(textfile_path)

    # Text files without any matching stroke file are skipped.
    if not textline_strokefile_paths:
        continue

    lines = [encode_textline(line, alphabet) for line in get_file_lines(textfile_path)]

    # Features and targets are paired purely by position, so a text file whose
    # number of lines differs from its number of stroke files would shift the
    # pairing of every later sample. Skip such files and report them.
    if len(lines) != len(textline_strokefile_paths):
        print(f"skipping {textfile_path}: {len(lines)} text lines but "
              f"{len(textline_strokefile_paths)} stroke files")
        continue

    rtp_strokes = []
    bezier_strokes = []

    for textfile_stroke_path in textline_strokefile_paths:
        # Parse each XML file once; hand one converter a copy so that the two
        # conversions stay independent even if a converter modifies its input.
        strokes = read_file(textfile_stroke_path)
        rtp_strokes.append(convert_to_rtps(strokes.copy()))
        bezier_strokes.append(strokes_to_bezier(strokes, precision=0.01))

    rtp_features.extend(rtp_strokes)
    bezier_features.extend(bezier_strokes)
    target.extend(lines)

In [None]:
# Number of samples and targets; the two counts are expected to agree.
print(len(rtp_features), len(target))

# Total length of all Bezier feature sequences ...
s = sum(len(elem) for elem in bezier_features)

# ... divided by the number of samples gives the mean sequence length.
print(s/len(rtp_features))

In [None]:
# Index of the sample that is inspected in this and the following cell.
sample = 1

# Plot the first feature representation of the sample ...
plot_rtps(rtp_features[sample])

# ... and show the decoded target text (a list of characters) next to it.
print(decode_textline(target[sample], alphabet))

In [None]:
# Plot the first 30 entries of the sample's Bezier features together with the
# decoded target text; `sample` is set in the previous cell.
plot_bezier_curves(bezier_features[sample][:30])

print(decode_textline(target[sample], alphabet))

In [None]:
# Length of the first sample's feature sequence.
print(len(rtp_features[0]))

## Storing the data

The data is then padded and stored as NumPy `.npy` files using np.save() (note that np.save() does not compress the data). The alphabet is also stored for later use when training the network.

In [None]:
def pad_data(l, value=0, width=None):
    """Pad variable-length sequences into one ``float32`` array.

    Args:
        l: List of sequences. Each sequence holds scalars (when ``width`` is
            not given) or rows of ``width`` values each.
        value: Fill value for the slots beyond the end of a sequence.
        width: Size of the trailing feature axis. When omitted (or otherwise
            falsy) the result is two-dimensional.

    Returns:
        An array of shape ``(len(l), max_len)`` or ``(len(l), max_len, width)``
        where ``max_len`` is the length of the longest sequence (0 for an
        empty ``l``).
    """
    # default=0 keeps an empty input list from raising in max().
    max_len = max((len(item) for item in l), default=0)

    if width:
        shape = (len(l), max_len, width)
    else:
        shape = (len(l), max_len)

    padded_numpy_array = np.full(shape, value, dtype=np.float32)

    for i, row in enumerate(l):
        # An empty sequence leaves its row fully padded; assigning it would
        # fail to broadcast against the (0, width) slice.
        if len(row):
            padded_numpy_array[i, :len(row)] = row

    return padded_numpy_array
        
# Pad the Bezier features to a common length (zero-filled, 11 values per
# entry) and write them to disk. np.save stores the array uncompressed and
# appends ".npy" to the file name when it is missing.
padded_bezier_features = pad_data(bezier_features, width=11)
print(padded_bezier_features[0])
np.save("../../data/processed_data/bezier_features_padded_high_precision", padded_bezier_features)

In [None]:
import pickle

# Persist the alphabet so that the same character -> index mapping can be
# loaded again later (this overwrites an existing alphabet file).
with open("../../data/processed_data/alphabet", "wb") as f:
    pickle.dump(alphabet, f)