In [120]:
from jax import random
from jax import numpy as jnp
import jax

# Module-level PRNG key built from a fixed seed so that runs are reproducible.
random_state = random.PRNGKey(seed=0)

@jax.jit
def loss(intercept: float, slope: float, dataset: jnp.ndarray) -> jax.Array:
    """Return the sum of squared residuals of a line over ``dataset``.

    Each row of ``dataset`` is read as a point whose first entry is the
    input and whose second entry is the observed value. The prediction for
    a row is ``intercept + slope * row[0]`` and its residual is
    ``row[1] - prediction``.

    Args:
        intercept: Intercept of the candidate line.
        slope: Slope of the candidate line.
        dataset: Array of points, one per row.

    Returns:
        The squared residuals summed over the rows; 0 for an empty dataset.
    """
    dataset = jnp.asarray(dataset)

    # Array shapes are static under jit, so this branch is resolved while
    # tracing and an empty dataset still yields a zero loss.
    if dataset.size == 0:
        return jnp.zeros(())

    # Vectorised over rows: a Python ``for`` loop here would be unrolled at
    # trace time, one copy of the body per row.
    predictions = intercept + slope * dataset[:, 0]
    residuals = dataset[:, 1] - predictions
    return jnp.sum(residuals**2, axis=0)


@jax.jit
def _descent_step(
    dataset: jnp.ndarray,
    intercept: float,
    slope: float,
    learning_rate: float,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Take one gradient-descent step on ``loss``.

    Returns the updated intercept, the updated slope, and the two partial
    derivatives (w.r.t. intercept, then slope) evaluated *before* the update.
    """
    # One differentiation pass yields both partial derivatives.
    intercept_grad, slope_grad = jax.grad(loss, argnums=(0, 1))(
        intercept, slope, dataset
    )
    return (
        intercept - intercept_grad * learning_rate,
        slope - slope_grad * learning_rate,
        intercept_grad,
        slope_grad,
    )


def gradient_descent(
    dataset: jnp.ndarray,
    *,
    learning_rate: float = 0.01,
    stop_learning_at: float = 0.0001,
    max_steps: int = 1000,
    key=None,
) -> tuple[jax.Array, jax.Array]:
    """Fit an intercept and a slope to ``dataset`` by minimising ``loss``.

    Both parameters start from independent standard-normal draws and are
    updated with plain gradient descent until both partial derivatives are
    smaller in magnitude than ``stop_learning_at`` or ``max_steps`` updates
    have been made.

    Args:
        dataset: Points to fit, in the layout expected by ``loss``.
        learning_rate: Factor applied to each gradient to get the step size.
        stop_learning_at: Gradient magnitude below which descent stops.
        max_steps: Upper bound on the number of update steps.
        key: PRNG key for the initial values; defaults to the module-level
            ``random_state``.

    Returns:
        The fitted ``(intercept, slope)`` pair.
    """
    if key is None:
        key = random_state

    # A key must not be consumed twice: drawing both values from the same key
    # would make the initial intercept and slope identical.
    intercept_key, slope_key = random.split(key)
    intercept = random.normal(intercept_key)
    slope = random.normal(slope_key)

    for _ in range(max_steps):
        intercept, slope, intercept_grad, slope_grad = _descent_step(
            dataset, intercept, slope, learning_rate
        )

        # The gradients are those measured before this step's update.
        if abs(slope_grad) < stop_learning_at and abs(intercept_grad) < stop_learning_at:
            break

    return intercept, slope

# Fit a line to three sample points; each row holds one point.
gradient_descent(
    jnp.array(
        [
            [0.5, 1.4],
            [2.3, 1.9],
            [2.9, 3.2],
        ]
    )
)

(Array(0.9486321, dtype=float32), Array(0.64106226, dtype=float32))