In [1]:
import pandas as pd
import numpy as np
import random
from ast import literal_eval
from sklearn.utils import shuffle
training_set = '../train_simplified/'
from os import listdir
from os.path import isfile, join
train_files = [join(training_set, f) for f in listdir(training_set) if isfile(join(training_set, f))]

In [2]:
def clean_drawing(inkarray):
    inkarray = literal_eval(inkarray)
    stroke_lengths = [len(stroke[0]) for stroke in inkarray]
    total_points = sum(stroke_lengths)
    np_ink = np.zeros((total_points, 3), dtype=np.float32)
    current_t = 0
    for stroke in inkarray:
        for i in [0, 1]:
            np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
        current_t += len(stroke[0])
        np_ink[current_t - 1, 2] = 1  # stroke_end
    # Preprocessing.
    # 1. Size normalization.
    lower = np.min(np_ink[:, 0:2], axis=0)
    upper = np.max(np_ink[:, 0:2], axis=0)
    scale = upper - lower
    scale[scale == 0] = 1
    np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
    # 2. Compute deltas.
    np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
    return np_ink

def load_and_clean(files,n_rows=100):
    """
    loads into a single dataframe a 100 random rows from each of the csv files in the parameter list "files"
    then cleans the drawing column by mapping the above function clean_drawing 
    """
    dfs = []
    counter = 0
    for file in files:
        print(counter)
        total_rows = sum(1 for line in open(file))
        skip = sorted(random.sample(xrange(1,total_rows),total_rows-n_rows))
        df = pd.read_csv(file, skiprows=skip)
        dfs += [df]
        counter += 1
    full_df = pd.concat(dfs)
    full_df['drawing'] = full_df['drawing'].map(clean_drawing)
    return full_df

In [5]:
df = load_and_clean(train_files[:15])
df = df.reset_index()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14


In [7]:
df.head()

Unnamed: 0,index,countrycode,drawing,key_id,recognized,timestamp,word
0,0,BR,"[[-0.13636363, 0.0], [-0.07792209, 0.015686275...",6597849682804736,True,2017-01-31 23:46:24.330910,lollipop
1,1,TH,"[[0.06730771, -0.06666666], [0.019230783, -0.0...",6534200897306624,True,2017-03-24 11:08:21.302110,lollipop
2,2,PL,"[[0.054878056, 0.41501978], [0.006097555, 0.23...",6543127101833216,True,2017-01-31 18:33:12.322410,lollipop
3,3,IM,"[[0.07182321, -0.572549], [-0.20994474, -0.070...",5467752183627776,True,2017-01-26 07:52:22.159220,lollipop
4,4,GB,"[[-0.26666668, 0.08235294], [-0.19999999, 0.10...",6263887919841280,True,2017-03-11 11:32:02.611740,lollipop
