Skip to content

Commit

Permalink
Some refactoring, some TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
IshanManchanda committed May 31, 2020
1 parent 4c2ec97 commit d7bb018
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ scipy = "==1.4.1"
wandb = "==0.8.36"

[requires]
python_full_version = "3.8.2"
python_version = "3.8"
16 changes: 11 additions & 5 deletions digits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@


def main():
if not os.path.isfile('data/mnist_py3_deskewed.pkl.gz'):
x = input('Deskewed data not found, generate now? (y/n): ')
if x.lower() == 'y':
deskew_data()
check_deskewed_files_exist()

if os.path.isfile('networks/network.json'):
x = input('Trained network found, train new network anyways? (y/n: ')
Expand All @@ -24,7 +21,8 @@ def main():
new_dir = os.path.join(archive_dir, '1')
os.rename(current_dir, new_dir)

# TODO: Refactor these functions to make sense
# REVIEW: Pass the loaded data as a global
# so that it can be used across runs?
training, validation, test = load_data()
wandb.init(project='digits')
n = train([784, 128, 10], 0.008, 0.2, 0.05, training, validation)
Expand All @@ -40,5 +38,13 @@ def gui():
run_gui(n)


def check_deskewed_files_exist():
# TODO: Check for all parts of deskewed dataset
if not os.path.isfile('data/mnist_py3_deskewed.pkl.gz'):
x = input('Deskewed data not found, generate now? (y/n): ')
if x.lower() == 'y':
deskew_data()


if __name__ == '__main__':
main()
6 changes: 5 additions & 1 deletion src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from globals import current_dir
from preprocessor import deskew_image, dots_to_image
from utils import draw_digit, save_digit
from utils import save_digit


class InputGUI:
Expand All @@ -26,6 +26,10 @@ def __init__(self, root, n=None):
self.label = tk.Label(root, text=desc)
self.label.grid(column=1, row=0, columnspan=3)

# TODO: Add text field(s) to display prediction and confidence
# REVIEW: Perhaps allow user to enter correct answer and save
# as additional test data?
# Perhaps make new training data out of it?
self.canvas = tk.Canvas(
self.root, width=self.size, height=self.size,
highlightthickness=2, highlightbackground='black'
Expand Down

0 comments on commit d7bb018

Please sign in to comment.