In [3]:
import jax
import jax.numpy as jnp
from jax import random, jit
import pandas as pd

In [4]:
data = pd.read_csv('data/digit-recognizer/train.csv')
key = random.key(42)


In [5]:
data = jnp.array(data)
m, n = data.shape
random.permutation(key, data, independent=True)
print(m, n)

42000 785


In [6]:

dev = data[0:1000].T
y_dev = dev[0]
x_dev = dev[1:n]
x_dev = x_dev / 255

train = data[1000:m].T
y_train = train[0]
x_train = train[1:n]
x_train = x_train / 255

print(x_train.shape)

(784, 41000)


In [7]:
y_train

Array([1, 5, 1, ..., 7, 6, 9], dtype=int32)

In [8]:
def generate_subkey(key):
    key, subkey = random.split(key)
    return subkey

def ReLU(vec):
    return jnp.maximum(vec, 0)

def der_ReLU(vec):
    return vec > 0

def softmax(vec):
    return jnp.exp(vec) / sum(jnp.exp(vec))

def one_hot_enc(vec):
    one_hot = jnp.zeros((vec.size, vec.max() + 1))
    one_hot = one_hot.at[jnp.arange(vec.size), vec].set(1)
    return one_hot.T

def init_params():
    key = random.key(42)
    key, subkey = random.split(key)
    first_weights = random.uniform(subkey, shape=(10, 784)) - 0.5
    key, subkey = random.split(key)
    first_bias = random.uniform(subkey, shape=(10, 1)) - 0.5
    key, subkey = random.split(key)
    second_weights = random.uniform(subkey, shape=(10, 10)) - 0.5
    key, subkey = random.split(key)
    second_bias = random.uniform(subkey, shape=(10, 1)) - 0.5
    return first_weights, first_bias, second_weights, second_bias

def forward(w1, b1, w2, b2, inp):
    out1 = w1.dot(inp) + b1
    act1 = ReLU(out1)
    out2 = w2.dot(act1) + b2
    act2 = softmax(out2)
    return out1, act1, out2, act2

def backward(out1, act1, out2, act2, first_w, second_w, inp, labels):
    one_hot_label = one_hot_enc(labels)
    d_out2 = act2 - one_hot_label
    d_weight2 = 1 / m * d_out2.dot(act1.T)
    d_bias2 = 1 / m * jnp.sum(d_out2)
    d_out1 = second_w.T.dot(d_out2) * der_ReLU(out1)
    d_weight1 = 1 / m * d_out1.dot(inp.T)
    d_bias1 = 1 / m * jnp.sum(out1)
    return d_weight1, d_bias1, d_weight2, d_bias2

def update_params(weight1, bias1, weight2, bias2, d_weight1, d_bias1, d_weight2, d_bias2, learning_rate):
    w1 = weight1 - learning_rate * d_weight1
    b1 = bias1 - learning_rate * d_bias1
    w2 = weight2 - learning_rate * d_weight2
    b2 = bias2 - learning_rate * d_bias2
    return w1, b1, w2, b2


In [9]:
def get_predictions(act2):
    return jnp.argmax(act2, 0)

def get_accuracy(predictions, labels):
    print(predictions, labels)
    return jnp.sum(predictions == labels) / labels.size

def gradient_descent(inp, labels, learning_rate, iterations):
    w1, b1, w2, b2 = init_params()
    for i in range(0, iterations):
        out1, act1, out2, act2 = forward(w1, b1, w2, b2, inp)
        dw1, db1, dw2, db2 = backward(out1, act1, out2, act2, w1, w2, inp, labels)
        w1, b1, w2, b2 = update_params(w1, b1, w2, b2, dw1, db1, dw2, db2, learning_rate)
        if i % 100 == 0:
            print("Iteration: ", i)
            predictions = get_predictions(act2)
            print(get_accuracy(predictions, labels))
    return w1, b1, w2, b2

In [10]:
jax.devices() 

[CpuDevice(id=0)]

In [11]:
w1, b1, w2, b2 = gradient_descent(x_train, y_train, 0.2, 1000)

Iteration:  0
[2 3 0 ... 0 3 7] [1 5 1 ... 7 6 9]
0.071707316
Iteration:  100
[1 6 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.7574634
Iteration:  200
[1 4 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8167317
Iteration:  300
[1 4 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8406829
Iteration:  400
[1 4 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8523171
Iteration:  500
[1 5 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8592683
Iteration:  600
[1 9 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8588049
Iteration:  700
[1 9 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.86765856
Iteration:  800
[1 9 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.8342439
Iteration:  900
[1 9 1 ... 7 6 9] [1 5 1 ... 7 6 9]
0.87217075
