<a href="https://colab.research.google.com/github/YCCS-Summer-2023-DDNMA/project/blob/80-practice-with-neural-networks-mk/Michael_Kupferstein/nn_from_scratch/nnScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MLP training on MNIST

In [34]:
# todo: add the training loop, loss fn

import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax
from jax import jit,vmap,pmap,grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [19]:
# Seed for the JAX PRNG key used when initializing the network parameters.
seed = 0
# Height and width (in pixels) of one MNIST image.
mnist_img_size = (28, 28)

def init_MLP(layer_widths,parent_key,scale=0.01):
  """Create randomly initialized parameters for a fully connected network.

  Args:
    layer_widths: sequence of layer sizes, input first and output last.
    parent_key: JAX PRNG key; it is split into one sub-key per layer.
    scale: factor multiplied into the standard-normal samples.

  Returns:
    A list with one ``[weights, biases]`` pair per layer, where ``weights``
    has shape ``(out_width, in_width)`` and ``biases`` has shape
    ``(out_width,)``.
  """
  layer_keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
  layer_dims = zip(layer_widths[:-1], layer_widths[1:])

  def init_layer(layer_key, fan_in, fan_out):
    # Independent sub-keys so weights and biases are not correlated.
    w_key, b_key = jax.random.split(layer_key)
    weights = scale * jax.random.normal(w_key, shape=(fan_out, fan_in))
    biases = scale * jax.random.normal(b_key, shape=(fan_out,))
    return [weights, biases]

  return [
      init_layer(layer_key, fan_in, fan_out)
      for (fan_in, fan_out), layer_key in zip(layer_dims, layer_keys)
  ]

# test: build a 784-512-256-10 network and show the shape of every parameter
key = jax.random.PRNGKey(seed)
MLP_params = init_MLP([784,512,256,10],key)

# `jax.tree_map` is a deprecated top-level alias that newer JAX releases no
# longer provide; `jax.tree_util.tree_map` is the stable spelling and is
# available in every JAX version.
print(jax.tree_util.tree_map(lambda x:x.shape,MLP_params))

[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]


In [24]:
from jax._src.numpy.reductions import percentile
def MLP_predict(params,x):
  """Run the MLP forward pass and return log-softmax outputs.

  Args:
    params: list of ``(weights, biases)`` pairs, one per layer; every layer
      except the last is followed by a ReLU.
    x: input fed to the first layer via ``jnp.dot(weights, x)``.

  Returns:
    The last layer's logits minus their log-sum-exp, i.e. log-probabilities.
  """
  hidden = x
  for layer_index in range(len(params) - 1):
    weights, biases = params[layer_index]
    pre_activation = jnp.dot(weights, hidden) + biases
    hidden = jax.nn.relu(pre_activation)

  out_weights, out_biases = params[-1]
  logits = jnp.dot(out_weights, hidden) + out_biases

  # log-softmax: log(exp(o_i) / sum_j exp(o_j)) = o_i - logsumexp(o)
  log_normalizer = logsumexp(logits)
  return logits - log_normalizer

# tests

#dummy_img_flat = np.random.randn(np.prod(mnist_img_size))
#print(dummy_img_flat.shape)

#predications = MLP_predict(MLP_params,dummy_img_flat)
#print(predications.shape)

# Vectorize the single-example predictor: `params` is shared across the batch
# (in_axes None) while `x` is mapped over its leading axis (in_axes 0).
batched_MLP_predict = jax.vmap(MLP_predict, in_axes=(None, 0))

# Smoke test on a random batch of 16 flattened images.
dummy_imgs_flat = np.random.standard_normal((16, np.prod(mnist_img_size)))
print(dummy_imgs_flat.shape)

predications = batched_MLP_predict(MLP_params, dummy_imgs_flat)
print(predications.shape)

(16, 784)
(16, 10)


In [39]:
# todo: add data loading in PyTorch
def custom_transform(x):
  """Convert an image to a flat float32 NumPy vector.

  Args:
    x: array-like image (used below as the per-sample MNIST transform).

  Returns:
    A 1-D ``np.float32`` array holding all of the image's values.
  """
  pixels = np.array(x, dtype=np.float32)
  return pixels.ravel()

def custom_collate_fn(batch):
  """Collate ``(image, label)`` samples into a pair of NumPy arrays.

  Args:
    batch: sequence of samples whose first item is the image array and whose
      second item is the label.

  Returns:
    ``(imgs, labels)``: the images stacked along a new leading axis, and the
    labels as a 1-D array.
  """
  # Transpose the list of samples into one tuple per field.
  fields = tuple(zip(*batch))
  label_column = fields[1]
  image_column = fields[0]

  labels = np.array(label_column)
  return np.stack(image_column), labels

batch_size = 128

# Both splits flatten every image to a float32 vector via `custom_transform`.
train_dataset = MNIST(
    root='train_mnist',
    train=True,
    download=True,
    transform=custom_transform,
)
test_dataset = MNIST(
    root='test_mnist',
    train=False,
    download=True,
    transform=custom_transform,
)

# Sanity check: shape of one transformed training image.
img = train_dataset[0][0]
print(img.shape)

# `custom_collate_fn` keeps batches as NumPy arrays instead of torch tensors.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Pull a single batch and inspect its shapes and dtypes.
batch_data = next(iter(train_loader))
imgs, lbls = batch_data
print(imgs.shape,imgs[0].dtype,lbls.shape,lbls[0].dtype)

(784,)
(128, 784) float32 (128,) int64


In [26]:
import os

# Show the kernel's current working directory; the relative dataset `root`
# paths ('train_mnist', 'test_mnist') used above resolve against it.
print(os.getcwd())

/content


# Visualizations

In [None]:
# todo: visualize the MLP weights
# todo: visualize embeddings using t-SNE
# todo: identify dead neurons

# Parallelization