In [1]:
import jax

# Display the JAX release this notebook was executed with.
jax.__version__

'0.4.12'

In [2]:
from jax.example_libraries import stax, optimizers
import jax.numpy as jnp 
from tensorflow import keras
from sklearn.model_selection import train_test_split

# Fetch the Fashion-MNIST train/test split through Keras.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

# Convert every array (images and labels alike) to a float32 JAX array.
X_train = jnp.array(X_train, dtype=jnp.float32)
X_test = jnp.array(X_test, dtype=jnp.float32)
Y_train = jnp.array(Y_train, dtype=jnp.float32)
Y_test = jnp.array(Y_test, dtype=jnp.float32)

# Append a trailing channel axis of size 1 (N, 28, 28, 1), then divide the
# pixel values by 255.0.
X_train = X_train.reshape(-1, 28, 28, 1) / 255.0
X_test = X_test.reshape(-1, 28, 28, 1) / 255.0

# Distinct label values found in the training set; len(classes) is used as
# the number of output classes further down.
classes = jnp.unique(Y_train)

# Last expression of the cell: show the resulting shapes.
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

In [3]:
# Layer stack: two 3x3 SAME-padded convolutions (32 filters, then 16), each
# followed by a ReLU; the result is flattened and fed to a dense layer with
# one unit per class, finished with a softmax.
cnn_layers = [
    stax.Conv(32, (3, 3), padding="SAME"),
    stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"),
    stax.Relu,
    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax,
]

# stax.serial composes the layers and hands back the network's
# (init function, apply function) pair.
conv_init, conv_apply = stax.serial(*cnn_layers)

In [4]:
rng = jax.random.PRNGKey(123)

# conv_init returns a pair; only its second element (the per-layer
# parameters) is kept.
weights = conv_init(rng, (18, 28, 28, 1))[1]

# Print the parameter shapes of every layer that actually has parameters.
for layer_params in weights:
    # Layers without parameters contribute an empty (falsy) entry: skip them.
    if not layer_params:
        continue
    w, b = layer_params
    print(f"Weight:{w.shape}, Biase:{b.shape}")

Weight:(3, 3, 1, 32), Biase:(1, 1, 1, 32)
Weight:(3, 3, 32, 16), Biase:(1, 1, 1, 16)
Weight:(12544, 10), Biase:(10,)


In [5]:
# Smoke test: run the first five training images through the network using the
# parameters initialised above and display the resulting per-class outputs.
preds = conv_apply(weights, X_train[:5])
preds

Array([[0.1010583 , 0.08450613, 0.08876862, 0.10357412, 0.09424469,
        0.06837885, 0.13585268, 0.10559723, 0.10213768, 0.11588172],
       [0.10652138, 0.08236859, 0.11018915, 0.11165013, 0.08258919,
        0.0772452 , 0.14238864, 0.09115906, 0.0884075 , 0.10748117],
       [0.09566505, 0.09545239, 0.10094573, 0.10249694, 0.09882495,
        0.08883354, 0.11344208, 0.09690957, 0.10275006, 0.10467971],
       [0.10200436, 0.08598677, 0.10633809, 0.10407417, 0.09481844,
        0.0829622 , 0.12135128, 0.09306401, 0.09924015, 0.11016053],
       [0.09390712, 0.08359885, 0.10012674, 0.11463808, 0.09753538,
        0.07206484, 0.12989207, 0.08960546, 0.10824842, 0.11038305]],      dtype=float32)

In [6]:
def crossEntropyLoss(weights, input_data, actual):
    """Summed categorical cross-entropy of the network on one batch.

    Args:
        weights: Network parameters, passed straight to ``conv_apply``.
        input_data: Batch of inputs, passed straight to ``conv_apply``.
        actual: Class labels for the batch; they are one-hot encoded with
            ``len(classes)`` classes.

    Returns:
        Scalar loss: the negative log-probability assigned to the labelled
        class, *summed* (not averaged) over the batch, so its magnitude grows
        with the batch size.
    """
    preds = conv_apply(weights,input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))

    # A probability that underflows to exactly 0 would give log(0) = -inf, and
    # the masked product 0 * -inf is NaN (the backward pass hits 0 / 0 as
    # well). Flooring at the smallest positive normal float keeps the log
    # finite and leaves every non-degenerate value untouched.
    smallest_positive = jnp.finfo(preds.dtype).tiny
    log_preds = jnp.log(jnp.maximum(preds, smallest_positive))

    return - jnp.sum(one_hot_actual * log_preds)

In [9]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Run mini-batch training and return the final optimizer state.

    For every epoch the data is walked front to back in consecutive slices of
    ``batch_size`` rows (the last slice may be shorter). Each slice is used to
    evaluate ``crossEntropyLoss`` and its gradient with respect to the current
    weights, and the optimizer state is advanced with ``opt_update``. The mean
    of the per-batch losses is printed once per epoch.

    Relies on the module-level names ``crossEntropyLoss``, ``opt_get_weights``
    and ``opt_update``.

    Args:
        X: Input samples, indexed along axis 0.
        Y: Targets aligned with ``X`` along axis 0.
        epochs: Number of passes over the data.
        opt_state: Initial optimizer state.
        batch_size: Number of samples per batch.

    Returns:
        The optimizer state after the last update.
    """
    num_samples = X.shape[0]

    # Build the loss-and-gradient function once rather than on every step.
    loss_and_grad = value_and_grad(crossEntropyLoss)

    for i in range(1, epochs + 1):
        losses = []

        # Plain Python offsets: no per-step device-to-host conversion, and no
        # trailing empty batch when num_samples is an exact multiple of
        # batch_size. Slicing past the end simply yields a shorter last batch.
        for start in range(0, num_samples, batch_size):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            # NOTE(review): the epoch number, not a per-batch counter, is passed
            # as the step argument - confirm this is intended if opt_update
            # ever uses a step-dependent schedule.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
    return opt_state

In [10]:
# Settings for the training run.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

# Re-initialise the network parameters with this cell's own key, so the cell
# no longer depends on the `rng` variable left behind by an earlier cell.
# conv_init returns a pair; only its second element (the parameters) is kept.
weights = conv_init(seed, (batch_size,28,28,1))
weights = weights[1]

# optimizers.sgd hands back three functions: one to build the optimizer state,
# one to apply an update, and one to read the parameters back out of the state.
# TrainModelInBatches picks up opt_update and opt_get_weights as globals.
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 228.892
CrossEntropyLoss : 145.997
CrossEntropyLoss : 127.736
CrossEntropyLoss : 117.267
CrossEntropyLoss : 109.906
CrossEntropyLoss : 104.241
CrossEntropyLoss : 99.730
CrossEntropyLoss : 96.072
CrossEntropyLoss : 93.087
CrossEntropyLoss : 90.570
CrossEntropyLoss : 88.415
CrossEntropyLoss : 86.527
CrossEntropyLoss : 84.845
CrossEntropyLoss : 83.308
CrossEntropyLoss : 81.903
CrossEntropyLoss : 80.579
CrossEntropyLoss : 79.349
CrossEntropyLoss : 78.185
CrossEntropyLoss : 77.085
CrossEntropyLoss : 76.043
CrossEntropyLoss : 75.054
CrossEntropyLoss : 74.103
CrossEntropyLoss : 73.185
CrossEntropyLoss : 72.294
CrossEntropyLoss : 71.427
