In [8]:
from typing import Callable, List
import random
Vector = List[float]

In [2]:
# difference quotient - the derivative is defined as the limit of the difference quotient 

def difference_quotient(f: Callable[[float], float],
                        x: float,
                        h: float) -> float:
    """Forward difference quotient of `f` at `x` for step `h`.

    This is the slope of the secant line through (x, f(x)) and
    (x + h, f(x + h)); as `h` shrinks it approximates the derivative f'(x).

    Args:
        f: function of one variable.
        x: point at which the quotient is evaluated.
        h: step size; for plain Python numbers h == 0 raises ZeroDivisionError.

    Returns:
        (f(x + h) - f(x)) / h
    """
    rise = f(x + h) - f(x)
    return rise / h

In [4]:
def partial_difference_quotient(f: Callable[[Vector], float],
                                v: Vector,
                                i: int,
                                h: float) -> float:
    """Forward difference quotient of `f` at `v` along coordinate `i`.

    Args:
        f: function of a vector.
        v: point at which the quotient is evaluated (not modified).
        i: index of the coordinate that is nudged by `h`.
        h: step size.

    Returns:
        (f(w) - f(v)) / h, where `w` is a copy of `v` with `h` added to
        component `i`. If `i` matches no index of `v`, then `w` equals `v`
        and the result is 0.
    """
    shifted = []
    for j, v_j in enumerate(v):
        # only the i-th coordinate moves; every other one gets a zero bump
        bump = h if j == i else 0
        shifted.append(v_j + bump)
    return (f(shifted) - f(v)) / h

In [5]:
def estimate_gradient(f: Callable[[Vector], float],
                      v: Vector,
                      h: float = 0.0001) -> Vector:
    """Numerically estimate the gradient of `f` at `v`.

    Args:
        f: function of a vector.
        v: point at which the gradient is estimated.
        h: step size used for every coordinate (default 0.0001).

    Returns:
        A list with one entry per component of `v`; entry `i` is the i-th
        partial difference quotient of `f` at `v`. Empty when `v` is empty.
    """
    return [partial_difference_quotient(f, v, i, h) for i, _ in enumerate(v)]

In [7]:
def scalar_multiply(c: float, v: Vector) -> Vector:
    """Return a new list holding every component of `v` multiplied by `c`."""
    scaled = []
    for component in v:
        scaled.append(c * component)
    return scaled

def add(v: Vector, w: Vector) -> Vector:
    """Return the component-wise sum of `v` and `w` as a new list.

    Both vectors must have the same length; this is checked with an assert.
    """
    assert len(v) == len(w), 'length of vectors must be the same'
    total = []
    for left, right in zip(v, w):
        total.append(left + right)
    return total

def gradient_step(v: Vector,
                  gradient: Vector,
                  step_size: float) -> Vector:
    """Return the point reached from `v` after one step along `gradient`.

    The result is v + step_size * gradient, computed component-wise; pass a
    negative `step_size` to move against the gradient (i.e. to minimize).
    `v` and `gradient` must have the same length (checked with an assert).
    """
    assert len(v) == len(gradient), 'length of vector and gradient must be the same'
    return add(v, scalar_multiply(step_size, gradient))

In [9]:
# gradient of f(v) = v_1^2 + ... + v_n^2: the i-th partial derivative is 2 * v_i
def sum_of_squares_gradient(v: Vector) -> Vector:
    """Return a new list with every component of `v` doubled."""
    doubled = []
    for v_i in v:
        doubled.append(2 * v_i)
    return doubled

# start from a random 3-d point with each coordinate drawn from [-10, 10]
v = [random.uniform(-10, 10) for _ in range(3)]

# 1000 fixed-size steps against the gradient; the negative step size makes
# each step shrink v toward the minimum of the sum of squares at the origin
for epoch in range(1000):
    grad = sum_of_squares_gradient(v)
    v = gradient_step(v, grad, -0.01)
    # trace the current point after every step
    print(epoch, v)

0 [6.081137314203362, -7.913315947420309, -3.2987814709553493]
1 [5.959514567919295, -7.755049628471903, -3.2328058415362424]
2 [5.840324276560909, -7.599948635902464, -3.1681497247055175]
3 [5.723517791029691, -7.447949663184415, -3.104786730211407]
4 [5.609047435209097, -7.298990669920727, -3.042690995607179]
5 [5.496866486504914, -7.153010856522312, -2.981837175695035]
6 [5.386929156774816, -7.009950639391866, -2.922200432181134]
7 [5.279190573639319, -6.869751626604028, -2.8637564235375113]
8 [5.173606762166533, -6.732356594071947, -2.806481295066761]
9 [5.070134626923203, -6.597709462190508, -2.750351669165426]
10 [4.9687319343847385, -6.465755272946698, -2.6953446357821176]
11 [4.869357295697044, -6.336440167487764, -2.6414377430664753]
12 [4.771970149783103, -6.209711364138008, -2.588608988205146]
13 [4.67653074678744, -6.085517136855248, -2.536836808441043]
14 [4.5830001318516915, -5.963806794118144, -2.4861000722722224]
15 [4.491340129214658, -5.84453065823578, -2.436378070826

974 [1.7305435560860793e-08, -2.251936967135466e-08, -9.387528553523198e-09]
975 [1.6959326849643577e-08, -2.206898227792757e-08, -9.199777982452733e-09]
976 [1.6620140312650704e-08, -2.1627602632369016e-08, -9.015782422803678e-09]
977 [1.628773750639769e-08, -2.1195050579721636e-08, -8.835466774347605e-09]
978 [1.5961982756269736e-08, -2.0771149568127205e-08, -8.658757438860653e-09]
979 [1.564274310114434e-08, -2.035572657676466e-08, -8.48558229008344e-09]
980 [1.5329888239121454e-08, -1.9948612045229367e-08, -8.31587064428177e-09]
981 [1.5023290474339024e-08, -1.954963980432478e-08, -8.149553231396135e-09]
982 [1.4722824664852243e-08, -1.9158647008238282e-08, -7.986562166768213e-09]
983 [1.4428368171555198e-08, -1.8775474068073517e-08, -7.826830923432847e-09]
984 [1.4139800808124094e-08, -1.8399964586712046e-08, -7.67029430496419e-09]
985 [1.3857004791961612e-08, -1.8031965294977805e-08, -7.516888418864906e-09]
986 [1.357986469612238e-08, -1.767132598907825e-08, -7.366550650487608e-0

In [20]:
def linear_gradient(x: float, y: float, theta: Vector) -> Vector:
    """Gradient of the squared error of a linear model at one data point.

    The model predicts slope * x + intercept, and the loss for the point is
    (predicted - y) ** 2. Only the gradient is returned; the loss value itself
    is not computed, which also avoids a float OverflowError from squaring a
    very large error.

    Args:
        x: input value of the data point.
        y: observed target value of the data point.
        theta: two-element sequence [slope, intercept].

    Returns:
        [d(loss)/d(slope), d(loss)/d(intercept)] = [2 * error * x, 2 * error].

    Raises:
        ValueError: if `theta` does not unpack into exactly two values.
    """
    slope, intercept = theta
    predicted = intercept + x * slope
    error = predicted - y
    return [2 * error * x, 2 * error]

In [23]:
def vector_mean(vectors: List[Vector]) -> Vector:
    """Return the component-wise mean of `vectors`.

    The width of the result is taken from the first vector. Indexing raises
    IndexError if `vectors` is empty or if any vector is shorter than the
    first one.
    """
    width = len(vectors[0])
    count = len(vectors)
    # add up column i across all vectors, then divide by the number of vectors
    return [sum(vector[i] for vector in vectors) / count for i in range(width)]

# noise-free training data on the line y = 20 * x + 5 for x = -50 .. 49
inputs = [(x, 20 * x + 5) for x in range(-50, 50)]

# random initial [slope, intercept], each drawn from [-1, 1]
theta = [random.uniform(-1, 1) for _ in range(2)]
learning_rate = 0.001

for epoch in range(5000):
    # average the per-point gradients over the whole data set
    grad = vector_mean([
        linear_gradient(x, y, theta)
        for x, y in inputs
    ])
    # move against the averaged gradient to reduce the squared error
    theta = gradient_step(theta, grad, -learning_rate)
    # trace the current [slope, intercept] after every step
    print(epoch, theta)


0 [33.267724471806574, 0.42634413141293237]
1 [11.145854121436422, 0.4487591676219131]
2 [25.901164060169528, 0.4490075034081058]
3 [16.059372579370333, 0.46401065246145917]
4 [22.623862500212454, 0.4691420037359066]
5 [18.245352854362032, 0.48082758222864697]
6 [21.165830473722753, 0.48811127991855163]
7 [19.217879185306842, 0.49830088783243726]
8 [20.51717288428817, 0.5065221652420792]
9 [19.650552208345033, 0.5160262937958832]
10 [20.228597703327658, 0.5246447934166364]
11 [19.843049976673868, 0.5338241015331308]
12 [20.100219489660063, 0.5425995033067383]
13 [19.928696199900045, 0.5516145237897849]
14 [20.04311124919046, 0.5604399909421054]
15 [19.966805236780907, 0.5693622222094117]
16 [20.017710269289346, 0.5781903030017738]
17 [19.98376544068701, 0.5870516326650596]
18 [20.00641550269443, 0.5958612948404164]
19 [19.991316720997656, 0.6046759877534299]
20 [20.001396423082316, 0.6134579524989207]
21 [19.994682043756594, 0.622232433017005]
22 [19.999169309247367, 0.6309826501947277

920 [19.999565059303922, 4.2758234798068]
921 [19.999565928924092, 4.277271397906491]
922 [19.99956679680554, 4.278716421039602]
923 [19.999567662951744, 4.280158554994328]
924 [19.999568527366183, 4.281597805547291]
925 [19.999569390052304, 4.283034178463562]
926 [19.999570251013576, 4.284467679496687]
927 [19.99957111025344, 4.285898314388707]
928 [19.999571967775342, 4.2873260888701825]
929 [19.999572823582717, 4.288751008660218]
930 [19.99957367767899, 4.29017307946648]
931 [19.99957453006758, 4.291592306985226]
932 [19.99957538075191, 4.293008696901324]
933 [19.99957622973538, 4.294422254888273]
934 [19.999577077021392, 4.295832986608232]
935 [19.99957792261334, 4.2972408977120375]
936 [19.999578766514613, 4.2986459938392265]
937 [19.99957960872859, 4.300048280618062]
938 [19.999580449258648, 4.301447763665555]
939 [19.999581288108146, 4.302844448587483]
940 [19.999582125280455, 4.304238340978416]
941 [19.999582960778916, 4.305629446421739]
942 [19.999583794606885, 4.3070177704896

1573 [19.99988228164789, 4.803998873036299]
1574 [19.999882517013894, 4.804390757571874]
1575 [19.999882751909304, 4.804781858573745]
1576 [19.999882986335066, 4.805172177608506]
1577 [19.99988322029212, 4.805561716239624]
1578 [19.999883453781397, 4.805950476027437]
1579 [19.999883686803834, 4.806338458529163]
1580 [19.99988391936037, 4.8067256652989085]
1581 [19.99988415145193, 4.807112097887671]
1582 [19.99988438307945, 4.807497757843348]
1583 [19.99988461424385, 4.80788264671074]
1584 [19.999884844946063, 4.808266766031562]
1585 [19.999885075187006, 4.808650117344445]
1586 [19.999885304967613, 4.809032702184943]
1587 [19.999885534288786, 4.809414522085541]
1588 [19.999885763151465, 4.809795578575659]
1589 [19.99988599155655, 4.810175873181659]
1590 [19.999886219504962, 4.810555407426852]
1591 [19.999886446997618, 4.810934182831503]
1592 [19.99988667403542, 4.8113122009128375]
1593 [19.999886900619288, 4.811689463185047]
1594 [19.99988712675012, 4.812065971159297]
1595 [19.999887352

2276 [19.99997117306724, 4.95200313963891]
2277 [19.99997123070379, 4.952099104532699]
2278 [19.999971288225105, 4.952194877554338]
2279 [19.99997134563141, 4.952290459087455]
2280 [19.999971402922938, 4.952385849514911]
2281 [19.999971460099914, 4.952481049218805]
2282 [19.999971517162574, 4.952576058580467]
2283 [19.999971574111143, 4.952670877980469]
2284 [19.999971630945847, 4.9527655077986195]
2285 [19.999971687666918, 4.952859948413968]
2286 [19.99997174427458, 4.952954200204807]
2287 [19.99997180076906, 4.953048263548672]
2288 [19.999971857150584, 4.9531421388223436]
2289 [19.999971913419383, 4.9532358264018495]
2290 [19.99997196957567, 4.953329326662465]
2291 [19.99997202561969, 4.953422639978716]
2292 [19.999972081551643, 4.9535157667243785]
2293 [19.999972137371778, 4.9536087072724815]
2294 [19.999972193080296, 4.953701461995308]
2295 [19.999972248677437, 4.953794031264398]
2296 [19.999972304163414, 4.9538864154505475]
2297 [19.999972359538454, 4.95397861492381]
2298 [19.9999

2944 [19.999992428627667, 4.987393660522536]
2945 [19.99999244376587, 4.987418865630119]
2946 [19.999992458873795, 4.9874440203426245]
2947 [19.999992473951522, 4.987469124760813]
2948 [19.999992488999094, 4.987494178985243]
2949 [19.99999250401659, 4.987519183116272]
2950 [19.999992519004053, 4.987544137254056]
2951 [19.99999253396155, 4.987569041498552]
2952 [19.999992548889146, 4.987593895949517]
2953 [19.99999256378689, 4.9876187007065065]
2954 [19.99999257865485, 4.987643455868881]
2955 [19.999992593493083, 4.987668161535797]
2956 [19.99999260830165, 4.987692817806219]
2957 [19.999992623080605, 4.987717424778909]
2958 [19.999992637830015, 4.987741982552432]
2959 [19.99999265254993, 4.987766491225157]
2960 [19.99999266724042, 4.987790950895257]
2961 [19.999992681901535, 4.987815361660707]
2962 [19.999992696533337, 4.987839723619287]
2963 [19.999992711135885, 4.987864036868582]
2964 [19.999992725709234, 4.987888301505981]
2965 [19.999992740253447, 4.987912517628678]
2966 [19.9999927

3623 [19.999998054686497, 4.996761051850232]
3624 [19.999998058575958, 4.996767527801218]
3625 [19.99999806245764, 4.996773990804192]
3626 [19.99999806633156, 4.996780440885041]
3627 [19.999998070197734, 4.996786878069602]
3628 [19.99999807405618, 4.996793302383661]
3629 [19.999998077906913, 4.996799713852949]
3630 [19.999998081749943, 4.99680611250315]
3631 [19.99999808558529, 4.996812498359894]
3632 [19.999998089412973, 4.99681887144876]
3633 [19.999998093232996, 4.996825231795275]
3634 [19.99999809704539, 4.996831579424917]
3635 [19.99999810085015, 4.9968379143631125]
3636 [19.999998104647315, 4.996844236635236]
3637 [19.999998108436877, 4.996850546266613]
3638 [19.99999811221887, 4.996856843282516]
3639 [19.999998115993296, 4.99686312770817]
3640 [19.99999811976018, 4.996869399568746]
3641 [19.99999812351953, 4.996875658889369]
3642 [19.999998127271365, 4.99688190569511]
3643 [19.999998131015694, 4.9968881400109915]
3644 [19.999998134752545, 4.996894361861985]
3645 [19.999998138481

4315 [19.999999513026832, 4.9991891893812825]
4316 [19.999999514000486, 4.999190810515547]
4317 [19.999999514972192, 4.9991924284085165]
4318 [19.999999515941955, 4.999194043066672]
4319 [19.99999951690978, 4.99919565449648]
4320 [19.999999517875672, 4.9991972627043975]
4321 [19.99999951883963, 4.999198867696864]
4322 [19.999999519801662, 4.99920046948031]
4323 [19.999999520761772, 4.999202068061151]
4324 [19.999999521719957, 4.99920366344579]
4325 [19.999999522676234, 4.999205255640619]
4326 [19.999999523630592, 4.999206844652014]
4327 [19.999999524583046, 4.999208430486341]
4328 [19.999999525533596, 4.999210013149952]
4329 [19.99999952648224, 4.999211592649186]
4330 [19.999999527428994, 4.999213168990369]
4331 [19.999999528373852, 4.999214742179817]
4332 [19.99999952931682, 4.999216312223831]
4333 [19.999999530257906, 4.9992178791287]
4334 [19.999999531197105, 4.999219442900701]
4335 [19.99999953213443, 4.999221003546097]
4336 [19.99999953306988, 4.999222561071139]
4337 [19.999999534

4895 [19.99999984746463, 4.99974602851212]
4896 [19.999999847769605, 4.99974653630256]
4897 [19.999999848073976, 4.999747043077724]
4898 [19.999999848377737, 4.999747548839642]
4899 [19.99999984868089, 4.9997480535903405]
4900 [19.99999984898344, 4.9997485573318405]
4901 [19.999999849285377, 4.9997490600661605]
4902 [19.999999849586718, 4.999749561795314]
4903 [19.999999849887455, 4.99975006252131]
4904 [19.99999985018759, 4.999750562246154]
4905 [19.999999850487125, 4.99975106097185]
4906 [19.99999985078606, 4.999751558700393]
4907 [19.999999851084397, 4.999752055433778]
4908 [19.999999851382142, 4.999752551173995]
4909 [19.999999851679284, 4.99975304592303]
4910 [19.99999985197584, 4.999753539682863]
4911 [19.9999998522718, 4.999754032455473]
4912 [19.999999852567164, 4.999754524242834]
4913 [19.999999852861944, 4.999755015046916]
4914 [19.99999985315613, 4.999755504869684]
4915 [19.99999985344973, 4.999755993713101]
4916 [19.999999853742743, 4.999756481579124]
4917 [19.9999998540351