In [1]:
import jax
from jax.example_libraries import stax, optimizers
import jax.numpy as jnp

In [3]:
from yellowbrick.datasets import load_energy
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

features = [
   "relative compactness",
   "surface area",
   "wall area",
   "roof area",
   "overall height",
   "orientation",
   "glazing area",
   "glazing area distribution",
]
target = ["heating load", "cooling load"]

df = load_energy(return_dataset=True).to_dataframe()
X, Y = df[features], df[target]

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape



((614, 8), (154, 8), (614, 2), (154, 2))

In [4]:
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

In [5]:
stax.Dense(5)

(<function jax.example_libraries.stax.Dense.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.Dense.<locals>.apply_fun(params, inputs, **kwargs)>)

In [6]:
neural_net_init, neural_net_apply = stax.serial(
                                                  stax.Dense(5),
                                                  stax.Relu,
                                                  stax.Dense(10),
                                                  stax.Relu,
                                                  stax.Dense(15),
                                                  stax.Relu,
                                                  stax.Dense(1),
                                                )

In [7]:
neural_net_init, neural_net_apply

(<function jax.example_libraries.stax.serial.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.serial.<locals>.apply_fun(params, inputs, **kwargs)>)

In [8]:
rando = jax.random.PRNGKey(123)

weights = neural_net_init(rando, (features,))

weights = weights[1] ## Weights are actually stored in second element of two value tuple

for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (8, 5), Biases : (5,)
Weights : (5, 10), Biases : (10,)
Weights : (10, 15), Biases : (15,)
Weights : (15, 1), Biases : (1,)


In [9]:
pred = neural_net_apply(weights, X_train[:5])

pred

DeviceArray([[-0.259568  ],
             [-0.49770448],
             [-1.2057577 ],
             [-0.00148434],
             [-0.53028834]], dtype=float32)

In [10]:
def MeanSquaredErrorLoss(weights, input_data, actual):
    pred = neural_net_apply(weights, input_data)
    pred = pred.squeeze()
    return jnp.power(actual - pred, 2).mean()

In [11]:
from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):

    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(loss))

    return opt_state

In [12]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 2500

weights = neural_net_init(rando, (features,))
weights = weights[1]


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)

ValueError: Incompatible shapes for broadcasting: ((614, 2), (1, 614))