In [5]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jax import jit
import jax

## Create dataset

In [64]:
# true model: Y = 0.5x_1 + 3x_2 - 0.4x_3

N = 1000
x_1 = jnp.linspace(0,1,N)
x_2 = jnp.linspace(0,4,N)
x_3 = jnp.linspace(0,0.5,N)

X = jnp.vstack((x_1,x_2,x_3)).T

w_real = jnp.array([0.5,3,-0.4])

y = jnp.dot(X,w_real)

In [66]:
# Add gaussian noise
key = jax.random.PRNGKey(0)

y = y + 0.4*jax.random.normal(key, shape=(N,))

In [94]:
train_size = int(0.8*N)
test_size = N - train_size

key,subkey = jax.random.split(key)
idx = jax.random.permutation(subkey,jnp.arange(N))

train_idx = idx[:train_size]
test_idx = idx[train_size:]

X_train = X[train_idx]
y_train = y[train_idx]

X_test = X[test_idx]
y_test = y[test_idx]

## Linear Regression using JAX

In [44]:
def model(theta,X):
  return jnp.dot(X,theta)

@jit
def loss_MSE(theta,X,y):
  return jnp.mean((y - model(theta,X))**2)

In [84]:
loss_grad = jit(jax.grad(loss_MSE))

def optimize(theta,X,y,lr=0.1):
  tol = 1e-6
  max_iter = 15000

  for i in range(max_iter):
    grad = loss_grad(theta,X,y)
    theta = theta - lr*grad

    if i % 1000 == 0:
      print(f"Iteration: {i}, Loss: {loss_MSE(theta,X,y)}")

    #if jnp.linalg.norm(grad) < tol:
    #  break

  return theta

In [88]:
theta = jax.random.normal(key, shape=(3,))
theta = optimize(theta,X_train,y_train,lr=0.15)

Iteration: 0, Loss: 32.01405334472656
Iteration: 1000, Loss: 23.70366859436035
Iteration: 2000, Loss: 23.70366859436035
Iteration: 3000, Loss: 23.70366859436035
Iteration: 4000, Loss: 23.70366859436035
Iteration: 5000, Loss: 23.70366859436035
Iteration: 6000, Loss: 23.70366859436035
Iteration: 7000, Loss: 23.70366859436035
Iteration: 8000, Loss: 23.70366859436035
Iteration: 9000, Loss: 23.70366859436035
Iteration: 10000, Loss: 23.70366859436035
Iteration: 11000, Loss: 23.70366859436035
Iteration: 12000, Loss: 23.70366859436035
Iteration: 13000, Loss: 23.70366859436035
Iteration: 14000, Loss: 23.70366859436035


In [97]:
train_error = loss_MSE(theta,X_train,y_train)
test_error = loss_MSE(theta,X_test,y_test)

print(f"Train error: {train_error}")
print(f"Test error: {test_error}")

Train error: 3.8326499462127686
Test error: 4.1850714683532715
