# Gradient Descent

[Gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) is the fundamental algorithm used to train neural networks. The following example uses gradient descent to find the optimum weights and biases for the simple multilayer perceptron (MLP) shown below. The network contains three layers: an input layer that accepts two values, a hidden layer with three neurons, and an output layer with one neuron. The network can be trained to transform two inputs into an output — for example, to add two inputs or square the difference. Values forwarded from the hidden layer to the output layer are transformed using the [ReLU](https://machinelearningmastery.com/rectified-linear-activation-function-for-deep-learning-neural-networks/) activation function, which turns negative numbers into zeros.

![](Images/network.png)

Let's begin by defining a class named `NeuralNetwork` that encapculates a network of this type. The network contains 13 trainable parameters: nine weights and four biases. `w0` and `w1` in the diagram correspond to `weights[0]` and `weights[1]` in the `NeuralNetwork` class, while `b0` and `b1` correspond to `biases[0]` and `biases[1]`.

In [1]:
import numpy as np

class NeuralNetwork():
    def __init__(self):
        self.weights = np.random.uniform(-1, 1, 9)
        self.biases = np.zeros(4)
        
    def show_weights_and_biases(self):
        print(f'h1: weights=[{self.weights[0]}, {self.weights[1]}], bias={self.biases[0]}')
        print(f'h2: weights=[{self.weights[2]}, {self.weights[3]}], bias={self.biases[1]}')
        print(f'h3: weights=[{self.weights[4]}, {self.weights[5]}], bias={self.biases[2]}')
        print(f'y: weights=[{self.weights[6]}, {self.weights[7]}, {self.weights[8]}], bias={self.biases[3]}')
        
    def relu(self, x):
        return max(0, x)
    
    def update_weights_and_biases(self, x1, x2, y, lr=0.01):
        prediction = self.predict(x1, x2)
        delta = prediction - y

        # Compute intermediate values for neurons in the hidden layer
        h1 = (x1 * self.weights[0]) + (x2 * self.weights[1]) + self.biases[0]
        h2 = (x1 * self.weights[2]) + (x2 * self.weights[3]) + self.biases[1]
        h3 = (x1 * self.weights[4]) + (x2 * self.weights[5]) + self.biases[2]

        # Compute deltas for 9 weights
        weight_deltas = np.empty(9)
        weight_deltas[0] = lr * (x1 * delta * self.weights[6])
        weight_deltas[1] = lr * (x2 * delta * self.weights[6])
        weight_deltas[2] = lr * (x1 * delta * self.weights[7])
        weight_deltas[3] = lr * (x2 * delta * self.weights[7])
        weight_deltas[4] = lr * (x1 * delta * self.weights[8])
        weight_deltas[5] = lr * (x2 * delta * self.weights[8])
        weight_deltas[6] = lr * delta * h1
        weight_deltas[7] = lr * delta * h2
        weight_deltas[8] = lr * delta * h3

        # Compute deltas for 4 biases
        bias_deltas = np.empty(4)
        bias_deltas[0] = lr * delta * self.weights[6]
        bias_deltas[1] = lr * delta * self.weights[7]
        bias_deltas[2] = lr * delta * self.weights[8]
        bias_deltas[3] = lr * delta

        # Update weights and biases
        self.weights -= weight_deltas
        self.biases -= bias_deltas

        # Show the results
        prediction = self.predict(x1, x2)
        print(f'Prediction: ({x1}, {x2}) => {prediction}, Error: {delta}')          
        return delta
    
    def predict(self, x1, x2):
        h1 = (x1 * self.weights[0]) + (x2 * self.weights[1]) + self.biases[0]
        h2 = (x1 * self.weights[2]) + (x2 * self.weights[3]) + self.biases[1]
        h3 = (x1 * self.weights[4]) + (x2 * self.weights[5]) + self.biases[2]
        y = (self.relu(h1) * self.weights[6]) + (self.relu(h2) * self.weights[7]) + (self.relu(h3) * self.weights[8]) + self.biases[3]
        return y

Create an instance of `NeuralNetwork` and show the randomly initialized weights and biases. Note that biases are simply initialized to 0.

In [2]:
model = NeuralNetwork()
model.show_weights_and_biases()

h1: weights=[0.1483880111917648, -0.7943936559526472], bias=0.0
h2: weights=[0.8778837883567088, -0.7589434650907823], bias=0.0
h3: weights=[0.7817547965444347, -0.03309822683623698], bias=0.0
y: weights=[0.34380530447017477, -0.2713691265157361, 0.10521625218734632], bias=0.0


Train the network for 1,000 iterations with samples that teach it how to add numbers together. The following code performs 5,000 forward passes through the network and 5,000 backpropagation passes:

In [3]:
x = np.array([[2, 2], [5, 1], [0, 4], [2, 8], [3, 0]])
y = np.array([4, 6, 4, 10, 3])

for i in range(1000):
    for j in range(len(x)):
        model.update_weights_and_biases(x[j][0], x[j][1], y[j], 0.01)

Prediction: (2, 2) => 0.25292706816963456, Error: -3.9070117863840848
Prediction: (5, 1) => 1.8576502654551839, Error: -6.197168519774474
Prediction: (0, 4) => 0.23135988438070537, Error: -3.8989581969384144
Prediction: (2, 8) => 2.9791696108506467, Error: -9.00777013618347
Prediction: (3, 0) => 0.9113798565956375, Error: -2.699704200956672
Prediction: (2, 2) => 2.3042507232574514, Error: -1.9399969241251003
Prediction: (5, 1) => 4.758630114416077, Error: -3.3485547446249737
Prediction: (0, 4) => 2.228631315616632, Error: -2.167640045074493
Prediction: (2, 8) => 10.584811026907799, Error: -3.92035458247084
Prediction: (3, 0) => 3.135795309045781, Error: 0.22815583618016966
Prediction: (2, 2) => 4.788113176128418, Error: 1.1145508856680415
Prediction: (5, 1) => 5.980173553114242, Error: 0.38442852941335204
Prediction: (0, 4) => 3.8330751882220464, Error: -0.24296572048081622
Prediction: (2, 8) => 10.214328152060528, Error: -0.36857112684289994
Prediction: (3, 0) => 2.816521670290274, Er

Prediction: (0, 4) => 3.9484798316356677, Error: -0.07601374624443924
Prediction: (2, 8) => 10.1070048995872, Error: -0.17781159232452204
Prediction: (3, 0) => 2.573244625945956, Error: -0.688366215043525
Prediction: (2, 2) => 4.531494406832108, Error: 0.7315733076628748
Prediction: (5, 1) => 5.993179518228075, Error: 0.1918351934376421
Prediction: (0, 4) => 3.9489437899445994, Error: -0.07532956578360217
Prediction: (2, 8) => 10.10718227418378, Error: -0.17813004574995972
Prediction: (3, 0) => 2.5728523120794433, Error: -0.6889631584836167
Prediction: (2, 2) => 4.53117665349291, Error: 0.7310994423513559
Prediction: (5, 1) => 5.99316052341541, Error: 0.192516780532495
Prediction: (0, 4) => 3.9494086102702797, Error: -0.07464411138220317
Prediction: (2, 8) => 10.107359958071997, Error: -0.17844911633548222
Prediction: (3, 0) => 2.5724593811475143, Error: -0.6895610995124564
Prediction: (2, 2) => 4.53085829864483, Error: 0.7306247423392715
Prediction: (5, 1) => 5.993141474797429, Error:

Prediction: (3, 0) => 2.556214319313788, Error: -0.7143303820599387
Prediction: (2, 2) => 4.517597280623268, Error: 0.7109063645948481
Prediction: (5, 1) => 5.992328834089625, Error: 0.22151420720891402
Prediction: (0, 4) => 3.9692319975708283, Error: -0.04540733605689651
Prediction: (2, 8) => 10.1149176496164, Error: -0.19207590640592898
Prediction: (3, 0) => 2.5557944774343127, Error: -0.7149717008659535
Prediction: (2, 2) => 4.5172517464837, Error: 0.7103940040404293
Prediction: (5, 1) => 5.992307083420644, Error: 0.22224823907315017
Prediction: (0, 4) => 3.9697347936348306, Error: -0.04466564629619896
Prediction: (2, 8) => 10.115108811737265, Error: -0.19242192577908845
Prediction: (3, 0) => 2.5553739422233894, Error: -0.7156141340419282
Prediction: (2, 2) => 4.516905488036283, Error: 0.7098806415105718
Prediction: (5, 1) => 5.992285254456736, Error: 0.22298359809573132
Prediction: (0, 4) => 3.97023854019621, Error: -0.043922545960570325
Prediction: (2, 8) => 10.115300308447845, Er

Prediction: (3, 0) => 2.4783178187121826, Error: -0.8284287552469154
Prediction: (2, 2) => 4.38911564061956, Error: 0.525271964579674
Prediction: (5, 1) => 6.022577675854779, Error: 0.3971577232853001
Prediction: (0, 4) => 4.114468900546399, Error: 0.1698960023206748
Prediction: (2, 8) => 10.161372743626828, Error: -0.27733712008644495
Prediction: (3, 0) => 2.4780961666601953, Error: -0.8287416940894254
Prediction: (2, 2) => 4.388553171281699, Error: 0.5244767951228768
Prediction: (5, 1) => 6.022634668086518, Error: 0.39776872459161616
Prediction: (0, 4) => 4.114989515559166, Error: 0.1706745531854068
Prediction: (2, 8) => 10.161524591513823, Error: -0.27761277938265216
Prediction: (3, 0) => 2.4778745287003296, Error: -0.8290545764073576
Prediction: (2, 2) => 4.387990174931522, Error: 0.5236810012921627
Prediction: (5, 1) => 6.0226917315178055, Error: 0.39837993366438695
Prediction: (0, 4) => 4.115510346379517, Error: 0.17145349764787277
Prediction: (2, 8) => 10.161676517673293, Error:

Prediction: (5, 1) => 6.0239656174734995, Error: 0.41187359120598277
Prediction: (0, 4) => 4.1270218492044535, Error: 0.1886885153840776
Prediction: (2, 8) => 10.165038708630298, Error: -0.28397538159410907
Prediction: (3, 0) => 2.4727843573050596, Error: -0.8362282426965688
Prediction: (2, 2) => 4.3748978320254075, Error: 0.5052087271682568
Prediction: (5, 1) => 6.026488873606981, Error: 0.4098532603075853
Prediction: (0, 4) => 4.1280194293201715, Error: 0.19019390499271527
Prediction: (2, 8) => 10.164410397119553, Error: -0.28284068014799857
Prediction: (3, 0) => 2.473197699343582, Error: -0.8354672030740797
Prediction: (2, 2) => 4.374167680624926, Error: 0.5041819173914641
Prediction: (5, 1) => 6.029745451048595, Error: 0.407591632259388
Prediction: (0, 4) => 4.128689918732615, Error: 0.1912044840811138
Prediction: (2, 8) => 10.164028974045065, Error: -0.28215642098094307
Prediction: (3, 0) => 2.474271388887674, Error: -0.8336458015115502
Prediction: (2, 2) => 4.373523171633604, Err

Prediction: (2, 2) => 4.232328634892745, Error: 0.30973153353886307
Prediction: (5, 1) => 6.089734353014693, Error: 0.2363559320411177
Prediction: (0, 4) => 4.090514839992044, Error: 0.13475916020353473
Prediction: (2, 8) => 10.106736249343468, Error: -0.1775168698032541
Prediction: (3, 0) => 2.685487539597314, Error: -0.48712680367982575
Prediction: (2, 2) => 4.232031594663691, Error: 0.309329454604061
Prediction: (5, 1) => 6.08963170145956, Error: 0.23608068661835535
Prediction: (0, 4) => 4.090386458960131, Error: 0.13456835732666406
Prediction: (2, 8) => 10.106601338933826, Error: -0.17727822513441005
Prediction: (3, 0) => 2.6858509172854452, Error: -0.48654599323840575
Prediction: (2, 2) => 4.231735373096102, Error: 0.3089284992467416
Prediction: (5, 1) => 6.089529306177018, Error: 0.2358061353792067
Prediction: (0, 4) => 4.090258461566303, Error: 0.1343781238005759
Prediction: (2, 8) => 10.106466793397994, Error: -0.17704026772138626
Prediction: (3, 0) => 2.686213395828861, Error:

Train the model again using a reduced learning rate.

In [4]:
for i in range(1000):
    for j in range(len(x)):
        model.update_weights_and_biases(x[j][0], x[j][1], y[j], 0.001)

Prediction: (2, 2) => 4.295450493968877, Error: 0.3030458494187709
Prediction: (5, 1) => 6.310431124330571, Error: 0.3314248882056585
Prediction: (0, 4) => 4.243800862434232, Error: 0.2523646157862762
Prediction: (2, 8) => 10.172864568715262, Error: 0.20704990273031143
Prediction: (3, 0) => 2.6652159192479967, Error: -0.3470738998605505
Prediction: (2, 2) => 4.2647384452850625, Error: 0.2714937759729441
Prediction: (5, 1) => 6.272121617167064, Error: 0.2903969530509549
Prediction: (0, 4) => 4.21112021243469, Error: 0.2184736619107417
Prediction: (2, 8) => 10.110035441793872, Error: 0.13160852822391256
Prediction: (3, 0) => 2.6465823461692684, Error: -0.3663092725409345
Prediction: (2, 2) => 4.241940492480717, Error: 0.2480798520901084
Prediction: (5, 1) => 6.243911698598696, Error: 0.2602083994917912
Prediction: (0, 4) => 4.18656220435132, Error: 0.19301896511244365
Prediction: (2, 8) => 10.062855400352207, Error: 0.07509854009080286
Prediction: (3, 0) => 2.6333066539422307, Error: -0.

Prediction: (0, 4) => 4.098451022844657, Error: 0.10177943212640894
Prediction: (2, 8) => 9.897776594658858, Error: -0.12167172926072034
Prediction: (3, 0) => 2.644580110488226, Error: -0.3680812819222683
Prediction: (2, 2) => 4.179937148919581, Error: 0.18443086444135126
Prediction: (5, 1) => 6.183817761255053, Error: 0.19595866606583723
Prediction: (0, 4) => 4.098460462623885, Error: 0.10178919009934173
Prediction: (2, 8) => 9.897772262510143, Error: -0.12167685753375146
Prediction: (3, 0) => 2.6445904196019185, Error: -0.3680706015143489
Prediction: (2, 2) => 4.179941431875412, Error: 0.18443525256236804
Prediction: (5, 1) => 6.183811487436984, Error: 0.1959519473535165
Prediction: (0, 4) => 4.0984698867369325, Error: 0.10179893187812361
Prediction: (2, 8) => 9.897767938203657, Error: -0.12168197651985224
Prediction: (3, 0) => 2.644600716379199, Error: -0.3680599338838184
Prediction: (2, 2) => 4.179945706266775, Error: 0.18443963190662416
Prediction: (5, 1) => 6.183805221362952, Err

Prediction: (0, 4) => 4.098947550242118, Error: 0.10229269648285388
Prediction: (2, 8) => 9.897549632871542, Error: -0.12194039529667045
Prediction: (3, 0) => 2.6451293497554755, Error: -0.36751225628488626
Prediction: (2, 2) => 4.180160341290683, Error: 0.18465953428444237
Prediction: (5, 1) => 6.183483924043287, Error: 0.19560116947855377
Prediction: (0, 4) => 4.098956162513587, Error: 0.10230159905967895
Prediction: (2, 8) => 9.89754571321548, Error: -0.12194503510260546
Prediction: (3, 0) => 2.6451390111483306, Error: -0.36750224674579757
Prediction: (2, 2) => 4.1801641727039245, Error: 0.18466345967094355
Prediction: (5, 1) => 6.183478060279516, Error: 0.19559489040528621
Prediction: (0, 4) => 4.098964760376467, Error: 0.10231048674232568
Prediction: (2, 8) => 9.897541800728158, Error: -0.12194966641968286
Prediction: (3, 0) => 2.6451486612961963, Error: -0.3674922488534982
Prediction: (2, 2) => 4.1801679962614395, Error: 0.18466737700701596
Prediction: (5, 1) => 6.183472203663389

Prediction: (2, 2) => 4.180938975777472, Error: 0.18545720401362775
Prediction: (5, 1) => 6.182140659165563, Error: 0.19416304395499662
Prediction: (0, 4) => 4.1008042823969815, Error: 0.10421202007613228
Prediction: (2, 8) => 9.896723712382496, Error: -0.12291798286233302
Prediction: (3, 0) => 2.647366713795078, Error: -0.36519415880506356
Prediction: (2, 2) => 4.180941088664104, Error: 0.18545936828651932
Prediction: (5, 1) => 6.182136359229626, Error: 0.19415844139185623
Prediction: (0, 4) => 4.100809732993665, Error: 0.10421765443489761
Prediction: (2, 8) => 9.8967213658105, Error: -0.1229207600755533
Prediction: (3, 0) => 2.647373910770562, Error: -0.36518670170298195
Prediction: (2, 2) => 4.1809431963588315, Error: 0.1854615272388651
Prediction: (5, 1) => 6.182132064018658, Error: 0.19415384389381796
Prediction: (0, 4) => 4.100815174082179, Error: 0.10422327896510986
Prediction: (2, 8) => 9.89671902397105, Error: -0.12292353168576042
Prediction: (3, 0) => 2.6473811003378076, Erro

Prediction: (2, 2) => 4.181009889585639, Error: 0.18552984108767134
Prediction: (5, 1) => 6.181992925586132, Error: 0.19400491739272052
Prediction: (0, 4) => 4.100989491020838, Error: 0.10440347295114805
Prediction: (2, 8) => 9.896644350952537, Error: -0.12301190752639002
Prediction: (3, 0) => 2.64761427410479, Error: -0.3649376487873077
Prediction: (2, 2) => 4.181011825457377, Error: 0.18553182395980095
Prediction: (5, 1) => 6.181988786755264, Error: 0.19400048753080323
Prediction: (0, 4) => 4.1009946174370295, Error: 0.10440877220611267
Prediction: (2, 8) => 9.896642165741719, Error: -0.1230144937009161
Prediction: (3, 0) => 2.6476212184938834, Error: -0.364930453289134
Prediction: (2, 2) => 4.181013756411231, Error: 0.18553380179220103
Prediction: (5, 1) => 6.181984652400602, Error: 0.19399606246724765
Prediction: (0, 4) => 4.100999734846561, Error: 0.10441406215093973
Prediction: (2, 8) => 9.896639985014115, Error: -0.12301707456771993
Prediction: (3, 0) => 2.6476281558652714, Erro

Prediction: (5, 1) => 6.18185068519505, Error: 0.19385267937342032
Prediction: (0, 4) => 4.101163646675909, Error: 0.10458350042900744
Prediction: (2, 8) => 9.896570491312554, Error: -0.12309931856822409
Prediction: (3, 0) => 2.647853222277753, Error: -0.3646900583553325
Prediction: (2, 2) => 4.181076536366252, Error: 0.18559810444912284
Prediction: (5, 1) => 6.181846698990943, Error: 0.1938484131219429
Prediction: (0, 4) => 4.101168466007472, Error: 0.10458848225225381
Prediction: (2, 8) => 9.896568458962559, Error: -0.12310172377098638
Prediction: (3, 0) => 2.647859927389047, Error: -0.3646831106767525
Prediction: (2, 2) => 4.181078299906807, Error: 0.1855999107204056
Prediction: (5, 1) => 6.181842717027584, Error: 0.1938441514163669
Prediction: (0, 4) => 4.101173276807254, Error: 0.10459345525624641
Prediction: (2, 8) => 9.896566430859629, Error: -0.12310412394569248
Prediction: (3, 0) => 2.6478666258521932, Error: -0.3646761698836385
Prediction: (2, 2) => 4.181080058796217, Error: 

Prediction: (0, 4) => 4.10214600508898, Error: 0.10559898631627007
Prediction: (2, 8) => 9.896174624993456, Error: -0.12356776312464568
Prediction: (3, 0) => 2.6493678785589503, Error: -0.3631205160016129
Prediction: (2, 2) => 4.181392707578517, Error: 0.18592186490082696
Prediction: (5, 1) => 6.1809553104979384, Error: 0.19289460355026833
Prediction: (0, 4) => 4.1021490308558395, Error: 0.10560211414452603
Prediction: (2, 8) => 9.896173485538652, Error: -0.12356911128975234
Prediction: (3, 0) => 2.649373185268785, Error: -0.3631150166349517
Prediction: (2, 2) => 4.18139349364676, Error: 0.18592266948525182
Prediction: (5, 1) => 6.1809522192879305, Error: 0.19289129665516214
Prediction: (0, 4) => 4.10215205086113, Error: 0.10560523601709804
Prediction: (2, 8) => 9.896172348952113, Error: -0.12357045605947725
Prediction: (3, 0) => 2.6493784874818385, Error: -0.36310952192554424
Prediction: (2, 2) => 4.181394276576561, Error: 0.1859234708534956
Prediction: (5, 1) => 6.180949130942449, Er

Prediction: (2, 2) => 4.181415637919289, Error: 0.18594533363078014
Prediction: (5, 1) => 6.180860795465638, Error: 0.192793496601106
Prediction: (0, 4) => 4.102240015496546, Error: 0.10569616787660419
Prediction: (2, 8) => 9.896139563377336, Error: -0.12360924601563816
Prediction: (3, 0) => 2.649535496673272, Error: -0.36294681125679595
Prediction: (2, 2) => 4.181416328921272, Error: 0.1859460407942466
Prediction: (5, 1) => 6.180857791044305, Error: 0.19279028273160037
Prediction: (0, 4) => 4.102242861235393, Error: 0.10569910961146611
Prediction: (2, 8) => 9.89613851354401, Error: -0.12361048808940112
Prediction: (3, 0) => 2.649540662852023, Error: -0.36294145743620154
Prediction: (2, 2) => 4.181417016931965, Error: 0.18594674489233398
Prediction: (5, 1) => 6.1808547893544965, Error: 0.19278707178990295
Prediction: (0, 4) => 4.102245701490445, Error: 0.10570204567777886
Prediction: (2, 8) => 9.896137466440576, Error: -0.12361172693160327
Prediction: (3, 0) => 2.6495458247493007, Erro

Show the weights and biases that were computed during training.

In [5]:
model.show_weights_and_biases()

h1: weights=[0.04532524687895712, -0.8903815259271088], bias=0.023610612209226924
h2: weights=[0.21346908229195602, -1.5904676059685565], bias=-0.1368067316663643
h3: weights=[0.937380778195616, 0.9039277080301252], bias=-0.029227857222607064
y: weights=[-0.5833987954979141, -1.0419467705323109, 1.0551137839275915], bias=0.33122463786858797


Ask the model to add 4 and 4.

In [6]:
model.predict(4, 4)

8.071545779956857

This is a *very* simple network with just nine trainable parameters, and it updates its weights and biases after every training sample. Imagine how much longer training would take if the network had 100 million parameters, which isn't uncommon in deep learning. In practice, data scientists run batches of training samples through the network and update the weights and biases after each batch, a technique known as *mini-batch gradient descent*. Still, this network proves the principle that gradient descent can converge on a reasonable set of weights and biases, and it shows in a very limited way how gradient descent is enacted.