Skip to content

Commit

Permalink
Lots of stuff WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
IshanManchanda committed May 24, 2020
1 parent 45ffe3c commit 8172518
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ src/wandb
# Files to ignore
scratch.py
data/mnist_py3_deskewed.pkl.gz
data/mnist_py3_mini_deskewed.pkl.gz
networks/*
!networks/humans.txt

Expand Down
31 changes: 18 additions & 13 deletions digits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import wandb

from src.gui import gui
from src.network import NeuralNetwork
from src.utils import deskew_data, load_data

Expand All @@ -13,34 +14,38 @@ def main():
if x.lower() == 'y':
deskew_data()

# TODO: Refactor these functions to make sense
training, validation, test = load_data()
train([784, 128, 10], 0.008, 0.2, 0.05, training, validation)

wandb.init(project='digits')
n = train([784, 128, 10], 0.008, 0.2, 0.05, training, validation)
try:
gui(n)
except:
wandb.run.save()

# root = tk.Tk()
# InputGUI(root, n)
# root.mainloop()

def train(size, eta, lmbda, alpha, training, validation):
wandb.init(project='digits')
n = NeuralNetwork(size, eta=eta, lmbda=lmbda, alpha=alpha)
n.train(
np.random.permutation(training)[:5000],
np.random.permutation(validation)[:500], epochs=20, batch_size=20
np.random.permutation(validation)[:500], epochs=5, batch_size=20
)

data_dir = os.path.join(os.getcwd(), 'networks')
i = get_save_index(data_dir)
n.save(os.path.join(data_dir, f'{i}.json'))
n.plot(os.path.join(data_dir, f'{i}.png'))
wandb.save(os.path.join(data_dir, f'{i}.json'))
wandb.save(os.path.join(data_dir, f'{i}.png'))
wandb.run.save()
data_dir = os.path.join(data_dir, str(i))
# TODO: Save to run folder
# os.path.join(wandb.run.dir, '')
n.save(os.path.join(data_dir, 'network.json'))
n.plot(os.path.join(data_dir, 'accuracy.png'))
wandb.save(os.path.join(data_dir, 'network.json'))
wandb.save(os.path.join(data_dir, 'accuracy.png'))
return n


def get_save_index(data_dir):
i = 1
while os.path.isfile(os.path.join(data_dir, f'{i}.json')):
while os.path.isdir(os.path.join(data_dir, str(i))):
i += 1
return i

Expand Down
25 changes: 17 additions & 8 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from preprocessor import deskew_image, dots_to_image
from utils import draw_digit
from utils import draw_digit, save_digit


class InputGUI:
Expand Down Expand Up @@ -56,16 +56,25 @@ def clear(self):
self.dots = []

def predict(self):
# TODO: Create new directory inside the run folder for this prediction
data = dots_to_image(self.dots, self.scale)
# draw_digit(data)

# save_digit(data, 'raw.png')
deskewed = deskew_image(data)
draw_digit(deskewed)
# TODO: Save images in prediction folder
# save_digit(data, 'deskewed.png')

# draw_digit(data)
# draw_digit(deskewed)

if self.n:
prediction = self.n.predict(data)
# TODO: Save raw as well as deskewed prediction as json
# prediction = self.n.predict(data.reshape((784,)))
prediction = self.n.predict(deskewed.reshape((784,)))
digit = np.argmax(prediction)
print(prediction, digit, prediction[digit])
print('Prediction: %s, confidence: %d%%' % (
digit, prediction[digit] * 100
))
print(prediction)

def draw(self, event):
x, y = event.x, event.y
Expand All @@ -78,9 +87,9 @@ def draw(self, event):
)


def gui():
def gui(n=None):
root = tk.Tk()
InputGUI(root)
InputGUI(root, n)
root.mainloop()


Expand Down
35 changes: 28 additions & 7 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from preprocessor import deskew_image


def load_data(deskew=True):
def load_data(mini=True, deskew=True):
"""
Loads MNIST dataset from disk.
Expand All @@ -20,8 +20,11 @@ def load_data(deskew=True):
# The _py3 version of the dataset is a redumped version for Python 3
# which doesn't use Python 2's latin1 encoding
data_dir = os.path.join(os.getcwd(), 'data')
file_path = os.path.join(data_dir, 'mnist_py3_deskewed.pkl.gz') if deskew \
file_path = (
os.path.join(data_dir, 'mnist_py3_mini_deskewed.pkl.gz') if mini
else os.path.join(data_dir, 'mnist_py3_deskewed.pkl.gz') if deskew
else os.path.join(data_dir, 'mnist_py3.pkl.gz')
)

if not os.path.isfile(file_path):
raise FileNotFoundError(f'{file_path} not found!')
Expand Down Expand Up @@ -56,15 +59,25 @@ def deskew_data():
data_path = os.path.join(data_dir, 'mnist_py3.pkl.gz')
with gzip.open(data_path, 'rb') as f:
data = pickle.load(f)
processed_data = []

processed_data = []
for section in data:
xs = [
list(deskew_image(x.reshape((28, 28))).reshape(784, ))
for x in section[0]
]
xs = []
for x in section[0]:
xs.append(list(deskew_image(x.reshape((28, 28))).reshape(784, )))
processed_data.append((xs, section[1]))

# deskew_mini_path = os.path.join(data_dir,
# 'mnist_py3_mini_deskewed.pkl.gz')
# processed_data_mini = [
# (processed_data[0][0][:5000], processed_data[0][1][:5000]),
# (processed_data[0][0][:1000], processed_data[0][1][:1000]),
# (processed_data[0][0][:1000], processed_data[0][1][:1000])
# ]
# with gzip.open(deskew_mini_path, 'wb') as f:
# A protocol of -1 means the latest one
# pickle.dump(processed_data_mini, f, protocol=-1)

with gzip.open(deskew_path, 'wb') as f:
# A protocol of -1 means the latest one
pickle.dump(processed_data, f, protocol=-1)
Expand All @@ -87,6 +100,14 @@ def draw_digit(image):
Image.fromarray(arr).resize((256, 256), Image.ANTIALIAS).show()


def save_digit(image, file_path):
    """
    Renders the image and saves it to disk.

    The input is reshaped to 28x28, scaled by 255 and cast to uint8, then
    upscaled to 256x256 before being written to file_path.

    :param image: array-like holding 784 values (anything reshapeable to
        28x28); presumably pixel intensities in [0, 1] -- TODO confirm
    :param file_path: destination path handed to PIL's Image.save
    """
    arr = (np.array(image).reshape((28, 28)) * 255).astype('uint8')
    # Image.ANTIALIAS was only an alias of Image.LANCZOS and has been removed
    # in Pillow 10; LANCZOS is the same resampling filter under its real name.
    Image.fromarray(arr).resize((256, 256), Image.LANCZOS).save(file_path)


def main():
deskew_data()

Expand Down

0 comments on commit 8172518

Please sign in to comment.