In [1]:
import random
import itertools

import jax
import jax.numpy as jnp
import numpy as np

from __future__ import print_function

### Neural Network: Single hidden layer with 3 neurons

In [24]:
from jax._src.lax.lax import bitwise_xor
def sigmoid(x):
  """Logistic function: returns 1 / (1 + exp(-x)), applied element-wise."""
  denominator = 1 + jnp.exp(-x)
  return 1 / denominator

def net(params, x):
  """Forward pass of the network: tanh hidden layer followed by a sigmoid output.

  ``params`` unpacks into ``(w1, b1, w2, b2)``; ``x`` is the input that is
  multiplied by ``w1``.
  """
  hidden_weights, hidden_bias, output_weights, output_bias = params
  hidden_pre_activation = jnp.dot(hidden_weights, x) + hidden_bias
  hidden = jnp.tanh(hidden_pre_activation)
  output_pre_activation = jnp.dot(output_weights, hidden) + output_bias
  return sigmoid(output_pre_activation)

def loss(params, x, y):
  """Binary cross-entropy between the network output for ``x`` and the target ``y``.

  Args:
    params: network parameters, forwarded unchanged to ``net``.
    x: network input, forwarded unchanged to ``net``.
    y: target value (0 or 1 in the XOR training loops of this notebook).

  Returns:
    ``-y*log(p) - (1 - y)*log(1 - p)``, where ``p`` is the output of ``net``
    clipped to ``[eps, 1 - eps]``.
  """
  out = net(params, x)
  # A saturated sigmoid yields exactly 0.0 or 1.0. Then log(0) = -inf and
  # 0 * -inf = NaN, which would propagate into the gradients and permanently
  # poison the parameters. Clipping keeps both logarithms finite and leaves
  # the value untouched everywhere else. 1e-7 is large enough that
  # 1 - eps is still distinct from 1.0 in float32.
  eps = 1e-7
  out = jnp.clip(out, eps, 1 - eps)
  cross_entropy = -y*jnp.log(out) - (1 - y)*jnp.log(1 - out)
  return cross_entropy

def test_all_outputs(inputs, params):
  """Print the thresholded prediction for every input and check them against XOR.

  Each network output is turned into 0/1 by comparing it with 0.5. Returns
  True when the list of predictions equals the element-wise XOR of each input.
  """
  predictions = []
  for sample in inputs:
    predictions.append(int(net(params, sample) > 0.5))

  for sample, predicted in zip(inputs, predictions):
    print(sample, '->', predicted)

  expected = [np.bitwise_xor(*sample) for sample in inputs]
  return (predictions == expected)

In [25]:
def initial_params(n_hidden=3, n_inputs=2):
    """Draw random starting parameters for the one-hidden-layer network.

    Args:
        n_hidden: number of hidden units. Defaults to 3, the size that was
            previously hard-coded.
        n_inputs: length of the input vector. Defaults to 2.

    Returns:
        A list ``[w1, b1, w2, b2]`` of standard-normal samples drawn from
        NumPy's global random state, with shapes ``(n_hidden, n_inputs)``,
        ``(n_hidden,)``, ``(n_hidden,)`` and a scalar float respectively.
    """
    return [
        np.random.randn(n_hidden, n_inputs),  # w1
        np.random.randn(n_hidden),  # b1
        np.random.randn(n_hidden),  # w2
        np.random.randn(),  #b2
    ]

In [26]:
# Draw one fresh set of parameters and display it to inspect its structure.
initial_params()

[array([[ 0.2288743 , -0.2190556 ],
        [-0.81096588, -0.12031487],
        [-0.46920282,  1.2080414 ]]),
 array([-0.13427283, -0.23177895,  1.08305937]),
 array([ 0.48815877,  0.4180898 , -1.46185995]),
 -0.868907718455939]

In [27]:
# Gradient of the loss with respect to its first argument (params).
loss_grad = jax.grad(loss)

learning_rate = 1.
# Safety cap: training can stall or diverge, and an open-ended loop would then
# hang a "Restart & Run All". Stop after this many updates instead.
max_iterations = 10000

inputs = jnp.array([[0,0],[0,1],[1,0],[1,1]])

params = initial_params()

for n in range(max_iterations):
  # Plain SGD: one uniformly sampled XOR example per update.
  x = inputs[np.random.choice(inputs.shape[0])]

  y = np.bitwise_xor(*x)
  grads = loss_grad(params, x, y)
  params = [ (param - learning_rate*grad) for param, grad in zip(params, grads)]
  # Evaluate on all four inputs every 100 updates; stop once XOR is reproduced.
  if n % 100 == 0:
    print("Iteration {}".format(n))
    if test_all_outputs(inputs, params):
      break
else:
  # Reached only when the loop ran out of iterations without a `break`.
  print("Stopped after {} iterations without fitting XOR".format(max_iterations))


Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


## jax.jit

In [30]:
# Same gradient function as before, but compiled with jax.jit.
loss_grad = jax.jit(jax.grad(loss))

learning_rate = 1.
# Safety cap: training can stall or diverge, and an open-ended loop would then
# hang a "Restart & Run All". Stop after this many updates instead.
max_iterations = 10000

inputs = jnp.array([[0,0],[0,1],[1,0],[1,1]])

params = initial_params()

for n in range(max_iterations):
  # Plain SGD: one uniformly sampled XOR example per update.
  x = inputs[np.random.choice(inputs.shape[0])]

  y = np.bitwise_xor(*x)
  grads = loss_grad(params, x, y)
  params = [ (param - learning_rate*grad) for param, grad in zip(params, grads)]
  # Evaluate on all four inputs every 100 updates; stop once XOR is reproduced.
  if n % 100 == 0:
    print("Iteration {}".format(n))
    if test_all_outputs(inputs, params):
      break
else:
  # Reached only when the loop ran out of iterations without a `break`.
  print("Stopped after {} iterations without fitting XOR".format(max_iterations))


Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [31]:
# Per-example gradients for a whole batch: params are shared (in_axes None),
# while x and y are mapped over their leading axis.
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes = (None, 0, 0), out_axes = 0))

# NOTE: this cell reuses `inputs` and `learning_rate` defined in an earlier cell.
params = initial_params()

batch_size = 100
# Safety cap: training can stall or diverge, and an open-ended loop would then
# hang a "Restart & Run All". Stop after this many updates instead.
max_iterations = 10000

for n in range(max_iterations):

  # Sample a mini-batch (with replacement) from the four XOR inputs.
  x = inputs[np.random.choice(inputs.shape[0], size=batch_size)]

  y = np.bitwise_xor( x[:, 0], x[:, 1])

  grads = loss_grad(params, x, y)

  # Average the per-example gradients over the batch axis. jnp.mean keeps the
  # result a JAX array instead of converting it to NumPy on every step.
  params = [param - learning_rate * jnp.mean(grad,axis=0) for param,grad in zip(params, grads)]

  # Evaluate on all four inputs every 100 updates; stop once XOR is reproduced.
  if n % 100 == 0:
    print('Iteration {}'.format(n))
    if test_all_outputs(inputs, params):
      break
else:
  # Reached only when the loop ran out of iterations without a `break`.
  print('Stopped after {} iterations without fitting XOR'.format(max_iterations))

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
