In [None]:
import jax
from sklearn.datasets import fetch_openml
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import grad

In [None]:
# Fetch the MNIST dataset (784 pixel features per image) from OpenML.
# NOTE(review): the `warn(` output below is emitted by scikit-learn; check the full
# message (possibly the `parser` default change) before deciding how to silence it.
mnist = fetch_openml(name='mnist_784')

  warn(


In [None]:
x, y = mnist['data'], mnist['target']
# Scale pixel intensities to [0, 1] and keep them as floats: casting the scaled
# values to an integer dtype would truncate every pixel below full intensity to 0.
x = jnp.array(x / 255, dtype=jnp.float32)
# Targets arrive as digit strings/categories; convert them to integer class ids.
y = jnp.array(y.astype(jnp.int8), dtype=jnp.int32)

In [None]:
# First 60,000 rows for training, the next 10,000 for testing.
x_train, y_train = x[:60000], y[:60000]
x_test, y_test = x[60000:70000], y[60000:70000]



In [None]:
# Features are still (examples, pixels) at this point.
x_train.shape

(60000, 784)

In [None]:
# Switch to one example per column (pixels x examples); the model code computes jnp.dot(W, A).
x_train, x_test = x_train.T, x_test.T


In [None]:
# Labels are 1-D here, and transposing a 1-D array is a no-op, so check the
# expected rank instead of transposing.
assert y_train.ndim == 1 and y_test.ndim == 1, "expected 1-D label vectors"

In [None]:
# Labels are a 1-D vector: one integer class id per example.
y_train.shape

(60000,)

In [None]:
# Add a trailing axis: (examples,) -> (examples, 1). `None` is the same sentinel as np.newaxis.
y_train, y_test = y_train[:, None], y_test[:, None]

In [None]:
# After adding the trailing axis: (examples, 1).
y_train.shape

(60000, 1)

In [None]:
# (examples, 1) -> (1, examples): one label per column, matching the feature layout.
y_train, y_test = y_train.T, y_test.T

In [None]:
# Labels are now a (1, examples) row vector.
y_train.shape

(1, 60000)

In [None]:
# Features are now (pixels, examples): one example per column.
x_train.shape

(784, 60000)

In [None]:
# jax.devices("gpu") raises RuntimeError (it does not return an empty list) when no
# GPU backend is available, so probe it explicitly to make the CPU fallback reachable.
# NOTE(review): JAX usually places new arrays on its default device already, so this
# explicit transfer may be redundant; confirm on the target machine.
try:
    gpu_devices = jax.devices("gpu")
except RuntimeError:
    gpu_devices = []

if gpu_devices:
    gpu = gpu_devices[0]
    x_train = jax.device_put(x_train, device=gpu)
    y_train = jax.device_put(y_train, device=gpu)
    x_test = jax.device_put(x_test, device=gpu)
    y_test = jax.device_put(y_test, device=gpu)
else:
    print("No GPU available, using CPU.")

In [None]:
def one_hot_columns(labels, num_classes=10):
    """One-hot encode integer labels with one example per column.

    Args:
        labels: integer class ids; any shape, flattened first (here (1, examples)).
        num_classes: number of rows in the encoding.

    Returns:
        (num_classes, n) float array; column j has a 1 in row labels[j], 0 elsewhere.
    """
    flat = labels.reshape(-1)
    encoded = jnp.zeros((flat.size, num_classes))
    encoded = encoded.at[jnp.arange(flat.size), flat].set(1)
    return encoded.T


Y_train = one_hot_columns(y_train)
Y_test = one_hot_columns(y_test)

In [None]:
def initialize_parameters_deep(layer_dims, seed=0, scale=0.01):
    """Create weights and biases for a fully connected network.

    Args:
        layer_dims: layer sizes, input size first and output size last.
        seed: integer seed for the JAX PRNG (default 0, the previously hard-coded value).
        scale: multiplier applied to standard-normal weights (default 0.01).
            NOTE(review): a small fixed scale can slow early training of ReLU nets;
            He scaling, sqrt(2 / fan_in), is a common alternative.

    Returns:
        dict with 'W<l>' of shape (layer_dims[l], layer_dims[l-1]) and zero
        'b<l>' of shape (layer_dims[l], 1) for l = 1..len(layer_dims)-1.
    """
    key = jax.random.PRNGKey(seed)
    parameters = {}
    for l in range(1, len(layer_dims)):
        # Split off a fresh subkey per layer so the layers get independent weights.
        weight_key, key = jax.random.split(key)
        parameters['W' + str(l)] = jax.random.normal(weight_key, shape=(layer_dims[l], layer_dims[l - 1])) * scale
        parameters['b' + str(l)] = jnp.zeros((layer_dims[l], 1))
    return parameters

In [None]:
def linear_forward(A, W, b):
    """Affine step of one layer: Z = W . A + b.

    Returns:
        (Z, cache) where cache = (A, W, b) is what linear_backward unpacks.
    """
    cache = (A, W, b)
    pre_activation = jnp.dot(W, A) + b
    return pre_activation, cache

In [None]:
def sigmoid(z):
    """Element-wise logistic function 1 / (1 + exp(-z)).

    Returns:
        (s, z): the activations and the input z, kept as the cache.
    """
    denominator = 1 + jnp.exp(-z)
    return 1 / denominator, z

In [None]:
def relu(z):
    """Element-wise max(0, z).

    Returns:
        (activations, z): z is returned unchanged as the cache for relu_backward.
    """
    activated = jnp.maximum(0, z)
    return activated, z

In [None]:
def softmax(z):
    """Column-wise softmax (normalised over axis 0, one example per column).

    Returns:
        (s, z): the probabilities and the raw input z as the cache.
    """
    # Subtracting each column's max leaves the result mathematically unchanged but
    # keeps exp() from overflowing to inf (and producing nan) for large logits.
    shifted = z - jnp.max(z, axis=0, keepdims=True)
    exp_z = jnp.exp(shifted)
    s = exp_z / jnp.sum(exp_z, axis=0, keepdims=True)
    return s, z

In [None]:
def linear_activation_forward(A_prev, W, b, activation):
    """Forward step of one layer: affine transform followed by an activation.

    Args:
        A_prev: activations from the previous layer (one example per column).
        W, b: this layer's weights and biases.
        activation: "relu" or "softmax".

    Returns:
        (A, cache) with cache = (linear_cache, activation_cache).

    Raises:
        ValueError: for an unsupported activation name (previously this surfaced
            as an UnboundLocalError).
    """
    if activation == "relu":
        activation_fn = relu
    elif activation == "softmax":
        activation_fn = softmax
    else:
        raise ValueError(f"Unsupported activation: {activation!r}")

    Z, linear_cache = linear_forward(A_prev, W, b)
    A, activation_cache = activation_fn(Z)
    return A, (linear_cache, activation_cache)

In [None]:
def L_model_forward(X, parameters):
    """Forward pass: ReLU for layers 1..L-1, softmax for the final layer L.

    Args:
        X: input matrix, one example per column.
        parameters: dict of 'W<l>' / 'b<l>' entries; L = len(parameters) // 2.

    Returns:
        (AL, caches): the softmax output and one cache per layer, in layer order.
    """
    num_layers = len(parameters) // 2
    caches = []
    activations = X
    for layer in range(1, num_layers):
        activations, layer_cache = linear_activation_forward(
            activations, parameters["W" + str(layer)], parameters["b" + str(layer)], "relu")
        caches.append(layer_cache)

    AL, output_cache = linear_activation_forward(
        activations, parameters["W" + str(num_layers)], parameters["b" + str(num_layers)], "softmax")
    caches.append(output_cache)
    return AL, caches

In [None]:
def compute_cost(AL, Y):
    """Mean categorical cross-entropy over examples (columns).

    Args:
        AL: predicted class probabilities, one example per column.
        Y: one-hot labels, broadcast-compatible with AL.

    Returns:
        0-d array: -sum(Y * log(AL)) / m with m = Y.shape[1].
    """
    m = Y.shape[1]
    # Only entries with a non-zero label contribute. Masking the rest avoids
    # 0 * log(0) = nan when a wrong class is assigned probability exactly 0.
    log_likelihood = jnp.where(Y != 0, Y * jnp.log(AL), 0.0)
    cost = -jnp.sum(log_likelihood) / m
    # jnp.squeeze returns a new array; its result must be used, not discarded.
    return jnp.squeeze(cost)

In [None]:
def linear_backward(dZ, cache):
    """Backward step through the affine part of one layer.

    Args:
        dZ: gradient of the cost w.r.t. this layer's pre-activation Z.
        cache: (A_prev, W, b) as stored by linear_forward; b is not needed.

    Returns:
        (dA_prev, dW, db), with dW and db averaged over the m = A_prev.shape[1] examples.
    """
    A_prev, W, _ = cache
    inv_m = 1 / A_prev.shape[1]
    dW = inv_m * jnp.dot(dZ, A_prev.T)
    db = inv_m * jnp.sum(dZ, axis=1, keepdims=True)
    dA_prev = jnp.dot(W.T, dZ)
    return dA_prev, dW, db

In [None]:
def softmax_backward(AL, Y):
    """Gradient w.r.t. the softmax input Z under cross-entropy loss: AL - Y."""
    return AL - Y

In [None]:
def relu_backward(dA, cache):
    """Pass dA through where the cached pre-activation Z was positive, else 0."""
    positive = cache > 0
    return jnp.where(positive, dA, 0)

In [None]:
def linear_activation_backward(Y,AL,dA, cache, activation):
    """Backward step of one layer: activation gradient, then the affine part.

    Args:
        Y: one-hot labels (used only by the softmax branch).
        AL: softmax output (used only by the softmax branch).
        dA: upstream gradient w.r.t. this layer's output (used only by the relu branch).
        cache: (linear_cache, activation_cache) from linear_activation_forward.
        activation: "relu" or "softmax".

    Returns:
        (dA_prev, dW, db) from linear_backward.

    Raises:
        ValueError: for an unsupported activation name (previously this surfaced
            as an UnboundLocalError).
    """
    linear_cache, activation_cache = cache
    if activation == "relu":
        dZ = relu_backward(dA, activation_cache)
    elif activation == "softmax":
        # Softmax combined with the cross-entropy cost gives dZ = AL - Y directly,
        # so dA is not used on this branch.
        dZ = softmax_backward(AL, Y)
    else:
        raise ValueError(f"Unsupported activation: {activation!r}")
    return linear_backward(dZ, linear_cache)

In [None]:
def L_model_backward(AL, Y, caches):
    """Backpropagate through every layer cached by L_model_forward.

    Args:
        AL: softmax output of the final layer.
        Y: one-hot labels; reshaped to AL.shape.
        caches: per-layer caches in forward order; the last belongs to the softmax layer.

    Returns:
        dict with "dW<l>" / "db<l>" for l = 1..L and "dA<l>" for l = 0..L-1
        (L = len(caches)), where "dA<l>" is the gradient w.r.t. activation A<l> (A0 = X).
    """
    grads = {}
    L = len(caches)
    Y = Y.reshape(AL.shape)

    # Output layer. The softmax branch derives dZ = AL - Y on its own, so no upstream
    # dA is needed. Computing -Y / AL here would be unused and divides by zero
    # wherever AL == 0.
    dA_prev, dW, db = linear_activation_backward(Y, AL, None, caches[L - 1], "softmax")
    grads["dA" + str(L - 1)] = dA_prev
    grads["dW" + str(L)] = dW
    grads["db" + str(L)] = db

    # Hidden ReLU layers, walking backwards from layer L-1 down to layer 1.
    for l in reversed(range(L - 1)):
        dA_prev, dW, db = linear_activation_backward(Y, AL, grads["dA" + str(l + 1)], caches[l], "relu")
        grads["dA" + str(l)] = dA_prev
        grads["dW" + str(l + 1)] = dW
        grads["db" + str(l + 1)] = db
    return grads

In [None]:
def update_parameters(params, grads, learning_rate):
    """One gradient-descent step on every 'W<l>' / 'b<l>' entry.

    Args:
        params: dict of current parameters; it is not modified.
        grads: dict holding the matching 'dW<l>' / 'db<l>' gradients.
        learning_rate: step size.

    Returns:
        A new dict (shallow copy of params) with the updated arrays.
    """
    updated = params.copy()
    num_layers = len(updated) // 2
    for layer in range(1, num_layers + 1):
        for prefix in ("W", "b"):
            name = prefix + str(layer)
            updated[name] = updated[name] - learning_rate * grads["d" + name]
    return updated

In [None]:
def get_predictions(Y_hat):
    """Predicted class per example: the row index of each column's maximum."""
    return jnp.argmax(Y_hat, axis=0)

def get_accuracy(predictions,Y):
    """Percentage of examples whose predicted class is the labelled one.

    Args:
        predictions: predicted class index per example, shape (m,) or (1, m).
        Y: one-hot labels, one example per column, shape (classes, m).

    Returns:
        A string such as '75.0%'.
    """
    flat_predictions = jnp.reshape(predictions, -1)
    m = Y.shape[1]
    # One vectorised gather instead of m separate device lookups + host syncs:
    # pick Y[predicted_class, column] for every column and count the 1s.
    hits = int(jnp.sum(Y[flat_predictions, jnp.arange(m)] == 1))
    return str((hits / m) * 100) + '%'

In [None]:
# Network architecture: 784 input features (one per pixel), two ReLU hidden
# layers of 30 and 20 units, and 10 softmax outputs (one per digit class).
layers_dims = [784, 30, 20, 10]

In [None]:

def L_layer_model(X, Y, layers_dims, learning_rate = 0.0075, num_iterations = 3000,print_cost=False):
    """Train a ReLU/softmax fully connected network with full-batch gradient descent.

    Args:
        X: input matrix, one example per column.
        Y: one-hot labels, one example per column.
        layers_dims: layer sizes, input size first and output size last.
        learning_rate: gradient-descent step size.
        num_iterations: number of full-batch updates; must be at least 1.
        print_cost: if True, print the cost every 100 iterations and at the last one.

    Returns:
        (parameters, costs, Y_predict, AL):
            parameters -- trained weights and biases;
            costs -- cost at every 100th iteration plus the final iteration;
            Y_predict -- one-hot encoding of the argmax of AL;
            AL -- output of the final iteration's forward pass (computed before
                  that iteration's parameter update).

    Raises:
        ValueError: if num_iterations < 1 (there would be no AL to return).
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    costs = []
    parameters = initialize_parameters_deep(layers_dims)
    for i in range(num_iterations):
        AL, caches = L_model_forward(X, parameters)
        grads = L_model_backward(AL, Y, caches)
        parameters = update_parameters(parameters, grads, learning_rate)
        cost = compute_cost(AL, Y)

        # Parenthesised check: `print_cost` must gate the final-iteration print too,
        # and the last index of range(num_iterations) is num_iterations - 1.
        is_checkpoint = i % 100 == 0 or i == num_iterations - 1
        if print_cost and is_checkpoint:
            print("Cost after iteration {}: {}".format(i, jnp.squeeze(cost)))
        if is_checkpoint:
            costs.append(cost)

    # Only the last iteration's prediction is returned, so build it once after the loop.
    Y_predict = jnp.zeros_like(AL)
    Y_predict = Y_predict.at[jnp.argmax(AL, axis=0), jnp.arange(AL.shape[1])].set(1)
    return parameters, costs, Y_predict, AL

In [None]:
parameters, costs, Y_predict, AL = L_layer_model(
    x_train, Y_train, layers_dims, learning_rate=0.8, num_iterations=3000, print_cost=True)
# AL comes from the last training iteration's forward pass, so this is training accuracy.
print("train accuracy : ", get_accuracy(get_predictions(AL), Y_train))

# Evaluate the trained parameters on the held-out split prepared earlier.
AL_test, _ = L_model_forward(x_test, parameters)
print("test accuracy : ", get_accuracy(get_predictions(AL_test), Y_test))

Cost after iteration 0: 2.302584409713745
0.8
Cost after iteration 100: 2.3011558055877686
0.8
Cost after iteration 200: 2.3011536598205566
0.8
Cost after iteration 300: 2.3011510372161865
0.8
Cost after iteration 400: 2.3011474609375
0.8
Cost after iteration 500: 2.3011417388916016
0.8
Cost after iteration 600: 2.3011317253112793
0.8
Cost after iteration 700: 2.3011131286621094
0.8
Cost after iteration 800: 2.301072120666504
0.8
Cost after iteration 900: 2.30096173286438
0.8
Cost after iteration 1000: 2.300509452819824
0.8
Cost after iteration 1100: 2.2958993911743164
0.8
Cost after iteration 1200: 2.221557140350342
0.8
Cost after iteration 1300: 2.1743087768554688
0.8
Cost after iteration 1400: 2.086869478225708
0.8
Cost after iteration 1500: 2.064865827560425
0.8
Cost after iteration 1600: 1.9783920049667358
0.8
Cost after iteration 1700: 1.9253406524658203
0.8
Cost after iteration 1800: 1.8988292217254639
0.8
Cost after iteration 1900: 1.860556960105896
0.8
Cost after iteration 200

KeyboardInterrupt: ignored