# OR Gate Perceptron

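$$
Y = A + B
$$

Here $+$ denotes Boolean OR, so $Y = 1$ whenever at least one input is 1 (in particular $1 + 1 = 1$).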

In [1]:
import jax.numpy as np
import jax

## Sigmoid Function

In [2]:
@jax.jit
def sigmoid(x):
    """Elementwise logistic sigmoid: sigma(x) = 1 / (1 + e^(-x)).

    Maps any real input into the open interval (0, 1).
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

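## Derivative of Sigmoid

Expressed in terms of the sigmoid's output $s = \sigma(x)$: $\sigma'(x) = s\,(1 - s)$. `sigmoidDerivative` therefore expects the activation $s$, not the raw input $x$.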

In [3]:
@jax.jit
def sigmoidDerivative(x):
    """Sigmoid derivative written in terms of the sigmoid's own output.

    Uses the identity sigma'(z) = sigma(z) * (1 - sigma(z)). The argument
    ``x`` must therefore already be an activation s = sigma(z), NOT the
    pre-activation z. Returns s * (1 - s), elementwise.
    """
    complement = 1 - x
    return x * complement

## Inputs and Labels

In [4]:
# Truth-table inputs: every combination of the two binary inputs (A, B), one row each.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

In [5]:
# OR-gate target for each input row, as a (4, 1) column vector.
Y = np.array([0, 1, 1, 1]).reshape(4, 1)

## Initialize Random Weights and Bias

In [6]:
# Fixed seed so that the random initialization (and hence training) is reproducible.
key = jax.random.PRNGKey(seed=0)

In [7]:
# Two inputs (A, B) feed a single output unit.
inputSize, outputSize = 2, 1

In [8]:
# Derive two independent subkeys from `key`: one to sample the weights, one for the bias.
# NOTE: despite its name, `biasWeights` is a PRNG key (used to sample `bias`), not a weight array.
keyWeights, biasWeights = jax.random.split(key, num=2)
keyWeights, biasWeights

(Array([1797259609, 2579123966], dtype=uint32),
 Array([ 928981903, 3453687069], dtype=uint32))

In [9]:
# Standard-normal initial weights: one row per input, one column per output unit.
weights = jax.random.normal(keyWeights, shape=(inputSize, outputSize))
weights

Array([[ 1.0040143],
       [-0.9063372]], dtype=float32)

In [10]:
# One standard-normal bias per output unit, drawn from the second subkey.
bias = jax.random.normal(biasWeights, shape=(outputSize,))
bias

Array([-2.4424558], dtype=float32)

## Learning Rate

In [11]:
# Step size that scales every weight/bias update in the training loop.
lr = 1

## Training

In [12]:
# Full-batch gradient descent on 0.5 * sum(error**2), rebinding `weights` and `bias`.
# NOTE: re-running this cell resumes training from the current `weights`/`bias`
# instead of starting again from the random initialization.
numEpochs = 10000
logEvery = 1000

# A plain Python range keeps the epoch counter on the host. Iterating over a
# jax.numpy array would yield a device scalar every step and force a device
# round-trip for each `epoch % logEvery` check.
for epoch in range(numEpochs):
    # Forward Propagation
    predict = sigmoid(np.dot(X, weights) + bias)

    # Calculate the error (target minus prediction)
    error = Y - predict

    # Backpropagation: this is the NEGATIVE gradient of the loss w.r.t. the
    # pre-activation, so adding it below performs gradient descent.
    derivativePredict = sigmoidDerivative(predict) * error

    weights += np.dot(X.T, derivativePredict) * lr
    bias += np.sum(derivativePredict, axis=0) * lr

    if epoch % logEvery == 0:
        print(f"Error at epoc {epoch}: {np.mean(np.abs(error))}")

Error at epoc 0: 0.6916971802711487
Error at epoc 1000: 0.030672451481223106
Error at epoc 2000: 0.021218687295913696
Error at epoc 3000: 0.01714799925684929
Error at epoc 4000: 0.014758002012968063
Error at epoc 5000: 0.013143092393875122
Error at epoc 6000: 0.011959503404796124
Error at epoc 7000: 0.01104463916271925
Error at epoc 8000: 0.010310417041182518
Error at epoc 9000: 0.009704383090138435


## Final Weights and Bias

In [13]:
# Learned parameters after training.
(weights, bias)

(Array([[8.672171],
        [8.672165]], dtype=float32),
 Array([-4.102197], dtype=float32))

## Prediction

In [15]:
# Forward pass with the trained parameters: one sigmoid output in (0, 1) per input row.
sigmoid(X @ weights + bias)

Array([[0.0162673],
       [0.9897479],
       [0.989748 ],
       [0.9999982]], dtype=float32)

In [16]:
# Round each sigmoid output to the nearest integer to recover hard 0/1 gate outputs.
np.round(sigmoid(X @ weights + bias))

Array([[0.],
       [1.],
       [1.],
       [1.]], dtype=float32)