# Linear Regression

In [14]:
import jax
import optax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

## Define the model and initialize weights

In [2]:
# A single fully-connected (Dense) layer with 5 output features. The input
# dimension is not fixed here; it is determined when `model.init` is called
# with an example input.
model = nn.Dense(features=5)

In [3]:
# Derive two independent PRNG keys from one seed: key1 draws the dummy input,
# key2 drives the parameter initialization.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
params = model.init(key2, x) # Initialization call
# `jax.tree_map` was a deprecated top-level alias that newer JAX releases no
# longer provide; `jax.tree_util.tree_map` is the supported spelling. The
# lambda argument is named `leaf` so it does not shadow the global `x` above.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params) # Checking output shapes



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [4]:
# Show the dummy input vector that was used to initialize the parameters.
print(x)

[-2.6105604   0.03385275  1.0863333  -1.480299    0.48895663  1.0625157
  0.5417483   0.01702273  0.27226844  0.3052244 ]


In [5]:
# Forward pass: run the Dense layer on the dummy input using the freshly
# initialized parameters. As the cell's last expression, the result is displayed.
model.apply(params, x)

DeviceArray([-0.7358944,  1.3583755, -0.7976872,  0.8168598,  0.6297792],            dtype=float32)

## Generate dataset

In [6]:
# Problem dimensions: number of samples, input size, output size.
nsamples, xdim, ydim = 20, 10, 5

# Ground-truth affine map (W, b), drawn from a fixed seed and packed in the
# same {'params': {'bias', 'kernel'}} layout that the model's parameters use.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, shape=(xdim, ydim))
b = random.normal(k2, shape=(ydim,))
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Noisy observations y = xW + b + 0.1 * eps.
# NOTE(review): k1 was already consumed to draw W and is split again here.
# JAX's PRNG guidance is to never reuse a key; splitting `key` into four
# subkeys up front would avoid this, but it would also change every number
# printed below, so the original key flow is kept as-is.
ksample, knoise = random.split(k1)
x_samples = random.normal(ksample, shape=(nsamples, xdim))
y_samples = jnp.dot(x_samples, W) + b
y_samples = y_samples + 0.1 * random.normal(knoise, shape=(nsamples, ydim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [10]:
# Peek at the first two input samples (rows of x_samples).
print(x_samples[0:2])

[[-1.2215654  -1.3890593  -1.1903723  -1.3718774   0.33591986  1.2989264
  -0.798972    0.30542666  0.2626647   0.16920286]
 [ 1.3967623   0.33278397 -0.02423434  0.7279588  -0.02220635 -0.04626204
   1.4736844  -0.8445644   0.35538498  0.92975104]]


In [11]:
# Peek at the first two noisy targets (rows of y_samples).
print(y_samples[0:2])

[[-0.44524443 -4.254266   -2.085645   -0.2375911  -0.5508175 ]
 [-2.023773    2.1660166   3.138679   -1.5070897   2.0076575 ]]


## Train with Gradient Descent

### Make an MSE Function

In [12]:
def make_mse_func(x_batched, y_batched):
    """Build a jitted loss function closed over a fixed batch of data.

    Args:
        x_batched: Inputs stacked along the leading axis.
        y_batched: Targets stacked along the leading axis, one per input.

    Returns:
        A `jax.jit`-compiled function `mse(params)` that applies the
        module-level `model` to every input and returns the mean over the
        batch of `inner(y - pred, y - pred) / 2`.
    """
    def mse(params):
        def half_squared_residual(x, y):
            # Loss for a single (x, y) pair; the residual is computed once.
            residual = y - model.apply(params, x)
            return jnp.inner(residual, residual) / 2.0

        # Vectorize the per-example loss over the leading (batch) axis,
        # then average across the batch.
        per_example = jax.vmap(half_squared_residual)(x_batched, y_batched)
        return jnp.mean(per_example, axis=0)

    # Compile once; the batch is baked into the closure.
    return jax.jit(mse)

# Loss evaluated on the sampled dataset, as a function of params only.
loss = make_mse_func(x_samples, y_samples)

### Gradient Descent

In [13]:
alpha = 0.3 # Gradient step size
# Reference point: loss of the ground-truth parameters on the noisy samples.
print('Loss for "true" W,b: ', loss(true_params))
# Returns (loss value, gradient pytree) in a single call.
grad_fn = jax.value_and_grad(loss)

# NOTE: this loop rebinds `params`, so re-running the cell continues training
# from the current values rather than from the initial ones.
for i in range(101):
    # We perform one gradient update
    loss_val, grads = grad_fn(params)
    # `jax.tree_multimap` was deprecated and then removed from JAX;
    # `jax.tree_util.tree_map` accepts several pytrees and is the replacement.
    params = jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

Loss for "true" W,b:  0.023639778
Loss step 0:  38.094772
Loss step 10:  0.44692174
Loss step 20:  0.10053458
Loss step 30:  0.03582275
Loss step 40:  0.018846864
Loss step 50:  0.013864852
Loss step 60:  0.012312567
Loss step 70:  0.011812925
Loss step 80:  0.01164931
Loss step 90:  0.011595252
Loss step 100:  0.011577307


## Train with Optax

In [15]:
# Loss-and-gradient function for the optax training loop.
loss_grad_fn = jax.value_and_grad(loss)

# Plain SGD with the same step size as the manual loop; the optimizer state
# is initialized from the current (already trained) parameters.
tx = optax.sgd(learning_rate=alpha)
opt_state = tx.init(params)

In [16]:
for i in range(101):
    # Evaluate the loss and its gradient at the current parameters.
    loss_val, grads = loss_grad_fn(params)
    # Ask the optimizer for the update, then apply it to the parameters.
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # Report only every tenth step.
    if i % 10:
        continue
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011576376
Loss step 10:  0.011571027
Loss step 20:  0.011569239
Loss step 30:  0.011568645
Loss step 40:  0.011568449
Loss step 50:  0.0115683805
Loss step 60:  0.011568364
Loss step 70:  0.011568367
Loss step 80:  0.011568349
Loss step 90:  0.011568355
Loss step 100:  0.011568355
