# AND Gate Perceptron

$$
Y = A \cdot B
$$

In [11]:
import jax.numpy as np
import jax

## Input and Output

In [12]:
# Truth-table inputs: every combination of the two binary inputs A and B.
X = np.array([(0, 0), (0, 1), (1, 0), (1, 1)])

# Targets follow Y = A · B: the element-wise product of the two input
# columns, kept as a (4, 1) column vector.
Y = X[:, 0:1] * X[:, 1:2]

## Activation Function

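* Sigmoid function: $\sigma(z) = \dfrac{1}{1 + e^{-z}}$
* Derivative of the sigmoid function, written in terms of its output $s = \sigma(z)$: $\sigma'(z) = s\,(1 - s)$. `derivativeSigmoid` therefore expects the already-activated value, not the raw input $z$.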

In [13]:
@jax.jit
def sigmoid(x):
    """Logistic activation, applied element-wise: 1 / (1 + exp(-x)).

    Maps real inputs into the range (0, 1); in float32 it saturates to
    exactly 0 or 1 for large-magnitude inputs.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

In [14]:
@jax.jit
def derivativeSigmoid(x):
    """Slope of the sigmoid, given the sigmoid's *output* ``x = sigmoid(z)``.

    Uses the identity sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)). Callers
    must pass the already-activated value, not the raw pre-activation ``z``.
    """
    complement = 1 - x
    return x * complement

## Initializing Random Weights and Bias

In [15]:
# Fixed seed so that Restart & Run All reproduces the same initial weights and bias.
SEED = 27
key = jax.random.PRNGKey(SEED)

In [16]:
# Derive two independent subkeys from `key`: one for the weights, one for the bias.
subkeys = jax.random.split(key, 2)
keyWeights = subkeys[0]
keyBias = subkeys[1]
keyWeights, keyBias

(Array([ 112854966, 2582967503], dtype=uint32),
 Array([1903097535, 1612154735], dtype=uint32))

In [17]:
# One weight per input feature (column of X) for each output (column of Y);
# deriving the shape from the data keeps this cell in sync with the truth table.
weights = jax.random.normal(keyWeights, (X.shape[1], Y.shape[1]))
weights

Array([[-1.2051456],
       [ 1.0610977]], dtype=float32)

In [18]:
# One bias term per output (column of Y), derived from the data rather than hard-coded.
bias = jax.random.normal(keyBias, (Y.shape[1],))
bias

Array([0.06784267], dtype=float32)

In [19]:
# Learning rate: multiplies every weight and bias update in the training loop.
lr = 10

## Training

In [20]:
EPOCHS = 10000
LOG_EVERY = 1000


@jax.jit
def trainStep(w, b, inputs, targets, learningRate):
    """Run one full-batch training step of the single-layer perceptron.

    Compiled with ``jax.jit`` so each epoch is a single XLA call instead of
    several separately dispatched array operations.

    Args:
        w: current weight matrix.
        b: current bias vector.
        inputs: input rows (the truth table ``X``).
        targets: expected outputs (``Y``).
        learningRate: step size multiplying the update.

    Returns:
        ``(newW, newB, error)``, where ``error = targets - predict`` is measured
        *before* the update so it can be logged for this epoch.
    """
    # Forward Propagation
    predict = sigmoid(np.dot(inputs, w) + b)

    # Error calculation
    error = targets - predict

    # Backpropagation: derivativeSigmoid takes the sigmoid *output* (predict),
    # not the pre-activation value.
    derivativePredict = derivativeSigmoid(predict) * error

    # error is (target - predict), so *adding* the step moves predictions toward the targets.
    newW = w + np.dot(inputs.T, derivativePredict) * learningRate
    newB = b + np.sum(derivativePredict) * learningRate
    return newW, newB, error


# NOTE: re-running this cell continues from the current `weights`/`bias`;
# re-run the initialization cells first to train from scratch.
for epoch in range(EPOCHS):
    weights, bias, error = trainStep(weights, bias, X, Y, lr)

    if epoch % LOG_EVERY == 0:
        print(f"Error at epoch {epoch}: {np.mean(np.abs(error))}")

Error at epoc 0: 0.5086139440536499
Error at epoc 1000: 0.012997671961784363
Error at epoc 2000: 0.009107393212616444
Error at epoc 3000: 0.007405339740216732
Error at epoc 4000: 0.006397179327905178
Error at epoc 5000: 0.005711973644793034
Error at epoc 6000: 0.005207630340009928
Error at epoc 7000: 0.004816536791622639
Error at epoc 8000: 0.004501832649111748
Error at epoc 9000: 0.004241542425006628


## Final Weights and Bias 

In [21]:
# Learned parameters after training: weights keep their (2, 1) shape and bias its (1,) shape.
weights, bias

(Array([[10.395933],
        [10.395933]], dtype=float32),
 Array([-15.678721], dtype=float32))

## Prediction

In [22]:
# Probability of output 1 for each row of the truth table, using the trained parameters.
preActivation = np.dot(X, weights) + bias
sigmoid(preActivation)

Array([[1.5517357e-07],
       [5.0525931e-03],
       [5.0525931e-03],
       [9.9401891e-01]], dtype=float32)

In [23]:
# Round the probabilities to hard 0/1 labels (values above 0.5 become 1, the rest 0).
predictedLabels = sigmoid(np.dot(X, weights) + bias).round()

# Sanity check: the trained perceptron must reproduce the AND truth table exactly.
assert (predictedLabels == Y).all(), "perceptron does not reproduce the AND truth table"

predictedLabels

Array([[0.],
       [0.],
       [0.],
       [1.]], dtype=float32)