In [1]:
"""
This uses classes defined in the file nn.py

We are mainly using the Layer class, which defines matrix operations
for the forward and backward pass at the instance of a single layer,
and its relationship to a 'child' layer. It would be useful to look at
how these matrices are instantiated, and how their operations are implemented.
"""
from layer import *
import numpy as np

# Xavier initialization, the 4 is from d_in + d_out
init = np.sqrt(6) / np.sqrt(3)

# Input layer, each sample will have 2 dimensions
il = InputLayer(size=2)
# Hidden Layers
hl1 = Layer(size=3, child_layer=il, init=init)
hl2 = Layer(size=3, child_layer=hl1, init=init)
# Output Layer; size should be the number of potential class layers
# No bias for output layer
ol = Layer(size=1, child_layer=hl2, init=init, add_bias=False)
# Loss Layer
ll = LossLayer(child_layer=ol)

# These are the XOR values. each tuple has (coordinate pairs, target label)
input_samples = [([1, 1], 0), ([1, -1], 1), ([-1, 1], 1), ([-1, -1], 0)]

# Run 20 epochs
for e in range(1000):
  """
  SGD, but in practice we will actually just sample everything
  """
  # For tracking losses in a single epoch
  losses = []
  for inp, target in input_samples:
    # Forward propogate each layer. Note the different forward implementation per layer type
    il.forward(inp)
    hl1.forward()
    hl2.forward()
    ol.forward()
    ll.forward(target)
    # Track per isntance loss
    losses.append(ll.loss)

    # Now backward propogate the error starting at the loss node. Note the different forward implementation per layer type
    ll.backward(target)
    ol.backward()
    hl2.backward()
    hl1.backward()

  print("AVG LOSS: %.2f" % (sum(losses) / len(losses)))




AVG LOSS: 1.08
AVG LOSS: 1.00
AVG LOSS: 0.93
AVG LOSS: 0.87
AVG LOSS: 0.82
AVG LOSS: 0.79
AVG LOSS: 0.77
AVG LOSS: 0.75
AVG LOSS: 0.74
AVG LOSS: 0.74
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.73
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 0.72
AVG LOSS: 