In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

## Simulate data

In [None]:
# Ground truth: a random line whose slope and intercept are drawn uniformly
# from [-10, 10), observed on a regular grid with Gaussian noise (std 5).
np.random.seed(0)  # fixed seed so Restart & Run All reproduces the same data

x = np.arange(-5, 5, 0.1)

w_true = (np.random.rand() - 0.5) * 20
b_true = (np.random.rand() - 0.5) * 20
y_true = w_true * x + b_true

noise = 5 * np.random.randn(x.shape[0])
y = y_true + noise

In [None]:
# Column-vector tensors of shape (N, 1). Cast to float32 to match the dtype of
# the model's parameters (torch.randn / torch.tensor of Python floats default
# to float32), instead of feeding float64 data into a float32 model.
x_torch = torch.tensor(x.reshape(-1, 1), dtype=torch.float32)
y_torch = torch.tensor(y.reshape(-1, 1), dtype=torch.float32)

In [None]:
# Raw training data: noisy samples around the hidden line.
plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Simulated data');

## Build model

In [None]:
class Model(torch.nn.Module):
    """Univariate linear model ``y = w * x + b`` with learnable ``w`` and ``b``.

    Args:
        w: Initial slope; used only when ``rand`` is False.
        b: Initial intercept; used only when ``rand`` is False.
        rand: If True, ignore ``w``/``b`` and draw both parameters with
            ``torch.randn``.
    """

    def __init__(self, w=0, b=0, rand=False):
        super().__init__()

        self.rand = rand

        # Both branches create shape-(1,) parameters so the parameter shapes do
        # not depend on how the model was initialised.
        if self.rand:
            self.w = nn.Parameter(torch.randn(1))
            self.b = nn.Parameter(torch.randn(1))
        else:
            self.w = nn.Parameter(torch.tensor([float(w)]))
            self.b = nn.Parameter(torch.tensor([float(b)]))

    def forward(self, x):
        """Return ``w * x + b``, broadcasting the single-element parameters over ``x``."""
        return self.w * x + self.b

In [None]:
# Start SGD from a random initialisation. Passing w/b here would have no
# effect: with rand=True, Model draws both parameters from torch.randn.
torch.manual_seed(0)  # make the random initial parameters reproducible
model = Model(rand=True)

In [None]:
# Predictions of the untrained model, used only for plotting; no autograd
# graph is needed.
with torch.no_grad():
    y_pred = model(x_torch)

In [None]:
# The randomly initialised model (before any training) against the data.
plt.scatter(x, y, label=f'train: w={w_true:.3f}, b={b_true:.3f}')
plt.plot(x, y_pred.detach().numpy().ravel(), color='C1',
         label=f'init: w={model.w.item():.3f}, b={model.b.item():.3f}')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Model before training')
plt.legend();

## Train model

In [None]:
# Mean-squared error (the least-squares objective) minimised by plain SGD
# with a small L2 penalty on the parameters.
criterion = nn.MSELoss()
lr = 0.1
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    weight_decay=1e-5,
)

In [None]:
epochs = 100
print_every = 10  # print a progress row every this many epochs
train_loss = np.zeros(epochs)
train_w = np.zeros(epochs)
train_b = np.zeros(epochs)

print('epoch \ttrain loss \t\tw \t\t\tb')
for epoch in range(epochs):
    # Forward pass; MSELoss is called as criterion(input, target).
    y_pred = model(x_torch)
    loss = criterion(y_pred, y_torch)

    # Reset gradients, backpropagate, and take one optimizer step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record the post-step parameters together with the loss *they* achieve,
    # so every (w, b, loss) triple describes the same model. `loss` above was
    # computed from the pre-step parameters.
    with torch.no_grad():
        train_loss[epoch] = criterion(model(x_torch), y_torch).item()
    train_w[epoch] = model.w.item()
    train_b[epoch] = model.b.item()

    if epoch % print_every == 0 or epoch == epochs - 1:
        print(f'{epoch} \t{train_loss[epoch]:.6f} \t{train_w[epoch]:.6f} \t{train_b[epoch]:.6f}')

In [None]:
fig, axs = plt.subplots(1, 3, figsize=[15, 5])

# Left panel: training loss per epoch.
axs[0].scatter(np.arange(epochs), train_loss)
axs[0].set_xlabel('epochs')

# Middle/right panels: loss along each parameter's trajectory, with a
# vertical line at the parameter's true value.
param_panels = zip(axs[1:], (train_w, train_b), (w_true, b_true), ('w', 'b'))
for panel, trajectory, true_value, param_name in param_panels:
    panel.scatter(trajectory, train_loss)
    panel.vlines(true_value, train_loss.min(), train_loss.max(), color='C1')
    panel.set_xlabel(param_name)

# Shared y-axis styling: all three panels show loss on a log scale.
for panel in axs:
    panel.set_ylabel('loss')
    panel.set_yscale('log')

In [None]:
# One fitted line per recorded epoch: row i is train_w[i] * x + train_b[i].
y_pred_train = np.outer(train_w, x) + train_b[:, np.newaxis]
# The last row uses the final recorded parameters.
y_pred_trained = y_pred_train[-1]

In [None]:
# Every intermediate SGD line (drawn faintly) plus the final fitted line.
plt.scatter(x, y, label=f'train: w={w_true:.3f}, b={b_true:.3f}')
for y_pred_train_i in y_pred_train[:-1]:
    plt.plot(x, y_pred_train_i, color='C1', alpha=0.1)
plt.plot(x, y_pred_trained, color='C1',
         label=f'sgd: w={model.w.item():.3f}, b={model.b.item():.3f}')
plt.xlabel('x')
plt.ylabel('y')
plt.title('SGD trajectory')
plt.legend();

## Compare with least squares

In [None]:
# Closed-form least-squares reference line; z holds [slope, intercept].
z = np.polyfit(x, y, 1)
p = np.poly1d(z)
fit = np.polyval(p.coeffs, x)

In [None]:
# Overlay the SGD line and the closed-form least-squares line on the data.
plt.scatter(x, y, label=f'train: w={w_true:.3f}, b={b_true:.3f}')
plt.plot(x, y_pred_trained, color='C1',
         label=f'sgd: w={model.w.item():.3f}, b={model.b.item():.3f}')
plt.plot(x, fit, '--', color='C2',
         label=f'polyfit: w={z[0]:.3f}, b={z[1]:.3f}')
plt.xlabel('x')
plt.ylabel('y')
plt.title('SGD vs. least squares')
plt.legend();

In [None]:
# Residuals of both fits: distribution (left) and along x (right).
res_sgd = y - y_pred_trained
res_fit = y - fit
# Shared bin edges so the two overlaid histograms are directly comparable
# (10 bins, matplotlib's default count).
bins = np.histogram_bin_edges(np.concatenate([res_sgd, res_fit]), bins=10)

fig, axs = plt.subplots(1, 2, figsize=[10, 5])
axs[0].hist(res_sgd, bins=bins, alpha=0.5, color='C1', label='sgd')
axs[0].hist(res_fit, bins=bins, alpha=0.5, color='C2', label='polyfit')
axs[0].set_xlabel('residual (y - prediction)')
axs[0].set_ylabel('count')
axs[0].legend()
axs[1].plot(x, res_sgd, alpha=0.5, color='C1', label='sgd')
axs[1].plot(x, res_fit, alpha=0.5, color='C2', label='polyfit')
axs[1].set_xlabel('x')
axs[1].set_ylabel('residual (y - prediction)')
axs[1].legend();

In [None]:
# Side-by-side table of fitted parameters and error metrics for both methods.
resid_sgd = y_pred_trained - y
resid_fit = fit - y
metric_rows = [
    ('weight', model.w.item(), z[0]),
    ('bias', model.b.item(), z[1]),
    ('mse', np.mean(np.square(resid_sgd)), np.mean(np.square(resid_fit))),
    ('mae', np.mean(np.abs(resid_sgd)), np.mean(np.abs(resid_fit))),
]

print('\t sgd\t\t\t polyfit')
for row_name, sgd_value, polyfit_value in metric_rows:
    print(f'{row_name}\t {sgd_value}\t {polyfit_value}')