In [1]:
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu


In [2]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap, DuffingMap, LorenzMap
from chaogatenn.utils import grad_norm

In [3]:
# Truth-table inputs: every combination of the two boolean gate inputs,
# one row per combination.
X = jnp.array(
    [
        [False, False],
        [False, True],
        [True, False],
        [True, True],
    ],
    dtype=bool,
)

# Expected output for each row of X, one target vector per logic gate.
AND_Y = jnp.array([False, False, False, True], dtype=bool)
OR_Y = jnp.array([False, True, True, True], dtype=bool)
XOR_Y = jnp.array([False, True, True, False], dtype=bool)
NAND_Y = jnp.array([True, True, True, False], dtype=bool)
NOR_Y = jnp.array([True, False, False, False], dtype=bool)
XNOR_Y = jnp.array([True, False, False, True], dtype=bool)

# Target used by the rest of the notebook; rebind Y to another vector above
# to train a different gate.
Y = XNOR_Y

In [4]:
# Choose the map object that is later handed to ChaoGate via its `Map`
# argument. Exactly one assignment is active; the commented-out lines are
# alternative maps (imported above) with example parameters.
# Map = LogisticMap(a=0.0)
Map = LorenzMap()
# Map = DuffingMap(
#     alpha=1.0, beta=5.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
# )

In [5]:
# Quick sanity check: evaluate the selected map at a single scalar input and
# display the result as the cell output.
Map(2)

-0.0

In [6]:
# Instantiate the gate with its starting parameter values and the map
# selected earlier in the notebook.
chao_gate = ChaoGate(
    DELTA=0.5,
    X0=0.5,
    X_THRESHOLD=1.0,
    Map=Map,
)

In [7]:
# Output of the untrained gate for every row of the truth table.
list(map(chao_gate, X))

[Array(0.26894143, dtype=float32, weak_type=True),
 Array(0.26894143, dtype=float32, weak_type=True),
 Array(0.26894143, dtype=float32, weak_type=True),
 Array(0.26894143, dtype=float32, weak_type=True)]

In [8]:
@eqx.filter_value_and_grad()
def compute_loss(
    chao_gate: ChaoGate,
    x: Bool[Array, "batch 2"],  # noqa: F821
    y: Bool[Array, "batch"],  # noqa: F821
) -> Float[Array, ""]:  # noqa: F821
    """Mean binary cross-entropy of the gate's outputs against boolean targets.

    The gate is applied to each row of ``x`` with ``jax.vmap``. Because the
    function is wrapped in ``eqx.filter_value_and_grad``, calling it yields a
    ``(loss, grads)`` pair, with gradients taken with respect to ``chao_gate``.

    Args:
        chao_gate: The gate being evaluated (and differentiated).
        x: Boolean inputs, one pair per row.
        y: Boolean target for each row of ``x``.

    Returns:
        The scalar loss (before the value-and-grad wrapping is applied).
    """
    eps = 1e-15  # keeps both logarithms finite when an output is exactly 0 or 1
    outputs = jax.vmap(chao_gate)(x)
    positive_term = y * jnp.log(outputs + eps)
    negative_term = (1 - y) * jnp.log(1 - outputs + eps)
    return -jnp.mean(positive_term + negative_term)

In [9]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "batch 2"],  # noqa: F821
    y: Bool[Array, "batch"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:  # noqa: F821
    """Run one optimisation step on the gate.

    Evaluates the loss and its gradients with ``compute_loss``, turns the
    gradients into parameter updates with ``optim`` and applies them to the
    model.

    Args:
        model: Gate to update.
        x: Boolean inputs, one pair per row.
        y: Boolean target for each row of ``x``.
        optim: Optax gradient transformation producing the updates.
        opt_state: Optimiser state matching ``optim``.

    Returns:
        A ``(loss, model, opt_state)`` tuple: the scalar loss evaluated on the
        model *before* the update, the updated model, and the new optimiser
        state.
    """
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [10]:
# AdaBelief optimiser with a fixed learning rate.
optim = optax.adabelief(learning_rate=3e-4)

# The optimiser state is initialised only from the part of the gate selected
# by eqx.is_inexact_array; everything else is filtered out.
opt_state = optim.init(
    eqx.filter(chao_gate, eqx.is_inexact_array),
)

In [11]:
# Upper bound on the number of optimisation steps; the loop exits early once
# the loss drops below 1e-3.
epochs = 4_000

for epoch in trange(epochs):
    # NOTE: this rebinds chao_gate / opt_state, so re-running the cell
    # continues training from the current state rather than starting over.
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % 10 == 0:
        # The gradient norm is only reported, never used for the update, so
        # the extra (un-jitted) value-and-grad pass runs on logging epochs
        # only. It is taken on the already-updated gate, whereas `loss` is the
        # value from before this epoch's update.
        _, grads = compute_loss(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

    if loss < 1e-3:
        break

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 0, Loss: 0.8132616281509399, Grad Norm: 0.23099958896636963
Epoch 10, Loss: 0.8124070763587952, Grad Norm: 0.2302396148443222
Epoch 20, Loss: 0.8112018704414368, Grad Norm: 0.22917640209197998
Epoch 30, Loss: 0.8096903562545776, Grad Norm: 0.2278442233800888
Epoch 40, Loss: 0.8079244494438171, Grad Norm: 0.22628086805343628
Epoch 50, Loss: 0.8059465289115906, Grad Norm: 0.22451581060886383
Epoch 60, Loss: 0.803789496421814, Grad Norm: 0.22257111966609955
Epoch 70, Loss: 0.8014801144599915, Grad Norm: 0.22046375274658203
Epoch 80, Loss: 0.7990409135818481, Grad Norm: 0.21820753812789917
Epoch 90, Loss: 0.7964912056922913, Grad Norm: 0.21581421792507172
Epoch 100, Loss: 0.7938483953475952, Grad Norm: 0.21329385042190552
Epoch 110, Loss: 0.7911282181739807, Grad Norm: 0.2106555849313736
Epoch 120, Loss: 0.7883448600769043, Grad Norm: 0.20790773630142212
Epoch 130, Loss: 0.7855114936828613, Grad Norm: 0.2050580084323883
Epoch 140, Loss: 0.7826403379440308, Grad Norm: 0.20211361348628

In [12]:
# Report the gate's scalar parameters after training so they can be compared
# with the values it was constructed with (DELTA=0.5, X0=0.5, X_THRESHOLD=1.0).
print("\nTrained ChaoGate Parameters:")
print(
    f"DELTA: {chao_gate.DELTA}, X0: {chao_gate.X0}, X_THRESHOLD: {chao_gate.X_THRESHOLD}"
)


Trained ChaoGate Parameters:
DELTA: 0.5, X0: 0.5, X_THRESHOLD: 1.5031135092158365e-07


In [13]:
# Hard-decision truth table: for each input pair, evaluate the gate's map at
# X0 + in_a * DELTA + in_b * DELTA and compare it with the trained threshold.
# NOTE(review): presumably this mirrors ChaoGate's own computation with a hard
# cut-off instead of its soft output -- confirm against ChaoGate.
[
    (
        bool(in_a),
        bool(in_b),
        (
            chao_gate.Map(
                chao_gate.X0 + in_a * chao_gate.DELTA + in_b * chao_gate.DELTA
            )
            > chao_gate.X_THRESHOLD
        ).item(),
    )
    for in_a, in_b in X
]

[(False, False, False),
 (False, True, False),
 (True, False, False),
 (True, True, False)]

In [14]:
# Accuracy of the trained gate on the truth table: an output above 0.5 is
# read as True and compared element-wise with the targets.
pred_ys = jax.vmap(chao_gate)(X)
num_correct = ((pred_ys > 0.5) == Y).sum()
final_accuracy = (num_correct / X.shape[0]).item()
print(f"final_accuracy={final_accuracy}")

final_accuracy=0.5
