<a href="https://colab.research.google.com/github/StarPrecursor/Adv_Python_Practice/blob/main/week_4/Week4_excercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Week 4 Exercise
Train a simple neural network (e.g. a single layer or a single fully connected layer) using JAX vectorization (optionally, also try the vectorization techniques available in PyTorch and TensorFlow). Compare their performance with each other and with the results from Week 3’s exercise.

## Common Settings

In [26]:
import matplotlib.pyplot as plt
from jax import random

# Reproducibility: root PRNG key used for JAX parameter initialisation.
seed = 42
seed_prng = random.PRNGKey(seed)

# Network architecture
n_layers, n_nodes = 3, 100  # number of hidden layers in the JAX model, units per hidden layer
init_scale = 0.05           # multiplier applied to standard-normal initial weights/biases

# Training schedule
n_epochs, batch_size = 20, 128
learn_rate = 0.0001

## Prepare Data

In [27]:
# from sklearn.datasets import load_breast_cancer
# 
# data = load_breast_cancer()
# X = data.data.transpose()
# y = data.target
# 
# print(type(X), X.shape)
# print(type(y), y.shape)
# print(data.feature_names)
# print(data.target_names)

In [28]:
from sklearn.datasets import make_blobs

# Synthetic binary classification problem: 5000 samples, 50 features, 2 blobs.
X, y = make_blobs(
  n_samples=5000, n_features=50, centers=2, cluster_std=25, random_state=0
)
# Put features on axis 0 and one sample per column -> X.shape == (50, 5000),
# the layout the JAX functions use (W @ X). y is 1-D, so its transpose is a no-op.
X, y = X.T, y.T

## JAX implementation

In [29]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random

### Build components

In [30]:
def lin(params, x):
  """Affine layer: W @ x + b.

  Args:
    params: pair (W, b); JAX_Model.init_params creates W with shape
      (dim_out, dim_in) and b with shape (dim_out, 1).
    x: inputs with features on axis 0, shape (dim_in, n_samples), or a single
      sample of shape (dim_in,).

  Returns:
    (dim_out, n_samples) for 2-D x, or (dim_out,) for a single 1-D sample.
  """
  w, b = params[0], params[1]
  if jnp.ndim(x) == 1:
    # A (dim_out, 1) bias would broadcast the (dim_out,) product up to
    # (dim_out, dim_out); flatten it so single samples (e.g. under vmap) work.
    b = jnp.reshape(b, (-1,))
  return jnp.dot(w, x) + b

def relu(x):
  """Element-wise rectified linear unit, max(0, x)."""
  return jnp.maximum(0, x)

def sigmoid(x):
  """Logistic sigmoid written via tanh: 0.5 * (tanh(x / 2) + 1) == 1 / (1 + exp(-x)),
  without evaluating exp(-x), which can overflow for large negative x."""
  return 0.5 * (jnp.tanh(x / 2) + 1)

def loss(params, y_pred, y_true):
  """Summed binary cross-entropy computed from predicted probabilities.

  NOTE(review): this definition is shadowed by `loss(params_list, x, y)` in the
  Utilities cell and is never called; `params` is unused. Consider removing it.
  """
  # probability the model assigns to the true label (y_true in {0, 1})
  label_probs = y_pred * y_true + (1 - y_pred) * (1 - y_true)
  return -jnp.sum(jnp.log(label_probs))


### Utilities

In [31]:
def predict(params_list, x):
  """Forward pass of the MLP.

  Args:
    params_list: list of (W, b) pairs, one per layer. All but the last layer use
      ReLU; the last layer feeds a sigmoid.
    x: inputs with features on axis 0, e.g. shape (dim_in, n_samples).

  Returns:
    Sigmoid outputs; (1, n_samples) for a single-unit output layer and 2-D x.
  """
  cur = x
  # hidden layers
  for params in params_list[:-1]:
    cur = relu(lin(params, cur))
  # output layer
  params = params_list[-1]
  return sigmoid(lin(params, cur))

def batch_predict(params_list, x):
  """Per-sample prediction vectorised with `jax.vmap`.

  Maps the forward pass over the columns of x (one sample per column). Each
  sample is lifted to a (dim_in, 1) column first: with the (dim_out, 1) biases
  built by init_params, a bare (dim_in,) vector can make `W @ x + b` broadcast
  to (dim_out, dim_out) instead of (dim_out,).

  Returns:
    Array of shape (n_out, n_samples) — the same layout as predict(params_list, x).
  """
  def predict_one(x_col):
    return predict(params_list, x_col[:, None])[:, 0]
  # in_axes=1: iterate over samples (columns); out_axes=1: stack results as columns
  return vmap(predict_one, in_axes=1, out_axes=1)(x)

def loss(params_list, x, y, eps=1e-7):
  """Summed binary cross-entropy of the network on (x, y).

  Args:
    params_list: list of (W, b) pairs (differentiated by `grad`).
    x: inputs, features on axis 0.
    y: 0/1 labels, one per sample.
    eps: lower bound on the probability of the true label before taking the log.

  Returns:
    Scalar loss, summed (not averaged) over samples, so its gradient — and the
    effective step size of `update` — grows with the number of samples.
  """
  preds = predict(params_list, x)
  label_probs = preds * y + (1 - preds) * (1 - y)
  # float32 sigmoid can round to exactly 0 or 1, giving log(0) = -inf and NaN
  # gradients; clamp the true-label probability away from zero.
  label_probs = jnp.clip(label_probs, eps, 1.0)
  # jnp.sum vectorises the reduction over all samples
  return -jnp.sum(jnp.log(label_probs))

# jit-compiled gradient of `loss` w.r.t. params_list. Despite the name, it returns
# gradients (same structure as params_list), not the loss value.
loss_jit = jit(grad(loss))

def update(params_list, x, y, lr=None):
  """One plain gradient-descent step on (x, y).

  Args:
    params_list: list of (W, b) pairs.
    x, y: inputs and labels used for this step (JAX_Model.train passes the
      whole dataset, i.e. full-batch descent).
    lr: step size; defaults to the global `learn_rate` at call time.

  Returns:
    A new list of (W, b) pairs; params_list itself is not modified.
  """
  step = learn_rate if lr is None else lr
  grads = loss_jit(params_list, x, y)
  return [
    (w - step * dw, b - step * db)
    for (w, b), (dw, db) in zip(params_list, grads)
  ]

### Model

In [32]:
class JAX_Model:
  """Fully connected binary classifier trained with hand-written JAX gradient descent.

  Architecture: `n_layers` hidden layers of width `n_nodes` with ReLU, then one
  sigmoid output unit. Parameters are a list of (W, b) pairs with W of shape
  (dim_out, dim_in) and b of shape (dim_out, 1).
  """

  def __init__(self, input_dim):
    # Layer widths from input to output, e.g. [50, 100, 100, 100, 1] with the
    # default config and 50 input features.
    self.dims = [input_dim] + [n_nodes] * n_layers + [1]
    self.params_list = []

  def init_params(self, key=None):
    """(Re)initialise every weight and bias as init_scale * N(0, 1).

    Args:
      key: optional JAX PRNG key; defaults to the global `seed_prng`, so repeated
        calls give identical parameters.
    """
    if key is None:
      key = seed_prng
    # Start from an empty list so calling init_params again re-initialises the
    # model instead of appending a second, shape-incompatible stack of layers.
    self.params_list = []
    layer_keys = random.split(key, len(self.dims) - 1)
    for i in range(len(self.dims) - 1):
      w_key, b_key = random.split(layer_keys[i])
      dim_in, dim_out = self.dims[i], self.dims[i + 1]
      w = init_scale * random.normal(w_key, (dim_out, dim_in))
      b = init_scale * random.normal(b_key, (dim_out, 1))
      self.params_list.append((w, b))

  def train(self, x, y, epochs=None):
    """Run `epochs` gradient steps on (x, y) (default: global `n_epochs`).

    Each step uses all of x and y, then prints the loss after the update.
    NOTE: looks up the global `loss` function at call time, so later cells must
    not rebind that name.

    Raises:
      RuntimeError: if init_params() has not been called.
    """
    if not self.params_list:
      raise RuntimeError("call init_params() before train()")
    if epochs is None:
      epochs = n_epochs
    for epoch in range(epochs):
      self.params_list = update(self.params_list, x, y)
      cur_loss = loss(self.params_list, x, y)
      print(f"epoch={epoch}, loss={cur_loss}")

  def predict(self, x):
    """Forward pass on x (features on axis 0) with the current parameters."""
    return predict(self.params_list, x)

  def batch_predict(self, x):
    """vmap-vectorised per-sample forward pass with the current parameters."""
    return batch_predict(self.params_list, x)


In [33]:
jax_mod = JAX_Model(input_dim=X.shape[0])
jax_mod.init_params()

# Sanity-check the untrained network: output shape, then the initial loss.
y_init = jax_mod.predict(X)
print(y_init.shape)
cur_loss = loss(jax_mod.params_list, X, y)
print(cur_loss)

(1, 5000)
3709.6646


In [34]:
%%time
# Time the JAX training run: n_epochs gradient steps on the whole of (X, y),
# printing the loss after each step.
jax_mod.train(X, y)

epoch=0, loss=6534.31494140625
epoch=1, loss=9767.412109375
epoch=2, loss=3760.5927734375
epoch=3, loss=3072.244873046875
epoch=4, loss=2404.154296875
epoch=5, loss=8105.107421875
epoch=6, loss=2960.2490234375
epoch=7, loss=1782.263427734375
epoch=8, loss=2220.838623046875
epoch=9, loss=2514.513671875
epoch=10, loss=1711.5341796875
epoch=11, loss=1349.20068359375
epoch=12, loss=1298.455322265625
epoch=13, loss=1274.458251953125
epoch=14, loss=1255.9686279296875
epoch=15, loss=1239.9541015625
epoch=16, loss=1225.56005859375
epoch=17, loss=1212.0341796875
epoch=18, loss=1199.35546875
epoch=19, loss=1187.437255859375
CPU times: user 1.5 s, sys: 80.7 ms, total: 1.58 s
Wall time: 1.07 s


---
Comparison with TensorFlow and PyTorch

---

## TensorFlow implementation

In [35]:
import tensorflow as tf

In [36]:
print(X.shape, y.shape)
# Features are on axis 0 and samples on axis 1, so every column of X needs exactly one label.
assert X.shape[1] == y.shape[0], "X columns and y labels are misaligned"

(50, 5000) (5000,)


In [37]:
input_dim = X.shape[0]
# Keras expects samples on axis 0, so transpose to (n_samples, n_features).
X_t = X.transpose()

tf.random.set_seed(seed)  # reproducible weight initialisation

# Same architecture as the JAX model: n_layers ReLU hidden layers of width
# n_nodes, then one sigmoid output unit. Without activations the stacked Dense
# layers would collapse into a single linear map.
tf_mod = tf.keras.Sequential([tf.keras.Input(shape=(input_dim,))])
for _ in range(n_layers):
  tf_mod.add(tf.keras.layers.Dense(n_nodes, activation="relu"))
# Sigmoid output: BinaryCrossentropy (from_logits=False by default) expects
# probabilities, not raw logits.
tf_mod.add(tf.keras.layers.Dense(1, activation="sigmoid"))

# NOTE(review): 0.01 differs from the shared `learn_rate`. Keras averages the loss
# over each batch whereas the JAX loss is summed over all samples, so equal step
# sizes would not be comparable anyway — confirm the intended setting.
opt = tf.keras.optimizers.SGD(learning_rate=0.01)
tf_mod.compile(
  optimizer=opt,
  loss=tf.keras.losses.BinaryCrossentropy(),
  metrics=["accuracy"],
)

In [38]:
%%time
# Mini-batch SGD with batch_size samples per step, for n_epochs epochs.
# NOTE: the JAX and PyTorch loops in this notebook take one full-batch step per
# epoch, so the work done per epoch differs between frameworks.
tf_mod.fit(X_t, y, batch_size=batch_size, epochs=n_epochs)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
CPU times: user 3.1 s, sys: 117 ms, total: 3.21 s
Wall time: 2.62 s


<keras.callbacks.History at 0x7f55cfa2dbd0>

## PyTorch implementation

In [39]:
import torch
import torch.nn as nn

In [40]:
# nn.Linear works on (n_samples, n_features) inputs, so put samples on axis 0.
X_t = X.T
input_dim = X_t.shape[1]

In [41]:
# Seed PyTorch so the weight initialisation is reproducible across runs.
torch.manual_seed(seed)

# Same architecture as the JAX model: n_layers hidden Linear+ReLU layers of
# width n_nodes, then a single sigmoid output unit.
hidden_dims = [input_dim] + [n_nodes] * n_layers
layers = []
for dim_in, dim_out in zip(hidden_dims[:-1], hidden_dims[1:]):
  layers.append(nn.Linear(dim_in, dim_out))
  layers.append(nn.ReLU())
layers.append(nn.Linear(hidden_dims[-1], 1))
layers.append(nn.Sigmoid())

model = nn.Sequential(*layers)
print(model)

Sequential(
  (0): Linear(in_features=50, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=100, bias=True)
  (3): ReLU()
  (4): Linear(in_features=100, out_features=100, bias=True)
  (5): ReLU()
  (6): Linear(in_features=100, out_features=100, bias=True)
  (7): ReLU()
  (8): Linear(in_features=100, out_features=1, bias=True)
  (9): Sigmoid()
)


In [42]:
data_x = torch.from_numpy(X_t).float()
# Labels as an (n_samples, 1) column so they line up with the model's output.
data_y = torch.from_numpy(y.reshape(-1, 1)).float()

# The model already ends in nn.Sigmoid, so use binary cross-entropy on
# probabilities. nn.CrossEntropyLoss would softmax over the single output column
# (always exactly 1), making the loss constantly 0 with zero gradients, so the
# network would never train.
loss_function = nn.BCELoss()
# NOTE(review): BCELoss averages over samples while the JAX loss is summed, so the
# same learn_rate takes much smaller effective steps here — confirm intended.
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

In [43]:
%%time
for epoch in range(n_epochs):
  pred_y = model(data_x).float()
  loss = loss_function(pred_y, data_y)
  model.zero_grad()
  loss.backward()
  optimizer.step()

CPU times: user 593 ms, sys: 20.9 ms, total: 614 ms
Wall time: 612 ms
