In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
from utils import load_run, extract_random_entries, generate_pixel_columns
import pandas as pd

# Load the saved run. A None return is treated as "no such run id".
run_id = '1638008947'
run = load_run(run_id)
if run is None:
    raise Exception("Invalid run id")

# Number of drawings to sample from each category file.
num_entries = 1000

og_data = run['data']
# Distinct labels in the run's stored data, most frequent first
# (value_counts sorts by count, descending).
cats = og_data['word'].value_counts().index.tolist()
print(cats)

# One ndjson file per category, e.g. ./dataset/panda.ndjson.
dataset_dir = './dataset'
files = [f'{dataset_dir}/{c}.ndjson' for c in cats]
try:
    img_params = run['img_params']
except KeyError as err:
    # Chain explicitly so the traceback reports the missing key as the direct
    # cause rather than as an unrelated error raised during handling.
    raise Exception("Unknown image params. Aborting") from err


# Draw `num_entries` recognized drawings from every category file, in the
# order of `files`, and collect them into one flat list of records.
flat_data = []
for path in files:
    flat_data.extend(extract_random_entries(path, num_entries, recognized=True))

# Each record becomes one row; record keys become columns.
df = pd.DataFrame(flat_data)
print(f'Loaded {len(df)} entries from {files}')

# Derive the pixel feature columns with the project helper, using the image
# parameters saved with the run.
df = generate_pixel_columns(df, **img_params)
print('Done generating pixel columns')
data = df.reset_index(drop=True)

# Unpack the fitted estimator plus the optional preprocessing steps stored
# with it; a key absent from the run leaves that step as None.
model = run['model']
pca = None
if 'pca' in run:
    pca = run['pca']
scaler = None
if 'scaler' in run:
    scaler = run['scaler']
print("Done loading run. PCA " + ("not " if pca is None else "") + "found.")

# Evaluate on a random subset of at most `sample_size` rows.
sample_size = 1000
# Set to an int for a reproducible evaluation subset; None uses NumPy's global
# random state, so the subset changes between runs unless that is seeded.
sample_seed = None
sample = data.sample(min(sample_size, len(data)), random_state=sample_seed).reset_index(drop=True)
target = sample['word'].values.tolist()
# Metadata and label columns; every remaining column is used as a feature.
metadata_columns = ['countrycode', 'timestamp', 'recognized', 'key_id', 'drawing', 'word']
test = sample.drop(columns=metadata_columns).to_numpy()

print('Predicting...')
# Apply each saved preprocessing step independently (scale first, then
# project). Gating the scaler on `pca` would crash on a run that has PCA but
# no scaler, and would skip scaling on a run that has a scaler but no PCA.
# NOTE(review): assumes any saved scaler was fitted on the same raw pixel
# features the model/PCA were trained on — confirm against the training code.
if scaler is not None:
    test = scaler.transform(test)
if pca is not None:
    test = pca.transform(test)
prediction = model.predict(test)
print('Done')

# Compare `prediction` with the true labels in `target` and report the
# fraction that match.
from sklearn.metrics import accuracy_score

print("Scoring...")
acc_score = accuracy_score(y_true=target, y_pred=prediction)
print(f"Accuracy score: {acc_score}")

/home/chris/swd_2/aai/final-project/notebook
/home/chris/swd_2/aai/final-project/notebook/runs/1638008947
['panda', 'sheep', 'tennis racquet', 'lantern', 'blueberry']
<class 'dict'>
Loaded 5000 entries from ['./dataset/panda.ndjson', './dataset/sheep.ndjson', './dataset/tennis racquet.ndjson', './dataset/lantern.ndjson', './dataset/blueberry.ndjson']
Done generating pixel columns
Done loading run. PCA found.
Predicting...
Done
Scoring...
Accuracy score: 0.862
