# CLOSED 

In [12]:
import torch 
import torch.nn as nn
import pandas as pd 
import numpy as np 
import plotly.graph_objects as go 

# Dataset and Parameters Initialization

In [13]:
# Ground-truth ("secret") parameters used to generate the targets
secret_weight = torch.tensor([2.0])
secret_bias = torch.tensor([1.0])
# Starting point read by the model classes further down
init_weight = torch.tensor([-1.0])
init_bias = torch.tensor([-3.0])

# Inputs from -5.0 (inclusive) to 5 (exclusive) in steps of 0.2, and their
# noise-free targets y = secret_weight * X + secret_bias
X = torch.arange(-5.0, 5, 0.2)
y = X * secret_weight + secret_bias


#-----------------------------
# Plotly: the data points against the line given by the initial parameters
scatter = go.Scatter(
  x=X,
  y=y,
  mode='markers',
  marker={'color': 'red', 'size': 8},
  name='Data Points',
)

line = go.Scatter(
  x=X,
  y=init_weight*X + init_bias,
  mode='lines+text',
  line={'color': 'blue'},
  name='initialized model',
)

layout = go.Layout(
  title='Dataset and initialized Parameters',
  xaxis={'title': 'X'},
  yaxis={'title': 'y'},
  plot_bgcolor='rgb(200,200,200)',
)

figure = go.Figure(data=[scatter, line], layout=layout)
figure.show()

# Loss function landscape 

In [14]:
# Span: 100 candidate values per axis, centred on the true parameters (+/- 5)
W_span = torch.linspace((secret_weight-5).item(),(secret_weight+5).item(),100)
B_spane = torch.linspace((secret_bias-5).item(),(secret_bias+5).item(),100)

# Grid formation
W_grid,B_grid = torch.meshgrid(W_span,B_spane,indexing='xy')

# Flatten so every (w, b) pair of the grid is one entry
W_flatten = W_grid.flatten()
B_flatten = B_grid.flatten()

# Mean-squared-error cost for every (w, b) pair.
# Broadcasting (n_pairs, 1) against (n_samples,) evaluates the whole grid at
# once instead of looping over the pairs in Python. X and y are flattened
# first so this works whether they are 1-D or column vectors.
X_flat = X.reshape(-1)
y_flat = y.reshape(-1)
grid_pred = W_flatten[:, None]*X_flat + B_flatten[:, None]
C_flat = torch.mean((grid_pred - y_flat)**2, dim=1)
C_span = C_flat.tolist()

C_grid = C_flat.view(W_grid.shape[0],W_grid.shape[1])
#----------------------------------------------------------------------------
# Plot the cost function landscape

cost_scatter = go.Surface(
  x = W_grid,
  y = B_grid,
  z = C_grid ,
  name = 'Cost function landscape',
  showscale = False
)

#--------
# Lowest cost found on the grid (first occurrence if there are ties)
min_val_index = C_span.index(min(C_span))
# Cost at the initial parameters. This must use init_bias; the previous
# version used a leftover loop variable and plotted the point at the wrong
# height.
init_pred = init_weight*X + init_bias
init_cost = torch.mean((init_pred - y)**2)

global_minima = go.Scatter3d (
  x = (W_flatten[min_val_index].item(),),
  y = (B_flatten[min_val_index].item(),),
  z = (C_span[min_val_index],),
  name = 'Global minima',
  mode = 'markers',
  marker = dict(color= 'green',size=12)
)

init_param = go.Scatter3d(
  x = (init_weight.item(),),
  y = (init_bias.item(),),
  z = (init_cost.item(),),
  name = 'initial point',
  mode = 'markers',
  marker = dict(color='red',size=8)
)

# Axis titles for the 3-D scene plus a grey page background
layout = go.Layout(
  title='Cost Function',
  scene={
    'xaxis': {'title': 'Weight'},
    'yaxis': {'title': 'Bias'},
    'zaxis': {'title': 'Cost'},
  },
  paper_bgcolor='rgb(200,200,200)',
)

# Surface first, then the two marker traces drawn on top of it
figure = go.Figure(data=[cost_scatter, global_minima, init_param], layout=layout)
figure.show()

# Model 

In [15]:
class Red_Formation(nn.Module):
  """One-input, one-output linear model started from a fixed (weight, bias).

  Args:
    weight: optional one-element tensor used as the starting weight.
      Defaults to the notebook-level ``init_weight``.
    bias: optional one-element tensor used as the starting bias.
      Defaults to the notebook-level ``init_bias``.
  """

  def __init__(self, weight=None, bias=None):
    super().__init__()
    self.layer = nn.Linear(1,1,bias=True)
    weight = init_weight if weight is None else weight
    bias = init_bias if bias is None else bias
    # Copy the values into the layer's own storage. Assigning a view of the
    # source tensor to ``.data`` would make the parameter share memory with
    # it, so in-place optimizer updates would also overwrite the "initial"
    # tensors.
    with torch.no_grad():
      self.layer.weight.copy_(weight.reshape(1,1))
      self.layer.bias.copy_(bias.reshape(1))


  def forward(self,x):
    """Return the linear layer applied to ``x``."""
    return self.layer(x)
  
#----------------------------------------
# Model trained with SGD + momentum
Red_model = Red_Formation()

# Mean-squared-error criterion and the optimizer (lr 0.01, momentum 0.9)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(Red_model.parameters(), lr=0.01, momentum=0.9)

# Show the starting value and shape of every parameter
for name, param in Red_model.named_parameters():
  value, shape = param.item(), param.shape
  print(f'{name:<12} | {value}  | {shape}')



layer.weight | -1.0  | torch.Size([1, 1])
layer.bias   | -3.0  | torch.Size([1])


In [16]:
# Shapes before reshaping
print(X.shape, y.shape, sep='\n')
# Turn both tensors into column vectors (one feature per row), matching the
# single input feature of nn.Linear(1, 1)
X, y = X.view(-1, 1), y.view(-1, 1)
# Shapes after reshaping
print(X.shape, y.shape, sep='\n')

torch.Size([50])
torch.Size([50])
torch.Size([50, 1])
torch.Size([50, 1])


In [17]:
Red_model.train()

num_epochs = 30
# History of weight, bias and loss; each entry is recorded BEFORE that
# epoch's parameter update so the path starts at the initial point.
W, B, L = [], [], []


for epoch in range(num_epochs):
  print(f'epoch : {epoch}')

  # Snapshot the current parameters
  w_before = Red_model.layer.weight.data.item()
  b_before = Red_model.layer.bias.data.item()
  print(f'w : {w_before:.4f}')
  print(f'b : {b_before:.4f}')
  W.append(w_before)
  B.append(b_before)

  # Forward pass and loss on the full dataset
  y_hat = Red_model(X)
  loss = loss_fn(y_hat, y)
  loss_value = loss.item()
  print(f'loss {loss_value}')
  L.append(loss_value)

  # Backward pass: gradients of the loss w.r.t. weight and bias
  loss.backward()
  print(f'dl/dw : {Red_model.layer.weight.grad.item()}')
  print(f'dl/db : {Red_model.layer.bias.grad.item()}')

  # Parameter update
  optimizer.step()
  print(f'updated w : {Red_model.layer.weight.data.item():.4f}')
  print(f'updated b : {Red_model.layer.bias.data.item():.4f}')

  # Clear gradients so they do not accumulate into the next epoch
  optimizer.zero_grad()

  print('-------------------')



epoch : 0
w : -1.0000
b : -3.0000
loss 88.66000366210938
dl/dw : -49.23999786376953
dl/db : -7.40000057220459
updated w : -0.5076
updated b : -2.9260
-------------------
epoch : 1
w : -0.5076
b : -2.9260
loss 65.88690948486328
dl/dw : -41.04156494140625
dl/db : -7.350479602813721
updated w : 0.3460
updated b : -2.7859
-------------------
epoch : 2
w : 0.3460
b : -2.7859
loss 35.897159576416016
dl/dw : -26.831947326660156
dl/db : -7.24098539352417
updated w : 1.3825
updated b : -2.5874
-------------------
epoch : 3
w : 1.3825
b : -2.5874
loss 15.606301307678223
dl/dw : -9.582206726074219
dl/db : -7.0512847900390625
updated w : 2.4112
updated b : -2.3382
-------------------
epoch : 4
w : 2.4112
b : -2.3382
loss 12.828592300415039
dl/dw : 7.5267720222473145
dl/db : -6.758692741394043
updated w : 3.2618
updated b : -2.0464
-------------------
epoch : 5
w : 3.2618
b : -2.0464
loss 23.327402114868164
dl/dw : 21.655866622924805
dl/db : -6.345132350921631
updated w : 3.8107
updated b : -1.7203

# Loss Function Landscape: SGD with Momentum Path

In [18]:
# Trajectory of (weight, bias, loss) recorded during the momentum run
path_SGD_with_momentum = go.Scatter3d(
    x=W,
    y=B,
    z=L,
    mode='markers+lines',
    marker={'color': 'yellow', 'size': 10},
    name='SGD with momentum',
)

# Overlay the trajectory on the cost surface and its reference markers
landscape_traces = [cost_scatter, global_minima, init_param, path_SGD_with_momentum]
figure = go.Figure(data=landscape_traces)
figure.show()

# Vanilla SGD 

In [19]:
# Re-create the starting point so this run begins from (-1, -3) even if an
# earlier cell modified these tensors.
init_weight = torch.tensor([-1.])
init_bias = torch.tensor([-3.])


class White_Formation(nn.Module):
  """One-input, one-output linear model started from a fixed (weight, bias).

  Args:
    weight: optional one-element tensor used as the starting weight.
      Defaults to the notebook-level ``init_weight``.
    bias: optional one-element tensor used as the starting bias.
      Defaults to the notebook-level ``init_bias``.
  """

  def __init__(self, weight=None, bias=None):
    super().__init__()
    self.layer = nn.Linear(1,1,bias=True)
    weight = init_weight if weight is None else weight
    bias = init_bias if bias is None else bias
    # Copy the values into the layer's own storage. Assigning a view of the
    # source tensor to ``.data`` would make the parameter share memory with
    # it, so in-place optimizer updates would also overwrite the "initial"
    # tensors.
    with torch.no_grad():
      self.layer.weight.copy_(weight.reshape(1,1))
      self.layer.bias.copy_(bias.reshape(1))

  def forward(self,x):
    """Return the linear layer applied to ``x``."""
    return self.layer(x)

# Model trained with plain SGD (no momentum)
White_model = White_Formation()
for name,param in White_model.named_parameters():
  print(f'{name:<13} | {param.item():.3f} | {param.shape}')


# Bind the criterion to `loss_fn`, the name the training loop calls.
# It was previously bound to `loss`, so the loop only worked because an
# earlier cell had already defined `loss_fn`.
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(White_model.parameters(),lr=0.01)

layer.weight  | -1.000 | torch.Size([1, 1])
layer.bias    | -3.000 | torch.Size([1])


In [20]:
White_model.train()

# History of weight, bias and loss; each entry is recorded BEFORE that
# epoch's parameter update.
W_white, B_white, L_white = [], [], []

for epoch in range(num_epochs):

  # Snapshot the current parameters
  W_white.append(White_model.layer.weight.data.item())
  B_white.append(White_model.layer.bias.data.item())

  # Forward pass and loss on the full dataset
  y_hat = White_model(X)
  loss = loss_fn(y_hat, y)
  loss_value = loss.item()
  print(f'{loss_value:.4f}')
  L_white.append(loss_value)

  # Backward pass, parameter update, then clear the gradients
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

88.6600
65.8869
49.9067
38.6487
30.6753
24.9882
20.8943
17.9121
15.7071
14.0471
12.7704
11.7647
10.9519
10.2776
9.7037
9.2037
8.7590
8.3566
7.9872
7.6444
7.3233
7.0207
6.7340
6.4615
6.2016
5.9533
5.7158
5.4883
5.2702
5.0611


In [21]:
# Trajectory of (weight, bias, loss) recorded during the plain-SGD run
path_vanilla_SGD = go.Scatter3d(
  x=W_white,
  y=B_white,
  z=L_white,
  mode='markers+lines',
  marker={'color': 'rgb(23,222,222)', 'size': 10},
  name='Vanilla SGD',
)

# Both optimizer paths on the same cost surface for comparison
comparison_traces = [
  cost_scatter,
  init_param,
  global_minima,
  path_vanilla_SGD,
  path_SGD_with_momentum,
]
figure = go.Figure(data=comparison_traces)
figure.show()

# Sanity check: manual SGD-with-momentum updates

In [22]:
# Hand-written momentum update, starting from the same (w, b) = (-1, -3):
#   v <- beta * v + grad
#   p <- p - eta * v
weight = torch.tensor(-1)
bias = torch.tensor(-3)
eta = 0.01   # learning rate
beta = 0.9   # momentum coefficient
velocity_weight = 0
velocity_bias = 0


for _ in range(4):

  # Forward pass and mean-squared-error loss
  forward_pass = weight * X + bias
  residual = y - forward_pass
  los = torch.mean(residual ** 2)
  print(los.item())

  # Gradient of the loss with respect to the weight
  dw = torch.mean(2 * residual * (-1) * (X))
  print(f'dw : {dw}')

  # Gradient of the loss with respect to the bias
  db = torch.mean(2 * residual * (-1) * (1))
  print(f'db : {db}')

  # Momentum step for the weight
  velocity_weight = (beta * velocity_weight) + dw
  weight = weight - (eta * velocity_weight)
  print(f'weight_update {weight}')

  # Momentum step for the bias
  velocity_bias = (beta * velocity_bias) + db
  bias = bias - (eta * velocity_bias)
  print(f'bias_update {bias}')

  print('---------')



88.66000366210938
dw : -49.24000549316406
db : -7.400000095367432
weight_update -0.507599949836731
bias_update -2.9260001182556152
---------
65.88690185546875
dw : -41.041568756103516
db : -7.350481033325195
weight_update 0.34597575664520264
bias_update -2.785895347595215
---------
35.897151947021484
dw : -26.831945419311523
db : -7.240985870361328
weight_update 1.382513403892517
bias_update -2.5873911380767822
---------
15.606298446655273
dw : -9.582198143005371
db : -7.0512847900390625
weight_update 2.4112191200256348
bias_update -2.3382246494293213
---------
