In [185]:
import numpy as np
from numba import njit

In [211]:
@njit
def sign(arr):
    return np.array(list(map(lambda x: [1] if x>=0 else [-1], arr[:, 0])))

class Model():
    def __init__(self, inp_size, act_func=None):
        self.inp_size = inp_size
        self.weights = np.random.rand(inp_size, 1)
        self.bias = np.random.randn(1)
        self.act = act_func

    def __call__(self, X_b):
        # X_b shape: (bs, inp_size)
        res = np.matmul(X_b, self.weights) \
              + np.matmul(np.ones((X_b.shape[0], self.inp_size)), np.stack([self.bias]*self.inp_size))
        return res
    def train(self, epochs, X_b, y_b, lr=0.001):
        for epoch in range(epochs):
            preds = self(X_b)
            loss = np.absolute(preds-y_b).mean()
            #if epoch%5==0: print(loss)
            weights_grad = np.transpose((sign(preds-y_b) * X_b).mean(axis=0)).reshape(2, 1)
            bias_grad = (sign(preds-y_b)).mean(axis=0)
            self.weights -= lr*weights_grad
            self.bias -= lr*bias_grad

In [212]:
model = Model(2)

In [213]:
X_b = np.random.randn(10000, 2)
y_b = X_b.sum(axis=1)[:, None]

In [214]:
y_b.shape

(10000, 1)

In [218]:
model.train(200, X_b, y_b, lr=6e-3)

In [163]:
model(X_b)

array([[-1.73846323e+00],
       [ 2.32820913e+00],
       [ 1.79651925e+00],
       [ 1.37511853e+00],
       [ 4.36758166e-01],
       [ 2.79309213e-01],
       [-1.62956734e+00],
       [-1.45703508e+00],
       [-1.10765633e+00],
       [-1.91975006e+00],
       [-8.60070658e-01],
       [-1.71896605e+00],
       [-3.19751783e+00],
       [ 1.39709574e+00],
       [ 1.58055963e+00],
       [-5.02495514e-03],
       [-2.22417675e+00],
       [ 2.35860336e-01],
       [ 3.40417041e-01],
       [-5.64293807e-01],
       [-3.13562218e-01],
       [ 1.36524494e+00],
       [-1.79262097e+00],
       [ 1.35736621e-01],
       [-3.46370640e-01],
       [-6.62981579e-01],
       [ 5.52339927e-01],
       [ 5.14048853e-01],
       [-1.16056196e-01],
       [ 2.23689411e-01],
       [-1.25404573e+00],
       [ 7.39396695e-01],
       [-9.71526776e-01],
       [-7.24489418e-01],
       [-2.49753139e-01],
       [ 4.30662306e-01],
       [-1.80300375e+00],
       [-1.80619604e+00],
       [-1.4

In [219]:
model.weights

array([[1.00060695],
       [1.0001764 ]])