In [None]:
# Mount Google Drive so the dataset paths under /content/drive/MyDrive are readable.
from google.colab import drive             
drive.mount('/content/drive')

# IPython shell escape (notebook-only syntax, not plain Python): install JAX.
!pip install jax

Mounted at /content/drive
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:

# Built with guidance from `Dense noise leaf detector.py` in the MulticlassLR directory of the same repo.

import numpy as np
import matplotlib.pyplot as plt
import copy
import sys
from Dataset_setupWithJAX import dataset_setup
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import jax
import jax.numpy as jnp
from jax import random, jit
# from numba import jit

# Load the training and testing leaf images via the project-local helper
# `dataset_setup` (its return format is not visible here; the reshapes below
# treat the first axis of X / Y as the example axis).
X, Y = dataset_setup("/content/drive/MyDrive/CitrusDataset/Leaves/Leafdataset/Training/")
X_test, Y_test = dataset_setup("/content/drive/MyDrive/CitrusDataset/Leaves/Leafdataset/Testing/")

# Flatten each example and transpose to (n_features, n_examples): one column per
# example, the layout forward_prop (W @ A) and back_prop (m = A.shape[1]) use.
# NOTE(review): no pixel scaling happens here; if dataset_setup returns raw
# 0-255 values, dividing by 255 usually helps training — confirm.
train_set_flatten = jnp.reshape(X,(X.shape[0],-1)).T
test_set_flatten = jnp.reshape(X_test, (X_test.shape[0],-1)).T

# Labels as row vectors of shape (1, n_examples), matching the single sigmoid output unit.
train_class = jnp.reshape(Y, (1, Y.shape[0]))
test_class = jnp.reshape(Y_test, (1, Y_test.shape[0]))

# Layer sizes: 12288 inputs (num_px * num_px * 3, presumably 64x64 RGB images —
# must equal train_set_flatten.shape[0]), four hidden ReLU layers of 128 units,
# and one sigmoid output unit (see model_forward).
layers_dims = (12288, 128, 128, 128, 128, 1)


def init_params(layer_dims, seed=0):
    """Initialise weights and biases for a fully connected network.

    Args:
        layer_dims: sequence of layer sizes ``(n_x, n_h1, ..., n_y)``; entry 0
            is the input size and entry ``l`` the number of units in layer ``l``.
        seed: integer seed for ``jax.random.PRNGKey``. The default of 0 keeps
            the original fixed seed.

    Returns:
        dict with, for ``l = 1 .. len(layer_dims) - 1``:
          - ``"W<l>"``: shape ``(layer_dims[l], layer_dims[l-1])``, standard
            normal samples scaled by ``1 / sqrt(layer_dims[l-1])``;
          - ``"b<l>"``: zeros of shape ``(layer_dims[l], 1)``.
    """
    key = jax.random.PRNGKey(seed)
    params = {}
    L = len(layer_dims)

    for l in range(1, L):
        # Fresh subkey per layer; reusing one key gives every layer the same
        # random stream.
        key, subkey = random.split(key)
        fan_in = layer_dims[l - 1]
        # Use the `layer_dims` argument (not the global `layers_dims`). Random
        # normal weights replace random.randint(minval=0.1, maxval=0.00001): its
        # float bounds are cast to integers (both 0), so every weight was 0 and
        # the network could not learn (cost stuck at ln 2).
        params["W" + str(l)] = random.normal(subkey, (layer_dims[l], fan_in)) / jnp.sqrt(fan_in)
        params["b" + str(l)] = jnp.zeros((layer_dims[l], 1))

    return params


def sigmoid(z):
    """Apply the logistic function element-wise.

    Returns a pair ``(activation, cache)``; the cache is the input ``z``
    itself, kept for the backward pass.
    """
    activation = 1 / (1 + jnp.exp(-z))
    return activation, z


def sigmoid_back(dA, cache):
    """Backpropagate through a sigmoid activation.

    Args:
        dA: gradient of the cost with respect to the sigmoid output.
        cache: the pre-activation ``Z`` saved by ``sigmoid``.

    Returns:
        ``dZ = dA * s * (1 - s)`` where ``s = sigmoid(Z)``.
    """
    Z = cache
    # s must be the sigmoid itself, 1 / (1 + e^{-Z}). The previous 1 / e^{-Z}
    # (= e^{Z}) produced a wrong derivative (e.g. 0 instead of 0.25 at Z = 0).
    s = 1 / (1 + jnp.exp(-Z))
    dZ = dA * s * (1 - s)

    return dZ


def relu(z):
    """Rectified linear unit: element-wise ``max(0, z)``.

    Returns ``(activation, cache)``; the cache is the input ``z`` for the
    backward pass.
    """
    return jnp.maximum(0, z), z


def relu_back(dA, cache):
    """Backpropagate through a ReLU activation.

    Args:
        dA: gradient of the cost with respect to the ReLU output.
        cache: the pre-activation ``Z`` saved by ``relu`` (same shape as ``dA``).

    Returns:
        ``dZ``: equal to ``dA`` where ``Z > 0`` and 0 elsewhere.
    """
    Z = cache
    # JAX arrays are immutable. ``dZ.at[...].set(0)`` returns a NEW array, so the
    # old code threw the masked result away and passed dA through unchanged.
    dZ = jnp.where(Z > 0, dA, 0)

    return dZ


def forward_prop(A, W, b):
    """Linear part of a layer's forward pass: ``Z = W @ A + b``.

    Returns ``(Z, cache)``, where ``cache = (A, W, b)`` is kept for the
    backward pass.
    """
    linear_cache = (A, W, b)
    Z = jnp.dot(W, A) + b
    return Z, linear_cache


def forward_activation(A_prev, W, b, activation):
    """Forward pass of one layer: the linear step followed by an activation.

    Args:
        A_prev: activations from the previous layer (or the network input).
        W: this layer's weight matrix.
        b: this layer's bias column.
        activation: ``"sigmoid"`` or ``"relu"``.

    Returns:
        ``(A, cache)`` where ``cache = (linear_cache, activation_cache)``.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    if activation == "sigmoid":
        activation_fn = sigmoid
    elif activation == "relu":
        activation_fn = relu
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f"Unsupported activation: {activation!r}")

    Z, linear_cache = forward_prop(A_prev, W, b)
    A, active_cache = activation_fn(Z)

    cache = (linear_cache, active_cache)

    return A, cache



def model_forward(X, params):
    """Forward pass through the whole network.

    Runs layers ``1 .. L-1`` with ReLU and layer ``L`` with sigmoid, where
    ``L = len(params) // 2`` (params holds one ``W`` and one ``b`` per layer).

    Returns:
        ``(AL, caches)``: the output activations and the per-layer caches in
        forward order, as consumed by ``model_back``.
    """
    n_layers = len(params) // 2
    caches = []
    A = X

    # Hidden layers use ReLU.
    for layer in range(1, n_layers):
        A, layer_cache = forward_activation(A, params[f"W{layer}"], params[f"b{layer}"], activation="relu")
        caches.append(layer_cache)

    # Output layer uses sigmoid.
    AL, layer_cache = forward_activation(A, params[f"W{n_layers}"], params[f"b{n_layers}"], activation="sigmoid")
    caches.append(layer_cache)

    return AL, caches


def cost_func(AL, Y, eps=1e-7):
    """Binary cross-entropy cost averaged over the examples.

    Args:
        AL: predicted probabilities; columns are examples (``m = Y.shape[1]``).
        Y: true labels in {0, 1}, same shape as ``AL``.
        eps: probabilities are clipped to ``[eps, 1 - eps]`` before taking logs,
            so a saturated sigmoid output (exactly 0 or 1) does not produce
            ``log(0)``.

    Returns:
        The cost as a 0-d array.
    """
    m = Y.shape[1]

    # Keep both logs finite: log(0) is -inf and 0 * -inf is nan.
    AL = jnp.clip(AL, eps, 1 - eps)
    cost = -(1 / m) * jnp.sum(Y * jnp.log(AL) + (1 - Y) * jnp.log(1 - AL))

    return jnp.squeeze(cost)


def reg_cost(AL, Y, parameters, lambd):
    """Cross-entropy cost plus an L2 penalty on every weight matrix.

    Args:
        AL: predicted probabilities.
        Y: true labels (``m = Y.shape[1]`` examples).
        parameters: dict with ``"W<l>"`` / ``"b<l>"`` entries for each layer.
        lambd: L2 regularisation strength.

    Returns:
        ``cost_func(AL, Y) + (lambd / (2 m)) * sum_l ||W_l||^2`` summed over all
        layers ``l = 1 .. len(parameters) // 2``.
    """
    m = Y.shape[1]
    n_layers = len(parameters) // 2

    cross_entropy_cost = cost_func(AL, Y)

    # Penalise every layer. The old code hard-coded W1..W3, so deeper layers
    # (W4, W5 in the default 5-layer net) were never regularised, and a
    # network with fewer than three layers raised KeyError.
    squared_norms = sum(
        jnp.sum(jnp.square(parameters["W" + str(l)])) for l in range(1, n_layers + 1)
    )
    L2_cost = (1 / m) * (lambd / 2) * squared_norms

    cost = cross_entropy_cost + L2_cost

    return cost


def back_prop(dZ, cache):
    """Linear part of a layer's backward pass.

    Given ``dZ`` and the ``(A_prev, W, b)`` cache from ``forward_prop``,
    return ``(dA_prev, dW, db)``. ``dW`` and ``db`` are averaged over the
    ``m`` examples (the columns of ``A_prev``).
    """
    A_prev, W, _ = cache
    inv_m = 1 / A_prev.shape[1]

    dW = inv_m * jnp.dot(dZ, A_prev.T)
    db = inv_m * jnp.sum(dZ, axis=1, keepdims=True)
    dA_prev = jnp.dot(W.T, dZ)

    return dA_prev, dW, db


def back_active(dA, cache, activation):
    """Backward pass of one layer: the activation derivative, then the linear step.

    Args:
        dA: gradient of the cost with respect to this layer's output activations.
        cache: ``(linear_cache, activation_cache)`` from ``forward_activation``.
        activation: ``"relu"`` or ``"sigmoid"``.

    Returns:
        ``(dA_prev, dW, db)`` from ``back_prop``.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    linear_cache, activation_cache = cache

    if activation == "relu":
        dZ = relu_back(dA, activation_cache)
    elif activation == "sigmoid":
        dZ = sigmoid_back(dA, activation_cache)
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f"Unsupported activation: {activation!r}")

    dA_prev, dW, db = back_prop(dZ, linear_cache)

    return dA_prev, dW, db


def model_back(AL, Y, caches, eps=1e-7):
    """Backward pass through the whole network (cross-entropy gradients only).

    Args:
        AL: output probabilities from ``model_forward``.
        Y: true labels; reshaped to ``AL.shape``.
        caches: per-layer caches from ``model_forward`` (forward order).
        eps: ``AL`` is clipped to ``[eps, 1 - eps]`` when computing ``dAL``, so a
            saturated sigmoid output does not divide by zero.

    Returns:
        dict with ``"dW<l>"`` / ``"db<l>"`` for ``l = 1 .. L`` and ``"dA<l>"`` for
        ``l = 0 .. L-1``, where ``L = len(caches)``.
    """
    grads = {}
    L = len(caches)
    Y = jnp.reshape(Y, AL.shape)

    # Derivative of the cross-entropy w.r.t. AL. Without clipping, AL == 0 or
    # AL == 1 (reachable in float32 once the sigmoid saturates) yields inf/nan
    # that then propagates into every gradient.
    AL_safe = jnp.clip(AL, eps, 1 - eps)
    dAL = -(jnp.divide(Y, AL_safe) - jnp.divide(1 - Y, 1 - AL_safe))

    # Output layer (sigmoid).
    dA_prev_temp, dW_temp, db_temp = back_active(dAL, caches[L - 1], activation="sigmoid")
    grads["dA" + str(L - 1)] = dA_prev_temp
    grads["dW" + str(L)] = dW_temp
    grads["db" + str(L)] = db_temp

    # Hidden layers (ReLU), from layer L-1 down to layer 1.
    for l in reversed(range(L - 1)):
        dA_prev_temp, dW_temp, db_temp = back_active(grads["dA" + str(l + 1)], caches[l], activation="relu")
        grads["dA" + str(l)] = dA_prev_temp
        grads["dW" + str(l + 1)] = dW_temp
        grads["db" + str(l + 1)] = db_temp

    return grads


def update_params(params, grads, learning_rate):
    """Take one gradient-descent step for every layer.

    For each layer ``l``, ``W<l> -= learning_rate * dW<l>`` and
    ``b<l> -= learning_rate * db<l>``. ``params`` itself is left untouched;
    a shallow copy holding the updated arrays is returned.
    """
    updated = params.copy()
    n_layers = len(updated) // 2

    for layer in range(1, n_layers + 1):
        for prefix in ("W", "b"):
            name = f"{prefix}{layer}"
            updated[name] = updated[name] - (learning_rate * grads["d" + name])

    return updated


def predict(X, y, parameters):
    """Predict binary labels and print the accuracy against ``y``.

    Args:
        X: input examples, one per column.
        y: true labels, shape ``(1, m)``.
        parameters: trained network parameters.

    Returns:
        ``p``: array with the same shape as the network output (``(1, m)``),
        containing 1.0 where the predicted probability exceeds 0.5 and 0.0
        otherwise.
    """
    m = X.shape[1]

    probas, _ = model_forward(X, parameters)

    # JAX arrays are immutable, so the old element-wise `p[0, i] = 1` loop
    # raised TypeError. Threshold in one vectorised step instead.
    p = jnp.where(probas > 0.5, 1.0, 0.0)

    print("Accuracy: "  + str(jnp.sum((p == y)/m)))

    return p


def nn_model(X, Y, layers_dims, learning_rate = 0.0075, n_iter = 10000, print_cost = False, lambd = 0.7):
    """Train the network with full-batch gradient descent and L2 regularisation.

    Args:
        X: training inputs, one example per column.
        Y: training labels, shape ``(1, m)``.
        layers_dims: layer sizes passed to ``init_params``.
        learning_rate: gradient-descent step size.
        n_iter: number of training iterations.
        print_cost: if True, print the cost every 100 iterations and at the
            final iteration.
        lambd: L2 regularisation strength, applied to both the reported cost
            and the weight gradients.

    Returns:
        ``(parameters, costs)``: the trained parameters and the costs recorded
        every 100 iterations.
    """
    costs = []
    m = Y.shape[1]

    parameters = init_params(layers_dims)

    for i in range(n_iter):
        AL, caches = model_forward(X, parameters)
        cost = reg_cost(AL, Y, parameters, lambd)
        grads = model_back(AL, Y, caches)

        # model_back returns cross-entropy gradients only. Add the derivative of
        # the L2 term, (lambd / m) * W, so regularisation actually affects
        # training rather than just the reported cost.
        for key in parameters:
            if key.startswith("W"):
                grads["d" + key] = grads["d" + key] + (lambd / m) * parameters[key]

        parameters = update_params(parameters, grads, learning_rate)

        # Parenthesised so print_cost=False suppresses ALL output; the old
        # `a and b or c` form always printed the final iteration.
        if print_cost and (i % 100 == 0 or i == n_iter - 1):
            print("Cost after iteration {}: {}".format(i, jnp.squeeze(cost)))
        if i % 100 == 0:
            costs.append(cost)

    return parameters, costs

 
# Train on the flattened training set. NOTE(review): each iteration runs the
# forward/backward passes eagerly; `jit` is imported above but not applied,
# which makes long runs slow.
parameters, costs = nn_model(train_set_flatten, train_class, layers_dims, learning_rate = 0.0007, n_iter = 1000, print_cost = True, lambd = 0.0005) 

# pred_train = predict(train_set_flatten, train_class, parameters)

# NOTE(review): train_class and pred_train have shape (1, m), so
# .argmax(axis=1) collapses each to a single index. For these binary labels,
# confusion_matrix(train_class.ravel(), pred_train.ravel()) is likely what is meant.
# print(pred_train.shape)
# cm = confusion_matrix(train_class.argmax(axis=1), pred_train.argmax(axis=1))
# cm_display = ConfusionMatrixDisplay(confusion_matrix = cm)

# cm_display.plot()
# plt.show()

# pred_test = predict(test_set_flatten, test_class, parameters)



Cost after iteration 0: 0.6931473016738892
Cost after iteration 100: 0.6931473016738892
Cost after iteration 200: 0.6931473016738892


KeyboardInterrupt: ignored

## Possibilities
To get the network to perform well on the test set, consider using a bigger dev set.

Things to try optimizing:
- Learning rate.
- Batch size.
- Number of hidden units.
- Number of hidden layers.
- Adam optimizer hyperparameters (beta1, beta2, epsilon).

## Record
- Without regularization, the network's best results were 0.34 error and 0.76 accuracy, each obtained with a different learning rate.

- With regularization, a run with learning rate 0.00005, 2000 iterations and lambd 0.7 resulted in an error of 0.639 and an accuracy of 0.526.

- With regularization, a run with learning rate 0.0005, 3000 iterations and lambd 0.7 resulted in a training error of 0.4902 and an accuracy of 0.7527.

- With regularization, a run with learning rate 0.001, 3000 iterations and lambd 0.7 resulted in a training error of 0.4864 and an accuracy of 0.75667.

- With regularization, a run with learning rate 0.005, 3000 iterations and lambd 0.7 resulted in a training error of 0.48048, an accuracy of 0.57665 and a test accuracy of 0.55161.

- With regularization, a run with learning rate 0.0007, 3000 iterations and lambd 0.7 resulted in a training error of 0.48800, an accuracy of 0.75272 and a test accuracy of 0.64516.

- With regularization, a run with learning rate 0.0007, 4000 iterations and lambd 0.001 resulted in a training error of 0.3537550592501022, a training accuracy of 0.7537091988130565
and a test accuracy of 0.6451612903225807.

- With regularization, a run with learning rate 0.0007, 4000 iterations and lambd 0.0001 resulted in a training error of 0.35358383575002045, a training accuracy of 0.7537091988130565
and a test accuracy of 0.6451612903225807.

- All of the above runs used the architecture
(12288, 128, 128, 128, 128, 1); it is changed in the runs below.

- With regularization, a run with learning rate 0.0007, 4000 iterations, lambd 0.001 and architecture
(12288, 1024, 512, 256, 128, 1) resulted in a training error of 0.3578575759145978, a training accuracy of 0.7378832838773492 and a test accuracy of 0.5774193548387099.

- Because the bigger model takes so long to train, the next step is to run the model with JAX after adding the Adam optimizer and dropout.