In [73]:
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# %load_ext jaxtyping
# %jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu


In [74]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float
from jax.typing import ArrayLike
from typing import Protocol, runtime_checkable, Any
from beartype import beartype as typechecker

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap
from chaogatenn.utils import grad_norm

In [75]:
# Stand-alone logistic map with a = 4.0, the same map parameter that
# LogisticMapChaogate uses internally. It is reused after training to
# recompute the gate's decisions by hand.
# NOTE(review): presumably Map(s) applies a single logistic-map iteration;
# confirm against chaogatenn.maps.LogisticMap.
Map = LogisticMap(a=4.0)

In [None]:
# All four input combinations of a two-input gate, one row per (x1, x2) pair.
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)

# Truth tables derived column-wise from X, so row i of every table is the
# gate's output for input row X[i].
AND_Y = X[:, 0] & X[:, 1]
OR_Y = X[:, 0] | X[:, 1]
XOR_Y = X[:, 0] ^ X[:, 1]
NAND_Y = ~AND_Y
NOR_Y = ~OR_Y
XNOR_Y = ~XOR_Y

# Target gate for this run: NOR (True only when both inputs are False).
Y = NOR_Y

In [77]:
class LogisticMapChaogate(eqx.Module):
    """Differentiable two-input logic gate built on an iterated logistic map.

    The two boolean inputs shift an initial state. The logistic map
    ``s -> a * s * (1 - s)`` is applied ``map_steps`` times, and the final
    state is compared with a learned threshold through a sigmoid. The output
    therefore lies in (0, 1) and can be trained with binary cross-entropy.

    The three learnable parameters are stored unconstrained and squashed into
    (0, 1) with a sigmoid on every call.
    """

    DELTA: Float[Array, ""]  # Unconstrained encoding parameter (sigmoid -> (0, 1))
    X_THRESHOLD: Float[Array, ""]  # Unconstrained threshold (sigmoid -> (0, 1))
    X0: Float[Array, ""]  # Unconstrained initial state (sigmoid -> (0, 1))
    map_steps: int  # Number of logistic map iterations
    # Logistic-map parameter. A plain Python float is not an inexact array, so
    # eqx.filter(..., eqx.is_inexact_array) / filter_value_and_grad leave it
    # out of training. The default keeps the previous hard-coded value.
    a: float = 4.0

    @typechecker
    def __call__(self, x: Bool[Array, "2"]) -> Float[Array, ""]:
        """Evaluate the gate on one input pair.

        Args:
            x: The two boolean gate inputs ``(x1, x2)``.

        Returns:
            A scalar in (0, 1), ``sigmoid(final_state - threshold)``. It is
            above 0.5 exactly when the final map state exceeds the threshold.
        """
        x1, x2 = x
        # Apply sigmoid to constrain parameters to (0, 1)
        delta_constrained = jax.nn.sigmoid(self.DELTA)
        x0_constrained = jax.nn.sigmoid(self.X0)
        x_threshold_constrained = jax.nn.sigmoid(self.X_THRESHOLD)

        def logistic_map(state: Float[Array, ""]) -> Float[Array, ""]:
            return self.a * state * (1.0 - state)

        # Each active input adds one DELTA. The two products are kept separate
        # on purpose: ``x1 + x2`` on bool arrays is a logical OR, not 2.
        # NOTE: this sum can reach up to 3. For states outside [0, 1] the map
        # returns negative values, so repeated iterations can grow without bound.
        initial_state = x0_constrained + x1 * delta_constrained + x2 * delta_constrained

        # Iterate the logistic map (unrolled; map_steps is a static Python int)
        state = initial_state
        for _ in range(self.map_steps):
            state = logistic_map(state)

        return jax.nn.sigmoid(state - x_threshold_constrained)

In [78]:
# Draw the three unconstrained gate parameters from a standard normal with a
# fixed seed so the run is reproducible. The gate squashes them with a sigmoid.
DELTA, X0, X_THRESHOLD = jax.random.normal(
    key=jax.random.PRNGKey(0),
    shape=(3,),
)


In [79]:
# Keyword arguments make the field mapping explicit, because the unpacking of
# the random draw uses a different order (DELTA, X0, X_THRESHOLD).
chao_gate = LogisticMapChaogate(
    DELTA=DELTA,
    X_THRESHOLD=X_THRESHOLD,
    X0=X0,
    map_steps=1,
)

In [80]:
# Untrained gate output for every input row (one scalar per row of X).
list(map(chao_gate, X))

[Array(0.58899033, dtype=float32),
 Array(0.1437102, dtype=float32),
 Array(0.1437102, dtype=float32),
 Array(5.2874624e-05, dtype=float32)]

In [81]:
def compute_loss(
    gate: ChaoGate, x: Bool[Array, "batch 2"], y: Bool[Array, "batch"]
) -> Float[Array, ""]:
    """Mean binary cross-entropy between the gate's outputs and the targets.

    Args:
        gate: Callable mapping one input row to a scalar probability in (0, 1).
            It is vmapped over the batch. In this notebook it is a
            LogisticMapChaogate even though the annotation says ChaoGate.
        x: Batch of boolean input pairs.
        y: Boolean target for each row of ``x``.

    Returns:
        Scalar loss. A small epsilon (1e-15) keeps both log arguments positive.
    """
    eps = 1e-15
    probs = jax.vmap(gate)(x)
    log_p_true = jnp.log(probs + eps)
    log_p_false = jnp.log(1 - probs + eps)
    per_example = y * log_p_true + (1 - y) * log_p_false
    return -jnp.mean(per_example)

In [82]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim 2"],
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:
    """Run one jitted optimisation step on ``model``.

    Args:
        model: The gate being trained. Only its inexact-array leaves are updated.
        x: Batch of boolean input pairs.
        y: Boolean targets, one per row of ``x``.
        optim: Optax gradient transformation.
        opt_state: Current optimiser state for ``optim``.

    Returns:
        ``(loss, new_model, new_opt_state)``, where ``loss`` is the scalar loss
        of ``model`` *before* this update is applied.
    """
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
    # Pass the current trainable parameters too, so optimisers whose update
    # depends on them (e.g. weight decay) work. adabelief simply ignores them.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [83]:
# AdaBelief with a fixed learning rate. Only the inexact (floating-point)
# array leaves of the gate are trainable, so the optimiser state is built from
# that filtered view of the model.
optim = optax.adabelief(learning_rate=3e-4)
opt_state = optim.init(
    eqx.filter(chao_gate, eqx.is_inexact_array)
)

In [84]:
epochs = 10_000
log_every = 10  # print loss / gradient norm every `log_every` epochs

for epoch in trange(epochs):
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    if epoch % log_every == 0:
        # The gradient norm is only used for logging. Computing it here avoids
        # an extra, un-jitted backward pass on every one of the 10k epochs.
        # It is evaluated at the *updated* model, whereas `loss` is the
        # pre-update loss returned by make_step.
        grads = eqx.filter_grad(compute_loss)(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 0, Loss: 1.102327823638916, Grad Norm: 0.664995014667511
Epoch 10, Loss: 1.0985865592956543, Grad Norm: 0.6626760959625244
Epoch 20, Loss: 1.093306541442871, Grad Norm: 0.6594369411468506
Epoch 30, Loss: 1.0866756439208984, Grad Norm: 0.6553881168365479
Epoch 40, Loss: 1.0789152383804321, Grad Norm: 0.6506514549255371
Epoch 50, Loss: 1.0702029466629028, Grad Norm: 0.6453244686126709
Epoch 60, Loss: 1.0606765747070312, Grad Norm: 0.6394805908203125
Epoch 70, Loss: 1.0504459142684937, Grad Norm: 0.6331790685653687
Epoch 80, Loss: 1.0396023988723755, Grad Norm: 0.626469075679779
Epoch 90, Loss: 1.028224229812622, Grad Norm: 0.6193932890892029
Epoch 100, Loss: 1.0163803100585938, Grad Norm: 0.6119894981384277
Epoch 110, Loss: 1.004132628440857, Grad Norm: 0.6042918562889099
Epoch 120, Loss: 0.9915360808372498, Grad Norm: 0.5963317155838013
Epoch 130, Loss: 0.9786419868469238, Grad Norm: 0.5881385207176208
Epoch 140, Loss: 0.9654965400695801, Grad Norm: 0.5797398090362549
Epoch 150, L

In [85]:
print("\nTrained ChaoGate Parameters:")

# Report parameters in the constrained (0, 1) space the gate actually uses,
# i.e. after the same sigmoid applied inside the gate's __call__.
print(
    f"Constrained: DELTA: {jax.nn.sigmoid(chao_gate.DELTA)}, "
    f"X0: {jax.nn.sigmoid(chao_gate.X0)}, "
    f"X_THRESHOLD: {jax.nn.sigmoid(chao_gate.X_THRESHOLD)}"
)


Trained ChaoGate Parameters:
Constrained: DELTA: 0.5825539827346802, X0: 0.20872355997562408, X_THRESHOLD: 0.0004729581414721906


In [86]:
def _manual_truth_table(gate, step_map, inputs):
    """Recompute the gate's boolean decision for each input row by hand.

    Repeats the gate's computation outside the module:
    1. Constrain the parameters with a sigmoid.
    2. Build the initial state ``x0 + x1 * delta + x2 * delta``.
    3. Apply ``step_map`` ``gate.map_steps`` times, matching the model instead
       of assuming a single step.
    4. Compare the final state with the threshold. This is the same decision
       as ``gate(x) > 0.5``.

    Args:
        gate: Trained gate exposing DELTA, X0, X_THRESHOLD and map_steps.
        step_map: One logistic-map iteration (assumed; Map is presumably that).
        inputs: Rows of boolean input pairs.

    Returns:
        A list of ``(x1, x2, output)`` tuples of Python bools.
    """
    # Parameters do not depend on the row, so constrain them once.
    delta = jax.nn.sigmoid(gate.DELTA)
    x0 = jax.nn.sigmoid(gate.X0)
    threshold = jax.nn.sigmoid(gate.X_THRESHOLD)

    rows = []
    for x1, x2 in inputs:
        # Separate products on purpose: x1 + x2 on bool arrays is a logical OR.
        state = x0 + x1 * delta + x2 * delta
        for _ in range(gate.map_steps):
            state = step_map(state)
        rows.append((bool(x1.item()), bool(x2.item()), (state > threshold).item()))
    return rows


_manual_truth_table(chao_gate, Map, X)

[(False, False, True),
 (False, True, True),
 (True, False, True),
 (True, True, False)]

In [87]:
# Batch-evaluate the trained gate and threshold its sigmoid output at 0.5.
pred_ys = jax.vmap(chao_gate)(X)
num_correct = ((pred_ys > 0.5) == Y).sum()
final_accuracy = float(num_correct / len(X))
print(f"final_accuracy={final_accuracy}")

final_accuracy=1.0
