In [1]:
import jax
from jax.example_libraries import stax, optimizers
import jax.numpy as jnp

In [2]:
from yellowbrick.datasets import load_concrete
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

X, Y = load_concrete(data_home=None, return_dataset=False)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape



((824, 8), (206, 8), (824,), (206,))

In [3]:
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

In [4]:
stax.Dense(5)

(<function jax.example_libraries.stax.Dense.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.Dense.<locals>.apply_fun(params, inputs, **kwargs)>)

In [5]:
neural_net_init, neural_net_apply = stax.serial(
                                                  stax.Dense(5),
                                                  stax.Relu,
                                                  stax.Dense(10),
                                                  stax.Relu,
                                                  stax.Dense(15),
                                                  stax.Relu,
                                                  stax.Dense(1),
                                                )

In [6]:
neural_net_init, neural_net_apply

(<function jax.example_libraries.stax.serial.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.serial.<locals>.apply_fun(params, inputs, **kwargs)>)

In [7]:
rando = jax.random.PRNGKey(123)

weights = neural_net_init(rando, (features,))

weights = weights[1] ## Weights are actually stored in second element of two value tuple

for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (8, 5), Biases : (5,)
Weights : (5, 10), Biases : (10,)
Weights : (10, 15), Biases : (15,)
Weights : (15, 1), Biases : (1,)


In [8]:
pred = neural_net_apply(weights, X_train[:5])

pred

DeviceArray([[-0.08696245],
             [-0.03468872],
             [-0.00760121],
             [-0.18586212],
             [-0.04303019]], dtype=float32)

In [9]:
def MeanSquaredErrorLoss(weights, input_data, actual):
    pred = neural_net_apply(weights, input_data)
    pred = pred.squeeze()
    return jnp.power(actual - pred, 2).mean()

In [10]:
from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):

    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(loss))

    return opt_state

In [11]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 2500

weights = neural_net_init(rando, (features,))
weights = weights[1]


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)

MSE : 102.73
MSE : 65.51
MSE : 74.54
MSE : 70.01
MSE : 69.14
MSE : 66.94
MSE : 65.08
MSE : 63.40
MSE : 61.86
MSE : 60.28
MSE : 59.17
MSE : 57.93
MSE : 56.80
MSE : 56.03
MSE : 54.69
MSE : 53.78
MSE : 53.00
MSE : 52.67
MSE : 51.12
MSE : 50.58
MSE : 50.69
MSE : 49.96
MSE : 49.53
MSE : 49.23
MSE : 49.04


In [12]:
## Make Predictions on test dataset
test_preds = neural_net_apply(opt_get_weights(final_opt_state), X_test) 
test_preds = test_preds.ravel()

## Make Predictions on train dataset
train_preds = neural_net_apply(opt_get_weights(final_opt_state), X_train) 
train_preds = train_preds.ravel()

test_preds[:5], train_preds[:5]

(DeviceArray([54.604584, 39.705597, 27.754255, 15.736749, 35.262512], dtype=float32),
 DeviceArray([16.224113, 36.45395 , 44.029045, 12.285954, 26.516912], dtype=float32))

In [13]:
from sklearn.metrics import r2_score

print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.to_py(), Y_train.to_py())))
print("Test  R^2 Score : {:.2f}".format(r2_score(test_preds.to_py(), Y_test.to_py())))

Train R^2 Score : 0.72
Test  R^2 Score : 0.71


In [14]:
def TrainModelWithBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(jnp.array(losses).mean()))

    return opt_state

In [15]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 500

weights = neural_net_init(rando, (features,))
weights = weights[1]


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelWithBatches(X_train, Y_train, epochs, opt_state)

MSE : 55.23
MSE : 42.40
MSE : 37.67
MSE : 34.27
MSE : 34.07


In [16]:
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(neural_net_apply(weights, X_batch))

    return preds

In [17]:
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

test_preds[:5], train_preds[:5]

(DeviceArray([67.84538 , 41.260525, 26.361334, 15.545101, 39.20099 ], dtype=float32),
 DeviceArray([16.22281  , 40.606377 , 50.369236 ,  7.7457733, 30.480265 ],            dtype=float32))

In [18]:
from sklearn.metrics import r2_score

print("Test  R^2 Score : {:.2f}".format(r2_score(test_preds, Y_test)))
print("Train R^2 Score : {:.2f}".format(r2_score(train_preds, Y_train)))

Test  R^2 Score : 0.84
Train R^2 Score : 0.86
