Skip to content

Commit

Permalink
Added some docstrings, comments; W&B WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
IshanManchanda committed May 23, 2020
1 parent 42d12af commit 7488191
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 0 additions & 3 deletions src/digits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from .utils import deskew_data, load_data


# TODO: Move to Weights and Biases to track models
# in place of the current "dump as json" method

def main():
if not os.path.isfile('../data/mnist_py3_deskewed.pkl.gz'):
x = input('Deskewed data not found, generate now? (y/n): ')
Expand Down
4 changes: 2 additions & 2 deletions src/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ def train(self, data, validation_data=None, epochs=10, batch_size=20):
# Each batch is then used to train the network
self.train_batch(batch)

# After each epoch, optionally print progress
# TODO: wandb.log({'epoch': epoch, 'loss': loss})
if validation_data is not None:
correct = self.validate(validation_data)
percentage = 100 * correct / n_validation
Expand Down Expand Up @@ -282,6 +280,8 @@ def softmax(z):


def main():
# TODO: Animated plotting to show performance change with time.
# check scratch.py
wandb.init(project='digits')
# n = NeuralNetwork([784, 256, 10])
training, validation, test = load_data()
Expand Down
16 changes: 16 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@


def load_data(deskew=True):
"""
Loads MNIST dataset from disk.
Returns the training, validation, and test data as a list of tuples,
where the outputs are one-hot encoded vectors.
"""
# The _py3 version of the dataset is a redumped version for Python 3
# which doesn't use Python 2's latin1 encoding
if deskew:
Expand All @@ -33,6 +39,10 @@ def load_data(deskew=True):


def deskew_data():
"""
Deskews the MNIST dataset and saves it to disk.
"""
# Check if deskewed data already exists
if os.path.isfile('../data/mnist_py3_deskewed.pkl.gz'):
return

Expand All @@ -54,12 +64,18 @@ def deskew_data():


def get_expected_y(digit):
"""
Returns a one-hot encoded vector of the inputted digit.
"""
y = np.array([0] * 10)
y[digit] = 1
return y


def draw_digit(image):
"""
Renders and displays the image.
"""
arr = (np.array(image).reshape((28, 28)) * 255).astype('uint8')
Image.fromarray(arr).resize((256, 256), Image.ANTIALIAS).show()

Expand Down
2 changes: 2 additions & 0 deletions wandb/settings
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[default]

0 comments on commit 7488191

Please sign in to comment.