### 1. Re-code the house price machine learning

In [1]:
%matplotlib inline

from sklearn.datasets import load_boston
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random

data = load_boston()
df = pd.DataFrame(data['data'], columns=data['feature_names'])
y = data['target']
print(df.describe())
print(data['DESCR'])
X = df.RM # I choose RM column

             CRIM          ZN       INDUS        CHAS         NOX          RM  \
count  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000   
mean     3.613524   11.363636   11.136779    0.069170    0.554695    6.284634   
std      8.601545   23.322453    6.860353    0.253994    0.115878    0.702617   
min      0.006320    0.000000    0.460000    0.000000    0.385000    3.561000   
25%      0.082045    0.000000    5.190000    0.000000    0.449000    5.885500   
50%      0.256510    0.000000    9.690000    0.000000    0.538000    6.208500   
75%      3.677083   12.500000   18.100000    0.000000    0.624000    6.623500   
max     88.976200  100.000000   27.740000    1.000000    0.871000    8.780000   

              AGE         DIS         RAD         TAX     PTRATIO           B  \
count  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000   
mean    68.574901    3.795043    9.549407  408.237154   18.455534  356.674032   
std     28.148861    2.1057

###### 1. Random Choose Method to get optimal *k* and *b*

For example, you can change the loss function: $Loss = \frac{1}{n} sum({y_i - \hat{y_i}})^2$ to $Loss = \frac{1}{n} sum(|{y_i - \hat{y_i}}|)$

And you can change the learning rate and observe the performance.

In [2]:
def RMSE(y, y_hat):
    return np.sqrt(sum((y - y_hat) ** 2) / len(y))

def MAE(y, y_hat):
    return sum(abs(y - y_hat)) / len(y)

def loss_random(X, y, n, loss=RMSE):
    loss_min = float('inf')
    k_best, b_best = 0, 0
    for i in range(n):
        k = random.random() * 200 - 100
        b = random.random() * 200 - 100
        y_hat = k * X + b
        loss_new = loss(y, y_hat)
        if loss_new < loss_min:
            loss_min = loss_new
            k_best, b_best = k, b
            print(f"round: {i}, k: {k_best}, b: {b_best}, {loss.__name__}: {loss_min}")
        
    return (k_best, b_best)
loss_random(X, y, 2000, RMSE)

round: 0, k: -84.89216131360669, b: 43.27872723386801, RMSE: 517.0395342752083
round: 1, k: -29.93201511411418, b: 96.63474215106558, RMSE: 117.44166453627572
round: 2, k: -5.898747215703622, b: 30.41671347758779, RMSE: 31.72354696925381
round: 14, k: 2.900817971855659, b: -3.985411735523087, RMSE: 11.45567904455622
round: 52, k: 18.25547361061095, b: -95.87204007188208, RMSE: 9.919303576007216
round: 501, k: 7.604334333818841, b: -29.509685330231065, RMSE: 7.923736080242468
round: 889, k: 10.512341765806426, b: -47.00406703666036, RMSE: 7.5250120089701875


(10.512341765806426, -47.00406703666036)

###### 2.Supervised Direction to get optimal *k* and *b*

In [3]:
# This is a combination of supervising and random walking.
def loss_spvs_dr(X, y, n, alpha=0.1, loss=RMSE):
    loss_min = float('inf')
    direction = [(1, 1), (1, -1), (-1, 1), (-1, -1)]
    
    k = random.random() * 200 - 100
    b = random.random() * 200 - 100
    
    for i in range(n):
        dr_k, dr_b = random.choice(direction) # random walk. Can we walk 4 directions and compare which one is the best?
        k_new = k + dr_k * alpha
        b_new = b + dr_b * alpha
        y_hat = k_new * X + b_new
        loss_new = loss(y, y_hat)
        if loss_new < loss_min:
            k, b = k_new, b_new
            k_best, b_best = k_new, b_new
            loss_min = loss_new
            print(f"round: {i}, k: {k_best}, b: {b_best}, {loss.__name__}: {loss_min}")
    return (k_best, b_best)
loss_spvs_dr(X, y, 2000)

round: 0, k: -14.011183137764585, b: 34.6876472727352, RMSE: 77.89524527603272
round: 1, k: -13.911183137764585, b: 34.7876472727352, RMSE: 77.17088855995638
round: 3, k: -13.811183137764585, b: 34.8876472727352, RMSE: 76.44667433917958
round: 4, k: -13.711183137764586, b: 34.7876472727352, RMSE: 75.91731777803741
round: 9, k: -13.611183137764586, b: 34.8876472727352, RMSE: 75.19333795861398
round: 10, k: -13.511183137764586, b: 34.9876472727352, RMSE: 74.46951174791349
round: 11, k: -13.411183137764587, b: 34.8876472727352, RMSE: 73.94039001225256
round: 12, k: -13.311183137764587, b: 34.7876472727352, RMSE: 73.41132590732764
round: 13, k: -13.211183137764587, b: 34.8876472727352, RMSE: 72.68785058685935
round: 14, k: -13.111183137764588, b: 34.9876472727352, RMSE: 71.96454436922932
round: 15, k: -13.011183137764588, b: 34.8876472727352, RMSE: 71.43574117120278
round: 16, k: -12.911183137764588, b: 34.9876472727352, RMSE: 70.7127152524332
round: 18, k: -12.811183137764589, b: 34.88764

round: 223, k: -3.1111831377646193, b: 34.18764727273519, RMSE: 13.396587406344453
round: 225, k: -3.011183137764619, b: 34.28764727273519, RMSE: 12.934223046122145
round: 227, k: -2.911183137764619, b: 34.387647272735194, RMSE: 12.497634549998562
round: 228, k: -2.811183137764619, b: 34.28764727273519, RMSE: 12.185390911190108
round: 234, k: -2.711183137764619, b: 34.18764727273519, RMSE: 11.888861807935763
round: 235, k: -2.611183137764619, b: 34.08764727273519, RMSE: 11.609251467444173
round: 236, k: -2.5111831377646188, b: 34.18764727273519, RMSE: 11.273065000976192
round: 238, k: -2.4111831377646187, b: 34.08764727273519, RMSE: 11.039015656789969
round: 239, k: -2.3111831377646186, b: 33.98764727273519, RMSE: 10.826178455113137
round: 241, k: -2.2111831377646185, b: 33.88764727273519, RMSE: 10.635826922695763
round: 242, k: -2.1111831377646184, b: 33.98764727273519, RMSE: 10.43257497240345
round: 243, k: -2.0111831377646183, b: 34.08764727273519, RMSE: 10.277528820079674
round: 24

(-1.711183137764618, 34.18764727273519)

#### Walk through all 4 directions and find the smallest loss

In [4]:
# Here I try completely supervised direction: walk all 4 directions at the same time, 
# then select the one with smallest loss.
def calculate_loss(X, y, n, alpha=0.01, loss=RMSE):
    '''
    calculate the loss for all 4 directions and select the smallest one
    '''
    loss_min = float('inf')
    direction = [(1, 1), (1, -1), (-1, 1), (-1, -1)]
    
    k = random.random() * 200 - 100
    b = random.random() * 200 - 100
    
    for i in range(n):
        loss_complete = []
        best_data = []
        k_b = []
        for coord in direction:
            dr_k, dr_b = coord
            k_new = k + dr_k * alpha
            b_new = b + dr_b * alpha
            k_b.append((k_new, b_new))
            y_hat = k_new * X + b_new
            loss_complete.append(loss(y, y_hat))
        best_data = sorted(zip(loss_complete, k_b), key=lambda x: x[0])[0]
        loss_new, (k_n, b_n) = best_data
        if loss_new < loss_min:
            k, b = k_n, b_n
            k_best, b_best = k_new, b_new
            loss_min = loss_new
            print(f"round: {i}, k: {k}, b: {b}, {loss.__name__}: {loss_min}")
    return (k_best, b_best)
calculate_loss(X, y, 2000)

round: 0, k: 93.309677934698, b: -8.450638010611877, RMSE: 558.6089231459674
round: 1, k: 93.299677934698, b: -8.460638010611877, RMSE: 558.5357481509604
round: 2, k: 93.28967793469799, b: -8.470638010611877, RMSE: 558.4625731582445
round: 3, k: 93.27967793469799, b: -8.480638010611877, RMSE: 558.3893981678207
round: 4, k: 93.26967793469798, b: -8.490638010611876, RMSE: 558.3162231796899
round: 5, k: 93.25967793469798, b: -8.500638010611876, RMSE: 558.243048193852
round: 6, k: 93.24967793469797, b: -8.510638010611876, RMSE: 558.1698732103094
round: 7, k: 93.23967793469797, b: -8.520638010611876, RMSE: 558.0966982290619
round: 8, k: 93.22967793469796, b: -8.530638010611876, RMSE: 558.0235232501111
round: 9, k: 93.21967793469796, b: -8.540638010611875, RMSE: 557.9503482734573
round: 10, k: 93.20967793469795, b: -8.550638010611875, RMSE: 557.8771732991022
round: 11, k: 93.19967793469795, b: -8.560638010611875, RMSE: 557.8039983270459
round: 12, k: 93.18967793469794, b: -8.570638010611875,

round: 125, k: 92.05967793469736, b: -9.70063801061185, RMSE: 549.4620668162727
round: 126, k: 92.04967793469736, b: -9.71063801061185, RMSE: 549.3888921147752
round: 127, k: 92.03967793469735, b: -9.72063801061185, RMSE: 549.3157174156851
round: 128, k: 92.02967793469735, b: -9.73063801061185, RMSE: 549.242542719004
round: 129, k: 92.01967793469734, b: -9.74063801061185, RMSE: 549.1693680247314
round: 130, k: 92.00967793469734, b: -9.75063801061185, RMSE: 549.0961933328692
round: 131, k: 91.99967793469733, b: -9.76063801061185, RMSE: 549.0230186434183
round: 132, k: 91.98967793469733, b: -9.770638010611849, RMSE: 548.9498439563798
round: 133, k: 91.97967793469732, b: -9.780638010611849, RMSE: 548.8766692717542
round: 134, k: 91.96967793469732, b: -9.790638010611849, RMSE: 548.8034945895429
round: 135, k: 91.95967793469731, b: -9.800638010611848, RMSE: 548.7303199097462
round: 136, k: 91.9496779346973, b: -9.810638010611848, RMSE: 548.657145232366
round: 137, k: 91.9396779346973, b: -9

round: 247, k: 90.83967793469674, b: -10.920638010611825, RMSE: 540.5347712909447
round: 248, k: 90.82967793469673, b: -10.930638010611824, RMSE: 540.4615968904011
round: 249, k: 90.81967793469673, b: -10.940638010611824, RMSE: 540.3884224923861
round: 250, k: 90.80967793469672, b: -10.950638010611824, RMSE: 540.3152480969011
round: 251, k: 90.79967793469672, b: -10.960638010611824, RMSE: 540.2420737039461
round: 252, k: 90.78967793469671, b: -10.970638010611824, RMSE: 540.1688993135232
round: 253, k: 90.7796779346967, b: -10.980638010611823, RMSE: 540.0957249256331
round: 254, k: 90.7696779346967, b: -10.990638010611823, RMSE: 540.0225505402765
round: 255, k: 90.7596779346967, b: -11.000638010611823, RMSE: 539.9493761574552
round: 256, k: 90.74967793469669, b: -11.010638010611823, RMSE: 539.8762017771693
round: 257, k: 90.73967793469669, b: -11.020638010611822, RMSE: 539.8030273994201
round: 258, k: 90.72967793469668, b: -11.030638010611822, RMSE: 539.7298530242094
round: 259, k: 90.7

round: 366, k: 89.64967793469613, b: -12.1106380106118, RMSE: 531.8270356655662
round: 367, k: 89.63967793469612, b: -12.120638010611799, RMSE: 531.753861573298
round: 368, k: 89.62967793469612, b: -12.130638010611799, RMSE: 531.6806874836847
round: 369, k: 89.61967793469611, b: -12.140638010611799, RMSE: 531.6075133967277
round: 370, k: 89.60967793469611, b: -12.150638010611798, RMSE: 531.5343393124275
round: 371, k: 89.5996779346961, b: -12.160638010611798, RMSE: 531.4611652307857
round: 372, k: 89.5896779346961, b: -12.170638010611798, RMSE: 531.3879911518028
round: 373, k: 89.5796779346961, b: -12.180638010611798, RMSE: 531.3148170754805
round: 374, k: 89.56967793469609, b: -12.190638010611798, RMSE: 531.2416430018197
round: 375, k: 89.55967793469608, b: -12.200638010611797, RMSE: 531.1684689308216
round: 376, k: 89.54967793469608, b: -12.210638010611797, RMSE: 531.0952948624871
round: 377, k: 89.53967793469607, b: -12.220638010611797, RMSE: 531.0221207968171
round: 378, k: 89.5296

round: 487, k: 88.43967793469551, b: -13.320638010611773, RMSE: 522.9729900966114
round: 488, k: 88.4296779346955, b: -13.330638010611773, RMSE: 522.8998163337228
round: 489, k: 88.4196779346955, b: -13.340638010611773, RMSE: 522.8266425736263
round: 490, k: 88.4096779346955, b: -13.350638010611773, RMSE: 522.7534688163224
round: 491, k: 88.39967793469549, b: -13.360638010611773, RMSE: 522.6802950618136
round: 492, k: 88.38967793469548, b: -13.370638010611772, RMSE: 522.6071213101002
round: 493, k: 88.37967793469548, b: -13.380638010611772, RMSE: 522.5339475611834
round: 494, k: 88.36967793469547, b: -13.390638010611772, RMSE: 522.4607738150648
round: 495, k: 88.35967793469547, b: -13.400638010611772, RMSE: 522.3876000717446
round: 496, k: 88.34967793469546, b: -13.410638010611772, RMSE: 522.3144263312254
round: 497, k: 88.33967793469546, b: -13.420638010611771, RMSE: 522.2412525935073
round: 498, k: 88.32967793469545, b: -13.430638010611771, RMSE: 522.1680788585919
round: 499, k: 88.3

round: 607, k: 87.2396779346949, b: -14.520638010611748, RMSE: 514.1921588200101
round: 608, k: 87.22967793469489, b: -14.530638010611748, RMSE: 514.1189854007282
round: 609, k: 87.21967793469489, b: -14.540638010611747, RMSE: 514.0458119843835
round: 610, k: 87.20967793469488, b: -14.550638010611747, RMSE: 513.9726385709779
round: 611, k: 87.19967793469488, b: -14.560638010611747, RMSE: 513.8994651605122
round: 612, k: 87.18967793469487, b: -14.570638010611747, RMSE: 513.8262917529878
round: 613, k: 87.17967793469487, b: -14.580638010611747, RMSE: 513.7531183484061
round: 614, k: 87.16967793469486, b: -14.590638010611746, RMSE: 513.6799449467684
round: 615, k: 87.15967793469486, b: -14.600638010611746, RMSE: 513.6067715480755
round: 616, k: 87.14967793469485, b: -14.610638010611746, RMSE: 513.5335981523291
round: 617, k: 87.13967793469484, b: -14.620638010611746, RMSE: 513.4604247595302
round: 618, k: 87.12967793469484, b: -14.630638010611746, RMSE: 513.3872513696804
round: 619, k: 87

round: 727, k: 86.03967793469428, b: -15.720638010611722, RMSE: 505.41136983875884
round: 728, k: 86.02967793469428, b: -15.730638010611722, RMSE: 505.33819678114884
round: 729, k: 86.01967793469427, b: -15.740638010611722, RMSE: 505.2650237266324
round: 730, k: 86.00967793469427, b: -15.750638010611722, RMSE: 505.1918506752102
round: 731, k: 85.99967793469426, b: -15.760638010611721, RMSE: 505.1186776268844
round: 732, k: 85.98967793469426, b: -15.770638010611721, RMSE: 505.0455045816558
round: 733, k: 85.97967793469425, b: -15.780638010611721, RMSE: 504.97233153952607
round: 734, k: 85.96967793469425, b: -15.79063801061172, RMSE: 504.89915850049675
round: 735, k: 85.95967793469424, b: -15.80063801061172, RMSE: 504.8259854645685
round: 736, k: 85.94967793469424, b: -15.81063801061172, RMSE: 504.75281243174334
round: 737, k: 85.93967793469423, b: -15.82063801061172, RMSE: 504.67963940202225
round: 738, k: 85.92967793469423, b: -15.83063801061172, RMSE: 504.60646637540634
round: 739, k:

round: 855, k: 84.75967793469363, b: -17.000638010611873, RMSE: 496.04524407499787
round: 856, k: 84.74967793469362, b: -17.010638010611874, RMSE: 495.97207142453783
round: 857, k: 84.73967793469362, b: -17.020638010611876, RMSE: 495.8988987773499
round: 858, k: 84.72967793469361, b: -17.030638010611877, RMSE: 495.8257261334355
round: 859, k: 84.7196779346936, b: -17.04063801061188, RMSE: 495.75255349279564
round: 860, k: 84.7096779346936, b: -17.05063801061188, RMSE: 495.67938085543216
round: 861, k: 84.6996779346936, b: -17.060638010611882, RMSE: 495.6062082213465
round: 862, k: 84.68967793469359, b: -17.070638010611884, RMSE: 495.5330355905403
round: 863, k: 84.67967793469359, b: -17.080638010611885, RMSE: 495.4598629630142
round: 864, k: 84.66967793469358, b: -17.090638010611887, RMSE: 495.3866903387709
round: 865, k: 84.65967793469358, b: -17.10063801061189, RMSE: 495.31351771781056
round: 866, k: 84.64967793469357, b: -17.11063801061189, RMSE: 495.24034510013576
round: 867, k: 84

round: 978, k: 83.529677934693, b: -18.230638010612065, RMSE: 487.0450330642319
round: 979, k: 83.519677934693, b: -18.240638010612066, RMSE: 486.9718608273565
round: 980, k: 83.50967793469299, b: -18.250638010612068, RMSE: 486.8986885939382
round: 981, k: 83.49967793469298, b: -18.26063801061207, RMSE: 486.825516363978
round: 982, k: 83.48967793469298, b: -18.27063801061207, RMSE: 486.7523441374779
round: 983, k: 83.47967793469297, b: -18.280638010612073, RMSE: 486.6791719144391
round: 984, k: 83.46967793469297, b: -18.290638010612074, RMSE: 486.6059996948632
round: 985, k: 83.45967793469296, b: -18.300638010612076, RMSE: 486.53282747875176
round: 986, k: 83.44967793469296, b: -18.310638010612077, RMSE: 486.45965526610695
round: 987, k: 83.43967793469295, b: -18.32063801061208, RMSE: 486.3864830569294
round: 988, k: 83.42967793469295, b: -18.33063801061208, RMSE: 486.31331085122116
round: 989, k: 83.41967793469294, b: -18.340638010612082, RMSE: 486.2401386489837
round: 990, k: 83.4096

round: 1098, k: 82.32967793469238, b: -19.430638010612252, RMSE: 478.26438976600025
round: 1099, k: 82.31967793469238, b: -19.440638010612254, RMSE: 478.19121795533374
round: 1100, k: 82.30967793469237, b: -19.450638010612256, RMSE: 478.11804614831755
round: 1101, k: 82.29967793469237, b: -19.460638010612257, RMSE: 478.04487434495394
round: 1102, k: 82.28967793469236, b: -19.47063801061226, RMSE: 477.97170254524457
round: 1103, k: 82.27967793469236, b: -19.48063801061226, RMSE: 477.8985307491905
round: 1104, k: 82.26967793469235, b: -19.490638010612262, RMSE: 477.8253589567941
round: 1105, k: 82.25967793469235, b: -19.500638010612263, RMSE: 477.75218716805654
round: 1106, k: 82.24967793469234, b: -19.510638010612265, RMSE: 477.67901538297997
round: 1107, k: 82.23967793469234, b: -19.520638010612267, RMSE: 477.6058436015658
round: 1108, k: 82.22967793469233, b: -19.530638010612268, RMSE: 477.5326718238155
round: 1109, k: 82.21967793469233, b: -19.54063801061227, RMSE: 477.45950004973173

round: 1219, k: 81.11967793469177, b: -20.64063801061244, RMSE: 469.41062767098515
round: 1220, k: 81.10967793469176, b: -20.650638010612443, RMSE: 469.3374563145251
round: 1221, k: 81.09967793469175, b: -20.660638010612445, RMSE: 469.2642849619259
round: 1222, k: 81.08967793469175, b: -20.670638010612446, RMSE: 469.1911136131902
round: 1223, k: 81.07967793469174, b: -20.680638010612448, RMSE: 469.1179422683191
round: 1224, k: 81.06967793469174, b: -20.69063801061245, RMSE: 469.0447709273145
round: 1225, k: 81.05967793469173, b: -20.70063801061245, RMSE: 468.97159959017824
round: 1226, k: 81.04967793469173, b: -20.710638010612453, RMSE: 468.8984282569128
round: 1227, k: 81.03967793469172, b: -20.720638010612454, RMSE: 468.82525692751886
round: 1228, k: 81.02967793469172, b: -20.730638010612456, RMSE: 468.75208560199917
round: 1229, k: 81.01967793469171, b: -20.740638010612457, RMSE: 468.67891428035495
round: 1230, k: 81.00967793469171, b: -20.75063801061246, RMSE: 468.6057429625881
rou

round: 1337, k: 79.93967793469116, b: -21.820638010612626, RMSE: 460.7764347536452
round: 1338, k: 79.92967793469116, b: -21.830638010612628, RMSE: 460.7032638655866
round: 1339, k: 79.91967793469115, b: -21.84063801061263, RMSE: 460.6300929816103
round: 1340, k: 79.90967793469115, b: -21.85063801061263, RMSE: 460.55692210171856
round: 1341, k: 79.89967793469114, b: -21.860638010612632, RMSE: 460.4837512259128
round: 1342, k: 79.88967793469114, b: -21.870638010612634, RMSE: 460.4105803541954
round: 1343, k: 79.87967793469113, b: -21.880638010612635, RMSE: 460.33740948656833
round: 1344, k: 79.86967793469113, b: -21.890638010612637, RMSE: 460.26423862303324
round: 1345, k: 79.85967793469112, b: -21.90063801061264, RMSE: 460.19106776359195
round: 1346, k: 79.84967793469112, b: -21.91063801061264, RMSE: 460.11789690824725
round: 1347, k: 79.83967793469111, b: -21.92063801061264, RMSE: 460.04472605700033
round: 1348, k: 79.8296779346911, b: -21.930638010612643, RMSE: 459.97155520985336
rou

round: 1457, k: 78.73967793469055, b: -23.020638010612814, RMSE: 451.99595789176385
round: 1458, k: 78.72967793469054, b: -23.030638010612815, RMSE: 451.92278750784203
round: 1459, k: 78.71967793469054, b: -23.040638010612817, RMSE: 451.8496171282453
round: 1460, k: 78.70967793469053, b: -23.05063801061282, RMSE: 451.77644675297546
round: 1461, k: 78.69967793469053, b: -23.06063801061282, RMSE: 451.7032763820348
round: 1462, k: 78.68967793469052, b: -23.07063801061282, RMSE: 451.63010601542567
round: 1463, k: 78.67967793469052, b: -23.080638010612823, RMSE: 451.5569356531496
round: 1464, k: 78.66967793469051, b: -23.090638010612825, RMSE: 451.48376529520954
round: 1465, k: 78.6596779346905, b: -23.100638010612826, RMSE: 451.4105949416073
round: 1466, k: 78.6496779346905, b: -23.110638010612828, RMSE: 451.337424592344
round: 1467, k: 78.6396779346905, b: -23.12063801061283, RMSE: 451.26425424742314
round: 1468, k: 78.62967793469049, b: -23.13063801061283, RMSE: 451.19108390684596
round:

round: 1577, k: 77.53967793468993, b: -24.220638010613, RMSE: 443.21554330328
round: 1578, k: 77.52967793468993, b: -24.230638010613003, RMSE: 443.1423734537554
round: 1579, k: 77.51967793468992, b: -24.240638010613004, RMSE: 443.06920360881816
round: 1580, k: 77.50967793468992, b: -24.250638010613006, RMSE: 442.99603376847
round: 1581, k: 77.49967793468991, b: -24.260638010613008, RMSE: 442.92286393271394
round: 1582, k: 77.48967793468991, b: -24.27063801061301, RMSE: 442.849694101552
round: 1583, k: 77.4796779346899, b: -24.28063801061301, RMSE: 442.7765242749859
round: 1584, k: 77.4696779346899, b: -24.290638010613012, RMSE: 442.7033544530188
round: 1585, k: 77.45967793468989, b: -24.300638010613014, RMSE: 442.6301846356523
round: 1586, k: 77.44967793468989, b: -24.310638010613015, RMSE: 442.55701482288936
round: 1587, k: 77.43967793468988, b: -24.320638010613017, RMSE: 442.48384501473146
round: 1588, k: 77.42967793468988, b: -24.33063801061302, RMSE: 442.41067521118174
round: 1589,

round: 1705, k: 76.25967793468928, b: -25.5006380106132, RMSE: 433.8498406413952
round: 1706, k: 76.24967793468927, b: -25.510638010613203, RMSE: 433.7766713980365
round: 1707, k: 76.23967793468927, b: -25.520638010613204, RMSE: 433.7035021595687
round: 1708, k: 76.22967793468926, b: -25.530638010613206, RMSE: 433.6303329259943
round: 1709, k: 76.21967793468926, b: -25.540638010613208, RMSE: 433.557163697316
round: 1710, k: 76.20967793468925, b: -25.55063801061321, RMSE: 433.48399447353535
round: 1711, k: 76.19967793468925, b: -25.56063801061321, RMSE: 433.410825254656
round: 1712, k: 76.18967793468924, b: -25.570638010613212, RMSE: 433.3376560406794
round: 1713, k: 76.17967793468924, b: -25.580638010613214, RMSE: 433.26448683160845
round: 1714, k: 76.16967793468923, b: -25.590638010613215, RMSE: 433.1913176274457
round: 1715, k: 76.15967793468923, b: -25.600638010613217, RMSE: 433.11814842819354
round: 1716, k: 76.14967793468922, b: -25.61063801061322, RMSE: 433.0449792338551
round: 1

round: 1837, k: 74.9396779346886, b: -26.820638010613408, RMSE: 424.19154375226697
round: 1838, k: 74.9296779346886, b: -26.83063801061341, RMSE: 424.11837517654493
round: 1839, k: 74.9196779346886, b: -26.84063801061341, RMSE: 424.04520660605556
round: 1840, k: 74.90967793468859, b: -26.850638010613412, RMSE: 423.9720380408021
round: 1841, k: 74.89967793468858, b: -26.860638010613414, RMSE: 423.89886948078595
round: 1842, k: 74.88967793468858, b: -26.870638010613416, RMSE: 423.82570092601105
round: 1843, k: 74.87967793468857, b: -26.880638010613417, RMSE: 423.75253237647934
round: 1844, k: 74.86967793468857, b: -26.89063801061342, RMSE: 423.67936383219376
round: 1845, k: 74.85967793468856, b: -26.90063801061342, RMSE: 423.60619529315767
round: 1846, k: 74.84967793468856, b: -26.91063801061342, RMSE: 423.5330267593724
round: 1847, k: 74.83967793468855, b: -26.920638010613423, RMSE: 423.4598582308419
round: 1848, k: 74.82967793468855, b: -26.930638010613425, RMSE: 423.3866897075683
roun

round: 1978, k: 73.52967793468788, b: -28.230638010613628, RMSE: 413.8748274905061
round: 1979, k: 73.51967793468788, b: -28.24063801061363, RMSE: 413.80165968020344
round: 1980, k: 73.50967793468787, b: -28.25063801061363, RMSE: 413.7284918755343
round: 1981, k: 73.49967793468787, b: -28.260638010613633, RMSE: 413.65532407650187
round: 1982, k: 73.48967793468786, b: -28.270638010613634, RMSE: 413.5821562831094
round: 1983, k: 73.47967793468786, b: -28.280638010613636, RMSE: 413.5089884953594
round: 1984, k: 73.46967793468785, b: -28.290638010613637, RMSE: 413.4358207132556
round: 1985, k: 73.45967793468785, b: -28.30063801061364, RMSE: 413.3626529368005
round: 1986, k: 73.44967793468784, b: -28.31063801061364, RMSE: 413.2894851659971
round: 1987, k: 73.43967793468784, b: -28.320638010613642, RMSE: 413.2163174008485
round: 1988, k: 73.42967793468783, b: -28.330638010613644, RMSE: 413.14314964135764
round: 1989, k: 73.41967793468783, b: -28.340638010613645, RMSE: 413.06998188752783
roun

(73.31967793468777, -28.44063801061366)

###### 3.Gradient Descent to get optimal *k* and *b*

#### Equations:

$$ RMSE = \frac{1}{n}\sum{(y - (kx+b))^2} = \frac{1}{n}\sum(y^2 -2y(kx+b) + (kx+b)^2)) = \frac{1}{n}\sum{(y^2 - 2yxk - 2yb + k^2x^2 + 2kxb + b^2)}$$

$$ \frac{\partial{_{loss}}}{\partial{_k}} = \frac{2}{n}(-y + kx + b)x = \frac{2}{n}(-y + \hat{y})x$$

$$ \frac{\partial{_{loss}}}{\partial{_b}} = \frac{2}{n}(-y + kx + b) = \frac{2}{n}(-y + \hat{y})$$

In [5]:
def partial_k(x, y, y_hat):
    n = len(y)
    gradient = 0
    for x_i, y_i, y_hat_i in zip(list(x), list(y), list(y_hat)):
        gradient += (y_i - y_hat_i) * x_i
    return -2 / n * gradient

def partial_b(y, y_hat):
    n = len(y)
    gradient = 0
    for y_i, y_hat_i in zip(list(y), list(y_hat)):
        gradient += (y_i - y_hat_i)
    return -2 / n * gradient

def gradient(X, y, n, alpha=0.01, loss=RMSE):
    loss_min = float('inf')
    
    k = random.random() * 200 - 100
    b = random.random() * 200 - 100
    
    for i in range(n):
        y_hat = k * X + b
        loss_new = RMSE(y, y_hat)
        if loss_new < loss_min:
            loss_min = loss_new
            if i % 1000 == 0:
                print(f"round: {i}, k: {k}, b: {b}, {loss.__name__}: {loss_min}")
        k_gradient = partial_k(X, y, y_hat)
        b_gradient = partial_b(y, y_hat)
        k += -k_gradient * alpha
        b += -b_gradient * alpha
    return (k, b)
k,b = gradient(X, y, 20000)
print(f'k: {k}, b: {b}')

round: 0, k: 21.973219901695515, b: -60.65693405747956, RMSE: 56.032694147953386
round: 1000, k: 12.480613507231768, b: -56.16168676288914, RMSE: 7.020761350581001
round: 2000, k: 11.758403150551079, b: -51.56762145623035, RMSE: 6.864272322522944
round: 3000, k: 11.190577017799928, b: -47.95561241682018, RMSE: 6.765726472279578
round: 4000, k: 10.744132967181455, b: -45.11572912611663, RMSE: 6.704084577110146
round: 5000, k: 10.39312360210179, b: -42.88291713234211, RMSE: 6.665694712890597
round: 6000, k: 10.117148204506174, b: -41.1274050604681, RMSE: 6.641852523942843
round: 7000, k: 9.900167042021303, b: -39.74716228002433, RMSE: 6.627071253251235
round: 8000, k: 9.729569136253364, b: -38.66196886923024, RMSE: 6.617917499952458
round: 9000, k: 9.595439310792425, b: -37.80875322085546, RMSE: 6.6122526462838
round: 10000, k: 9.489981919641034, b: -37.1379263107485, RMSE: 6.608748412169915
round: 11000, k: 9.407067763074888, b: -36.610499576158176, RMSE: 6.606581293077354
round: 12000,

###### 4. Try different Loss function and learning rate. 

$$ MAE = \frac{1}{n}\sum{|y - \hat{y}|} = \frac{1}{n}{\sum{|y - (kx+b)}|} $$

$$ \frac{\partial{loss}}{\partial{k}} = \frac{1}{n}{\sum{\left\{
\begin{array}{rcl}
-x       &      & {y - \hat{y} > 0}\\
x     &      & {y - \hat{y} < 0}
\end{array} \right.}} $$

$$ \frac{\partial{loss}}{\partial{b}} = \frac{1}{n}{\sum {\left\{
\begin{array}{rcl}
-1       &      & {y - \hat{y} > 0}\\
1     &      & {y - \hat{y} < 0}
\end{array} \right.}} $$

In [6]:
# gradient(X, y, 20000, alpha=0.1) # overflow
# gradient(X, y, 20000, alpha=1) # overflow

def partial_k(x, y, y_hat):
    n = len(x)
    gradient = 0
    for xi, y_i, y_hat_i in zip(x, y, y_hat):
        if y_i - y_hat_i > 0:
            gradient += xi
        else:
            gradient -= xi
    return -1 / n * gradient

def partial_b(y, y_hat):
    n = len(y)
    gradient = 0
    for y_i, y_hat_i in zip(y, y_hat):
        if y_i - y_hat_i > 0:
            gradient += 1
        else:
            gradient -= 1
    return -1 / n
# MAE with default alpha
gradient(X, y, 20000, loss=MAE) 

round: 0, k: 21.082558361947818, b: 72.2998020025725, MAE: 182.57650366861864


(-8.337893100500922, 72.69505891961762)

In [7]:
# MAE with alpha = 0.1
gradient(X, y, 20000, loss=MAE, alpha=0.1)

round: 0, k: 77.27374600884403, b: 86.70902123451057, MAE: 551.9313139312618


(-11.247893516848169, 90.66159040439331)

In [8]:
# MAE with alpha = 0.5
gradient(X, y, 20000, loss=MAE, alpha=1)

round: 0, k: -33.89921134031128, b: 33.26095423449465, MAE: 204.66170713462762


(-5.141768652560754, 72.78664593406414)