In [48]:
# Ask JAX to run on the CPU backend. This presumably has to be set before JAX
# initialises its backends (i.e. before the first `import jax` in a fresh kernel).
# NOTE(review): newer JAX releases document `JAX_PLATFORMS` for this — confirm
# which variable the installed JAX version honours.
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# Load the jaxtyping IPython extension; the commented line below uses its
# `%jaxtyping.typechecker` magic.
%load_ext jaxtyping
# Uncomment to runtime-check jaxtyping annotations with beartype in later cells.
# %jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping


In [49]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate, NChaoGate
from chaogatenn.maps import (
    LogisticMap,
    DuffingMap,
    LorenzMap,
    RosslerMap,
    ChenMap,
    RosslerHyperchaosMap,
)
from chaogatenn.utils import grad_norm

In [50]:
# Truth tables for the 2-input logic gates, one entry per row of X.
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)  # Input combinations
# Derive every gate from the inputs so the tables cannot drift out of sync with X.
AND_Y = X[:, 0] & X[:, 1]  # AND gate output
OR_Y = X[:, 0] | X[:, 1]  # OR gate output
XOR_Y = X[:, 0] ^ X[:, 1]  # XOR
NAND_Y = ~AND_Y  # NAND
NOR_Y = ~OR_Y  # NOR
XNOR_Y = ~XOR_Y  # XNOR
# Target gate for the 2-input experiment (the next cell overrides X and Y).
Y = NAND_Y

In [51]:
# All 2**3 = 8 combinations of three boolean inputs in binary counting order:
# row i holds the bits of i, most-significant bit first.
X3 = ((jnp.arange(8)[:, None] >> jnp.arange(2, -1, -1)) & 1).astype(bool)

# Gate outputs derived row-wise from the inputs.
AND_Y3 = X3.all(axis=1)  # AND gate output: true only when every input is true
OR_Y3 = X3.any(axis=1)  # OR gate output: true when any input is true
NAND_Y3 = ~AND_Y3  # NAND
NOR_Y3 = ~OR_Y3  # NOR

# Select the 3-input NOR task for training (replaces the 2-input X / Y).
X = X3
Y = NOR_Y3

In [52]:
# Candidate chaotic maps. Constructors are wrapped in lambdas so that only the
# selected map is actually instantiated.
MAP_FACTORIES = {
    "logistic": lambda: LogisticMap(a=4.0),
    "lorenz": lambda: LorenzMap(),
    "duffing": lambda: DuffingMap(
        alpha=1.0, beta=1.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
    ),
    "chen": lambda: ChenMap(dt=1e-5, steps=100),
    "rossler_hyperchaos": lambda: RosslerHyperchaosMap(),
}
MAP_NAME = "logistic"  # change this key to train on a different map
Map = MAP_FACTORIES[MAP_NAME]()

In [53]:
# Smoke test: evaluate the selected map at the single input 2 (outside [0, 1]).
# With LogisticMap(a=4.0) this displayed -8.0, which is consistent with
# a * x * (1 - x) — NOTE(review): confirm against the LogisticMap implementation.
Map(2)  # type: ignore

-8.0

In [54]:
# Draw the three scalar gate parameters from a standard normal with a fixed seed.
init_key = jax.random.PRNGKey(42)
DELTA, X0, X_THRESHOLD = jax.random.normal(init_key, shape=(3,))
chao_gate = NChaoGate(Map=Map, DELTA=DELTA, X0=X0, X_THRESHOLD=X_THRESHOLD)

In [55]:
# Untrained gate output for every input combination (one scalar per row of X).
list(map(chao_gate, X))

[Array(4.0121733e-05, dtype=float32),
 Array(0.0005, dtype=float32),
 Array(0.0005, dtype=float32),
 Array(0.00469383, dtype=float32),
 Array(0.0005, dtype=float32),
 Array(0.00469383, dtype=float32),
 Array(0.00469383, dtype=float32),
 Array(0.03252234, dtype=float32)]

In [56]:
@eqx.filter_value_and_grad()
def compute_loss(
    chao_gate: ChaoGate, x: Bool[Array, "batch n"], y: Bool[Array, "batch"]
) -> Float[Array, ""]:  # noqa: F821
    """Mean binary cross-entropy between the gate's outputs and boolean targets.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this returns
    ``(loss, grads)``, with gradients taken with respect to ``chao_gate``.

    Args:
        chao_gate: Gate module, applied to each row of ``x`` via ``jax.vmap``.
            NOTE(review): this notebook passes an ``NChaoGate``; confirm it is
            compatible with the ``ChaoGate`` annotation if runtime type-checking
            is enabled.
        x: Boolean input combinations, one row per sample.
        y: Boolean target output for each row of ``x``.

    Returns:
        Scalar mean BCE loss. Assumes the gate outputs lie in [0, 1] — TODO confirm.
    """
    eps = 1e-15  # keeps both log arguments away from zero
    probs = jax.vmap(chao_gate)(x)
    log_p = jnp.log(probs + eps)
    log_not_p = jnp.log(1 - probs + eps)
    per_sample = y * log_p + (1 - y) * log_not_p
    return -jnp.mean(per_sample)

In [57]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim n"],
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:  # noqa: F821
    """Perform one optimisation step on ``model``.

    Args:
        model: Gate being trained.
        x: Boolean input combinations, one row per sample.
        y: Boolean target output for each row of ``x``.
        optim: Optax transformation that turns gradients into updates.
        opt_state: Current state of ``optim``.

    Returns:
        ``(loss, model, opt_state)``: the scalar loss evaluated at the parameters
        *before* this update, the updated model, and the updated optimiser state.
    """
    loss, grads = compute_loss(model, x, y)
    # Pass the current trainable params too: optimisers such as weight-decay
    # variants require them, and transformations that don't use them ignore
    # the argument.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [58]:
LEARNING_RATE = 3e-5
optim = optax.adabelief(learning_rate=LEARNING_RATE)
# Optimiser state only tracks the floating-point array leaves of the gate.
trainable_params = eqx.filter(chao_gate, eqx.is_inexact_array)
opt_state = optim.init(trainable_params)

In [59]:
epochs = 5_000
LOG_EVERY = 10  # print progress every LOG_EVERY epochs
LOSS_TOL = 1e-3  # stop early once the training loss falls below this

for epoch in trange(epochs):
    # `loss` is evaluated at the parameters *before* this step's update.
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % LOG_EVERY == 0:
        # This extra (un-jitted) forward/backward pass exists only for logging,
        # so it is skipped on epochs that don't print. Note the gradient is taken
        # at the *updated* parameters, one step ahead of `loss`.
        _, grads = compute_loss(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

    if loss < LOSS_TOL:
        break

  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch 0, Loss: 1.271533727645874, Grad Norm: 1.7303177118301392
Epoch 10, Loss: 1.270793080329895, Grad Norm: 1.7301713228225708
Epoch 20, Loss: 1.2697443962097168, Grad Norm: 1.7299641370773315
Epoch 30, Loss: 1.2684203386306763, Grad Norm: 1.729699730873108
Epoch 40, Loss: 1.266859531402588, Grad Norm: 1.7293795347213745
Epoch 50, Loss: 1.2650904655456543, Grad Norm: 1.7290043830871582
Epoch 60, Loss: 1.2631356716156006, Grad Norm: 1.7285733222961426
Epoch 70, Loss: 1.2610108852386475, Grad Norm: 1.7280858755111694
Epoch 80, Loss: 1.2587285041809082, Grad Norm: 1.7275406122207642
Epoch 90, Loss: 1.256299614906311, Grad Norm: 1.7269353866577148
Epoch 100, Loss: 1.253732681274414, Grad Norm: 1.7262697219848633
Epoch 110, Loss: 1.2510342597961426, Grad Norm: 1.725541353225708
Epoch 120, Loss: 1.2482103109359741, Grad Norm: 1.7247490882873535
Epoch 130, Loss: 1.2452672719955444, Grad Norm: 1.7238924503326416
Epoch 140, Loss: 1.2422094345092773, Grad Norm: 1.7229704856872559
Epoch 150, Lo

In [60]:
# Summarise the learned scalar gate parameters.
params_summary = (
    f"DELTA: {chao_gate.DELTA}, "
    f"X0: {chao_gate.X0}, "
    f"X_THRESHOLD: {chao_gate.X_THRESHOLD}"
)
print("\nTrained ChaoGate Parameters:", params_summary, sep="\n")


Trained ChaoGate Parameters:
DELTA: -1.0826739072799683, X0: -0.11785778403282166, X_THRESHOLD: -3.5346035957336426


In [61]:
# Check the trained gate's hard decisions against the current target truth table Y.
print("\nTest the trained model:")
print("Input | Output | Expected")
for x, y in zip(X, Y):
    # Apply the map to X0 shifted by DELTA once per active input, then compare
    # the result with the learned threshold.
    output = (
        chao_gate.Map(chao_gate.X0 + (x * chao_gate.DELTA).sum())
        > chao_gate.X_THRESHOLD
    )
    marker = "" if bool(output) == bool(y) else "  <-- mismatch"
    print(f"{x} | {output} | {bool(y)}{marker}")



Test the trained model:
Input | Output
[False False False] | True
[False False  True] | False
[False  True False] | False
[False  True  True] | False
[ True False False] | False
[ True False  True] | False
[ True  True False] | False
[ True  True  True] | False


In [62]:
# Fraction of input rows where the thresholded gate output matches the target.
pred_ys = jax.vmap(chao_gate)(X)
is_correct = (pred_ys > 0.5) == Y
final_accuracy = jnp.mean(is_correct).item()
print(f"final_accuracy={final_accuracy}")

final_accuracy=1.0
