In [1]:
import numpy as np

In [2]:
class Linear:
    def __init__(self, in_dim, out_dim):
        self.w = np.random.randn(in_dim, out_dim)
        self.b = np.zeros(out_dim)
        self.dw = None
        self.db = None
        
    def forward(self, x):
        return np.matmul(x, self.w) + self.b
    
    def backward(self, d, x):
        self.dw = d * x
        assert self.dw.T.shape == self.w.shape, (self.dw.T.shape, self.w.shape)
        self.db = d
        
    def step(self, lr):
        print(self.dw.shape, self.w.shape)
        self.w = self.w - lr * self.dw.T
        self.b = self.b - lr * self.db

In [3]:
class MeanSquaredError:
    def __init__(self): 
        pass
    def forward(self, y_, y):
        assert y_.shape == y.shape
        return y_, np.sum(np.square(y - y_))
    
    def backward(self, y_, y):
        return -(y - y_)

In [4]:
class LinearRegression:
    def __init__(self):
        self.linear = Linear(8, 1)
        self.loss = MeanSquaredError()
        
    def forward(self, x, y):
        x = self.linear.forward(x)
        loss = self.loss.forward(x, y)
        return loss
    
    def backward(self, x, y_, y):
        d = self.loss.backward(y_, y)
        self.linear.backward(d, x)
    
    def step(self, lr):
        self.linear.step(lr)
        

In [5]:
linreg = LinearRegression()

In [6]:
x = np.random.randn(1, 8)
y = np.ones([1, 1])

In [7]:
for _ in range(20):
    y_, loss = linreg.forward(x, y)
    linreg.backward(x, y_, y)
    linreg.step(0.1)
    print(loss, 'loss')

(1, 8) (8, 1)
0.2891110748431639 loss
(1, 8) (8, 1)
0.09450205198452707 loss
(1, 8) (8, 1)
0.03088998868040899 loss
(1, 8) (8, 1)
0.010097044250764313 loss
(1, 8) (8, 1)
0.003300431853720666 loss
(1, 8) (8, 1)
0.0010788157554355096 loss
(1, 8) (8, 1)
0.0003526336812147316 loss
(1, 8) (8, 1)
0.00011526575552917583 loss
(1, 8) (8, 1)
3.767704307752194e-05 loss
(1, 8) (8, 1)
1.2315536115199025e-05 loss
(1, 8) (8, 1)
4.025592706219572e-06 loss
(1, 8) (8, 1)
1.3158498732638632e-06 loss
(1, 8) (8, 1)
4.301132815289291e-07 loss
(1, 8) (8, 1)
1.4059159688814402e-07 loss
(1, 8) (8, 1)
4.595532843191205e-08 loss
(1, 8) (8, 1)
1.5021468267052246e-08 loss
(1, 8) (8, 1)
4.9100836965304e-09 loss
(1, 8) (8, 1)
1.6049644067999104e-09 loss
(1, 8) (8, 1)
5.246164640504388e-10 loss
(1, 8) (8, 1)
1.714819551110857e-10 loss
