In [141]:
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping


In [142]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap, DuffingMap, LorenzMap
from chaogatenn.utils import grad_norm

In [143]:
# Truth tables for all 2-input logic gates over the same four input rows.
# Every target is derived from the two input columns of X, so the tables can
# never fall out of sync with X's row order.
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)  # Input combinations
AND_Y = X[:, 0] & X[:, 1]  # [0, 0, 0, 1]
OR_Y = X[:, 0] | X[:, 1]  # [0, 1, 1, 1]
XOR_Y = X[:, 0] ^ X[:, 1]  # [0, 1, 1, 0]
NAND_Y = ~AND_Y  # [1, 1, 1, 0]
NOR_Y = ~OR_Y  # [1, 0, 0, 0]
XNOR_Y = ~XOR_Y  # [1, 0, 0, 1]

# Gate the ChaoGate is trained to reproduce in this run (swap for any *_Y above).
Y = NAND_Y

In [144]:
# Chaotic map driving the gate. Candidates are zero-argument factories so only
# the selected map is constructed; switch maps by changing MAP_NAME.
MAP_FACTORIES = {
    "logistic": lambda: LogisticMap(a=4.0),
    "lorenz": lambda: LorenzMap(),
    "duffing": lambda: DuffingMap(
        alpha=1.0, beta=5.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
    ),
}
MAP_NAME = "lorenz"
Map = MAP_FACTORIES[MAP_NAME]()

In [145]:
# Sanity check: evaluate the selected map at a single input and display the raw result.
Map(2)

Array(-8.679307, dtype=float32)

In [146]:
# Initial values for the gate parameters that the training loop later updates.
chao_gate = ChaoGate(
    Map=Map,
    X0=0.5,
    DELTA=0.5,
    X_THRESHOLD=1.0,
)

In [147]:
# Gate output for each input row of X, before any training.
list(map(chao_gate, X))

[Array(3.8154376e-05, dtype=float32, weak_type=True),
 Array(4.0713287e-05, dtype=float32, weak_type=True),
 Array(4.0713287e-05, dtype=float32, weak_type=True),
 Array(4.9321417e-05, dtype=float32, weak_type=True)]

In [148]:
@eqx.filter_value_and_grad
def compute_loss(
    chao_gate: ChaoGate,
    x: Bool[Array, "batch 2"],
    y: Bool[Array, "batch"],  # noqa: F821
) -> Float[Array, ""]:  # noqa: F821
    """Mean binary cross-entropy between the gate's outputs and boolean targets.

    Wrapped by ``eqx.filter_value_and_grad``, so a call returns
    ``(loss, grads)`` with gradients taken with respect to ``chao_gate``.

    Args:
        chao_gate: the gate being trained; applied row-wise via ``jax.vmap``.
        x: boolean input rows, one pair of inputs per row.
        y: boolean target for each row.

    Returns:
        Scalar mean BCE. The gate output is used directly as a probability.
    """
    eps = 1e-15  # keeps both log terms finite when an output is exactly 0 or 1
    pred = jax.vmap(chao_gate)(x)
    log_p = jnp.log(pred + eps)
    log_one_minus_p = jnp.log(1 - pred + eps)
    return -jnp.mean(y * log_p + (1 - y) * log_one_minus_p)

In [149]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim 2"],
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:  # noqa: F821
    """Run one jitted optimisation step on ``model``.

    Args:
        model: current gate.
        x: boolean input rows; shares the leading ``dim`` axis with ``y``.
        y: boolean targets, one per row of ``x``.
        optim: optax optimizer used to turn gradients into updates.
        opt_state: optimizer state matching ``model``'s trainable leaves.

    Returns:
        ``(loss, model, opt_state)``: the scalar loss of the *input* model, the
        updated model, and the updated optimizer state.
    """
    loss, grads = compute_loss(model, x, y)
    # Pass the trainable leaves as ``params`` so optimizers that need them
    # (e.g. decoupled weight decay) work too; optimizers that ignore params are unaffected.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [150]:
learning_rate = 3e-4
optim = optax.adabelief(learning_rate=learning_rate)
# The optimizer state covers only the floating-point array leaves of the gate.
trainable = eqx.filter(chao_gate, eqx.is_inexact_array)
opt_state = optim.init(trainable)

In [151]:
epochs = 4_000
# Jitted loss/grad used only for the periodic gradient-norm diagnostic; built once
# outside the loop instead of running an eager forward+backward pass every epoch.
loss_and_grad = eqx.filter_jit(compute_loss)

for epoch in trange(epochs):
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % 10 == 0:
        # Gradient norm at the freshly *updated* parameters (one step ahead of `loss`).
        _, grads = loss_and_grad(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

    if loss < 1e-3:
        break

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 0, Loss: 7.597957611083984, Grad Norm: 0.7731635570526123
Epoch 10, Loss: 7.594200611114502, Grad Norm: 0.7740147113800049
Epoch 20, Loss: 7.588825702667236, Grad Norm: 0.7752259969711304
Epoch 30, Loss: 7.581957817077637, Grad Norm: 0.7767817974090576
Epoch 40, Loss: 7.573703765869141, Grad Norm: 0.7786918878555298
Epoch 50, Loss: 7.564134120941162, Grad Norm: 0.7809476852416992
Epoch 60, Loss: 7.553296089172363, Grad Norm: 0.7835504412651062
Epoch 70, Loss: 7.541145324707031, Grad Norm: 0.7864962220191956
Epoch 80, Loss: 7.52766227722168, Grad Norm: 0.7898054122924805
Epoch 90, Loss: 7.512805938720703, Grad Norm: 0.7934421300888062
Epoch 100, Loss: 7.496501445770264, Grad Norm: 0.7974137663841248
Epoch 110, Loss: 7.478690147399902, Grad Norm: 0.8016677498817444
Epoch 120, Loss: 7.459303379058838, Grad Norm: 0.8061787486076355
Epoch 130, Loss: 7.438255310058594, Grad Norm: 0.8109058737754822
Epoch 140, Loss: 7.415485382080078, Grad Norm: 0.8157999515533447
Epoch 150, Loss: 7.390

In [152]:
# Header line followed by the trained scalar parameters of the gate.
print(
    "\nTrained ChaoGate Parameters:",
    f"DELTA: {chao_gate.DELTA}, X0: {chao_gate.X0}, X_THRESHOLD: {chao_gate.X_THRESHOLD}",
    sep="\n",
)


Trained ChaoGate Parameters:
DELTA: 6.6756744384765625, X0: 12.585663795471191, X_THRESHOLD: -7.212272644042969


In [153]:
# Recompute each boolean output by hand from the trained parameters:
# shift X0 by DELTA for each active input, apply the map, then threshold.
list(
    (
        bool(in_a.item()),
        bool(in_b.item()),
        (
            chao_gate.Map(chao_gate.X0 + in_a * chao_gate.DELTA + in_b * chao_gate.DELTA)
            > chao_gate.X_THRESHOLD
        ).item(),
    )
    for in_a, in_b in zip(X[:, 0], X[:, 1])
)

[(False, False, True),
 (False, True, True),
 (True, False, True),
 (True, True, False)]

In [154]:
# Fraction of input rows whose thresholded gate output matches the target gate.
pred_ys = jax.vmap(chao_gate)(X)
is_correct = (pred_ys > 0.5) == Y
num_correct = is_correct.sum()
final_accuracy = float(num_correct / X.shape[0])
print(f"final_accuracy={final_accuracy}")

final_accuracy=1.0
