# **JAX Implementation of a Neural Cellular Automaton (CA) Model**

## **Installations**


In [56]:
!pip3 install jax
!pip3 install pytest



## **Imports**

In [57]:
import jax
import jax.numpy as jnp
from jax import lax
import time

## **Model**

In [58]:
class CAModel:
  """Neural cellular-automaton update rule.

  A fixed depthwise "perception" convolution (identity + rotated Sobel filters)
  is followed by two learned 1x1 convolutions (a per-cell dense network with a
  ReLU in between). The network output is added to the incoming state as a
  residual update.

  Args:
    channel_n: number of state channels per cell.
    rng: JAX PRNG key used to initialise the two 1x1 convolution kernels.
    hidden_n: width of the hidden layer (default 128).
  """

  # Layout shared by every convolution in this model:
  # activations are batch-height-width-channel, kernels are height-width-in-out.
  _DIMENSION_NUMBERS = ("NHWC", "HWIO", "NHWC")

  def __init__(self, channel_n, rng, hidden_n=128):
    self.channel_n = channel_n
    k1, k2 = jax.random.split(rng)

    # NOTE(review): weights are N(0, 1) samples scaled by 0.1. The scale is an
    # ad-hoc choice with no stated justification -- TODO justify it or switch to
    # a fan-in based initialiser.

    # Conv 1 -- (channel_n * 3) perception channels -> hidden_n channels
    self.w1 = jax.random.normal(k1, (1, 1, channel_n * 3, hidden_n)) * 0.1
    self.b1 = jnp.zeros((hidden_n,))

    # Conv 2 -- hidden_n channels -> channel_n channels (the state update)
    self.w2 = jax.random.normal(k2, (1, 1, hidden_n, channel_n)) * 0.1
    self.b2 = jnp.zeros((channel_n,))

  def perceive(self, x, angle):
    """Depthwise 3x3 perception of the state.

    Each state channel is convolved with three filters:
      - an identity filter;
      - the Sobel d/dx filter rotated by `angle`;
      - the Sobel d/dy filter rotated by `angle`.

    Args:
      x: state in NHWC layout whose last axis has `channel_n` entries.
        The kernel is float32; presumably x must be float32 too -- TODO confirm
        for other dtypes.
      angle: rotation of the gradient filters, in radians (fed to jnp.cos/jnp.sin).

    Returns:
      Array of shape x.shape[:-1] + (channel_n * 3,). Output channel i*3 + j is
      filter j (0 = identity, 1 = rotated d/dx, 2 = rotated d/dy) applied to
      input channel i. "SAME" padding treats out-of-grid cells as zero.
    """
    identity = jnp.outer(jnp.float32([0, 1, 0]), jnp.float32([0, 1, 0]))
    dx = jnp.outer(jnp.array([1, 2, 1]), jnp.array([-1, 0, 1])) / 8.0  # Sobel x, scaled by 1/8
    dy = dx.T
    c, s = jnp.cos(angle), jnp.sin(angle)
    # (3, 3, 3) laid out as [filter, H, W]
    base_filters = jnp.stack([identity, c * dx - s * dy, s * dx + c * dy])

    # Build the (3, 3, 1, channel_n*3) HWIO kernel with one vectorised op instead
    # of channel_n*3 scatter updates.
    # Step 1: move the filter axis last.
    # Step 2: repeat the [identity, dx, dy] triple once per state channel, so
    # kernel[..., 0, i*3 + j] == base_filters[j].
    per_channel = jnp.transpose(base_filters, (1, 2, 0))  # [H, W, filter]
    kernel = jnp.tile(per_channel, (1, 1, self.channel_n))[:, :, None, :]

    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=self._DIMENSION_NUMBERS,
        feature_group_count=self.channel_n,  # depthwise: one group of 3 filters per input channel
    )

  def _dense(self, h, w, b):
    """1x1 convolution (a per-cell dense layer) of NHWC array `h` with kernel `w`, plus bias `b`."""
    return lax.conv_general_dilated(
        h,
        w,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=self._DIMENSION_NUMBERS,
    ) + b

  def __call__(self, x, angle):
    """One CA step: return x + dense2(relu(dense1(perceive(x, angle)))).

    Args:
      x: state in NHWC layout with `channel_n` channels.
      angle: rotation of the perception gradient filters, in radians.

    Returns:
      The updated state, with the same shape as `x`.
    """
    perception = self.perceive(x, angle)                              # channel_n * 3 channels
    hidden = jax.nn.relu(self._dense(perception, self.w1, self.b1))  # hidden_n channels
    update = self._dense(hidden, self.w2, self.b2)                   # channel_n channels
    return x + update  # residual: old state + update

## **Model Output and Layer Shapes**

In [59]:
# Initialise the CA model with a fixed PRNG key so the statistics printed below
# are reproducible (a wall-clock seed changed "Mean change" on every run).
rng = jax.random.PRNGKey(0)
model = CAModel(channel_n=16, rng=rng)

# One full CA step via __call__ on a dummy all-ones state
x = jnp.ones((1, 3, 3, 16))
out = model(x, angle=0.0)

# Re-run the same step stage by stage to inspect intermediate shapes.
# Depthwise perception layer
y = model.perceive(x, angle=0.0)

# Copy of the first dense (1x1 conv) layer, followed by ReLU
z = jax.nn.relu(
    lax.conv_general_dilated(y, model.w1,
                             window_strides=(1, 1),
                             padding='SAME',
                             dimension_numbers=("NHWC", "HWIO", "NHWC")) + model.b1)

# Copy of the second dense (1x1 conv) layer: the state update
w = lax.conv_general_dilated(z, model.w2,
                             window_strides=(1, 1),
                             padding='SAME',
                             dimension_numbers=("NHWC", "HWIO", "NHWC")) + model.b2

# The staged recomputation must agree with __call__, which returns x + update.
assert jnp.allclose(x + w, out, atol=1e-5)

print("Input shape:", x.shape)
print("Output shape:", out.shape)
print("Mean change:", jnp.mean(jnp.abs(out-x)))
print("---------------------------------------")
print("Shape of depthwise layer:", y.shape)
print("Shape of first dense layer:", z.shape)
print("Shape of second dense layer:", w.shape)


Input shape: (1, 3, 3, 16)
Output shape: (1, 3, 3, 16)
Mean change: 0.31669202
---------------------------------------
Shape of depthwise layer: (1, 3, 3, 48)
Shape of first dense layer: (1, 3, 3, 128)
Shape of second dense layer: (1, 3, 3, 16)


## **Testing**

### **Unit Testing for Dense Layer Convolution**

In [60]:
def test_Dense_shape_preservation():
  """A full CA step (perceive + both dense layers + residual) keeps the input shape."""
  # Fixed key: a wall-clock seed made any failure impossible to replay.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model(x, angle=0.0)
  assert y.shape == (1, 3, 3, channel_n)
  print("Dense layer shape preservation test passed!")

def test_Dense_angle_perservation():
  """A full CA step with a non-zero perception angle still keeps the input shape.

  NOTE(review): the angle is interpreted in radians (it is passed to jnp.cos and
  jnp.sin), so 50.0 is about 8 full turns -- confirm this was intended rather
  than degrees.
  """
  # Fixed key: a wall-clock seed made any failure impossible to replay.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model(x, angle=50.0)
  assert y.shape == (1, 3, 3, channel_n)
  print("Dense layer angle doesn't change shape!")

def test_Dense_reproducability():
  """Two CA steps with the same model, input and angle produce the same output."""
  # Fixed key: a wall-clock seed made any failure impossible to replay.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y1 = model(x, angle=0.0)
  y2 = model(x, angle=0.0)
  assert jnp.allclose(y1, y2)
  # This test exercises the full (dense) step; the old message wrongly said "Depthwise".
  print("Dense layer reproducibility test passed!")

def test_Dense_IO():
  """A CA step on an all-ones state must actually change it (non-zero residual update)."""
  # Built once: the original constructed rng/model twice, and the first copy was
  # immediately discarded. Fixed key so any failure can be replayed.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model(x, angle=0.0)
  assert not jnp.allclose(y, x)
  print("Dense layer IO test passed!")


def test_Dense_FirstLayer():
  """Tests if the first dense layer actually changes the output, by comparing it
  with the same layer whose weights and bias are zeroed.

  With zeroed parameters the layer yields relu(0 + 0) = 0 everywhere. This test
  therefore effectively checks that the real first layer's activations are not
  all (close to) zero.
  """
  def dense(h, w, b):
    # 1x1 convolution + bias, the same operation as the model's first dense layer
    return lax.conv_general_dilated(h, w,
                                    window_strides=(1, 1),
                                    padding='SAME',
                                    dimension_numbers=("NHWC", "HWIO", "NHWC")) + b

  # Fixed key: a wall-clock seed made any failure impossible to replay.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  perceive_out = model.perceive(x, angle=0.0)

  y = jax.nn.relu(dense(perceive_out, model.w1, model.b1))  # the normal layer
  z = jax.nn.relu(dense(perceive_out,
                        jnp.zeros_like(model.w1),
                        jnp.zeros_like(model.b1)))  # the zeroed layer

  assert not jnp.allclose(y, z)
  print("First Dense Layer Test Passed!")

def test_Dense_SecondLayer():
  """Tests if the second dense layer actually changes the output, by comparing it
  with the same layer whose weights and bias are zeroed.

  The zeroed layer outputs exactly 0 everywhere (there is no activation after
  this layer). This test therefore checks that the real state update is not all
  (close to) zero.
  """
  def dense(h, w, b):
    # 1x1 convolution + bias, the same operation as the model's dense layers
    return lax.conv_general_dilated(h, w,
                                    window_strides=(1, 1),
                                    padding='SAME',
                                    dimension_numbers=("NHWC", "HWIO", "NHWC")) + b

  # Fixed key: a wall-clock seed made any failure impossible to replay.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  perceive_out = model.perceive(x, angle=0.0)
  firstLayer_out = jax.nn.relu(dense(perceive_out, model.w1, model.b1))

  y = dense(firstLayer_out, model.w2, model.b2)  # the normal layer
  z = dense(firstLayer_out,
            jnp.zeros_like(model.w2),
            jnp.zeros_like(model.b2))  # the zeroed layer

  assert not jnp.allclose(y, z)
  print("Second Dense Layer Test Passed!")

# Run every dense-layer test in definition order; each prints its own pass message
# and raises AssertionError on failure.
for dense_test in (
    test_Dense_shape_preservation,
    test_Dense_angle_perservation,
    test_Dense_reproducability,
    test_Dense_IO,
    test_Dense_FirstLayer,
    test_Dense_SecondLayer,
):
  dense_test()

# TODO: find shapes of the convolutions of dense layers


Dense layer shape preservation test passed!
Dense layer angle doesn't change shape!
Depthwise layer reproducability test passed!
Dense layer IO test passed!
First Dense Layer Test Passed!
Second Dense Layer Test Passed!


### **Unit Testing for Depthwise Layer Convolution**

In [61]:
def test_Depthwise_shape_preservation():
  """The perception layer triples the channel count and keeps the spatial shape."""
  # Fixed key for consistency with the other tests (perceive itself does not use
  # the random weights, so a wall-clock seed only added nondeterminism).
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model.perceive(x, angle=0.0)
  assert y.shape == (1, 3, 3, channel_n * 3)
  print("Depthwise layer shape preservation test passed!")

def test_angle_perservation():
  """Rotating the perception filters (angle in radians) does not change the output shape."""
  # Fixed key for consistency with the other tests; perceive does not depend on it.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model.perceive(x, angle=50.0)
  assert y.shape == (1, 3, 3, channel_n * 3)
  print("Depthwise layer (angle) shape preservation test passed!")

def test_Depthwise_reproducability():
  """Two perception passes with the same input and angle give the same output."""
  # Fixed key for consistency with the other tests; perceive does not depend on it.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y1 = model.perceive(x, angle=0.0)
  y2 = model.perceive(x, angle=0.0)
  assert jnp.allclose(y1, y2)
  print("Depthwise layer reproducability test passed!")

def test_kernel_construction():
  """Rebuild the depthwise perception kernel at angle 0 and check it channel by channel.

  For every state channel k, kernel slots 3k, 3k+1 and 3k+2 must hold the
  identity, Sobel d/dx and Sobel d/dy filters respectively.

  NOTE(review): this reconstructs the kernel with a copy of perceive's code
  rather than inspecting CAModel.perceive. It therefore cannot catch regressions
  in the model's own kernel -- TODO test perceive directly (e.g. via its impulse
  response).
  """
  angle = 0.0
  channel_n = 16
  identify = jnp.outer(jnp.float32([0, 1, 0]), jnp.float32([0, 1, 0]))
  dx = jnp.outer(jnp.array([1, 2, 1]), jnp.array([-1, 0, 1])) / 8.0
  dy = dx.T
  c, s = jnp.cos(angle), jnp.sin(angle)
  base_filters = jnp.stack([identify, c * dx - s * dy, s * dx + c * dy])
  kernel = jnp.zeros((3, 3, 1, channel_n * 3))
  for i in range(channel_n):
    for j in range(3):
      kernel = kernel.at[:, :, 0, i * 3 + j].set(base_filters[j])

  # Assert inside the loop: previously the match flags were overwritten on every
  # iteration, so only the last channel was actually checked.
  for k in range(channel_n):
    assert jnp.allclose(kernel[:, :, 0, k * 3], identify), f"identity filter wrong for channel {k}"
    assert jnp.allclose(kernel[:, :, 0, k * 3 + 1], dx), f"Sobel x filter wrong for channel {k}"
    assert jnp.allclose(kernel[:, :, 0, k * 3 + 2], dy), f"Sobel y filter wrong for channel {k}"

  print("Kernel construction test passed!")

def test_Depthwise_IO():
  """The perception output must not simply be an all-ones tensor with 3x the channels.

  With an all-ones input, "SAME" zero padding makes the Sobel channels differ
  from 1, so a real perception pass cannot equal jnp.ones of the output shape.
  """
  # Built once: the original constructed rng/model twice, and the first copy was
  # immediately discarded. Fixed key for consistency with the other tests.
  rng = jax.random.PRNGKey(0)
  channel_n = 16
  model = CAModel(channel_n, rng=rng)
  x = jnp.ones((1, 3, 3, channel_n))
  y = model.perceive(x, angle=0.0)
  z = jnp.ones((1, 3, 3, channel_n * 3))
  assert not jnp.allclose(y, z)
  # This is the depthwise test; the old message wrongly said "Dense layer IO".
  print("Depthwise layer IO test passed!")


# Run every depthwise-layer test in definition order; each prints its own pass
# message and raises AssertionError on failure.
for depthwise_test in (
    test_Depthwise_shape_preservation,
    test_angle_perservation,
    test_Depthwise_reproducability,
    test_kernel_construction,
    test_Depthwise_IO,
):
  depthwise_test()







Depthwise layer shape preservation test passed!
Depthwise layer (angle) shape preservation test passed!
Depthwise layer reproducability test passed!
Kernel construction test passed!
Dense layer IO test passed!
