In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random, jit, lax, vmap, pmap, grad, value_and_grad

In [2]:
jax.__version__

'0.4.1'

In [3]:
rng = np.random.default_rng()
key = jax.random.PRNGKey(0)

In [26]:
n_features = 7
batch_size = 5
learing_rate = 0.01
n_epochs = 3

In [5]:
def init_classifier_mlp(n_features, hidden_layer):
    weights_init = jax.nn.initializers.glorot_normal()
    
    W1 = weights_init(key, (hidden_layer, n_features))
    var = jnp.sqrt(n_features)
    b1 = jax.random.uniform(key, (hidden_layer,), minval=-var, maxval=var)

    w2 = weights_init(key, (1, hidden_layer)).squeeze()
    var = jnp.sqrt(hidden_layer)
    b2 = jax.random.uniform(key, minval=-var, maxval=var)

    return W1, b1, w2, b2


In [15]:
def predict(W1, b1, w2, b2, x):
    z1 = jax.nn.relu(W1 @ x + b1)
    p = jax.nn.sigmoid(w2 @ z1 + b2)
    return p
predict = vmap(predict, in_axes=(None, None, None, None, 0))

In [16]:
def loss(W1, b1, w2, b2, X, y):
    p = predict(W1, b1, w2, b2, X)
    return -jnp.mean(y * jnp.log(p) + (1-y) * jnp.log(1 - p))

In [25]:
@jit
def update(W1, b1, w2, b2, X, y):
    dW1, db1, dw2, db2 = grad(loss, argnums=(0, 1, 2, 3))(W1, b1, w2, b2, X, y)
    return W1 - learing_rate * dW1, b1 - learing_rate * db1, w2 - learing_rate * dw2, b2 - learing_rate * db2

In [45]:
def accuracy(W1, b1, w2, b2, cutoff, X, y):
    p = predict(W1, b1, w2, b2, X)
    y_hat = jnp.where(p > cutoff, 1., 0.)
    return jnp.mean(y_hat == y)

In [36]:
X = jax.random.uniform(key, shape=(batch_size, n_features))
y = jax.random.choice(key, a=jnp.array([0, 1]), shape=(batch_size,), p=jnp.array([0.7, 0.3]))
print(X.shape,  y.shape)

(5, 7) (5,)


In [37]:
X

Array([[0.42231905, 0.600477  , 0.716727  , 0.6702663 , 0.19844568,
        0.62467563, 0.7409996 ],
       [0.5134349 , 0.1870414 , 0.3951999 , 0.05926502, 0.26171422,
        0.32934666, 0.8185371 ],
       [0.4934113 , 0.77583563, 0.86196864, 0.61747384, 0.93201673,
        0.5313933 , 0.11078537],
       [0.8293711 , 0.5833154 , 0.9171771 , 0.36765325, 0.9984219 ,
        0.15992486, 0.30305398],
       [0.6148175 , 0.44295645, 0.9467739 , 0.9498491 , 0.5454718 ,
        0.60043406, 0.09577227]], dtype=float32)

In [38]:
a1 = jnp.array([
    [1, 1],
    [2, 2],
    [3, 3]
])

a2 = jnp.array([
    [4, 4],
    [5, 5]
])

In [43]:
jnp.concatenate((a1, a2), axis=0)

Array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4],
       [5, 5]], dtype=int32)

In [44]:
c1 = jnp.array([1, 1, 1])
c2 = jnp.array([2, 2, 2])
jnp.concatenate((c1, c2))

Array([1, 1, 1, 2, 2, 2], dtype=int32)

In [19]:
W1, b1, w2, b2 = init_classifier_mlp(n_features, 8)
print(W1.shape, b1.shape)
print(w2.shape, b2.shape)

(8, 7) (8,)
(8,) ()


In [30]:
p = predict(W1, b1, w2, b2, X)
p

Array([0.98897505, 0.99490064, 0.99331564, 0.99691415, 0.99141484],      dtype=float32)

In [31]:
jnp.where(p > 0.5, 1., 0.)

Array([1., 1., 1., 1., 1.], dtype=float32, weak_type=True)

In [32]:
y_hat = jnp.array([1., 0., 1.])
y = jnp.array([0., 1., 1.])

In [33]:
jnp.mean(y_hat == y)

Array(0.33333334, dtype=float32)

In [34]:
y_hat == y

Array([False, False,  True], dtype=bool)

In [46]:
W1, b1, w2, b2 = init_classifier_mlp(n_features, 8)
X_all = None
y_all = None
for _ in range(n_epochs):
    X = jax.random.uniform(key, shape=(batch_size, n_features))
    y = jax.random.choice(key, a=jnp.array([0, 1]), shape=(batch_size,), p=jnp.array([0.7, 0.3]))
    W1, b1, w2, b2 = update(W1, b1, w2, b2, X, y)
    
    if X_all is None:
        X_all = X
        y_all = y
    else:
        X_all = jnp.concatenate((X_all, X), axis=0)
        y_all = jnp.concatenate((y_all, y))
    acc = accuracy(W1, b1, w2, b2, 0.5, X_all, y_all)
    print(f"Accuracy: {acc:.5f}")

Accuracy: 0.20000
Accuracy: 0.20000
Accuracy: 0.20000
