In [9]:
import os
import numpy as np
import pickle
import random
import xml.etree.cElementTree as et
import html
import matplotlib.pyplot as plt
import sys

# 1) Data preprocessing

XML data structure of the files :
```xml
<WhiteboardCaptureSession>
    <General> General Information </General>
    <Transcription>
        <Text> The text corresponding to the strokes </Text>
        <TextLine id="a01-001z-01" text="By Trevor Williams. A move">
              <Word id="a01-001z-01-01" text="By">
                <Char id="a01-001z-01-01-01" text="B"/>
                <Char id="a01-001z-01-01-02" text="y"/>
              </Word>
              .................. similar details for every word in the textline
        </TextLine>
        .................. similar details for every textline in the text
    </Transcription>
    <WhiteboardDescription>
        <SensorLocation corner="top_left"/>
        <DiagonallyOppositeCoords x="6912" y="8798"/>
        <VerticallyOppositeCoords x="214" y="8878"/>
        <HorizontallyOppositeCoords x="7038" y="196"/>
    </WhiteboardDescription>
    <StrokeSet>
        <Stroke colour="black" start_time="13090871.35" end_time="13090871.78">
            <Point x="895" y="992" time="13090871.35"/>
            ............ similar details of points in the stroke
        </Stroke>
        ............... similar details of stroke in the StrokeSet
    </StrokeSet>
</WhiteboardCaptureSession>
```
So firstly, we need to parse the data from the xml and arrange them in a proper format : 

In [10]:
def distance(point_1, point_2):
    return np.sqrt(np.sum(np.square(point_1 - point_2)))

# To make sure that the points are not too far away, and if they are, remove that piece of data
def clear_points(points):
    plot = False
    points_to_remove = set()
    for i in range(1, len(points) - 1):
        p1, p2, p3 = points[i - 1: i + 2, :2]
        
        # sum of distance between 3 sequential points in the stroke
        dis = distance(p1, p2) + distance(p2, p3)
        if dis > 1500:
            points_to_remove.add(i)
        
    valid_pts = []
    for i in range(len(points)):
        if i not in points_to_remove:
            valid_pts.append(points[i])
    
    return np.array(valid_pts)

# separates the points of the stroke into separate groups based on the distance between the points
# interprete them as separate strokes
# SHAPES : init = (30, 3) , zip = (14, 30) , final = [(14, 3), (16, 3)]
def separate(points):
    seps = []
    for i in range(0, len(points) - 1):
        if distance(points[i], points[i+1]) > 600:
            seps.append(i+1)
    z = zip([0] + seps, seps + [len(points)])
    
    final = [points[b:e] for b, e in zip([0] + seps, seps + [len(points)])]
    
    return final

In [24]:
data = []
characters = set()

# Surf through all the files in the dataset
for root, dirs, files in os.walk('.\data'):
    for file in files:
        if file.split('.')[-1] == 'xml':
            raw_data = et.parse(root + "\\" + file).getroot()
            transcription = raw_data.findall("Transcription")  # gives us a list, transcription object at transcription[0]
            strokeset = raw_data.findall('StrokeSet') # gives us a list, strokeset object at strokeset[0]
            
            if not transcription:
                continue
            
            # get the text
            ascii_text = [html.unescape(line.get('text')) for line in transcription[0].findall("TextLine")]
            
            # get every stroke as a list of points in the strokeset and adds it to a list
            strokes_init = [stroke.findall('Point') for stroke in strokeset[0].findall('Stroke')] 
            
            strokes = []
            midpoints = []
            chars = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-\'#)/!",;?&:.(+%'
            charlist = [char for char in chars]
            
            for single_stroke in strokes_init:
                
                # returns the list of (x, y, eos) coordinates for a single stroke
                coords = np.array([[int(point.get('x')), int(point.get('y')), 0] for point in single_stroke])
                
                # assign the last element of the last point as 1 to indicate end of stroke
                coords[-1, 2] = 1
                
                # Now, we need to filter the data
                coords = clear_points(coords)
                
                if len(coords) == 0:
                    continue
                    
                seps = separate(coords)
                
                for stroke in seps:
                    # check for single points as strokes
                    if len(seps) > 1 and len(stroke) == 1:
                        continue
                    stroke[-1, 2] = 1
                    
                    # get the maximum value of x and y coordinates for each stroke
                    xmax, ymax = max(stroke, key = lambda x: x[0])[0], max(stroke, key = lambda x: x[1])[1]
                    
                    # get the minimum value of x and y coordinates for each stroke
                    xmin, ymin = min(stroke, key = lambda x : x[0])[0], min(stroke, key = lambda x: x[1])[1]
                    
                    strokes.append(stroke)
                    
                    # calculate the midpoints for each stroke
                    midpoints.append([(xmax + xmin)/2, (ymax + ymin)/2])
            
            # for every point in a single strokeset
            distances = [-abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) for p1, p2 in zip(midpoints, midpoints[1:])]
            splits = sorted(np.argsort(distances)[:len(ascii_text) - 1] + 1)
            lines = []
            
            for b, e in zip([0] + splits, splits + [len(strokes)]):
                lines.append([point for stroke in strokes[b:e] for point in stroke])
            
            data.append((ascii_text, lines))
            
# we assign indexes to every character in the dataset          
translation = {0 : '<NULL>'}
i = 1
for char in charlist:
    translation[i] = char
    i += 1

print(translation)

{0: '<NULL>', 1: ' ', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'x', 26: 'y', 27: 'z', 28: 'A', 29: 'B', 30: 'C', 31: 'D', 32: 'E', 33: 'F', 34: 'G', 35: 'H', 36: 'I', 37: 'J', 38: 'K', 39: 'L', 40: 'M', 41: 'N', 42: 'O', 43: 'P', 44: 'Q', 45: 'R', 46: 'S', 47: 'T', 48: 'U', 49: 'V', 50: 'W', 51: 'X', 52: 'Y', 53: 'Z', 54: '1', 55: '2', 56: '3', 57: '4', 58: '5', 59: '6', 60: '7', 61: '8', 62: '9', 63: '0', 64: '-', 65: "'", 66: '#', 67: ')', 68: '/', 69: '!', 70: '"', 71: ',', 72: ';', 73: '?', 74: '&', 75: ':', 76: '.', 77: '(', 78: '+', 79: '%'}


In [26]:
def translate(text, translation):
    final_list = []
    for i in range(len(text)):
        try:
            final_list.append(translation[text[i]])
        except KeyError:
            # keep adding new characters to the translation
            translation[text[i]] = len(translation)
            final_list.append(translation[text[i]])
    return translation, final_list
    
dataset = []
labels = []

for texts, lines in data:
    for text, line in zip(texts, lines):
        line = np.array(line, dtype = np.float32)
        # shift the coordinates
        line[:,0] = line[:,0] - np.min(line[:,0])
        line[:,1] = line[:,1] - np.min(line[:,1])
        
        # now add the lines to the dataset
        dataset.append(line)
        
        # get the indexes just in case if we need one hots somewhere in the future
        translation , final_list= translate(text, translation)
        labels.append(final_list)

In [27]:
whole_data = np.concatenate(dataset, axis=0)
y_std = np.std(whole_data[:, 1])
norm_data = []
# normalize the data
for line in dataset:
    line[:, :2] /= y_std
    norm_data.append(line)
dataset = norm_data

In [28]:
# Now that all the data is parsed successfully, we save the data
try:
    os.makedirs('data_parsed')
except:
    pass
np.save('data_parsed/dataset', np.array(dataset))
np.save('data_parsed/labels', np.array(labels))
with open('data_parsed/translation.pkl', 'wb') as translation_file:
    pickle.dump(translation, translation_file)