In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [3]:
seed = 0


def init_MLP(parkey, layer_widths, scale = 0.01):
    params = []
    keys = jax.random.split(parkey, num=len(layer_widths)-1)
    for in_width, out_width, key in zip(layer_widths[:-1],layer_widths[1:], keys):
        wkey, bkey = jax.random.split(key)
        params.append([scale*jax.random.normal(wkey, shape=(out_width,in_width)),scale*jax.random.normal(bkey, shape=(out_width))])

    return params


key = jax.random.PRNGKey(seed)
MLP_params = init_MLP(key, [784,512,256,10])
print(jax.tree.map(lambda x: x.shape, MLP_params))


[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]


In [4]:
# @jax.jit
def predict(params, x):
    # print("hi")
    hidden_layers = params[:-1]

    activation = x
    for w, b in hidden_layers:
        activation = jax.nn.relu(jnp.dot(w, activation) + b)

    w_last, b_last = params[-1]

    logits = jnp.dot(w_last, activation) + b_last

    return logits - logsumexp(logits)

dummy_img_flat = np.random.randn(16, 784)

prediction = jax.vmap(predict, in_axes=(None, 0))(MLP_params, dummy_img_flat)
batch_predict = jax.vmap(predict, in_axes=(None, 0))
# prediction = jax.vmap(predict, in_axes=(None, 1))(MLP_params, dummy_img_flat)
print(prediction.shape)


(16, 10)


In [5]:
def custom_collate(batch):
    transposed_data = list(zip(*batch))
    # print((transposed_data))

    imgs = np.array(transposed_data[0])
    labels = np.array(transposed_data[1])

    # print(len(imgs))

    return imgs, labels


train_dataset = MNIST(root='./train_mnist',train=True, download=True,transform=lambda x: np.ravel(np.array(x, dtype=np.float32)))
test_dataset = MNIST(root='./test_mnist',train=False, download=True,transform=lambda x: np.ravel(np.array(x, dtype=np.float32)))
# print(type(train_dataset))
# print((train_dataset[0][0].shape))

BATCH_SIZE = 128

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate)

batch_data = next(iter(train_loader))
print(len(batch_data[0][0]))

784


In [8]:
from jax import grad, value_and_grad


NUM_EPOCHS = 10

def loss(params, imgs, gt_labels):
    output = batch_predict(params, imgs)
    return -jnp.mean(output*gt_labels)

def update(params, imgs, gt_labels, lr=0.01):
    l, grads = value_and_grad(loss)(params,imgs,gt_labels)
    return l, jax.tree.map(lambda p, g: p - lr*g, params, grads)

for epoch in range((NUM_EPOCHS)):

    for cnt, (imgs, labels) in enumerate(train_loader):
        gt_labels = jax.nn.one_hot(labels,len(MNIST.classes))
        l, MLP_params = update(MLP_params, imgs, gt_labels)

        # if cnt % 50 == 0:
        print(l)
    break



0.034985673
0.030251497
0.015898306
0.027458886
0.01796588
0.015383233
0.012061434
0.03117384
0.040019322
0.019851575
0.009944989
0.02219074
0.021367017
0.029891139
0.017947875
0.02304166
0.03506242
0.030132512
0.032620814
0.04325784
0.033331152
0.016994527
0.01829738
0.03180411
0.033397105
0.026375938
0.037283655
0.016214786
0.019808222
0.023359388
0.028543223
0.022539264
0.027493984
0.028903132
0.035730515
0.013779183
0.011998841
0.020292196
0.022780454
0.013053355
0.020621426
0.021685978
0.034751084
0.020720262
0.03930826
0.023384405
0.021060808
0.03641354
0.02548731
0.018885398
0.024537155
0.022366678
0.02661773
0.03961853
0.0203943
0.021237403
0.022857346
0.017862849
0.024456888
0.027587492
0.027551262
0.031196887
0.026749391
0.029305452
0.021211242
0.017546715
0.02685385
0.02631746
0.016128778
0.019181807
0.018716136
0.030002391
0.031723198
0.024329623
0.027736623
0.020719036
0.026990367
0.031531096
0.024753598
0.03018405
0.0380037
0.0359079
0.020102404
0.028484423
0.023890516
0.