<a href="https://colab.research.google.com/github/PetchMa/deeplearning_fundamentals/blob/main/MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, grad, pmap,value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader 

In [None]:
# Integer seed for the root JAX PRNG key; change it to get a different
# (but still reproducible) parameter initialisation.
seed = 0

def init_MLP(layer_widths, parent_key, scale=0.01):
  """Create randomly initialised parameters for a fully connected network.

  Args:
    layer_widths: sequence of layer sizes, input size first and output size
      last. Consecutive pairs define one dense layer each.
    parent_key: JAX PRNG key that every per-layer key is derived from.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    A list with one `[weights, biases]` entry per layer, where `weights` has
    shape `(out_width, in_width)` and `biases` has shape `(out_width,)`.
  """
  num_layers = len(layer_widths) - 1
  layer_keys = jax.random.split(parent_key, num=num_layers)
  layer_dims = zip(layer_widths[:-1], layer_widths[1:])

  def init_layer(dims, layer_key):
    # Weights and biases get independent sub-keys of the layer key.
    fan_in, fan_out = dims
    w_key, b_key = jax.random.split(layer_key)
    weights = scale * jax.random.normal(w_key, shape=(fan_out, fan_in))
    biases = scale * jax.random.normal(b_key, shape=(fan_out,))
    return [weights, biases]

  return [init_layer(dims, layer_key)
          for dims, layer_key in zip(layer_dims, layer_keys)]

# Root PRNG key; every parameter of the network is derived from it.
rng = jax.random.PRNGKey(seed)

# 784 inputs -> 512 -> 256 -> 10 outputs.
MLP_params = init_MLP([784, 512, 256, 10], rng)
# Show the shape of every parameter array. `jax.tree_util.tree_map` is the
# stable spelling; the top-level `jax.tree_map` alias was deprecated and is
# no longer available in newer JAX releases.
print(jax.tree_util.tree_map(lambda x: x.shape, MLP_params))

[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]


In [None]:
def MLP_predict(params, x):
  """Run a single input through the network and return log-probabilities.

  Args:
    params: list of `(weights, biases)` pairs, one per layer. All layers but
      the last are followed by a ReLU.
    x: one input vector (not a batch).

  Returns:
    The logits of the last layer minus their log-sum-exp, i.e. the
    log-softmax of the logits.
  """
  def affine(layer, inputs):
    weights, bias = layer
    return jnp.dot(weights, inputs) + bias

  features = x
  for layer in params[:-1]:
    features = jax.nn.relu(affine(layer, features))

  logits = affine(params[-1], features)
  # Subtracting logsumexp normalises the logits in log space.
  log_partition = logsumexp(logits)
  return logits - log_partition

# Length of one flattened input vector (784 = 28 * 28).
mnist_img_size = 784

# Vectorise the single-example predictor: the parameters are shared across
# the batch (in_axes None) and the inputs are mapped over their leading axis.
batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))

# Small smoke test on a random batch of 16 flattened "images".
dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))
# Print the shape of the batch that is actually defined above; the previous
# name `dummy_img_flat` does not exist and raised NameError on a fresh kernel.
print(dummy_imgs_flat.shape)
predictions = batched_MLP_predict(MLP_params, dummy_imgs_flat)
print(predictions.shape)

(784,)
(16, 10)


In [None]:
# data loading
def custom_transform(x):
  """Convert one sample to a flat float32 NumPy vector.

  Args:
    x: anything `np.array` accepts (presumably the image object handed over
      by the dataset -- confirm against the dataset's documentation).

  Returns:
    A 1-D float32 array holding every element of `x`.
  """
  pixels = np.array(x, dtype=np.float32)
  return pixels.ravel()
# Training split; downloaded into ./train_mnist if it is not already there.
train_dataset = MNIST(
    root='train_mnist',
    train=True,
    download=True,
    transform=custom_transform,
)
# Held-out test split, stored under the same root and using the same transform.
test_dataset = MNIST(
    root='train_mnist',
    train=False,
    download=True,
    transform=custom_transform,
)

In [None]:
def custom_collate_fn(batch):
  """Collate a list of `(image, label)` samples into two NumPy arrays.

  Args:
    batch: sequence of samples whose first item is the image array and whose
      second item is the label.

  Returns:
    `(imgs, labels)`: the images stacked along a new leading axis and the
    labels gathered into one array.
  """
  # Turn the list of per-sample tuples into per-field columns.
  columns = tuple(zip(*batch))
  imgs = np.stack(columns[0])
  labels = np.array(columns[1])
  return imgs, labels


# Shuffled mini-batches of 128 samples, collated into NumPy arrays.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Pull a single batch to sanity-check what the loader yields.
batch_data = next(iter(train_loader))
imgs, labels = batch_data
print(labels.shape)

(128,)


In [None]:
# Upper bound on training epochs: the training loop below iterates over
# range(num_epochs).
num_epochs = 10

def loss_fn(params, imgs, gt_lbls):
  """Loss of the network on one batch.

  Args:
    params: network parameters passed on to `batched_MLP_predict`.
    imgs: batch of inputs, one example per row.
    gt_lbls: targets with the same shape as the predictions (the training
      loop passes one-hot labels).

  Returns:
    The negated mean, over every element, of `predictions * gt_lbls`.
  """
  log_probs = batched_MLP_predict(params, imgs)
  # Multiplying by the targets keeps only the entries the targets select.
  weighted = log_probs * gt_lbls
  return -jnp.mean(weighted)
def update(params, imgs, gt_lbls, lr = 0.01):
  """Perform one gradient-descent step on a batch.

  Args:
    params: current network parameters (a pytree of arrays).
    imgs: batch of inputs passed to `loss_fn`.
    gt_lbls: targets passed to `loss_fn`.
    lr: learning rate (step size).

  Returns:
    `(loss, new_params)`: the loss evaluated at the incoming `params`, and a
    pytree with the same structure holding `p - lr * g` for every leaf.
  """
  loss, grads = value_and_grad(loss_fn)(params,imgs,gt_lbls)
  # `jax.tree_multimap` has been removed from JAX; `jax.tree_util.tree_map`
  # accepts several pytrees and applies the function leaf-wise.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  return loss, new_params

# Train for num_epochs passes over the training set. The previous version
# ended with a `break` that stopped after the first epoch, so num_epochs had
# no effect.
for epoch in range(num_epochs):
  for count, (imgs, lbls) in enumerate(train_loader):
    # One-hot encode the integer labels so they have the same shape as the
    # per-class predictions they are multiplied with in loss_fn.
    gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))
    loss, MLP_params = update(MLP_params, imgs, gt_labels)
    # Log every 10th batch. The previous `if count % 10:` was inverted and
    # printed on every batch except multiples of 10.
    if count % 10 == 0:
      print(loss)

0.036939897
0.035040345
0.03589448
0.037989363
0.031648155
0.04111546
0.030190075
0.02479034
0.030622883
0.03551941
0.029108495
0.030718375
0.04005232
0.024839683
0.02788696
0.03585416
0.03162801
0.030825837
0.020814296
0.03646675
0.032824684
0.029956942
0.026571726
0.030289158
0.034255087
0.023487752
0.03094072
0.030177874
0.034139518
0.031127146
0.041550975
0.03225876
0.026452316
0.02748906
0.033869576
0.025060622
0.036224402
0.02608375
0.0396219
0.033136014
0.023931673
0.02887987
0.026016075
0.029170787
0.026117254
0.028350431
0.030607313
0.03254031
0.04584183
0.028943632
0.026195599
0.031373166
0.025547868
0.029473139
0.029517261
0.025783485
0.04441143
0.029654214
0.028630568
0.037721552
0.024736708
0.02483247
0.025957445
0.0323667
0.026985884
0.02844376
0.026298268
0.029461503
0.026763955
0.015076037
0.027490968
0.030276382
0.02285366
0.031382617
0.022840772
0.041468997
0.016474327
0.02953183
0.032792505
0.0372267
0.023551175
0.02310056
0.025620788
0.029785527
0.036935512
0.025445

In [None]:
def accuracy(params,loader):
  for img, lbls in loader:
    batched 