# XOR: backprop numpy para MLP CrossEntropy

In [1]:
import numpy as np

# Print arrays with four decimals.
np.set_printoptions(precision=4)

# The four XOR input points, one sample per row.
X = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])

# One-hot targets, row-aligned with X: column 0 is hot when both inputs
# are equal, column 1 when they differ.
y = np.array([
    [1, 0],
    [0, 1],
    [0, 1],
    [1, 0],
])

# Initial parameters of the first (hidden) layer.
W = np.array([
    [1, 1],
    [1, 1],
])
b1 = np.array([-1, 0.5])

# Initial parameters of the second (output) layer.
V = np.array([
    [1, -1],
    [-1, 1],
])
b2 = np.array([1, -1])

## Forward: $\small\quad\boldsymbol{x}\to\boldsymbol{z}(\mathbf{W},\boldsymbol{b}_1)\to\boldsymbol{h}\to\boldsymbol{a}(\mathbf{V},\boldsymbol{b}_2)\to\mathcal{L}$

In [2]:
# Forward pass over the whole batch (one sample per row of X).
#
# The weight matrices are applied transposed so that every sample sees
# z = W x + b1 and a = V h + b2.  That is the layout the backward cell
# differentiates: it builds gW[n][j, i] = u_j * x_i and propagates the error
# with `ut @ V` and `ut @ W`.  For the symmetric initial W and V this prints
# the same numbers as `X @ W` / `h @ V`, but it stays consistent with the
# gradients once an update makes the weights non-symmetric.
z = X @ W.T + b1
print('z =', str(z).replace('\n',','))

# ReLU activation of the hidden layer.
h = np.maximum(0, z)
print('h =', str(h).replace('\n',','))

# Output-layer logits.
a = h @ V.T + b2
print('a =', str(a).replace('\n',','))

# Softmax over the classes: normalise each row of exp(a) by its own sum.
exp_a = np.exp(a)
y_pred = exp_a / exp_a.sum(axis=1, keepdims=True)
print('y_pred =', str(y_pred).replace('\n',','))

# Cross-entropy terms per sample and class; only the entry of the true class
# is non-zero because y is one-hot.  L is the mean over the samples.
Ln = -y * np.log(y_pred)
print('Ln =', str(Ln).replace('\n',','), '\nL =', np.sum(Ln) / len(X))

z = [[-1.   0.5], [ 0.   1.5], [ 0.   1.5], [ 1.   2.5]]
h = [[0.  0.5], [0.  1.5], [0.  1.5], [1.  2.5]]
a = [[ 0.5 -0.5], [-0.5  0.5], [-0.5  0.5], [-0.5  0.5]]
y_pred = [[0.7311 0.2689], [0.2689 0.7311], [0.2689 0.7311], [0.2689 0.7311]]
Ln = [[ 0.3133 -0.    ], [-0.      0.3133], [-0.      0.3133], [ 1.3133 -0.    ]] 
L = 0.5632616875182226


## Backward: $\small\quad\mathcal{L}\to\boldsymbol{a}(\mathbf{V},\boldsymbol{b}_2)\to\boldsymbol{h}\to\boldsymbol{z}(\mathbf{W},\boldsymbol{b}_1)\to\boldsymbol{x}$

In [7]:
# Backward pass, one sample at a time, as a chain of vector-Jacobian products.
#
# `ut` is the running row vector holding the derivative of the sample loss
# with respect to the current node; it is multiplied on the right by the
# Jacobian of each step while walking from the logits back to the input.
# Per-sample parameter gradients are stored along axis 0 and averaged at the
# end.  The weight gradients are laid out as g[j, i] = u_j * (layer input)_i.
gV = np.zeros((len(X), *V.shape)); gb2 = np.zeros((len(X), *b2.shape))
gW = np.zeros((len(X), *W.shape)); gb1 = np.zeros((len(X), *b1.shape))
for n in range(len(X)):
    print('*** x =', str(X[n]).replace('\n',','), '***')
    # Derivative of softmax + cross-entropy with respect to the logits a.
    ut = (y_pred[n] - y[n]).reshape(1, -1); print('uJLa =', str(ut).replace('\n',','))
    # Outer product of the logit error (column) with the hidden activation (row).
    gV[n] = np.kron(ut.T, h[n].reshape(1, -1)); print('gV =', str(gV[n]).replace('\n',','))
    gb2[n] = ut; print('gb2 =', str(gb2[n]).replace('\n',','))
    # Push the error through the output layer to the hidden activations.
    ut = ut @ V; print('uJLaJah =', str(ut).replace('\n',','))
    # ReLU Jacobian: diagonal 0/1 mask; the derivative at z == 0 is taken as 0.
    Jhz = np.diag(np.heaviside(z[n], 0.0)); print('Jhz =', str(Jhz).replace('\n',','))
    ut = ut @ Jhz; print('uJLaJahJhz =', str(ut).replace('\n',','))
    # Outer product of the pre-activation error (column) with the input (row).
    gW[n] = np.kron(ut.T, X[n, :]); print('gW =', str(gW[n]).replace('\n',','))
    gb1[n] = ut; print('gb1 =', str(gb1[n]).replace('\n',','))
    # Derivative with respect to the input itself; printed for completeness,
    # it is not used by any parameter update.
    ut = ut @ W; print('uJLaJahJhzJzx =', str(ut).replace('\n',','))

# Average the per-sample gradients over the batch.
gWavg = gW.sum(axis=0) / len(X); print('\ngWavg =', str(gWavg).replace('\n',','))
gb1avg = gb1.sum(axis=0) / len(X); print('gb1avg =', str(gb1avg).replace('\n',','))
gVavg = gV.sum(axis=0) / len(X); print('gVavg =', str(gVavg).replace('\n',','))
# Label fixed: this line reports gb2avg (it used to print 'gb1avg =').
gb2avg = gb2.sum(axis=0) / len(X); print('gb2avg =', str(gb2avg).replace('\n',','))

*** x = [0 0] ***
uJLa = [[-0.2689  0.2689]]
gV = [[-0.     -0.1345], [ 0.      0.1345]]
gb2 = [-0.2689  0.2689]
uJLaJah = [[-0.5379  0.5379]]
Jhz = [[0. 0.], [0. 1.]]
uJLaJahJhz = [[0.     0.5379]]
gW = [[0. 0.], [0. 0.]]
gb1 = [0.     0.5379]
uJLaJahJhzJzx = [[0.5379 0.5379]]
*** x = [0 1] ***
uJLa = [[ 0.2689 -0.2689]]
gV = [[ 0.      0.4034], [-0.     -0.4034]]
gb2 = [ 0.2689 -0.2689]
uJLaJah = [[ 0.5379 -0.5379]]
Jhz = [[0. 0.], [0. 1.]]
uJLaJahJhz = [[ 0.     -0.5379]]
gW = [[ 0.      0.    ], [-0.     -0.5379]]
gb1 = [ 0.     -0.5379]
uJLaJahJhzJzx = [[-0.5379 -0.5379]]
*** x = [1 0] ***
uJLa = [[ 0.2689 -0.2689]]
gV = [[ 0.      0.4034], [-0.     -0.4034]]
gb2 = [ 0.2689 -0.2689]
uJLaJah = [[ 0.5379 -0.5379]]
Jhz = [[0. 0.], [0. 1.]]
uJLaJahJhz = [[ 0.     -0.5379]]
gW = [[ 0.      0.    ], [-0.5379 -0.    ]]
gb1 = [ 0.     -0.5379]
uJLaJahJhzJzx = [[-0.5379 -0.5379]]
*** x = [1 1] ***
uJLa = [[-0.7311  0.7311]]
gV = [[-0.7311 -1.8276], [ 0.7311  1.8276]]
gb2 = [-0.7311  0.7311]

## SGD: $\small\quad\boldsymbol{\theta}'=\boldsymbol{\theta}-\eta\frac{\partial\mathcal{L}}{\partial\boldsymbol{\theta}}$

In [9]:
# One gradient-descent step on every parameter, using the batch-averaged
# gradients.  The updated values get an `n` prefix (nW, nb1, nV, nb2) and the
# original parameters are left untouched.
eta = 0.1  # learning rate

nW = W - eta * gWavg
print('nW =', str(nW).replace('\n',','))

# Labels now match the variable names; they used to read 'ngb1 =' and
# 'ngV =', which looked like gradients rather than updated parameters.
nb1 = b1 - eta * gb1avg
print('nb1 =', str(nb1).replace('\n',','))

nV = V - eta * gVavg
print('nV =', str(nV).replace('\n',','))

nb2 = b2 - eta * gb2avg
print('nb2 =', str(nb2).replace('\n',','))

nW = [[1.0366 1.0366], [0.9769 0.9769]]
ngb1 = [-0.9634  0.4769]
ngV = [[ 1.0183 -0.9711], [-1.0183  0.9711]]
nb2 = [ 1.0116 -1.0116]
