In [25]:
import numpy as np
import jax.numpy as jnp
from jax import grad,random
from jax import jit,vmap
from jax.tree_util import tree_map
from jax.nn import relu
from jax.scipy.special import logsumexp

In [22]:
seed = 0

def init_MLP(layer_widths, parent_key, scale = 0.01):
  params = []

  keys = random.split(parent_key, num=len(layer_widths)-1)

  for n_in, n_out, key in zip(layer_widths[:-1],layer_widths[1:],keys):
    weight_key, bias_key = random.split(key)

    params.append([
        scale*random.normal(weight_key, shape=(n_out,n_in)),
        scale*random.normal(bias_key, shape=(n_out,))
    ])

  return params

key = random.PRNGKey(seed)
MLP = init_MLP([784,128,128,10], key)
tree_map(lambda x: x.shape,MLP)

[[(128, 784), (128,)], [(128, 128), (128,)], [(10, 128), (10,)]]

In [27]:
def MLP_predict(params, x):
  hidden_layers = params[:-1]

  activation = x
  for w,b in hidden_layers:
    activation = relu(jnp.dot(w,activation)+b)

  w_last, b_last = params[-1]
  logits = jnp.dot(w_last,activation)+b_last

  return logits-logsumexp(logits)

batched_MLP_predict = vmap(MLP_predict, in_axes=(None,0))

dummy_img = np.random.randn(16,np.prod((28,28)))
pred = batched_MLP_predict(MLP, dummy_img)
pred.shape

(16, 10)

## ETL

In [32]:
import os
from torchvision.datasets import MNIST

In [33]:
print(os.getcwd())

/content


In [34]:
train_dataset = MNIST(root='train_mnist',train=True,download=True, transform=None)
print(type(train_dataset))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 53.1MB/s]


Extracting train_mnist/MNIST/raw/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 2.05MB/s]

Extracting train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 11.9MB/s]


Extracting train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 7.79MB/s]

Extracting train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

<class 'torchvision.datasets.mnist.MNIST'>



