In [1]:
%load_ext autoreload
# Loading the extension does not switch reloading on by itself; mode 2 reloads
# all imported modules before each cell runs.
%autoreload 2

# Run JAX on CPU only. JAX_PLATFORMS is the current option and restricts which
# backends get initialised; JAX_PLATFORM_NAME is kept for older JAX versions.
# Both must be set before jax is first imported.
%env JAX_PLATFORMS=cpu
%env JAX_PLATFORM_NAME=cpu
# Uncomment to make equinox runtime checks drop into a debugger instead of raising.
# %env EQX_ON_ERROR=breakpoint

env: JAX_PLATFORM_NAME=cpu


In [2]:
import jax
import jax.numpy as jnp
import equinox as eqx

In [3]:
# Reference chaogate parameters for the logistic map with a = 4.
# For gate G the inputs x, y in {0, 1} each shift the initial state by DELTA,
#     state = G_X + x * DELTA + y * DELTA,
# and the gate output is 1 iff logistic_map(state) > G_TRUE.
# Several cases land exactly on the threshold (e.g. AND at (0, 1):
# 4 * 1/4 * 3/4 == 3/4), so the comparison must be strict. All values here are
# small dyadic fractions, so those boundary cases are exact in floating point.
DELTA = 1 / 4

# Initial state per gate.
AND_X = 0
OR_X = 1 / 8
XOR_X = 1 / 4
NAND_X = 3 / 8
# NOR reference point: the gate dictionaries in this notebook use NOR (not XOR).
# Checked: f(1/2) = 1 > 3/4, while f(3/4) = 3/4 and f(1) = 0 do not exceed 3/4.
NOR_X = 1 / 2

# Output threshold per gate.
AND_TRUE = 3 / 4
OR_TRUE = 11 / 16
XOR_TRUE = 3 / 4
NAND_TRUE = 11 / 16
NOR_TRUE = 3 / 4

In [4]:
# One-hot encoding of each gate type, fed to the MLP as its 4-wide input.
# Row `row` has a single 1 in column `row`; key order is AND, OR, NOR, NAND.
gate_types = {
    name: jnp.array([1 if col == row else 0 for col in range(4)])
    for row, name in enumerate(("AND", "OR", "NOR", "NAND"))
}

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [5]:
xy_combinations = [(0, 0), (0, 1), (1, 0), (1, 1)]

# Truth tables: one output per entry of `xy_combinations`, in the same order.
# Derived from bitwise logic on the 0/1 inputs rather than typed out by hand.
gate_outputs = {
    "AND": [x & y for x, y in xy_combinations],
    "OR": [x | y for x, y in xy_combinations],
    "NOR": [1 - (x | y) for x, y in xy_combinations],
    "NAND": [1 - (x & y) for x, y in xy_combinations],
}

In [6]:
def logistic_map(x: float, a: float = 4):
    """One iteration of the logistic map: ``a * x * (1 - x)``.

    Uses only arithmetic operators, so it applies elementwise to NumPy/JAX
    arrays as well as Python scalars.

    Args:
        x: current state.
        a: growth parameter (defaults to 4).

    Returns:
        The next state of the map.
    """
    one_minus_x = 1 - x
    return a * x * one_minus_x

In [7]:
def chao_gate(x: bool, y: bool, DELTA: float, X0: float, X_TRUE: float) -> float:
    """Evaluate a logistic-map chaogate on inputs ``x`` and ``y``.

    Each active input shifts the initial state ``X0`` by ``DELTA``. One step of
    ``logistic_map`` (with its default ``a``) is applied to the shifted state,
    and the threshold ``X_TRUE`` is subtracted.

    Args:
        x: first input (0/1 or bool).
        y: second input (0/1 or bool).
        DELTA: amount added to the state per active input.
        X0: initial state of the map before the inputs are applied.
        X_TRUE: threshold compared against the map's output.

    Returns:
        ``logistic_map(X0 + x * DELTA + y * DELTA) - X_TRUE``. With the
        reference constants defined in this notebook, ``result > 0`` reproduces
        the named gate's truth table.
    """
    shifted_state = X0 + x * DELTA + y * DELTA
    return logistic_map(shifted_state) - X_TRUE

In [8]:
class MLP(eqx.Module):
    """MLP mapping a one-hot gate encoding to chaogate parameters.

    Architecture: ``Linear(4 -> h1) -> ReLU -> LayerNorm -> (Linear -> ReLU)* ->
    Linear(-> 3)``. The final layer has no activation, so outputs are
    unconstrained. NOTE(review): the 3 outputs are presumably
    (DELTA, X0, X_TRUE) for ``chao_gate`` — TODO confirm once the loss exists.
    """

    layer_sizes: list  # [4, *hidden_sizes, 3]
    layers: list  # eqx.nn.Linear layers, one per consecutive pair of layer_sizes
    layer_norm: eqx.nn.LayerNorm  # normalises the first hidden activation

    def __init__(self, hidden_sizes: list, key):
        """
        Args:
            hidden_sizes: widths of the hidden layers (any sequence of ints).
                Must be non-empty, because the LayerNorm is sized for and applied
                to the output of the first hidden layer.
            key: JAX PRNG key used to initialise the Linear layers.

        Raises:
            ValueError: if ``hidden_sizes`` is empty.
        """
        hidden_sizes = list(hidden_sizes)
        if not hidden_sizes:
            # Without a hidden layer, __call__ would reuse the 4->3 layer on a
            # 3-dim input and fail with an obscure shape error.
            raise ValueError("MLP needs at least one hidden layer (hidden_sizes was empty)")

        self.layer_sizes = [4] + hidden_sizes + [3]
        # Splits one key more than there are layers (the last is unused); kept
        # as-is so parameter initialisation for a given key does not change.
        keys = jax.random.split(key, len(self.layer_sizes))

        self.layers = [
            eqx.nn.Linear(in_features, out_features, key=keys[key_idx])
            for key_idx, (in_features, out_features) in enumerate(
                zip(self.layer_sizes[:-1], self.layer_sizes[1:])
            )
        ]

        self.layer_norm = eqx.nn.LayerNorm(shape=(self.layer_sizes[1],))

    def __call__(self, gate_type: jnp.ndarray) -> jnp.ndarray:
        """Return the 3 raw chaogate parameters for one gate encoding.

        Args:
            gate_type: unbatched one-hot gate encoding of shape (4,),
                e.g. ``gate_types["AND"]``.

        Returns:
            Array of shape (3,) from the final (activation-free) Linear layer.
        """
        x = jax.nn.relu(self.layers[0](gate_type))
        x = self.layer_norm(x)
        for layer in self.layers[1:-1]:
            x = jax.nn.relu(layer(x))

        return self.layers[-1](x)

In [9]:
# | test

key = jax.random.PRNGKey(0)
mlp = MLP(hidden_sizes=[8, 8], key=key)
gate_type = gate_types["AND"]
params = mlp(gate_type)
# A test cell should be able to fail: one finite value per chaogate parameter.
assert params.shape == (3,)
assert bool(jnp.all(jnp.isfinite(params)))
params

Array([-0.34048057, -0.31616643,  0.02645211], dtype=float32)

In [10]:
def loss_fn(model: eqx.Module, gate_type: jnp.ndarray, data):
    """Training loss for ``model`` on one gate type — not implemented yet.

    NOTE(review): the intended semantics are not visible here. Presumably this
    should call ``model(gate_type)`` to get chaogate parameters, evaluate
    ``chao_gate`` on each input pair, and compare against the gate's truth
    table (``data``?) — TODO define the schema of ``data`` and the loss.

    Raises:
        NotImplementedError: always, until the loss is written. This avoids
            silently returning ``None``, which would only fail later and less
            clearly (e.g. inside ``jax.grad``).
    """
    raise NotImplementedError("loss_fn is not implemented yet")