In [1]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import SGDRegressor
from matplotlib import pyplot as plt
from sklearn.metrics import mean_squared_error
import pandas as pd

import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot
import plotly.io as pio
# Make 'iframe' the default plotly renderer for figures shown in this notebook.
pio.renderers.default = 'iframe'

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Seeded generator so the synthetic data set (and every number derived from it)
# is identical after Restart Kernel -> Run All.
rng = np.random.default_rng(42)

# 1000 samples of a single feature drawn uniformly from [0, 2), with a
# noise-free linear target y = 4 + 3x (ideal fit: slope 3, intercept 4).
X = 2 * rng.random((1000, 1))
y = 4 + 3 * X  # column vector with the same (1000, 1) shape as X

In [4]:
# Despite the name, this is a *stochastic* gradient descent regressor.
# - max_iter=1 + warm_start=True: every fit() call runs one epoch and continues
#   from the coefficients left by the previous call.
# - tol=None: switches off the early-stopping check, which cannot trigger within
#   a single epoch and would otherwise only produce a convergence warning.
# - random_state: a RandomState *instance* (not an int) so the per-epoch sample
#   shuffling is reproducible yet still differs from one fit() call to the next.
GD_model = SGDRegressor(
    max_iter=1,
    eta0=0.0001,
    warm_start=True,
    tol=None,
    random_state=np.random.RandomState(42),
)

In [5]:
# Train one epoch at a time and log the parameter trajectory.
N_EPOCHS = 100
out = []  # one row per epoch: [iteration, slope m, intercept b, training MSE]

for i in range(N_EPOCHS):
    # GD_model was created with max_iter=1 and warm_start=True, so each fit()
    # call performs a single additional epoch. y is a column vector; sklearn
    # expects 1-D targets, hence ravel().
    GD_model.fit(X, y.ravel())
    out.append([
        i,
        GD_model.coef_[0],       # model parameter (1): slope m
        GD_model.intercept_[0],  # model parameter (2): intercept b
        mean_squared_error(y, GD_model.predict(X)),
    ])

    _out = out[-1]
    print(f"Itteration Number: {_out[0]}, Param (1): {_out[1]}, Param (2): {_out[2]}, MSE: {_out[3]}")

# Build the log frame once, after training, instead of re-creating it from
# scratch on every iteration.
df = pd.DataFrame(out, columns=["itteration", "m", "b", "MSE"])

Itteration Number: 0, Param (1): 0.1892792436273006, Param (2): 0.16283822122433275, MSE: 47.507326427679494
Itteration Number: 1, Param (1): 0.36785942491595747, Param (2): 0.3172356566867626, MSE: 42.791905567813956
Itteration Number: 2, Param (1): 0.5382006365440316, Param (2): 0.46405273531269126, MSE: 38.530046056084636
Itteration Number: 3, Param (1): 0.6981031513410183, Param (2): 0.6031003560246581, MSE: 34.71950992843571
Itteration Number: 4, Param (1): 0.8515627916746693, Param (2): 0.7357165580816324, MSE: 31.258935236852047
Itteration Number: 5, Param (1): 0.9954045541174817, Param (2): 0.8613059147579504, MSE: 28.167709217262768
Itteration Number: 6, Param (1): 1.1315549761119272, Param (2): 0.9805241396567392, MSE: 25.387442214266496
Itteration Number: 7, Param (1): 1.2612018384461707, Param (2): 1.0939347176196463, MSE: 22.876108315409674
Itteration Number: 8, Param (1): 1.3849915379118944, Param (2): 1.2019556082865637, MSE: 20.603462313280275
Itteration Number: 9, Para

In [6]:
# Display the per-iteration log (iteration index, slope m, intercept b, MSE) built by the training loop.
df

Unnamed: 0,itteration,m,b,MSE
0,0,0.189279,0.162838,47.507326
1,1,0.367859,0.317236,42.791906
2,2,0.538201,0.464053,38.530046
3,3,0.698103,0.603100,34.719510
4,4,0.851563,0.735717,31.258935
...,...,...,...,...
95,95,3.498061,3.353502,0.102177
96,96,3.497582,3.356901,0.101215
97,97,3.496916,3.360075,0.100319
98,98,3.496170,3.363218,0.099438


In [7]:
# 3-D view of the optimisation path: each marker is one logged iteration, placed
# at (intercept b, slope m) with its training MSE as height and colour.
trace = go.Scatter3d(
    x=df['b'],
    y=df['m'],
    z=df['MSE'],
    mode='lines+markers',
    marker=dict(size=5, color=df['MSE'], colorscale='Viridis', opacity=0.8),
    line=dict(color='blue', width=2)
)

layout = go.Layout(
    title='SGD parameter trajectory vs. training MSE',
    scene=dict(
        xaxis=dict(title='b'),
        yaxis=dict(title='m'),
        zaxis=dict(title='MSE')
    ),
    width=800,
    height=600
)

fig = go.Figure(data=[trace], layout=layout)

# fig.show() renders through pio.renderers.default, so the renderer configured
# in the imports cell is honoured. The former init_notebook_mode(connected=True)
# call replaced that default with a notebook renderer before iplot() ran
# (plotly >= 4 behaviour — verify against the installed version).
fig.show()