In [3]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [4]:
# Seed for the JAX PRNG key from which the MLP parameters are initialised.
seed = 0


def init_MLP(parkey, layer_widths, scale = 0.01):
    """Create randomly initialised weights and biases for a fully connected net.

    Args:
        parkey: JAX PRNG key. It is split once per layer, and each layer key is
            split again so weights and biases draw from separate sub-keys.
        layer_widths: Sequence of layer sizes, input width first. Every pair of
            consecutive widths defines one dense layer.
        scale: Factor multiplied onto the standard-normal samples.

    Returns:
        A list with one ``[weights, bias]`` entry per layer, where ``weights``
        has shape ``(out_width, in_width)`` and ``bias`` has shape
        ``(out_width,)``.
    """
    params = []
    keys = jax.random.split(parkey, num=len(layer_widths)-1)
    for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):
        wkey, bkey = jax.random.split(key)
        weights = scale * jax.random.normal(wkey, shape=(out_width, in_width))
        # The shape has to be a 1-tuple: a bare ``(out_width)`` is just an int.
        bias = scale * jax.random.normal(bkey, shape=(out_width,))
        params.append([weights, bias])

    return params


# Build the root PRNG key from the notebook seed and initialise a
# 784 -> 512 -> 256 -> 10 network from it.
key = jax.random.PRNGKey(seed)
layer_widths = [784, 512, 256, 10]
MLP_params = init_MLP(key, layer_widths)

# Show the shape of every parameter array as a quick sanity check.
param_shapes = jax.tree.map(jnp.shape, MLP_params)
print(param_shapes)


[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]


In [5]:
# @jax.jit
def predict(params, x):
    """Run the MLP forward on ``x`` and return normalised log-scores.

    Args:
        params: Sequence of ``(weights, bias)`` pairs, one per layer. All but
            the last pair are followed by a ReLU.
        x: Input passed to the first layer via ``jnp.dot(weights, x)``.

    Returns:
        The final-layer outputs minus their log-sum-exp.
    """
    *_, (final_w, final_b) = [params[-1]]

    hidden = x
    for weight, bias in params[:-1]:
        pre_activation = jnp.dot(weight, hidden) + bias
        hidden = jax.nn.relu(pre_activation)

    scores = jnp.dot(final_w, hidden) + final_b
    # Subtracting the log-sum-exp makes exp(result) sum to one.
    normaliser = logsumexp(scores)
    return scores - normaliser

# Smoke-test the network on a reproducible random batch of 16 inputs of width 784.
dummy_img_flat = np.random.default_rng(seed).standard_normal((16, 784))

# Vectorise predict over the leading axis of the input; in_axes=None keeps the
# parameters shared across the whole batch.
batch_predict = jax.vmap(predict, in_axes=(None, 0))

prediction = batch_predict(MLP_params, dummy_img_flat)
print(prediction.shape)


(16, 10)


In [10]:
def custom_collate(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Args:
        batch: Sequence of samples; the first element of each sample goes into
            the image array and the second into the label array.

    Returns:
        A tuple ``(imgs, labels)`` of NumPy arrays stacked along a new leading
        axis.
    """
    # zip(*batch) turns a list of samples into one tuple per field.
    columns = tuple(zip(*batch))
    return np.array(columns[0]), np.array(columns[1])


def _flatten_to_float32(img):
    """Convert a dataset sample image into a flat 1-D float32 NumPy array."""
    return np.ravel(np.array(img, dtype=np.float32))


# Train and test splits share the same flattening transform.
train_dataset = MNIST(
    root='./train_mnist',
    train=True,
    download=True,
    transform=_flatten_to_float32,
)
test_dataset = MNIST(
    root='./test_mnist',
    train=False,
    download=True,
    transform=_flatten_to_float32,
)

BATCH_SIZE = 128

# custom_collate keeps batches as NumPy arrays.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate,
)

# Pull one batch and print the length of its first flattened image.
batch_data = next(iter(train_loader))
print(len(batch_data[0][0]))

784


In [11]:
from jax import grad, value_and_grad


# Number of epochs the training loop iterates over (used as range(NUM_EPOCHS)).
NUM_EPOCHS = 10

def loss(params, imgs, gt_labels):
    """Negative mean of the batch predictions weighted by the target labels.

    Args:
        params: Network parameters forwarded to ``batch_predict``.
        imgs: Batch of inputs forwarded to ``batch_predict``.
        gt_labels: Target array multiplied element-wise with the predictions.

    Returns:
        Scalar loss. The mean runs over every element of the product, not only
        over the batch axis.
    """
    log_scores = batch_predict(params, imgs)
    weighted = log_scores * gt_labels
    return -jnp.mean(weighted)

@jax.jit
def update(params, imgs, gt_labels, lr=0.01):
    """Perform one gradient-descent step on ``loss``.

    The function is jit-compiled so the loss-and-gradient computation is traced
    and compiled once per input shape instead of being rebuilt and dispatched
    op by op on every mini-batch.

    Args:
        params: Current network parameters (a pytree of arrays).
        imgs: Batch of inputs passed to ``loss``.
        gt_labels: Targets passed to ``loss``.
        lr: Step size applied to every gradient leaf.

    Returns:
        A tuple ``(loss_value, new_params)`` where ``new_params`` has the same
        tree structure as ``params``.
    """
    l, grads = value_and_grad(loss)(params, imgs, gt_labels)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return l, new_params

# The one-hot width is constant for the whole run, so look it up once.
num_classes = len(MNIST.classes)

# Train for all NUM_EPOCHS passes over the training loader.
for epoch in range(NUM_EPOCHS):
    for cnt, (imgs, labels) in enumerate(train_loader):
        gt_labels = jax.nn.one_hot(labels, num_classes)
        l, MLP_params = update(MLP_params, imgs, gt_labels)

        # Report the mini-batch loss every 50 steps.
        if cnt % 50 == 0:
            print(f"epoch {epoch} step {cnt}: loss {l}")



0.24243307
0.09901209
0.06954474
0.04485036


KeyboardInterrupt: 