In [2]:
import sys
import tqdm
import pandas as pd


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

import numpy as np

In [3]:
# Step 2. Parse and visualize data
# parse train data: read CSV files with train features (train_x) and train targets (train_y)
x_train = pd.read_csv("D:\\Dataset\\train\\train_x.csv", header=None)
y_train = pd.read_csv("D:\\Dataset\\train\\train_y.csv", header=None)

# show first 10 samples
print(pd.concat([x_train, y_train], axis=1).head(10))

x_train = x_train.to_numpy()
y_train = y_train.to_numpy()
print("Shape of train features:", x_train.shape)
print("Shape of train targets:", y_train.shape)

       0      1     0
0  23.98  6.459  11.8
1  21.52  6.193  11.0
2   7.74  6.750  23.7
3   4.81  7.249  35.4
4  18.06  5.454  15.2
5   5.90  6.487  24.4
6   2.94  6.998  33.4
7   6.36  7.163  31.6
8  17.44  6.749  13.4
9   4.56  6.975  34.9
Shape of train features: (354, 2)
Shape of train targets: (354, 1)


In [4]:
# Step 3. Prototypes.

# In this demo we will use linear regression to predict targets from features.
# In linear regression model with parameters thetas 
# the prediction y is calculated from features x using linear combination of x and thetas.
# For example, for the case of 2 features: 
# y = theta_0 * x_o + theta_1 * x_1

# Let's define some helper functions

def predict_fn(x, thetas):
    '''
    Predict target from features x using parameters thetas and linear regression
    
    param x: input features, shape NxM, N - number of samples to predict, M - number of features
    param thetas: vector of linear regression parameters, shape Mx1
    return y_hat: predicted scalar value for each input samples, shape Nx1
    '''    
    # TODO: calculate y_hat using linear regression
    y_hat = np.zeros((x.shape[0], 1))
    for i in range(len(x)):
        y_hat[i] = thetas[0] * x[i][0] + thetas[1] * x[i][1]
    return y_hat


def loss_fn(x_train, y_train, thetas):
    '''
    Calculate average loss value for train dataset (x_train, y_train).
    
    param x_train: input features, shape NxM, N - number of samples to predict, M - number of features
    param y_train: input tagrets, shape Nx1
    param thetas: vector of linear regression parameters, shape Mx1
    return loss: predicted scalar value for each input samples, shape Mx1
    '''
    y_predicted = predict_fn(x_train, thetas)    
    loss = np.mean(np.power(y_train - y_predicted, 2))   
    return loss


def gradient_fn(x_train, y_train, thetas):
    '''
    Calculate gradient value for linear regression.
    
    param x_train: input features, shape NxM, N - number of samples to predict, M - number of features
    param y_train: input tagrets, shape Nx1
    param thetas: vector of linear regression parameters, shape Mx1
    return g: predicted scalar value for each input samples, shape Mx1
    '''  
    # TODO: calculate vector gradient
    g = np.zeros_like(thetas)
    for i in range(len(x_train)):
        g[0] += -2 * x_train[i][0] * (y_train[i] - x_train[i][0] * thetas[0] - x_train[i][1] * thetas[1])
        g[1] += -2 * x_train[i][1] * (y_train[i] - x_train[i][0] * thetas[0] - x_train[i][1] * thetas[1])
    g[0] = g[0] / len(x_train)
    g[1] = g[1] / len(x_train)
    return g

In [None]:
# Step 4. Gradient descent.

# now let's find optimal parameters using gradient descent
MAX_ITER = 1000000
thetas = np.random.randn(2, 1)
alpha = 4.22e-3

progress = tqdm.tqdm(range(MAX_ITER), "Training", file=sys.stdout)
loss_val = loss_fn(x_train, y_train, thetas)
progress.set_postfix(loss_val=loss_val)

for iter in progress:
    gradient = gradient_fn(x_train, y_train, thetas)
    print(gradient)
    thetas_2 = thetas - alpha * gradient
    
    # TODO: add stop conditions
    if (abs(thetas_2[0] - thetas[0]) < 0.00000001) or (abs(thetas_2[1] -thetas[1]) < 0.00000001):
        progress.close()
        loss_val = loss_fn(x_train, y_train, thetas)
        print("Stop condition detected")
        print("Final loss:", loss_val)
        break
    
    if iter % 100 == 0:
        loss_val = loss_fn(x_train, y_train, thetas_2)
        progress.set_postfix(loss_val=f"{loss_val:8.4f}", thetas=f"{thetas_2[0][0]:5.4f} {thetas_2[1][0]:5.4f}")
    thetas = thetas_2
    
progress.close()

Training:   0%|                                                          | 0/1000000 [00:00<?, ?it/s, loss_val=1.27e+3][[-827.64405959]
 [-446.58327978]]
Training:   0%|                                  | 0/1000000 [00:00<?, ?it/s, loss_val=1208.8392, thetas=2.9688 0.9562][[906.04602541]
 [234.77671013]]
[[-829.11926445]
 [-425.058166  ]]
[[893.36451643]
 [249.95592526]]
[[-829.34129779]
 [-406.8979044 ]]
[[881.8993738 ]
 [262.10735005]]
[[-828.53527778]
 [-391.51896387]]
[[871.44549438]
 [271.76105434]]
Training:   0%|                        | 8/1000000 [00:00<3:33:31, 78.06it/s, loss_val=1208.8392, thetas=2.9688 0.9562][[-826.88706002]
 [-378.43944746]]
[[861.8334579 ]
 [279.35475087]]
[[-824.55007743]
 [-367.2613877 ]]
[[852.92331133]
 [285.24988424]]
[[-821.65098866]
 [-357.65612599]]
[[844.59943586]
 [289.74491703]]
[[-818.29434257]
 [-349.35223902]]
[[836.76630823]
 [293.08630151]]
[[-814.56643007]
 [-342.12556842]]
Training:   0%|                       | 17/1000000 [00:00<3:26:3

Training:   0%|                       | 171/1000000 [00:01<3:13:52, 85.95it/s, loss_val=487.9329, thetas=0.6449 5.3875][[483.97380886]
 [186.97575435]]
[[-482.31589473]
 [-186.33526389]]
[[480.66367208]
 [185.69693615]]
[[-479.01709831]
 [-185.06082356]]
[[477.37617506]
 [184.42686416]]
[[-475.7408639 ]
 [-183.79509998]]
[[474.11116294]
 [183.16547861]]
[[-472.48703721]
 [-182.53803351]]
[[470.86848192]
 [181.91272013]]
Training:   0%|                       | 180/1000000 [00:02<3:13:59, 85.90it/s, loss_val=487.9329, thetas=0.6449 5.3875][[-469.25546498]
 [-181.28956489]]
[[467.64797925]
 [180.66852975]]
[[-466.04599503]
 [-180.04963529]]
[[464.44950324]
 [179.43284888]]
[[-462.85847619]
 [-178.81818629]]
[[461.27290323]
 [178.20561935]]
[[-459.69275833]
 [-177.59515987]]
[[458.11802958]
 [176.98678337]]
[[-456.54869236]
 [-176.3804984 ]]
Training:   0%|                       | 189/1000000 [00:02<3:12:58, 86.35it/s, loss_val=487.9329, thetas=0.6449 5.3875][[454.9847337 ]
 [175.77628354]

[[270.06874646]
 [104.33686565]]
Training:   0%|                      | 342/1000000 [00:04<3:16:07, 84.95it/s, loss_val=144.5677, thetas=-0.0011 5.1383][[-269.14359499]
 [-103.97944775]]
[[268.22161274]
 [103.62325422]]
[[-267.30278885]
 [-103.26828088]]
[[266.38711249]
 [102.91452354]]
[[-265.47457288]
 [-102.56197804]]
[[264.56515929]
 [102.21064022]]
[[-263.65886099]
 [-101.86050595]]
[[262.75566733]
 [101.51157111]]
[[-261.85556765]
 [-101.16383158]]
Training:   0%|                      | 351/1000000 [00:04<3:18:20, 84.00it/s, loss_val=144.5677, thetas=-0.0011 5.1383][[260.95855138]
 [100.81728327]]
[[-260.06460793]
 [-100.47192211]]
[[259.17372679]
 [100.12774401]]
[[-258.28589747]
 [ -99.78474494]]
[[257.4011095 ]
 [ 99.44292086]]
[[-256.51935248]
 [ -99.10226773]]
[[255.64061602]
 [ 98.76278154]]
[[-254.76488977]
 [ -98.42445831]]
[[253.89216342]
 [ 98.08729404]]
Training:   0%|                      | 360/1000000 [00:04<3:18:35, 83.89it/s, loss_val=144.5677, thetas=-0.0011 5.138

[[-151.22273214]
 [ -58.42255386]]
[[150.70470133]
 [ 58.2224207 ]]
[[-150.18844509]
 [ -58.02297312]]
Training:   0%|                       | 513/1000000 [00:06<3:15:52, 85.05it/s, loss_val=57.5426, thetas=-0.3263 5.0127][[149.67395735]
 [ 57.82420877]]
[[-149.16123204]
 [ -57.62612532]]
[[148.65026313]
 [ 57.42872041]]
[[-148.14104461]
 [ -57.23199175]]
[[147.63357047]
 [ 57.035937  ]]
[[-147.12783474]
 [ -56.84055385]]
[[146.62383147]
 [ 56.64584002]]
[[-146.12155472]
 [ -56.45179319]]
[[145.62099858]
 [ 56.2584111 ]]
Training:   0%|                       | 522/1000000 [00:06<3:16:29, 84.78it/s, loss_val=57.5426, thetas=-0.3263 5.0127][[-145.12215715]
 [ -56.06569146]]
[[144.62502456]
 [ 55.87363201]]
[[-144.12959495]
 [ -55.68223047]]
[[143.63586249]
 [ 55.49148461]]
[[-143.14382137]
 [ -55.30139216]]
[[142.65346579]
 [ 55.1119509 ]]
[[-142.16478999]
 [ -54.92315859]]
[[141.67778819]
 [ 54.73501301]]
[[-141.19245468]
 [ -54.54751195]]
Training:   0%|                       | 531/100

Training:   0%|                       | 685/1000000 [00:08<3:12:27, 86.54it/s, loss_val=42.8720, thetas=-0.4220 4.9757][[82.95033691]
 [32.04657433]]
[[-82.66618102]
 [-31.93679512]]
[[82.38299854]
 [31.82739196]]
[[-82.10078613]
 [-31.71836358]]
[[81.81954047]
 [31.60970869]]
[[-81.53925825]
 [-31.50142601]]
[[81.25993617]
 [31.39351426]]
[[-80.98157094]
 [-31.28597217]]
[[80.70415928]
 [31.17879849]]
Training:   0%|                       | 694/1000000 [00:08<3:13:22, 86.13it/s, loss_val=42.8720, thetas=-0.4220 4.9757][[-80.42769793]
 [-31.07199194]]
[[80.15218362]
 [30.96555127]]
[[-79.87761313]
 [-30.85947522]]
[[79.6039832 ]
 [30.75376255]]
[[-79.33129063]
 [-30.64841201]]
[[79.05953219]
 [30.54342236]]
[[-78.7887047 ]
 [-30.43879236]]
Training:   0%|                       | 694/1000000 [00:08<3:13:22, 86.13it/s, loss_val=35.4863, thetas=-0.4900 4.9495][[78.51880495]
 [30.33452079]]
[[-78.24982978]
 [-30.23060641]]
Training:   0%|                       | 703/1000000 [00:08<3:16:57,

[[-44.88054376]
 [-17.33890102]]
[[44.72680031]
 [17.27950463]]
Training:   0%|                       | 866/1000000 [00:10<3:16:08, 84.90it/s, loss_val=31.7681, thetas=-0.5382 4.9309][[-44.57358353]
 [-17.22031171]]
[[44.42089161]
 [17.16132155]]
[[-44.26872275]
 [-17.10253348]]
[[44.11707516]
 [17.04394679]]
[[-43.96594706]
 [-16.98556079]]
[[43.81533667]
 [16.92737481]]
[[-43.66524221]
 [-16.86938814]]
[[43.51566191]
 [16.81160012]]
[[-43.36659402]
 [-16.75401006]]
Training:   0%|                       | 875/1000000 [00:10<3:18:19, 83.96it/s, loss_val=31.7681, thetas=-0.5382 4.9309][[43.21803678]
 [16.69661727]]
[[-43.06998844]
 [-16.6394211 ]]
[[42.92244725]
 [16.58242085]]
[[-42.77541149]
 [-16.52561587]]
[[42.62887941]
 [16.46900547]]
[[-42.48284929]
 [-16.41258901]]
[[42.33731942]
 [16.3563658 ]]
[[-42.19228808]
 [-16.30033519]]
[[42.04775356]
 [16.24449653]]
Training:   0%|                       | 884/1000000 [00:10<3:16:31, 84.73it/s, loss_val=31.7681, thetas=-0.5382 4.9309][[-

[[24.28276103]
 [ 9.38126758]]
[[-24.19957765]
 [ -9.34913097]]
[[24.11667922]
 [ 9.31710445]]
Training:   0%|                      | 1046/1000000 [00:12<3:15:29, 85.17it/s, loss_val=28.9538, thetas=-0.5966 4.9083][[-24.03406477]
 [ -9.28518764]]
[[23.95173332]
 [ 9.25338016]]
[[-23.86968391]
 [ -9.22168165]]
[[23.78791557]
 [ 9.19009172]]
[[-23.70642734]
 [ -9.15861001]]
[[23.62521825]
 [ 9.12723614]]
[[-23.54428736]
 [ -9.09596975]]
[[23.4636337 ]
 [ 9.06481046]]
[[-23.38325633]
 [ -9.03375791]]
Training:   0%|                      | 1055/1000000 [00:12<3:15:38, 85.10it/s, loss_val=28.9538, thetas=-0.5966 4.9083][[23.30315431]
 [ 9.00281174]]
[[-23.22332668]
 [ -8.97197157]]
[[23.14377251]
 [ 8.94123705]]
[[-23.06449086]
 [ -8.91060782]]
[[22.98548081]
 [ 8.88008351]]
[[-22.90674141]
 [ -8.84966377]]
[[22.82827174]
 [ 8.81934823]]
[[-22.75007087]
 [ -8.78913654]]
[[22.6721379 ]
 [ 8.75902835]]
Training:   0%|                      | 1064/1000000 [00:12<3:15:10, 85.30it/s, loss_val=28.

Training:   0%|                      | 1226/1000000 [00:14<3:14:50, 85.43it/s, loss_val=28.2406, thetas=-0.6261 4.8969][[-12.95916154]
 [ -5.00657079]]
[[12.91476844]
 [ 4.9894202 ]]
[[-12.87052742]
 [ -4.97232837]]
[[12.82643795]
 [ 4.95529509]]
[[-12.78249952]
 [ -4.93832016]]
[[12.7387116 ]
 [ 4.92140338]]
[[-12.69507368]
 [ -4.90454455]]
[[12.65158525]
 [ 4.88774347]]
[[-12.6082458 ]
 [ -4.87099994]]
Training:   0%|                      | 1235/1000000 [00:14<3:15:43, 85.05it/s, loss_val=28.2406, thetas=-0.6261 4.8969][[12.56505481]
 [ 4.85431377]]
[[-12.52201177]
 [ -4.83768476]]
[[12.47911618]
 [ 4.82111272]]
[[-12.43636754]
 [ -4.80459744]]
[[12.39376533]
 [ 4.78813874]]
[[-12.35130907]
 [ -4.77173643]]
[[12.30899824]
 [ 4.7553903 ]]
[[-12.26683236]
 [ -4.73910016]]
[[12.22481092]
 [ 4.72286583]]
Training:   0%|                      | 1244/1000000 [00:14<3:14:07, 85.75it/s, loss_val=28.2406, thetas=-0.6261 4.8969][[-12.18293343]
 [ -4.70668711]]
[[12.14119939]
 [ 4.69056381]]
[[-

Training:   0%|                      | 1407/1000000 [00:16<3:19:35, 83.39it/s, loss_val=28.0598, thetas=-0.6409 4.8912][[6.96363982]
 [2.69029409]]
[[-6.93978507]
 [-2.68107818]]
[[6.91601204]
 [2.67189384]]
[[-6.89232045]
 [-2.66274096]]
[[6.86871002]
 [2.65361943]]
[[-6.84518046]
 [-2.64452916]]
[[6.82173151]
 [2.63547002]]
[[-6.79836289]
 [-2.62644191]]
[[6.77507432]
 [2.61744474]]
Training:   0%|                      | 1416/1000000 [00:16<3:17:22, 84.32it/s, loss_val=28.0598, thetas=-0.6409 4.8912][[-6.75186552]
 [-2.60847838]]
[[6.72873623]
 [2.59954274]]
[[-6.70568618]
 [-2.59063771]]
[[6.68271508]
 [2.58176318]]
[[-6.65982267]
 [-2.57291905]]
[[6.63700869]
 [2.56410522]]
[[-6.61427285]
 [-2.55532159]]
[[6.5916149 ]
 [2.54656804]]
[[-6.56903457]
 [-2.53784448]]
Training:   0%|                      | 1425/1000000 [00:16<3:14:42, 85.48it/s, loss_val=28.0598, thetas=-0.6409 4.8912][[6.54653159]
 [2.5291508 ]]
[[-6.52410569]
 [-2.52048691]]
[[6.50175662]
 [2.51185269]]
[[-6.47948411]

[[3.67827588]
 [1.42104476]]
[[-3.66567552]
 [-1.4161768 ]]
[[3.65311833]
 [1.41132553]]
Training:   0%|                      | 1596/1000000 [00:18<3:13:32, 85.98it/s, loss_val=28.0293, thetas=-0.6453 4.8895][[-3.64060415]
 [-1.40649087]]
[[3.62813285]
 [1.40167277]]
[[-3.61570426]
 [-1.39687118]]
[[3.60331825]
 [1.39208604]]
[[-3.59097467]
 [-1.38731729]]
Training:   0%|                      | 1596/1000000 [00:18<3:13:32, 85.98it/s, loss_val=28.0140, thetas=-0.6484 4.8883][[3.57867337]
 [1.38256488]]
[[-3.56641422]
 [-1.37782874]]
[[3.55419706]
 [1.37310883]]
[[-3.54202175]
 [-1.36840509]]
Training:   0%|                      | 1605/1000000 [00:18<3:15:20, 85.18it/s, loss_val=28.0140, thetas=-0.6484 4.8883][[3.52988815]
 [1.36371746]]
[[-3.51779611]
 [-1.35904589]]
[[3.5057455 ]
 [1.35439032]]
[[-3.49373616]
 [-1.3497507 ]]
[[3.48176797]
 [1.34512698]]
[[-3.46984077]
 [-1.34051909]]
[[3.45795444]
 [1.33592699]]
[[-3.44610882]
 [-1.33135062]]
[[3.43430378]
 [1.32678992]]
Training:   0%

Training:   0%|                      | 1777/1000000 [00:20<3:12:12, 86.56it/s, loss_val=28.0063, thetas=-0.6505 4.8874][[1.9562883]
 [0.7557816]]
[[-1.94958681]
 [-0.75319258]]
[[1.94290827]
 [0.75061244]]
[[-1.93625262]
 [-0.74804113]]
[[1.92961976]
 [0.74547863]]
[[-1.92300963]
 [-0.7429249 ]]
[[1.91642214]
 [0.74037993]]
[[-1.90985722]
 [-0.73784367]]
[[1.90331478]
 [0.7353161 ]]
Training:   0%|                      | 1786/1000000 [00:20<3:11:04, 87.07it/s, loss_val=28.0063, thetas=-0.6505 4.8874][[-1.89679476]
 [-0.73279719]]
[[1.89029707]
 [0.73028691]]
[[-1.88382164]
 [-0.72778523]]
[[1.87736839]
 [0.72529212]]
[[-1.87093725]
 [-0.72280755]]
[[1.86452814]
 [0.72033149]]
[[-1.85814099]
 [-0.71786391]]
[[1.85177571]
 [0.71540479]]
[[-1.84543224]
 [-0.71295408]]
[[1.8391105 ]
 [0.71051178]]
Training:   0%|                      | 1796/1000000 [00:21<3:09:07, 87.97it/s, loss_val=28.0063, thetas=-0.6505 4.8874][[-1.83281042]
 [-0.70807784]]
[[1.82653192]
 [0.70565224]]
[[-1.82027492]
 

[[1.04045048]
 [0.40196188]]
[[-1.0368863 ]
 [-0.40058492]]
[[1.03333432]
 [0.39921267]]
[[-1.02979452]
 [-0.39784512]]
[[1.02626684]
 [0.39648225]]
[[-1.02275124]
 [-0.39512406]]
Training:   0%|                      | 1967/1000000 [00:23<3:13:58, 85.75it/s, loss_val=28.0004, thetas=-0.6532 4.8864][[1.01924769]
 [0.39377052]]
[[-1.01575614]
 [-0.39242161]]
[[1.01227655]
 [0.39107732]]
[[-1.00880888]
 [-0.38973764]]
[[1.00535309]
 [0.38840255]]
[[-1.00190914]
 [-0.38707204]]
[[0.99847699]
 [0.38574608]]
[[-0.99505659]
 [-0.38442466]]
[[0.99164791]
 [0.38310777]]
Training:   0%|                      | 1976/1000000 [00:23<3:13:57, 85.76it/s, loss_val=28.0004, thetas=-0.6532 4.8864][[-0.98825091]
 [-0.38179539]]
[[0.98486554]
 [0.3804875 ]]
[[-0.98149177]
 [-0.3791841 ]]
[[0.97812956]
 [0.37788516]]
[[-0.97477886]
 [-0.37659067]]
[[0.97143964]
 [0.37530062]]
[[-0.96811187]
 [-0.37401498]]
[[0.96479549]
 [0.37273375]]
[[-0.96149047]
 [-0.37145691]]
Training:   0%|                      | 198

 [-0.21305079]]
Training:   0%|                      | 2147/1000000 [00:25<3:15:35, 85.03it/s, loss_val=27.9989, thetas=-0.6546 4.8859][[0.54957809]
 [0.21232096]]
[[-0.54769545]
 [-0.21159363]]
[[0.54581926]
 [0.21086879]]
[[-0.54394949]
 [-0.21014644]]
[[0.54208613]
 [0.20942656]]
[[-0.54022915]
 [-0.20870914]]
[[0.53837853]
 [0.20799419]]
[[-0.53653426]
 [-0.20728168]]
[[0.5346963 ]
 [0.20657161]]
Training:   0%|                      | 2156/1000000 [00:25<3:15:38, 85.01it/s, loss_val=27.9989, thetas=-0.6546 4.8859][[-0.53286463]
 [-0.20586398]]
[[0.53103925]
 [0.20515876]]
[[-0.52922011]
 [-0.20445597]]
[[0.52740721]
 [0.20375558]]
[[-0.52560051]
 [-0.20305759]]
[[0.52380001]
 [0.202362  ]]
[[-0.52200567]
 [-0.20166878]]
[[0.52021748]
 [0.20097794]]
[[-0.51843541]
 [-0.20028947]]
Training:   0%|                      | 2165/1000000 [00:25<3:13:27, 85.97it/s, loss_val=27.9989, thetas=-0.6546 4.8859][[0.51665945]
 [0.19960336]]
[[-0.51488958]
 [-0.19891959]]
[[0.51312576]
 [0.19823817]

[[-0.28929913]
 [-0.11176623]]
[[0.28830811]
 [0.11138336]]
Training:   0%|                      | 2336/1000000 [00:27<3:16:15, 84.72it/s, loss_val=27.9986, thetas=-0.6552 4.8856][[-0.28732047]
 [-0.1110018 ]]
[[0.28633622]
 [0.11062155]]
[[-0.28535535]
 [-0.11024261]]
[[0.28437783]
 [0.10986496]]
[[-0.28340366]
 [-0.1094886 ]]
[[0.28243283]
 [0.10911354]]
[[-0.28146532]
 [-0.10873976]]
[[0.28050113]
 [0.10836726]]
[[-0.27954024]
 [-0.10799603]]
Training:   0%|                      | 2345/1000000 [00:27<3:15:32, 85.03it/s, loss_val=27.9986, thetas=-0.6552 4.8856][[0.27858264]
 [0.10762608]]
[[-0.27762833]
 [-0.10725739]]
[[0.27667728]
 [0.10688997]]
[[-0.27572949]
 [-0.10652381]]
[[0.27478495]
 [0.1061589 ]]
[[-0.27384364]
 [-0.10579524]]
[[0.27290556]
 [0.10543282]]
[[-0.27197069]
 [-0.10507165]]
[[0.27103902]
 [0.10471172]]
Training:   0%|                      | 2354/1000000 [00:27<3:16:08, 84.77it/s, loss_val=27.9986, thetas=-0.6552 4.8856][[-0.27011055]
 [-0.10435301]]
[[0.26918525

[[0.15333645]
 [0.05923916]]
[[-0.15281118]
 [-0.05903623]]
[[0.15228771]
 [0.05883399]]
[[-0.15176603]
 [-0.05863245]]
[[0.15124614]
 [0.0584316 ]]
[[-0.15072803]
 [-0.05823143]]
Training:   0%|                      | 2525/1000000 [00:29<3:13:20, 85.98it/s, loss_val=27.9985, thetas=-0.6556 4.8855][[0.15021169]
 [0.05803195]]
[[-0.14969712]
 [-0.05783316]]
[[0.14918432]
 [0.05763504]]
[[-0.14867327]
 [-0.05743761]]
[[0.14816397]
 [0.05724085]]
[[-0.14765642]
 [-0.05704476]]
[[0.14715061]
 [0.05684935]]
[[-0.14664653]
 [-0.05665461]]
[[0.14614417]
 [0.05646053]]
Training:   0%|                      | 2534/1000000 [00:29<3:12:31, 86.35it/s, loss_val=27.9985, thetas=-0.6556 4.8855][[-0.14564354]
 [-0.05626712]]
[[0.14514462]
 [0.05607437]]
[[-0.14464741]
 [-0.05588228]]
[[0.1441519 ]
 [0.05569085]]
[[-0.14365809]
 [-0.05550007]]
[[0.14316598]
 [0.05530995]]
[[-0.14267554]
 [-0.05512048]]
[[0.14218679]
 [0.05493166]]
[[-0.14169972]
 [-0.05474348]]
Training:   0%|                      | 254

Training:   0%|                      | 2706/1000000 [00:31<3:14:54, 85.28it/s, loss_val=27.9984, thetas=-0.6557 4.8854][[-0.08071665]
 [-0.03118362]]
[[0.08044015]
 [0.0310768 ]]
[[-0.08016459]
 [-0.03097034]]
[[0.07988998]
 [0.03086425]]
[[-0.0796163 ]
 [-0.03075852]]
[[0.07934357]
 [0.03065316]]
[[-0.07907177]
 [-0.03054815]]
[[0.0788009]
 [0.0304435]]
[[-0.07853096]
 [-0.03033922]]
Training:   0%|                      | 2715/1000000 [00:31<3:14:34, 85.42it/s, loss_val=27.9984, thetas=-0.6557 4.8854][[0.07826194]
 [0.03023529]]
[[-0.07799385]
 [-0.03013171]]
[[0.07772667]
 [0.03002849]]
[[-0.07746041]
 [-0.02992563]]
[[0.07719506]
 [0.02982311]]
[[-0.07693062]
 [-0.02972095]]
[[0.07666708]
 [0.02961914]]
[[-0.07640445]
 [-0.02951767]]
[[0.07614272]
 [0.02941656]]
Training:   0%|                      | 2724/1000000 [00:31<3:13:14, 86.02it/s, loss_val=27.9984, thetas=-0.6557 4.8854][[-0.07588188]
 [-0.02931579]]
[[0.07562194]
 [0.02921536]]
[[-0.07536289]
 [-0.02911528]]
[[0.07510472]


Training:   0%|                      | 2896/1000000 [00:33<3:10:46, 87.11it/s, loss_val=27.9984, thetas=-0.6558 4.8854][[-0.04205426]
 [-0.01624701]]
[[0.0419102 ]
 [0.01619135]]
[[-0.04176663]
 [-0.01613589]]
[[0.04162356]
 [0.01608061]]
[[-0.04148097]
 [-0.01602553]]
Training:   0%|                      | 2896/1000000 [00:33<3:10:46, 87.11it/s, loss_val=27.9984, thetas=-0.6558 4.8854][[0.04133887]
 [0.01597063]]
[[-0.04119726]
 [-0.01591592]]
[[0.04105613]
 [0.0158614 ]]
[[-0.04091549]
 [-0.01580706]]
Training:   0%|                      | 2905/1000000 [00:33<3:13:19, 85.96it/s, loss_val=27.9984, thetas=-0.6558 4.8854][[0.04077533]
 [0.01575292]]
[[-0.04063565]
 [-0.01569895]]
[[0.04049645]
 [0.01564517]]
[[-0.04035772]
 [-0.01559158]]
[[0.04021947]
 [0.01553817]]
[[-0.0400817 ]
 [-0.01548494]]
[[0.03994439]
 [0.0154319 ]]
[[-0.03980756]
 [-0.01537903]]
[[0.03967119]
 [0.01532635]]
Training:   0%|                      | 2914/1000000 [00:34<3:13:27, 85.90it/s, loss_val=27.9984, thetas

 [-0.00870047]]
[[0.02244341]
 [0.00867066]]
[[-0.02236653]
 [-0.00864096]]
[[0.02228991]
 [0.00861136]]
[[-0.02221355]
 [-0.00858186]]
[[0.02213746]
 [0.00855246]]
[[-0.02206162]
 [-0.00852317]]
Training:   0%|                      | 3085/1000000 [00:36<3:15:20, 85.06it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.02198605]
 [0.00849397]]
[[-0.02191073]
 [-0.00846487]]
[[0.02183567]
 [0.00843587]]
[[-0.02176087]
 [-0.00840698]]
[[0.02168633]
 [0.00837818]]
[[-0.02161204]
 [-0.00834948]]
[[0.02153801]
 [0.00832087]]
[[-0.02146423]
 [-0.00829237]]
[[0.0213907 ]
 [0.00826396]]
Training:   0%|                      | 3094/1000000 [00:36<3:14:18, 85.51it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.02131742]
 [-0.00823565]]
[[0.0212444 ]
 [0.00820744]]
[[-0.02117162]
 [-0.00817933]]
[[0.02109909]
 [0.00815131]]
[[-0.02102682]
 [-0.00812338]]
[[0.02095479]
 [0.00809556]]
[[-0.020883  ]
 [-0.00806782]]
Training:   0%|                      | 3094/1000000 [00:36<3:14:18, 85.51it/s, loss_va

Training:   0%|                      | 3265/1000000 [00:38<3:12:57, 86.09it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.01185487]
 [0.00457995]]
[[-0.01181426]
 [-0.00456426]]
[[0.01177379]
 [0.00454862]]
[[-0.01173346]
 [-0.00453304]]
[[0.01169326]
 [0.00451751]]
[[-0.01165321]
 [-0.00450204]]
[[0.01161329]
 [0.00448661]]
[[-0.0115735 ]
 [-0.00447124]]
[[0.01153386]
 [0.00445593]]
Training:   0%|                      | 3274/1000000 [00:38<3:12:37, 86.24it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.01149435]
 [-0.00444066]]
[[0.01145497]
 [0.00442545]]
[[-0.01141573]
 [-0.00441029]]
[[0.01137663]
 [0.00439518]]
[[-0.01133765]
 [-0.00438013]]
[[0.01129882]
 [0.00436512]]
[[-0.01126011]
 [-0.00435017]]
[[0.01122154]
 [0.00433527]]
[[-0.0111831 ]
 [-0.00432042]]
Training:   0%|                      | 3283/1000000 [00:38<3:12:23, 86.34it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.01114479]
 [0.00430562]]
[[-0.01110661]
 [-0.00429087]]
[[0.01106856]
 [0.00427617]]
[[-0.01103065]

[[-0.00624043]
 [-0.00241089]]
[[0.00621905]
 [0.00240263]]
Training:   0%|                      | 3454/1000000 [00:40<3:12:40, 86.20it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00619775]
 [-0.0023944 ]]
[[0.00617652]
 [0.0023862 ]]
[[-0.00615536]
 [-0.00237803]]
[[0.00613427]
 [0.00236988]]
[[-0.00611326]
 [-0.00236176]]
[[0.00609232]
 [0.00235367]]
[[-0.00607145]
 [-0.00234561]]
[[0.00605065]
 [0.00233757]]
[[-0.00602992]
 [-0.00232957]]
Training:   0%|                      | 3463/1000000 [00:40<3:11:51, 86.57it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00600927]
 [0.00232159]]
[[-0.00598868]
 [-0.00231363]]
[[0.00596817]
 [0.00230571]]
[[-0.00594772]
 [-0.00229781]]
[[0.00592735]
 [0.00228994]]
[[-0.00590704]
 [-0.00228209]]
[[0.00588681]
 [0.00227428]]
[[-0.00586664]
 [-0.00226649]]
[[0.00584654]
 [0.00225872]]
Training:   0%|                      | 3472/1000000 [00:40<3:12:29, 86.28it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00582652]
 [-0.00225098]]
[[0.00580656

Training:   0%|                      | 3634/1000000 [00:42<3:11:53, 86.54it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00334182]
 [-0.00129106]]
[[0.00333038]
 [0.00128664]]
[[-0.00331897]
 [-0.00128223]]
[[0.0033076 ]
 [0.00127784]]
[[-0.00329627]
 [-0.00127346]]
[[0.00328498]
 [0.0012691 ]]
[[-0.00327372]
 [-0.00126475]]
[[0.00326251]
 [0.00126042]]
[[-0.00325133]
 [-0.0012561 ]]
Training:   0%|                      | 3643/1000000 [00:42<3:12:24, 86.31it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.0032402]
 [0.0012518]]
[[-0.0032291 ]
 [-0.00124751]]
[[0.00321803]
 [0.00124324]]
[[-0.00320701]
 [-0.00123898]]
[[0.00319602]
 [0.00123473]]
[[-0.00318508]
 [-0.0012305 ]]
[[0.00317416]
 [0.00122629]]
[[-0.00316329]
 [-0.00122209]]
[[0.00315246]
 [0.0012179 ]]
Training:   0%|                      | 3652/1000000 [00:42<3:12:12, 86.39it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00314166]
 [-0.00121373]]
[[0.00313089]
 [0.00120957]]
[[-0.00312017]
 [-0.00120543]]
[[0.00310948]


[[-0.00177735]
 [-0.00068665]]
[[0.00177126]
 [0.0006843 ]]
[[-0.00176519]
 [-0.00068195]]
[[0.00175914]
 [0.00067962]]
[[-0.00175312]
 [-0.00067729]]
Training:   0%|                      | 3823/1000000 [00:44<3:14:49, 85.22it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00174711]
 [0.00067497]]
[[-0.00174113]
 [-0.00067266]]
[[0.00173516]
 [0.00067035]]
[[-0.00172922]
 [-0.00066806]]
[[0.0017233 ]
 [0.00066577]]
[[-0.00171739]
 [-0.00066349]]
[[0.00171151]
 [0.00066121]]
[[-0.00170565]
 [-0.00065895]]
[[0.0016998 ]
 [0.00065669]]
Training:   0%|                      | 3832/1000000 [00:44<3:15:00, 85.14it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00169398]
 [-0.00065444]]
[[0.00168818]
 [0.0006522 ]]
[[-0.00168239]
 [-0.00064997]]
[[0.00167663]
 [0.00064774]]
[[-0.00167089]
 [-0.00064552]]
[[0.00166516]
 [0.00064331]]
[[-0.00165946]
 [-0.00064111]]
[[0.00165377]
 [0.00063891]]
[[-0.00164811]
 [-0.00063672]]
Training:   0%|                      | 3841/1000000 [00:44<3:14:01, 85.

Training:   0%|                      | 4003/1000000 [00:46<3:16:03, 84.67it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00094204]
 [0.00036394]]
[[-0.00093882]
 [-0.0003627 ]]
[[0.0009356 ]
 [0.00036145]]
[[-0.00093239]
 [-0.00036022]]
[[0.0009292 ]
 [0.00035898]]
[[-0.00092602]
 [-0.00035775]]
[[0.00092285]
 [0.00035653]]
[[-0.00091968]
 [-0.00035531]]
[[0.00091653]
 [0.00035409]]
Training:   0%|                      | 4012/1000000 [00:46<3:14:45, 85.24it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00091339]
 [-0.00035288]]
[[0.00091026]
 [0.00035167]]
[[-0.00090715]
 [-0.00035046]]
[[0.00090404]
 [0.00034926]]
[[-0.00090094]
 [-0.00034806]]
[[0.00089786]
 [0.00034687]]
[[-0.00089478]
 [-0.00034568]]
[[0.00089171]
 [0.0003445 ]]
[[-0.00088866]
 [-0.00034332]]
Training:   0%|                      | 4021/1000000 [00:47<3:15:07, 85.07it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00088562]
 [0.00034214]]
[[-0.00088258]
 [-0.00034097]]
[[0.00087956]
 [0.0003398 ]]
[[-0.00087655]

Training:   0%|                      | 4193/1000000 [00:48<3:12:03, 86.41it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00049081]
 [0.00018962]]
[[-0.00048913]
 [-0.00018897]]
[[0.00048746]
 [0.00018832]]
[[-0.00048579]
 [-0.00018768]]
[[0.00048412]
 [0.00018703]]
[[-0.00048247]
 [-0.00018639]]
[[0.00048081]
 [0.00018575]]
[[-0.00047917]
 [-0.00018512]]
Training:   0%|                      | 4193/1000000 [00:49<3:12:03, 86.41it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[0.00047752]
 [0.00018448]]
Training:   0%|                      | 4202/1000000 [00:49<3:15:48, 84.76it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-0.00047589]
 [-0.00018385]]
[[0.00047426]
 [0.00018322]]
[[-0.00047263]
 [-0.00018259]]
[[0.00047101]
 [0.00018197]]
[[-0.0004694 ]
 [-0.00018135]]
[[0.00046779]
 [0.00018072]]
[[-0.00046619]
 [-0.00018011]]
[[0.00046459]
 [0.00017949]]
[[-0.000463  ]
 [-0.00017887]]
Training:   0%|                      | 4211/1000000 [00:49<3:14:13, 85.45it/s, loss_val=27.9984, thetas

 [-0.0001005 ]]
[[0.00025925]
 [0.00010016]]
[[-2.58365856e-04]
 [-9.98156356e-05]]
[[2.57480794e-04]
 [9.94737056e-05]]
Training:   0%|                      | 4382/1000000 [00:51<3:12:57, 85.99it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[-2.56598764e-04]
 [-9.91329470e-05]]
[[2.55719755e-04]
 [9.87933557e-05]]
[[-2.54843758e-04]
 [-9.84549278e-05]]
[[2.53970761e-04]
 [9.81176591e-05]]
[[-2.53100755e-04]
 [-9.77815459e-05]]
[[2.5223373e-04]
 [9.7446584e-05]]
[[-2.51369674e-04]
 [-9.71127695e-05]]
[[2.50508579e-04]
 [9.67800986e-05]]
[[-2.49650433e-04]
 [-9.64485673e-05]]
Training:   0%|                      | 4391/1000000 [00:51<3:11:27, 86.67it/s, loss_val=27.9984, thetas=-0.6559 4.8854][[2.48795227e-04]
 [9.61181716e-05]]
[[-2.47942950e-04]
 [-9.57889078e-05]]
[[2.47093593e-04]
 [9.54607720e-05]]
[[-2.46247145e-04]
 [-9.51337601e-05]]
[[2.45403598e-04]
 [9.48078685e-05]]
[[-2.44562940e-04]
 [-9.44830933e-05]]
[[2.43725161e-04]
 [9.41594306e-05]]
[[-2.42890253e-04]
 [-9.38368767e-0

In [12]:
for i in range(10):
    y_hat = predict_fn(x_train, thetas)
    print("Target: ", y_train[i][0], ", predicted:", y_hat[i][0])

Target:  11.8 , predicted: 15.825663061797137
Target:  11.0 , predicted: 16.139711609970796
Target:  23.7 , predicted: 27.899383727052204
Target:  35.4 , predicted: 32.259012900433305
Target:  15.2 , predicted: 14.798901476444952
Target:  24.4 , predicted: 27.821420233616394
Target:  33.4 , predicted: 32.259351209667344
Target:  31.6 , predicted: 30.822201690957264
Target:  13.4 , predicted: 21.53211004080499
Target:  34.9 , predicted: 31.084403500872533
