In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np

In [3]:
if ".." not in sys.path:
    sys.path.append("..")

In [4]:
from raw.network import LinearANN
from raw.losses import binary_cross_entropy, d_binary_cross_entropy

In [5]:
input_size = 10
hidden_layers = [5]
output_size = 2
model = LinearANN(input_size, hidden_layers, output_size)

In [6]:
model

LinearANN {
  Linear - weights: (10, 5), bias: (5,)
  Linear - weights: (5, 2), bias: (2,)
}

In [7]:
def one_hot(positive_class_inds, batch_size, output_size):
    y = np.zeros((batch_size, output_size), dtype=np.float)
    y[np.arange(len(y)), positive_class_inds] = 1
    return y

In [8]:
batch_size = 1
x = np.random.normal(size=(batch_size, input_size))
y = np.random.randint(0, 2, size=(batch_size,))
print(y)
y = one_hot(y, batch_size, output_size)

[0]


In [9]:
y, x

(array([[1., 0.]]),
 array([[-0.81684848, -0.05829526, -1.5175493 ,  0.9165314 , -1.85453004,
         -0.85769108, -0.4260255 ,  0.15387974,  0.73615067,  0.07140471]]))

In [10]:
out = model(x)
out.shape

(1, 2)

In [11]:
loss_per_batch = binary_cross_entropy(out, y)

In [12]:
loss_per_batch

array([[0.54023501, 1.55101486]])

In [13]:
batch_loss = np.mean(np.sum(loss_per_batch, -1))
print("Batch loss", batch_loss)

Batch loss 2.091249868213781


In [14]:
loss_error = d_binary_cross_entropy(out, y)

In [15]:
loss_error

array([[-1.71641019,  4.71625408]])

Sense check. Remember if the loss error is negative then it means increasing the activation value will send the loss down and decreasing will increase it.

In [16]:
out, y

(array([[0.58261132, 0.78796732]]), array([[1., 0.]]))

In [17]:
model.backward(loss_error)

In [18]:
for _ in range(100):
    out = model(x)
    print(out, y)
    loss_per_batch = binary_cross_entropy(out, y)
    batch_loss = np.mean(np.sum(loss_per_batch, -1))
    loss_error = d_binary_cross_entropy(out, y)
    model.backward(loss_error, lr=0.1)

[[0.58615781 0.78262758]] [[1. 0.]]
[[0.62058667 0.72485851]] [[1. 0.]]
[[0.65056077 0.66384735]] [[1. 0.]]
[[0.67670358 0.60285526]] [[1. 0.]]
[[0.69958996 0.54466162]] [[1. 0.]]
[[0.71971955 0.49112891]] [[1. 0.]]
[[0.73751468 0.44316483]] [[1. 0.]]
[[0.75332831 0.40093292]] [[1. 0.]]
[[0.76745417 0.36412859]] [[1. 0.]]
[[0.78013648 0.33221173]] [[1. 0.]]
[[0.79157838 0.30456362]] [[1. 0.]]
[[0.80194901 0.28057704]] [[1. 0.]]
[[0.81138957 0.25970008]] [[1. 0.]]
[[0.82001826 0.24145201]] [[1. 0.]]
[[0.82793449 0.22542426]] [[1. 0.]]
[[0.83522224 0.21127436]] [[1. 0.]]
[[0.84195281 0.19871738]] [[1. 0.]]
[[0.8481871  0.18751691]] [[1. 0.]]
[[0.8539774  0.17747685]] [[1. 0.]]
[[0.85936884 0.16843427]] [[1. 0.]]
[[0.86440062 0.16025337]] [[1. 0.]]
[[0.86910689 0.15282058]] [[1. 0.]]
[[0.87351763 0.14604045]] [[1. 0.]]
[[0.87765919 0.13983249]] [[1. 0.]]
[[0.88155491 0.1341284 ]] [[1. 0.]]
[[0.88522549 0.12887003]] [[1. 0.]]
[[0.88868936 0.12400763]] [[1. 0.]]
[[0.89196303 0.11949839]] [[