# [nn.py](../src/nn.py): A Python Package for Implementing Simple Neural Networks

To illustrate the basic functionality of the [`nn.py`](../src/nn.py) package, we will implement the three-layer neural network for identifying images of handwritten digits that is discussed in the notebook [01_exercise_mnist](01_exercise_mnist.ipynb). Recall that the network consists of an input layer containing 784 neurons (one per pixel of a 28×28 image), a hidden layer (with a ReLU activation) containing 40 neurons, and an output layer containing 10 neurons (one per digit). Such a network can be created as follows:

In [9]:
input_layer = Layer(784)
hidden_layer = Layer(40, activation=RELU)
output_layer = Layer(10)
layers = [input_layer, hidden_layer, output_layer]

network = Network(layers)
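
To make the shape of this architecture concrete, here is a minimal NumPy sketch of a forward pass through a 784-40-10 network. It is not how `nn.py` is implemented: the names `weights`, `relu`, and `forward` are hypothetical, and the sketch assumes no bias terms and no activation on the output layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Layer sizes taken from the text: 784 inputs, 40 hidden neurons, 10 outputs.
sizes = [784, 40, 10]

# Hypothetical small random weights; nn.py's own initialization may differ.
weights = [rng.normal(0.0, 0.01, size=(n_in, n_out))
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def relu(x):
    """Rectified linear unit: max(0, x), applied elementwise."""
    return np.maximum(x, 0.0)


def forward(image):
    """Propagate one flattened 28x28 image through the two weight matrices."""
    hidden = relu(image @ weights[0])  # ReLU on the hidden layer
    return hidden @ weights[1]         # assumed: no activation on the output


example_image = rng.random(784)  # stand-in for a real MNIST image
outputs = forward(example_image)
print(outputs.shape)
```

Each of the 10 output values corresponds to one candidate digit.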

To train the network, we first create a new `Dataset` object, which stores the data on which the network will be trained.

In [10]:
mnist_training_dataset = Dataset(training_images, training_targets)

If we want to record any information from the training session, we have to instruct the `Dataset` to do so. This is accomplished by adding a 'statistic' to the `Dataset`. A statistic is any function of the outputs and targets for a given datum in the dataset, together with the previous record. For example, if we want to record the mean squared deviation for each datum in the dataset, we can do so as follows:

In [13]:
def mean_squared_deviation(outputs, targets, record):
    """Mean-squared deviation for a datapoint"""
    return np.sum((outputs - targets) ** 2) / len(outputs)


mnist_training_dataset.add_statistic("Loss", mean_squared_deviation)
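
As a quick sanity check, the statistic can be evaluated by hand outside of `nn.py`. The sketch below assumes that the outputs and targets are one-dimensional NumPy arrays of length 10 and that the target is one-hot encoded. Neither assumption is confirmed here, and the `targets` and `outputs` values are made up for illustration.

```python
import numpy as np


def mean_squared_deviation(outputs, targets, record):
    """Mean-squared deviation for a datapoint (same formula as above)."""
    return np.sum((outputs - targets) ** 2) / len(outputs)


# Hypothetical one-hot target for the digit 3 and an all-zero output.
targets = np.zeros(10)
targets[3] = 1.0
outputs = np.zeros(10)

# The statistic ignores `record`, so None is passed as a placeholder.
loss = mean_squared_deviation(outputs, targets, None)
print(loss)  # 0.1
```

Note that `len(outputs)` counts only the first axis of an array. If the arrays were two-dimensional, `outputs.size` would be the divisor that counts every element.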

To train the network, we simply call the network on the dataset, setting the `train` flag to `True` and specifying an initialization parameter and learning rate. In addition to updating the network's weights, calling the network on the dataset will also return the training record, which we store in a variable for later analysis. 

In [15]:
training_record = network(mnist_training_dataset, train=True, initialize=0.01, learning_rate=0.005)
training_record


| | Loss |
|---|---|
| 0 | 1.003605 |
| 1 | 0.994902 |
| 2 | 0.998143 |
| 3 | 1.000776 |
| 4 | 1.006966 |
| ... | ... |
| 59995 | 0.024958 |
| 59996 | 0.162332 |
| 59997 | 0.153164 |
| 59998 | 0.182135 |
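
Because the loss is recorded per datum, it fluctuates considerably from one row to the next, and a rolling mean makes the downward trend easier to see. The sketch below assumes that the training record behaves like a pandas `DataFrame` with a `Loss` column. It uses a synthetic record (`synthetic_record`, a made-up decaying curve with noise) rather than the real one.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the training record: a noisy, decaying per-datum loss.
rng = np.random.default_rng(0)
steps = np.arange(1000)
synthetic_record = pd.DataFrame(
    {"Loss": np.exp(-steps / 300) + 0.05 * rng.random(1000)}
)

# Average the loss over a sliding window of 100 data points.
# The first 99 entries are NaN because the window is not yet full.
smoothed = synthetic_record["Loss"].rolling(window=100).mean()

print(smoothed.iloc[99], smoothed.iloc[-1])
```

With the real record, the same call would be `training_record["Loss"].rolling(window=100).mean()`, with the window size chosen to taste.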


Now that we have trained the network, we can examine the output of the network on a given set of inputs. To do so, we call the network on the inputs. 

In [16]:
network(training_images[103])

| | Output |
|---|---|
| 0 | -0.020873 |
| 1 | 0.002021 |
| 2 | -0.030372 |
| 3 | -0.004016 |
| 4 | 0.019964 |
| 5 | -0.055834 |
| 6 | 0.020547 |
| 7 | 0.995928 |
| 8 | 0.016964 |
| 9 | 0.139536 |
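
The network's guess is the index of the largest output. The short sketch below reads that index off with `np.argmax`, using the ten values copied from the table above. The array is typed in by hand here, not taken from the object that `network(...)` returns.

```python
import numpy as np

# Output values copied from the table above (rows 0 through 9).
outputs = np.array([-0.020873, 0.002021, -0.030372, -0.004016, 0.019964,
                    -0.055834, 0.020547, 0.995928, 0.016964, 0.139536])

# The index of the largest output is the network's best guess for the digit.
guess = int(np.argmax(outputs))
print(guess)  # 7
```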


To test the network, we simply create a new `Dataset` object using the MNIST testing images and targets.

In [27]:
mnist_testing_dataset = Dataset(testing_images, testing_targets)

For each datum in the dataset, we would like to keep track of both the correct label of the image and the network's best guess as to the displayed digit.

In [28]:
mnist_testing_dataset.add_statistic("Digit", lambda outputs, targets, record: np.argmax(targets))
mnist_testing_dataset.add_statistic("Guess", lambda outputs, targets, record: np.argmax(outputs))

Now, we call the network on the testing dataset. As long as the `train` flag is not set to `True`, the network's weights will not be updated. 

In [40]:
testing_record = network(mnist_testing_dataset)

Now that we have the testing record, we can analyze the data however we like (NB: for data analysis, we do not rely on the functionality provided by `nn.py`, but instead use the `numpy` and `pandas` packages). For example, we can add a column to the record indicating whether or not the network guessed the digit correctly.

In [41]:
testing_record['Correct'] = np.where(testing_record['Guess'] == testing_record['Digit'], 1, 0)
testing_record

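| | Digit | Guess | Correct |
|---|---|---|---|
| 0 | 7 | 7 | 1 |
| 1 | 2 | 6 | 0 |
| 2 | 1 | 1 | 1 |
| 3 | 0 | 0 | 1 |
| 4 | 4 | 4 | 1 |
| ... | ... | ... | ... |
| 9995 | 2 | 2 | 1 |
| 9996 | 3 | 3 | 1 |
| 9997 | 4 | 9 | 0 |
| 9998 | 5 | 5 | 1 |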


We can then compute the percentage of correct guesses by summing the values in the `Correct` column and dividing by the number of rows:

In [43]:
size = testing_record['Correct'].size
total_correct = testing_record['Correct'].sum()
percent_correct = (total_correct / size) * 100

print(f"The network guessed {percent_correct}% ({total_correct}/{size}) of the displayed digits correctly.")

The network guessed 89.56% (8956/10000) of the displayed digits correctly.
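
The same record supports finer-grained analysis with plain `pandas`, such as accuracy per digit or a confusion table. The sketch below works on a hypothetical nine-row `sample` built from the rows of the testing record shown earlier, so its numbers say nothing about the full test set.

```python
import pandas as pd

# Nine rows copied from the testing record shown earlier (not the full record).
sample = pd.DataFrame({
    "Digit": [7, 2, 1, 0, 4, 2, 3, 4, 5],
    "Guess": [7, 6, 1, 0, 4, 2, 3, 9, 5],
})
sample["Correct"] = (sample["Guess"] == sample["Digit"]).astype(int)

# Fraction of correct guesses for each true digit.
accuracy_by_digit = sample.groupby("Digit")["Correct"].mean()

# Rows are true digits, columns are guesses; off-diagonal cells are errors.
confusion = pd.crosstab(sample["Digit"], sample["Guess"])

print(accuracy_by_digit)
print(confusion)
```

Applied to `testing_record` itself, the same two calls would show which digits the network confuses most often.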
