In [1]:
import jax
from flax import struct
from flax import linen as nn
from flax.core import freeze, unfreeze
import jax.numpy as jnp
import jax.random as jr
from jax import vmap, grad, jit
from typing import Callable

key = jr.PRNGKey(123)



In [39]:
# Create the predict function from a set of parameters
def make_predict(W,b):
    def predict(x):
        return jnp.dot(W,x)+b
    return predict

In [40]:
# Create the loss from the data points set
def make_mse(x_batched,y_batched):
    def mse(W,b):
        # Define the squared loss for a single pair (x,y)
        def squared_error(x,y):
            y_pred = make_predict(W,b)(x)
            return jnp.inner(y-y_pred,y-y_pred)/2.0
        # We vectorize the previous to compute the average of the loss on all samples.
        return jnp.mean(vmap(squared_error)(x_batched,y_batched), axis=0)
    return jit(mse) # And finally we jit the result.

In [41]:
# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b
k1, k2 = jr.split(key)
W = jr.normal(k1, (ydim, xdim))
b = jr.normal(k2, (ydim,))
true_predict = make_predict(W,b)

# Generate samples with additional noise
ksample, knoise = jr.split(k1)
x_samples = jr.normal(ksample, (nsamples, xdim))
y_samples = vmap(true_predict)(x_samples) + 0.1*jr.normal(knoise,(nsamples, ydim))

# Generate MSE for our samples
mse = make_mse(x_samples,y_samples)

In [42]:
# Initialize estimated W and b with zeros.
What = jnp.zeros_like(W)
bhat = jnp.zeros_like(b)

alpha = 0.3 # Gradient step size
print('Loss for "true" W,b: ', mse(W,b))
for i in range(101):
    # We perform one gradient update
    What, bhat = What - alpha*grad(mse,0)(What,bhat), bhat - alpha*grad(mse,1)(What,bhat)
    if (i%20==0):
        print("Loss step {}: ".format(i), mse(What,bhat))

Loss for "true" W,b:  0.026416078
Loss step 0:  11.808268
Loss step 20:  0.07759722
Loss step 40:  0.02324226
Loss step 60:  0.0143046575
Loss step 80:  0.012400283
Loss step 100:  0.01197335


## Flax setup

In [43]:
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

In [46]:
key, subkey = jr.split(key)
x = jr.normal(key, (10, ))
params = model.init(key, x)
jax.tree_map(lambda x: x.shape, params)

FrozenDict({'params': {'bias': (5,), 'kernel': (10, 5)}})

The size of x controls the shape of the model's parameters. `model.init` controls the initialisation of model parameters and returns the params as a frozen dictionary; an immutable structure.

In [50]:
# Evaluate the model at some data points x for a given set of params
model.apply(params, x)

DeviceArray([ 1.3209158,  0.1376045,  0.6276722,  1.2021543, -1.1306524],            dtype=float32)

# Gradient descent

In [54]:
# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b
key = jr.PRNGKey(0)
k1, k2 = jr.split(key)
W = jr.normal(k1, (xdim, ydim))
b = jr.normal(k2, (ydim,))
true_params = freeze({'params': {'bias': b, 'kernel': W}})

In [56]:
# Generate samples with additional noise
ksample, knoise = jr.split(k1)
x_samples = jr.normal(ksample, (nsamples, xdim))
y_samples = jnp.dot(x,W) + b
y_samples += 0.1*jr.normal(knoise,(nsamples, ydim)) # Adding noise
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [57]:
def make_mse_func(x_batched, y_batched):
    def mse(params):
    # Define the squared loss for a single pair (x,y)
        def squared_error(x, y):
            pred = model.apply(params, x)
            return jnp.inner(y-pred, y-pred)/2.0
        # We vectorize the previous to compute the average of the loss on all samples.
        return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
    return jax.jit(mse) # And finally we jit the result.

# Get the sampled loss
loss = make_mse_func(x_samples, y_samples)


In [58]:
alpha = 0.3 # Gradient step size
print('Loss for "true" W,b: ', loss(true_params))
grad_fn = jax.value_and_grad(loss)

Loss for "true" W,b:  29.405216


In [59]:
for i in range(101):
    loss_val, grad = grad_fn(params)
    params = jax.tree_multimap(lambda old, grad: old - alpha * grad, params, grad)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

Loss step 0:  24.20204
Loss step 10:  0.35897788
Loss step 20:  0.06296228
Loss step 30:  0.023336558
Loss step 40:  0.01489224
Loss step 50:  0.0125975255
Loss step 60:  0.011899963
Loss step 70:  0.011677139
Loss step 80:  0.0116043445
Loss step 90:  0.011580311
Loss step 100:  0.011572338


In [60]:
from flax import optim

In [61]:
# Choose an optimiser
optimizer_def = optim.GradientDescent(learning_rate=alpha) # Choose the method
# Create the wrapping optimizer with initial parameters
optimizer = optimizer_def.create(params) 
loss_grad_fn = jax.value_and_grad(loss)

In [64]:
for i in range(101):
    loss_val, grad = loss_grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad) # Return the updated optimizer with parameters.
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011568356
Loss step 10:  0.011568359
Loss step 20:  0.011568356
Loss step 30:  0.01156836
Loss step 40:  0.011568361
Loss step 50:  0.011568356
Loss step 60:  0.011568361
Loss step 70:  0.01156836
Loss step 80:  0.011568353
Loss step 90:  0.011568354
Loss step 100:  0.011568355


In [65]:
class GaussianProcess

SyntaxError: invalid syntax (<ipython-input-65-ee8783676bbf>, line 1)

In [3]:
from typing import Sequence

class SimpleMLP(nn.Module):
    features: Sequence[int]
    
    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x

    def __add__(self, other):
        return self.features + other.features

m1 = SimpleMLP(features=5)
m2 = SimpleMLP(features=10)
m1+m2

15