### Linear Regression with Gradient Descent from Scratch
We'll be using JAX library to obtain the gradients, everything else is from scratch

Implementation here is only done for single point, but can be extended to multiple points without any changes.

In [0]:
import jax
import jax.numpy as jnp

In [0]:
class LinearRegression():
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def predict(self, x):
        return self.w*x + self.b

    def rms(self, xs, ys):
        return jnp.sqrt(jnp.sum(jnp.square(self.w*xs + self.b - ys)))

In [0]:
lr = LinearRegression(10., 5.)
xs = jnp.array([42.])
ys = jnp.array([21.])

print(lr.rms(xs, ys))



404.0


In [0]:
print(lr.predict(42.))

425.0


In [0]:
def loss_fn(w, b, xs, ys):
    lr = LinearRegression(w, b)
    return lr.rms(xs, ys)

In [0]:
grad_fn = jax.grad(loss_fn, argnums = (0,1))

print(loss_fn(13., 0., xs, ys))
print(grad_fn(13., 0., xs, ys))

525.0
(DeviceArray(42., dtype=float32), DeviceArray(1., dtype=float32))


In [0]:
def loss_fn(params, xs, ys):
    lr = LinearRegression(params['w'], params['b'])
    return lr.rms(xs, ys)

In [0]:
grad_fn = jax.grad(loss_fn)

In [0]:
params = {'w':42., 'b':0.}

for _ in range(15):
    print(loss_fn(params, xs, ys))
    grads = grad_fn(params, xs, ys)
    for name in params.keys():
        params[name] -=0.070538 * grads[name]

LinearRegression(params['w'], params['b']).predict(42.)

1743.0
1618.5004
1494.0007
1369.5011
1245.0016
1120.502
996.0024
871.50275
747.0031
622.5036
498.00403
373.50446
249.00485
124.50531
0.005754471


DeviceArray(-103.49381, dtype=float32)