In [3]:
import numpy as np
from sklearn.datasets import load_linnerud
from sklearn.preprocessing import StandardScaler

In [4]:
X, y = load_linnerud(return_X_y=True)

In [5]:
X = StandardScaler().fit_transform(X)

In [6]:
class Linear:
    def __init__(self, in_dim, out_dim):
        self.w = np.random.randn(in_dim, out_dim)
        self.b = np.zeros(out_dim)
        self.dw = None
        self.db = None
        
    def forward(self, x):
        self.x = x
        return np.matmul(x, self.w) + self.b
    
    def backward(self, d):
        self.dw = np.repeat(np.sum(self.x, axis=0)[..., np.newaxis], self.w.shape[-1], axis=-1)
        assert self.dw.shape == self.w.shape
        self.db = np.sum(d, 0)
        assert self.db.shape == self.b.shape
        
        dd = np.repeat(np.sum(self.w, axis=-1)[np.newaxis, ...], self.x.shape[0], axis=0)
        assert dd.shape == self.x.shape
        return d.T @ dd

        
    def step(self, lr):
        self.w = self.w - lr * self.dw
        self.b = self.b - lr * self.db

In [7]:
class MeanSquaredError:
    def __init__(self, average=True): 
        self.average = average
    def forward(self, y_, y):
        assert y_.shape == y.shape
        l = np.sum(np.square(y - y_))
        if self.average:
            l /= len(y)
        return y_, l
    
    def backward(self, y_, y, average=True):
        d = -(y - y_)
        if average:
            d /= len(d)
        return d

In [11]:
class LinearRegression:
    def __init__(self):
        self.linear = Linear(3, 3)
        self.loss = MeanSquaredError()
        
    def forward(self, x, y):
        x = self.linear.forward(x)
        loss = self.loss.forward(x, y)
        return loss
    
    def backward(self, x, y_, y):
        d = self.loss.backward(y_, y)
        self.linear.backward(d)
    
    def step(self, lr):
        self.linear.step(lr)
        

In [15]:
linreg = LinearRegression()

In [16]:
for i in range(100000):
    y_, loss = linreg.forward(X, y)
    if i % 1 == 0:
        print(loss, 'loss')
    linreg.backward(X, y_, y)
    linreg.step(0.1)


36925.71431677842 loss
30029.031616778422 loss
24442.71862977843 loss
19917.805110308425 loss
16252.625159537729 loss
13283.829399413455 loss
10879.1048337128 loss
8931.277935495269 loss
7353.538147939066 loss
6075.568920018545 loss
5040.413845402922 loss
4201.938234964267 loss
3522.772990508957 loss
2972.649142500156 loss
2527.048825613027 loss
2166.1125689344512 loss
1873.7542010248058 loss
1636.9439230179926 loss
1445.127597832475 loss
1289.7563744322053 loss
1163.9056834779865 loss
1061.966623805069 loss
979.3959854700063 loss
912.5137684186051 loss
858.3391726069709 loss
814.4577499995469 loss
778.9137976875335 loss
750.1231963148025 loss
726.8028092028906 loss
707.9132956422417 loss
692.6127896581163 loss
680.2193798109749 loss
670.1807178347901 loss
662.0494016340806 loss
655.4630355115057 loss
650.1280789522201 loss
645.8067641391988 loss
642.3064991406516 loss
639.4712844918283 loss
637.1747606262816 loss
635.3145762951885 loss
633.8078269870034 loss
632.5873600473731 loss
631

KeyboardInterrupt: 