From d7bb018d0a7467a8422855adb387b613e81b6c8b Mon Sep 17 00:00:00 2001 From: Ishan Manchanda Date: Sun, 31 May 2020 22:44:52 +0530 Subject: [PATCH] Some refactoring, some TODOs --- Pipfile | 2 +- digits.py | 16 +++++++++++----- src/gui.py | 6 +++++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Pipfile b/Pipfile index b67b76c..dd8118b 100644 --- a/Pipfile +++ b/Pipfile @@ -14,4 +14,4 @@ scipy = "==1.4.1" wandb = "==0.8.36" [requires] -python_full_version = "3.8.2" +python_version = "3.8" diff --git a/digits.py b/digits.py index 2943672..003f092 100644 --- a/digits.py +++ b/digits.py @@ -10,10 +10,7 @@ def main(): - if not os.path.isfile('data/mnist_py3_deskewed.pkl.gz'): - x = input('Deskewed data not found, generate now? (y/n): ') - if x.lower() == 'y': - deskew_data() + check_deskewed_files_exist() if os.path.isfile('networks/network.json'): x = input('Trained network found, train new network anyways? (y/n: ') @@ -24,7 +21,8 @@ def main(): new_dir = os.path.join(archive_dir, '1') os.rename(current_dir, new_dir) - # TODO: Refactor these functions to make sense + # REVIEW: Pass the loaded data as a global + # so that it can be used across runs? training, validation, test = load_data() wandb.init(project='digits') n = train([784, 128, 10], 0.008, 0.2, 0.05, training, validation) @@ -40,5 +38,13 @@ def gui(): run_gui(n) +def check_deskewed_files_exist(): + # TODO: Check for all parts of deskewed dataset + if not os.path.isfile('data/mnist_py3_deskewed.pkl.gz'): + x = input('Deskewed data not found, generate now? (y/n): ') + if x.lower() == 'y': + deskew_data() + + if __name__ == '__main__': main() diff --git a/src/gui.py b/src/gui.py index bfad376..052690c 100644 --- a/src/gui.py +++ b/src/gui.py @@ -5,7 +5,7 @@ from globals import current_dir from preprocessor import deskew_image, dots_to_image -from utils import draw_digit, save_digit +from utils import save_digit class InputGUI: @@ -26,6 +26,10 @@ def __init__(self, root, n=None): self.label = tk.Label(root, text=desc) self.label.grid(column=1, row=0, columnspan=3) + # TODO: Add text field(s) to display prediction and confidence + # REVIEW: Perhaps allow user to enter correct answer and save + # as additional test data? + # Perhaps make new training data out of it? self.canvas = tk.Canvas( self.root, width=self.size, height=self.size, highlightthickness=2, highlightbackground='black'