In [29]:
# Request the CPU backend for JAX; this runs before `jax` is imported in the next cell.
# NOTE(review): newer JAX releases prefer the JAX_PLATFORMS variable — confirm for the pinned version.
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# Load the jaxtyping IPython extension and select beartype as its runtime type checker.
# Presumably this type-checks annotated functions defined in later cells — confirm in the jaxtyping docs.
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping


In [30]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap
from chaogatenn.utils import grad_norm

In [31]:
# Truth tables for the two-input Boolean gates.
# Row i of X is an input pair; entry i of each *_Y array is that gate's output for it.
X = jnp.array(
    [[False, False], [False, True], [True, False], [True, True]],
    dtype=bool,
)  # all four input combinations
AND_Y = jnp.array([False, False, False, True], dtype=bool)
OR_Y = jnp.array([False, True, True, True], dtype=bool)
XOR_Y = jnp.array([False, True, True, False], dtype=bool)
NAND_Y = jnp.array([True, True, True, False], dtype=bool)
XNOR_Y = jnp.array([True, False, False, True], dtype=bool)
# Target used for training in the rest of the notebook (XOR here, not AND).
Y = XOR_Y

In [32]:
# Nonlinear map handed to the gate below; LogisticMap with a=4.0 is the active choice.
Map = LogisticMap(a=4.0)
# Alternative maps kept for experimentation. LorenzMap and DuffingMap are not among the
# names imported in the import cell above, so import them before enabling either line.
# Map = LorenzMap()
# Map = DuffingMap(
#     alpha=1.0, beta=5.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
# )

In [33]:
# Sanity check: evaluate the map at a single point.
# The displayed -8.0 matches a * x * (1 - x) = 4 * 2 * (1 - 2), assuming the standard logistic form.
Map(2)

-8.0

In [34]:
# Untrained gate built around the map chosen above.
# NOTE(review): DELTA, X0 and X_THRESHOLD appear to be the trainable parameters
# (they are printed after training further down) — confirm in chaogatenn.chaogate.
chao_gate = ChaoGate(DELTA=1.0, X0=1.0, X_THRESHOLD=1.0, Map=Map)

In [35]:
# Gate output for each of the four input rows before any training.
list(map(chao_gate, X))

[Array(7.0954744e-23, dtype=float32, weak_type=True),
 Array(8.985826e-37, dtype=float32, weak_type=True),
 Array(8.985826e-37, dtype=float32, weak_type=True),
 Array(0., dtype=float32, weak_type=True)]

In [36]:
@eqx.filter_value_and_grad()
def compute_loss(
    chao_gate: ChaoGate, x: Bool[Array, "batch 2"], y: Bool[Array, "batch"]
) -> Float[Array, ""]:  # noqa: F821
    """Mean binary cross-entropy between the gate's outputs and boolean targets.

    The function body returns the scalar loss only. Because it is wrapped in
    ``eqx.filter_value_and_grad``, calling ``compute_loss(...)`` yields a
    ``(loss, grads)`` pair, which is how ``make_step`` consumes it.

    Args:
        chao_gate: Gate to evaluate; it is mapped over the rows of ``x``.
        x: Boolean input pairs, one row per example.
        y: Boolean target for each row of ``x``.

    Returns:
        Scalar mean binary cross-entropy over the batch.
    """
    eps = 1e-15  # keeps both logarithms finite when a prediction is exactly 0 or 1
    probs = jax.vmap(chao_gate)(x)
    positive_term = y * jnp.log(probs + eps)
    negative_term = (1 - y) * jnp.log(1 - probs + eps)
    return -jnp.mean(positive_term + negative_term)

In [37]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim 2"],
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:  # noqa: F821
    """Run one optimisation step and return the loss, updated model and optimiser state.

    Args:
        model: Gate whose parameters are being trained.
        x: Boolean input pairs, one row per example.
        y: Boolean target for each row of ``x``.
        optim: optax gradient transformation that produces the parameter updates.
        opt_state: Optimiser state matching ``optim``.

    Returns:
        ``(loss, model, opt_state)``. ``loss`` is the scalar loss evaluated on the
        model *before* the update is applied; ``model`` and ``opt_state`` are the
        post-update values.
    """
    # compute_loss is wrapped in eqx.filter_value_and_grad, so it yields (loss, grads).
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [38]:
# AdaBelief optimiser; its state is initialised from the gate's inexact-array leaves only.
optim = optax.adabelief(learning_rate=3e-4)
opt_state = optim.init(eqx.filter(chao_gate, eqx.is_inexact_array))

In [39]:
epochs = 4_000

# NOTE: this cell rebinds `chao_gate` and `opt_state`, so re-running it continues
# training from the current state instead of starting over.
for epoch in trange(epochs):
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % 10 == 0:
        # The gradient norm is only used for logging, so the extra (un-jitted)
        # value-and-grad pass runs on logged epochs only instead of every epoch.
        # It is evaluated on the already-updated model, while `loss` is the
        # pre-update loss returned by make_step.
        _, grads = compute_loss(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 0, Loss: 17.26938819885254, Grad Norm: 5.00504484080011e-07
Epoch 10, Loss: 17.26938819885254, Grad Norm: 5.441605708256247e-07
Epoch 20, Loss: 17.26938819885254, Grad Norm: 6.041383358024177e-07
Epoch 30, Loss: 17.26938819885254, Grad Norm: 6.836237389507005e-07
Epoch 40, Loss: 17.26938819885254, Grad Norm: 7.896969691500999e-07
Epoch 50, Loss: 17.26938819885254, Grad Norm: 9.350888490189391e-07
Epoch 60, Loss: 17.26938819885254, Grad Norm: 1.1426196806496591e-06
Epoch 70, Loss: 17.26938819885254, Grad Norm: 1.4552085758623434e-06
Epoch 80, Loss: 17.26938819885254, Grad Norm: 1.958715301952907e-06
Epoch 90, Loss: 17.26938819885254, Grad Norm: 2.8310419111221563e-06
Epoch 100, Loss: 17.26938819885254, Grad Norm: 4.4224880184629e-06
Epoch 110, Loss: 17.26938819885254, Grad Norm: 7.32496573618846e-06
Epoch 120, Loss: 17.26938819885254, Grad Norm: 1.2474984032451175e-05
Epoch 130, Loss: 17.269386291503906, Grad Norm: 2.152916567865759e-05
Epoch 140, Loss: 17.269386291503906, Grad No

In [40]:
# Report the gate parameters as they stand after the training loop above.
print("\nTrained ChaoGate Parameters:")
print(
    f"DELTA: {chao_gate.DELTA}, X0: {chao_gate.X0}, X_THRESHOLD: {chao_gate.X_THRESHOLD}"
)


Trained ChaoGate Parameters:
DELTA: 4.162750997238618e-07, X0: 1.1203441619873047, X_THRESHOLD: -0.5393111109733582


In [41]:
# Hard-threshold truth table: for every input pair, apply the gate's map to
# X0 + x1 * DELTA + x2 * DELTA and compare the result with X_THRESHOLD.
# The two products are kept separate, exactly as written, because the inputs are booleans.
[
    (
        bool(pair[0].item()),
        bool(pair[1].item()),
        (
            chao_gate.Map(
                chao_gate.X0 + pair[0] * chao_gate.DELTA + pair[1] * chao_gate.DELTA
            )
            > chao_gate.X_THRESHOLD
        ).item(),
    )
    for pair in X
]

[(False, False, True),
 (False, True, True),
 (True, False, True),
 (True, True, True)]

In [42]:
# Fraction of the input rows whose prediction, thresholded at 0.5, equals the target.
pred_ys = jax.vmap(chao_gate)(X)
hits = (pred_ys > 0.5) == Y
num_correct = hits.sum()
final_accuracy = (num_correct / X.shape[0]).item()
print(f"final_accuracy={final_accuracy}")

final_accuracy=0.25
