In [71]:
import numpy as np


In [72]:
class Linear:
    def __init__(self, in_dim, out_dim):
        self.w = np.ones([in_dim, out_dim])
        self.b = np.zeros(out_dim)
        self.dw = None
        self.db = None
        
    def forward(self, x):
        return np.matmul(x, self.w) + self.b
    
    def backward(self, d, x):
        self.dw = np.matmul(x.T, d)
        assert self.dw.T.shape == self.w.shape, (self.dw.T.shape, self.w.shape)
        self.db = d
        
    def step(self, lr):
        self.w = self.w - lr * self.dw.T
        self.b = self.b - lr * self.db

In [73]:
class MeanSquaredError:
    def __init__(self): 
        pass
    def forward(self, y_, y):
        assert y_.shape == y.shape
        return y_, np.sum(np.square(y - y_))
    
    def backward(self, y_, y):
        return -(y - y_)

In [74]:
class LinearRegression:
    def __init__(self):
        self.linear = Linear(2, 2)
        self.loss = MeanSquaredError()
        
    def forward(self, x, y):
        x = self.linear.forward(x)
        loss = self.loss.forward(x, y)
        return loss
    
    def backward(self, x, y_, y):
        d = self.loss.backward(y_, y)
        self.linear.backward(d, x)
    
    def step(self, lr):
        self.linear.step(lr)
        

In [75]:
linreg = LinearRegression()

In [76]:
x = np.array([[1, 0], [0, 1]])
y = np.array([[0, 0], [0, 0]])

In [77]:
for _ in range(1000):
    y_, loss = linreg.forward(x, y)
    print(loss, 'loss')
    linreg.backward(x, y_, y)
    linreg.step(0.1)


4.0 loss
2.5600000000000005 loss
1.6384 loss
1.048576 loss
0.6710886399999999 loss
0.4294967295999999 loss
0.2748779069439999 loss
0.17592186044415997 loss
0.11258999068426233 loss
0.07205759403792786 loss
0.0461168601842738 loss
0.029514790517935232 loss
0.018889465931478572 loss
0.012089258196146287 loss
0.007737125245533639 loss
0.0049517601571415415 loss
0.0031691265005705763 loss
0.002028240960365163 loss
0.0012980742146337012 loss
0.00083076749736557 loss
0.0005316911983139658 loss
0.00034028236692093647 loss
0.00021778071482939868 loss
0.00013937965749081306 loss
8.92029807941191e-05 loss
5.7089907708235556e-05 loss
3.653754093327102e-05 loss
2.3384026197294315e-05 loss
1.4965776766268362e-05 loss
9.5780971304123e-06 loss
6.1299821634643125e-06 loss
3.923188584616896e-06 loss
2.5108406941548137e-06 loss
1.6069380442590807e-06 loss
1.0284403483259468e-06 loss
6.582018229285699e-07 loss
4.212491666742847e-07 loss
2.6959946667163445e-07 loss
1.7254365866984606e-07 loss
1.1042794154

1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-31 loss
1.9721522630525295e-

In [78]:
print(linreg.linear.b)

[[-0.5 -0.5]
 [-0.5 -0.5]]


In [79]:
print(linreg.linear.w)

[[0.5 0.5]
 [0.5 0.5]]
