In [16]:
# Pin JAX to the CPU backend.
# NOTE(review): JAX reads this when it initialises its backends, so presumably it has
# no effect if JAX was already imported earlier in this (reused) kernel — confirm.
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# Load the jaxtyping IPython extension. Uncomment the next line to have every
# jaxtyping annotation in later cells checked at runtime by beartype.
%load_ext jaxtyping
# %jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping


In [17]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap, DuffingMap, LorenzMap, RosslerMap
from chaogatenn.utils import grad_norm

In [18]:
# Truth tables for two-input logic gates. X lists every input combination
# (columns: x1, x2); each *_Y array is the corresponding gate output per row.
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)  # Input combinations
AND_Y = jnp.array([0, 0, 0, 1], dtype=bool)  # AND gate output
OR_Y = jnp.array([0, 1, 1, 1], dtype=bool)  # OR gate output
XOR_Y = jnp.array([0, 1, 1, 0], dtype=bool)  # XOR
NAND_Y = jnp.array([1, 1, 1, 0], dtype=bool)  # NAND
NOR_Y = jnp.array([1, 0, 0, 0], dtype=bool)  # NOR
XNOR_Y = jnp.array([1, 0, 0, 1], dtype=bool)  # XNOR

GATE_TARGETS = {
    "AND": AND_Y,
    "OR": OR_Y,
    "XOR": XOR_Y,
    "NAND": NAND_Y,
    "NOR": NOR_Y,
    "XNOR": XNOR_Y,
}

# Gate the rest of the notebook trains and evaluates on; change this name to
# train a different gate.
TARGET_GATE = "XNOR"
Y = GATE_TARGETS[TARGET_GATE]

In [19]:
# Chaotic map that drives the gate. Other maps imported above can be swapped in
# by replacing the assignment below with one of:
#   Map = LogisticMap(a=4.0)
#   Map = LorenzMap()
#   Map = DuffingMap(
#       alpha=1.0, beta=1.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
#   )
Map = RosslerMap()

In [20]:
# Sanity check: evaluate the selected map at a single input (2) before building the gate.
Map(2)  # type: ignore

Array(-1.1745434, dtype=float32)

In [21]:
# Initialise the three trainable gate parameters from a standard normal draw,
# using a fixed PRNG seed so the starting point is reproducible.
init_key = jax.random.PRNGKey(42)
DELTA, X0, X_THRESHOLD = jax.random.normal(init_key, shape=(3,))
chao_gate = ChaoGate(Map=Map, DELTA=DELTA, X0=X0, X_THRESHOLD=X_THRESHOLD)

In [22]:
# Gate output for each input row before training (displayed as a list).
list(map(chao_gate, X))

[Array(0.9000974, dtype=float32),
 Array(0.89128643, dtype=float32),
 Array(0.89128643, dtype=float32),
 Array(0.88171566, dtype=float32)]

In [23]:
@eqx.filter_value_and_grad()
def compute_loss(
    chao_gate: ChaoGate,
    x: Bool[Array, "batch 2"],  # noqa: F722
    y: Bool[Array, "batch"],  # noqa: F821
) -> Float[Array, ""]:  # noqa: F722
    """Mean binary cross-entropy of the gate's outputs against boolean targets.

    Because of ``eqx.filter_value_and_grad``, calling this returns
    ``(loss, grads)``, with gradients taken w.r.t. ``chao_gate`` only.

    Args:
        chao_gate: gate model, applied to each row of ``x`` via ``jax.vmap``.
        x: input bit pairs, one row per example.
        y: target output bit for each row of ``x``.

    Returns:
        Scalar mean BCE loss.
    """
    # Guards log(0) only; NOTE(review): assumes chao_gate outputs lie in [0, 1].
    eps = 1e-15
    pred = jax.vmap(chao_gate)(x)
    # Cast the boolean labels explicitly instead of relying on bool promotion
    # in `y * ...` and `1 - y`; the numeric result is unchanged.
    target = y.astype(pred.dtype)
    return -jnp.mean(
        target * jnp.log(pred + eps) + (1 - target) * jnp.log(1 - pred + eps)
    )

In [24]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim 2"],  # noqa: F722
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:  # noqa: F722
    """Run one jitted optimisation step.

    Args:
        model: current gate model.
        x: input bit pairs, one row per example.
        y: target output bit for each row of ``x``.
        optim: optax optimiser whose state is ``opt_state``.
        opt_state: current optimiser state.

    Returns:
        ``(loss, model, opt_state)``: the scalar loss evaluated at the model
        *before* the update, the updated model, and the updated optimiser state.
    """
    loss, grads = compute_loss(model, x, y)
    # Pass the current trainable leaves as `params` so optimisers that need them
    # (e.g. ones with weight decay) also work; others simply ignore the argument.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [25]:
# AdaBelief optimiser; its state covers only the gate's floating-point array leaves.
trainable_params = eqx.filter(chao_gate, eqx.is_inexact_array)
optim = optax.adabelief(learning_rate=3e-3)
opt_state = optim.init(trainable_params)

In [26]:
epochs = 5_000
log_every = 10  # print diagnostics every `log_every` epochs

for epoch in trange(epochs):
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % log_every == 0:
        # Diagnostics-only extra forward/backward pass (compute_loss is not jitted
        # on its own), so run it only on logging epochs. Note: it measures the
        # gradient of the *updated* model, while `loss` is from before the update.
        _, grads = compute_loss(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

    # Early stop once the loss is essentially zero.
    if loss < 1e-3:
        break

  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch 0, Loss: 1.1673038005828857, Grad Norm: 0.4794231057167053
Epoch 10, Loss: 1.1385128498077393, Grad Norm: 0.4699413776397705
Epoch 20, Loss: 1.0989536046981812, Grad Norm: 0.4556359052658081
Epoch 30, Loss: 1.0516053438186646, Grad Norm: 0.4361053705215454
Epoch 40, Loss: 1.000169038772583, Grad Norm: 0.41114988923072815
Epoch 50, Loss: 0.948222279548645, Grad Norm: 0.3808229863643646
Epoch 60, Loss: 0.8989492654800415, Grad Norm: 0.34564855694770813
Epoch 70, Loss: 0.8549132347106934, Grad Norm: 0.30678901076316833
Epoch 80, Loss: 0.8178040981292725, Grad Norm: 0.26603856682777405
Epoch 90, Loss: 0.7882799506187439, Grad Norm: 0.22556018829345703
Epoch 100, Loss: 0.7660256624221802, Grad Norm: 0.18747439980506897
Epoch 110, Loss: 0.7500277757644653, Grad Norm: 0.15346908569335938
Epoch 120, Loss: 0.7389442920684814, Grad Norm: 0.12458822876214981
Epoch 130, Loss: 0.7314277291297913, Grad Norm: 0.10121803730726242
Epoch 140, Loss: 0.7263311147689819, Grad Norm: 0.0831783339381218

In [27]:
# Report the learned gate parameters (header and values on separate lines).
print(
    "\nTrained ChaoGate Parameters:",
    f"DELTA: {chao_gate.DELTA}, X0: {chao_gate.X0}, X_THRESHOLD: {chao_gate.X_THRESHOLD}",
    sep="\n",
)


Trained ChaoGate Parameters:
DELTA: 1.6840472483181657e-07, X0: 0.17210151255130768, X_THRESHOLD: -0.11809306591749191


In [28]:
def _truth_table_row(x1, x2):
    """Return (x1, x2, fires) where `fires` thresholds the map at the gate's input."""
    gate_input = chao_gate.X0 + x1 * chao_gate.DELTA + x2 * chao_gate.DELTA
    fires = chao_gate.Map(gate_input) > chao_gate.X_THRESHOLD
    return bool(x1.item()), bool(x2.item()), fires.item()


# Truth table obtained by thresholding the map output directly.
[_truth_table_row(x1, x2) for x1, x2 in X]

[(False, False, True),
 (False, True, True),
 (True, False, True),
 (True, True, True)]

In [29]:
# Threshold the gate's continuous outputs at 0.5 and score them against the targets.
pred_ys = jax.vmap(chao_gate)(X)
predicted_labels = pred_ys > 0.5
num_correct = (predicted_labels == Y).sum()
final_accuracy = float(num_correct / X.shape[0])
print(f"final_accuracy={final_accuracy}")

final_accuracy=0.5
