In [1]:
%load_ext autoreload
%autoreload 2
import time
from utils import generate_pixel_columns, extract_best_entries

In [2]:
category = 'car'
image_gen_params = {
    'magnification': 1,
    'resolution': 32,
    'invert_color': True,
    'stroke_width_scale': 1
}

if category is None:
    raise Exception("Must select a category")

df = extract_best_entries(f'./dataset/{category}.ndjson', 2000, recognized=True)
df = generate_pixel_columns(df, **image_gen_params)
df = df.drop(columns=['word', 'timestamp', 'recognized', 'key_id', 'complexity', 'drawing'])

In [3]:
train_amt = int(len(df) * .9)

train = df[:train_amt]
test = df[train_amt:]
# del df

train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

print(f'Train: {len(train)} entries, test: {len(test)} entries.')

Train: 1800 entries, test: 200 entries.


In [4]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neural_network import MLPClassifier
from itertools import repeat
import time

y = train['countrycode'].to_numpy()
X = train.drop(columns=['countrycode']).to_numpy()
print("Done generating features and target")

scaler = StandardScaler()
X = scaler.fit_transform(X)
pca = PCA(.85)
X = pca.fit_transform(X)
print(f'PCA & standardization done. Keeping {pca.n_components_} features')

classifier = MLPClassifier(hidden_layer_sizes=tuple(repeat(int(pca.n_components_ * 1.2), 3)), solver='lbfgs', alpha=1e-07)
start = time.time()
model = OneVsRestClassifier(classifier, n_jobs=-1).fit(X, y)
end = time.time()
print(f"Done training model in {'{:.2f}'.format(end - start)}s")

Done generating features and target
PCA & standardization done. Keeping 386 features
Done training model in 56.25s


In [5]:
test2 = test.drop(columns=['countrycode']).to_numpy()
test2 = scaler.transform(test2)
test2 = pca.transform(test2)
prediction = model.predict(test2)

from sklearn.metrics import accuracy_score
acc_score = accuracy_score(test['countrycode'].values.tolist(), prediction)
print(f"Accuracy: {acc_score}")

Accuracy: 0.35
