-
Notifications
You must be signed in to change notification settings - Fork 0
/
gui.py
117 lines (92 loc) · 2.99 KB
/
gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import tkinter as tk
import numpy as np
from globals import current_dir
from preprocessor import deskew_image, dots_to_image
from utils import save_digit
class InputGUI:
    """Tk window for drawing a single digit and classifying it.

    The window shows a square canvas the user draws on with the left mouse
    button, plus Clear, Predict and Close buttons. Predict turns the drawn
    dots into an image, deskews it, saves both images to a fresh numbered
    directory and, when a network was supplied, prints its prediction.
    """

    def __init__(self, root, n=None):
        """Build all widgets inside *root*.

        :param root: the Tk toplevel window to populate.
        :param n: optional NeuralNetwork object; when ``None`` the Predict
            button only saves the images and prints no prediction.
        """
        self.root = root
        self.n = n  # NeuralNetwork object
        self.dots = []  # (x, y) canvas coordinates of every drawn dot
        self.scale = 15  # Stroke size: radius of each dot in canvas pixels
        # 279 px square canvas; presumably ~10 canvas pixels per pixel of the
        # 28x28 network input -- TODO confirm against dots_to_image.
        self.size = 28 * 10 - 1

        self.root.minsize(700, 400)
        self.root.title('Digits')

        desc = (
            'Draw a single digit in the canvas.\n'
            'For best output, try to ensure it is centered\n'
            'in the frame and nearly fills it.'
        )
        self.label = tk.Label(root, text=desc)
        self.label.grid(column=1, row=0, columnspan=3)

        # TODO: Add text field(s) to display prediction and confidence
        # REVIEW: Perhaps allow user to enter correct answer and save
        # as additional test data?
        # Perhaps make new training data out of it?

        self.canvas = tk.Canvas(
            self.root, width=self.size, height=self.size,
            highlightthickness=2, highlightbackground='black'
        )
        self.canvas.grid(column=1, row=1, columnspan=3)
        # A click paints one dot; dragging with the button held paints a stroke.
        self.canvas.bind('<Button-1>', self.draw)
        self.canvas.bind('<B1-Motion>', self.draw)

        self.clear_button = tk.Button(
            self.root, text='Clear', command=self.clear
        )
        self.predict_button = tk.Button(
            self.root, text='Predict', command=self.predict
        )
        self.close_button = tk.Button(
            self.root, text='Close', command=self.root.destroy
        )
        self.clear_button.grid(column=1, row=2)
        self.predict_button.grid(column=2, row=2)
        self.close_button.grid(column=3, row=2)

        # Needed for center alignment: the empty outer columns 0 and 4 absorb
        # the extra width, keeping columns 1-3 centered.
        self.root.grid_columnconfigure(0, weight=1)
        self.root.grid_columnconfigure(4, weight=1)

    def clear(self):
        """Erase everything on the canvas and forget the recorded dots."""
        self.canvas.delete('all')
        self.dots = []

    def predict(self):
        """Save the current drawing and, if a network is attached, classify it.

        Converts the recorded dots to an image and deskews it. Both images
        are saved as ``raw.png`` and ``deskewed.png`` in a new numbered
        prediction directory. When a network is attached, the deskewed image
        (flattened to 784 values) is passed to ``self.n.predict``. The
        arg-max digit, its confidence and the full output vector are then
        printed. Nothing is saved or predicted while the canvas is empty.
        """
        if not self.dots:
            # Avoid creating an empty prediction directory and classifying a
            # blank image.
            print('Nothing to predict: the canvas is empty.')
            return
        data = dots_to_image(self.dots, self.scale)
        deskewed = deskew_image(data)
        # Create the output directory only once the images exist, so a
        # failure above does not leave an empty numbered directory behind.
        prediction_dir = get_prediction_dir()
        save_digit(data, os.path.join(prediction_dir, 'raw.png'))
        # Save the deskewed image (not the raw one) under this name.
        save_digit(deskewed, os.path.join(prediction_dir, 'deskewed.png'))
        if self.n is not None:
            # TODO: Save raw as well as deskewed prediction as json
            # The network is fed the deskewed image; pass ``data`` instead
            # to classify the raw drawing.
            prediction = self.n.predict(deskewed.reshape((784,)))
            digit = np.argmax(prediction)
            print('Prediction: %s, confidence: %d%%' % (
                digit, prediction[digit] * 100
            ))
            print(prediction)

    def draw(self, event):
        """Record and paint a dot at the mouse position.

        Bound to left-button press and drag. Events outside the
        ``size`` x ``size`` drawing area are ignored.
        """
        x, y = event.x, event.y
        if 0 <= x < self.size and 0 <= y < self.size:
            self.dots.append((x, y))
            self.canvas.create_oval(
                x - self.scale, y - self.scale,
                x + self.scale, y + self.scale,
                fill='#222222'
            )
def run_gui(n=None):
    """Open the digit-drawing window and block until the user closes it.

    :param n: optional NeuralNetwork handed to the window; without it the
        Predict button saves the drawing but prints no prediction.
    """
    window = tk.Tk()
    InputGUI(root=window, n=n)
    window.mainloop()
def get_prediction_dir(base_dir=None):
    """Create and return a new, uniquely numbered prediction directory.

    Directories live under ``<base_dir>/predictions`` and are named ``1``,
    ``2``, ``3``, ... The lowest number not already taken by any entry
    (directory or file) is used.

    :param base_dir: parent of the ``predictions`` folder; defaults to
        ``current_dir`` from the ``globals`` module.
    :returns: path of the newly created directory.
    """
    if base_dir is None:
        base_dir = current_dir
    prediction_dir = os.path.join(base_dir, 'predictions')
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs(prediction_dir, exist_ok=True)
    i = 1
    while True:
        path = os.path.join(prediction_dir, str(i))
        try:
            # mkdir either claims the name or raises, so an existing entry
            # (including a stray *file* with this name) is skipped instead
            # of crashing or being reused.
            os.mkdir(path)
        except FileExistsError:
            i += 1
        else:
            return path
if __name__ == '__main__':
    # Standalone launch has no network attached, so Predict only saves images.
    run_gui(n=None)