## Task3

In [3]:
import numpy as np
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

Create Dataset

In [4]:
# load dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

X_data = np.concatenate((X_train, X_test), axis=0)
y_data = np.concatenate((y_train, y_test), axis=0)
print(X_data.shape,y_data.shape)

Xtrain, Xtest, ytrain, ytest = train_test_split(X_data, y_data, test_size=0.2, random_state=42)
print(Xtrain.shape, ytrain.shape, Xtest.shape, ytest.shape)

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
(70000, 28, 28) (70000,)
(56000, 28, 28) (56000,) (14000, 28, 28) (14000,)


Import Necessary Libraries

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, vmap, value_and_grad
from sklearn.metrics import accuracy_score
key = jax.random.PRNGKey(0)
from tqdm import tqdm
from jax.scipy.special import logsumexp
import numpy as np

## Exp 1
Tried one approach but probably used a wrong loss function so had poor prediction

In [None]:
class neural_network():
  def __init__(self):
    self.loss_ = []

  def relu_activation(self,x): # relu activation
    return jnp.maximum(0, x) 

  def forward(self,params,x): # computing forward propogation
    hidden1, hidden2, last = params
    x = self.relu_activation(jnp.dot(x,hidden1['weights'])+ hidden1['biases'])
    x = self.relu_activation(jnp.dot(x,hidden2['weights'])+ hidden2['biases'])
    x = jnp.dot(x,last['weights'])+ last['biases']
    return x 

  def loss(self,params, x, y): # loss function
    batch_forward = vmap(self.forward, in_axes=(None, 0), out_axes=0)
    preds = batch_forward(params, x)
    return -jnp.sum(preds*y)

  def init_neuralnet_params(self,layer_dims): # initializing parameters
    params = []
    for i,(in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
      params.append(
          dict(weights=jax.random.normal(jax.random.PRNGKey(i),(in_dim, out_dim)),
              biases=jnp.ones(shape=(out_dim,))
              )
      )
    return params


  def update(self, params, x, y,lr): # updating the parameters
    grads = jax.grad(self.loss)
    grad_loss = grads(params, x, y)

    # using jax.tree
    return jax.tree_map(
      lambda p, g: p - lr * g, params, grad_loss)
    

  def fit(self, X, y, batch_size=1, n_iter=150, lr=0.01, lr_type='constant'): # training the model

    X = jnp.array(X)
    X = jnp.reshape(X,(X.shape[0], 28*28))
    y = jnp.array(y)

    index = 0
    params = self.init_neuralnet_params([784, 512, 512, 10])

    for iter in tqdm(range(n_iter)):
      if index >= X.shape[0]:
        index = 0

      if lr_type == 'inverse':
                lr = lr / (iter+1)

      sub_X = X[index:index+batch_size, :]
      sub_y = y[index:index+batch_size]

      params_updated =  self.update(params, sub_X, sub_y,lr)

      self.loss_.append(self.loss(params, sub_X, sub_y))

      index += batch_size

    return params_updated

  def predict_values(self, X): # prediting 
    X = jnp.array(X)
    X = jnp.reshape(X,(X.shape[0], 28*28))

    batch_forward = vmap(self.forward, in_axes=(None, 0), out_axes=0)
    return batch_forward(params, X)

  def predict(self,params,X):
    preds = self.predict_values(X)
    return np.argmax(preds, axis=1)

  def plot_loss(self):
    fig = plt.figure(figsize=(10,8))
    plt.plot(self.loss_)
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title('LossVsIterations')
    plt.legend()
    # plt.show()
    plt.savefig('LossVsIterations.png')

In [None]:
model = neural_network()

params = model.fit(Xtrain, ytrain, batch_size=10, n_iter=15000, lr=0.03, lr_type='constant')
y_pred_test = model.predict(params, Xtest)
y_pred_train = model.predict(params, Xtrain)
acc_test = accuracy_score(ytest, y_pred_test)
acc_train = accuracy_score(ytrain, y_pred_train)

print("Test Accuracy for our model =", acc_test)
print("Train Accuracy for our model =", acc_train)
model.plot_loss()

## Exp2 
using cross entropy loss function and defining all required functions. Unable to predict due to some JAX issue which can be resolved. 

In [14]:
class neural_network1():
  def __init__(self):
    self.loss_ = []

  def relu_activation(self, x):
    return jnp.maximum(0, x)

  def softmax(self, x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=0)

  def forward(self,params,x): # forward pass across the network
    for w, b in params[:-1]: # iterating over the first and inner layers
      x = self.relu_activation(jnp.dot(x,w)+ b)
    w, b = params[-1] # for the last layer
    x = jnp.dot(x,w)+ b
    return x 

  def batch_forward(self, params, x):
    batch = vmap(self.forward, in_axes=(None, 0), out_axes=0)
    return batch(params, x)

  def loss(self, params, x, y): # computing the loss = -summation(y * log y-hat)
    x_ = self.batch_forward(params, x)
    y_pred = self.softmax(x_)
    print(y_pred)
    loss= -1* jnp.sum(jnp.dot(y,jnp.log(y_pred)))
    return loss

  def init_neuralnet_params(self,layer_dims): # randomly initializing the parameters for the network with the weights according to dimensions of layers
    params = []
    for i,(in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
      weights=jax.random.normal(jax.random.PRNGKey(i),(in_dim, out_dim))
      biases=jnp.ones((out_dim,))
      params.append((weights, biases))
    return params
  
  def fit(self, X, y, batch_size=1, n_iter=150, lr=0.01, lr_type='constant'):

    X = jnp.array(X)
    X = jnp.reshape(X,(X.shape[0], 28*28))
    y = jnp.array(y)

    index = 0
    params = self.init_neuralnet_params([784, 512, 512, 10])
    grads = jax.grad(self.loss)

    for iter in tqdm(range(n_iter)):
      if index >= X.shape[0]:
        index = 0

      if lr_type == 'inverse':
                lr = lr / (iter+1)

      sub_X = X[index:index+batch_size, :]
      sub_y = y[index:index+batch_size]

      grad_loss = grads(params, sub_X, sub_y)
      grad_loss_np = np.array(grad_loss)
      params_np = np.array(params)
      params_np = params_np - (lr * grad_loss_np)
      self.loss_.append(self.loss(params, X, y))

      index += batch_size

    params_updated = params_np

    return params_updated

  def predict_values(self, X):
    X = jnp.array(X)
    X = jnp.reshape(X,(X.shape[0], 28*28))
    return self.softmax(self.batch_forward(params, X))

  def predict(self,params,X):
    preds = self.predict_values(X)
    # print(preds)
    return np.argmax(preds, axis=1)

  def plot_loss(self):
    fig = plt.figure(figsize=(10,8))
    plt.plot(self.loss_)
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title('LossVsIterations')
    plt.legend()
    # plt.show()
    plt.savefig('LossVsIterations.png')

In [None]:
model = neural_network1()

params = model.fit(Xtrain, ytrain, batch_size=256, n_iter=50, lr=0.03, lr_type='constant')
y_pred_test = model.predict(params, Xtest)
y_pred_train = model.predict(params, Xtrain)
acc_test = accuracy_score(ytest, y_pred_test)
acc_train = accuracy_score(ytrain, y_pred_train)

print("Test Accuracy for our model =", acc_test)
print("Train Accuracy for our model =", acc_train)
model.plot_loss()

## References

* https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
* https://jax.readthedocs.io/en/latest/notebooks/quickstart.html