In [88]:
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# %load_ext jaxtyping
# %jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu


In [89]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float
from jax.typing import ArrayLike
from typing import Protocol, runtime_checkable, Any
from beartype import beartype as typechecker

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap
from chaogatenn.utils import grad_norm

In [90]:
Map = LogisticMap(a=4.0)

In [91]:
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)  # Input combinations
AND_Y = jnp.array([0, 0, 0, 1], dtype=bool)  # AND gate output
OR_Y = jnp.array([0, 1, 1, 1], dtype=bool)  # OR gate output
XOR_Y = jnp.array([0, 1, 1, 0], dtype=bool)  # XOR
NAND_Y = jnp.array([1, 1, 1, 0], dtype=bool)  # NAND
NOR_Y = jnp.array([1, 0, 0, 0], dtype=bool)  # NOR
XNOR_Y = jnp.array([1, 0, 0, 1], dtype=bool)  # XNOR
Y = NOR_Y

In [92]:
class LogisticMapChaogate(eqx.Module):
    DELTA: Float[Array, ""]  # Unconstrained encoding parameter
    X_THRESHOLD: Float[Array, ""]  # Unconstrained threshold
    X0: Float[Array, ""]  # Unconstrained initial state
    map_steps: int  # Number of logistic map iterations

    @typechecker
    def __call__(self, x: Bool[Array, "2"]) -> Float[Array, ""]:
        """
        Logistic map-based chaogate function.
        """
        x1, x2 = x
        # Apply sigmoid to constrain parameters to [0, 1]
        delta_constrained = jax.nn.sigmoid(self.DELTA)
        x0_constrained = jax.nn.sigmoid(self.X0)
        x_threshold_constrained = jax.nn.sigmoid(self.X_THRESHOLD)

        def logistic_map(state: Float) -> Float:
            a = 4  # Fixed parameter
            return a * state * (1.0 - state)

        # Compute initial state based on inputs and constrained parameters
        initial_state = x0_constrained + x1 * delta_constrained + x2 * delta_constrained

        # Iterate the logistic map
        state = initial_state
        for _ in range(self.map_steps):
            state = logistic_map(state)

        return jax.nn.sigmoid(state - x_threshold_constrained)

In [93]:
DELTA, X0, X_THRESHOLD = jax.random.normal(jax.random.PRNGKey(0), (3,))


In [94]:
chao_gate = LogisticMapChaogate(DELTA, X_THRESHOLD, X0, map_steps=1)

In [95]:
[chao_gate(x) for x in X]

[Array(0.58899033, dtype=float32),
 Array(0.1437102, dtype=float32),
 Array(0.1437102, dtype=float32),
 Array(5.2874624e-05, dtype=float32)]

In [96]:
def compute_loss(
    gate: ChaoGate, x: Bool[Array, "batch 2"], y: Bool[Array, "batch"]
) -> Float[Array, ""]:
    pred = jax.vmap(gate)(x)
    loss = -jnp.mean(y * jnp.log(pred + 1e-15) + (1 - y) * jnp.log(1 - pred + 1e-15))

    return loss

In [97]:
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "dim 2"],
    y: Bool[Array, "dim"],  # noqa: F821
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> (Float[Array, "dim"], ChaoGate, optax.OptState):  # type: ignore  # noqa: F821
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    # jax.debug.print(f"{grads, updates}")
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [98]:
optim = optax.adabelief(3e-4)
opt_state = optim.init(eqx.filter(chao_gate, eqx.is_inexact_array))

In [99]:
epochs = 10_000

for epoch in trange(epochs):
    loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)  # type: ignore
    grads = eqx.filter_grad(compute_loss)(chao_gate, X, Y)
    grad_norm_value = grad_norm(grads)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss}, Grad Norm: {grad_norm_value}")

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 0, Loss: 0.20992277562618256, Grad Norm: 0.13402342796325684
Epoch 10, Loss: 0.20924827456474304, Grad Norm: 0.13340365886688232
Epoch 20, Loss: 0.20829875767230988, Grad Norm: 0.13253842294216156
Epoch 30, Loss: 0.20711088180541992, Grad Norm: 0.13145893812179565
Epoch 40, Loss: 0.20572808384895325, Grad Norm: 0.13020040094852448
Epoch 50, Loss: 0.2041856348514557, Grad Norm: 0.12879133224487305
Epoch 60, Loss: 0.20251137018203735, Grad Norm: 0.12725414335727692
Epoch 70, Loss: 0.20072734355926514, Grad Norm: 0.12560704350471497
Epoch 80, Loss: 0.19885212182998657, Grad Norm: 0.12386508285999298
Epoch 90, Loss: 0.1969016194343567, Grad Norm: 0.12204159051179886
Epoch 100, Loss: 0.1948893666267395, Grad Norm: 0.12014837563037872
Epoch 110, Loss: 0.19282719492912292, Grad Norm: 0.11819592863321304
Epoch 120, Loss: 0.19072571396827698, Grad Norm: 0.11619386821985245
Epoch 130, Loss: 0.1885940432548523, Grad Norm: 0.11415120214223862
Epoch 140, Loss: 0.1864403933286667, Grad Norm: 0

In [100]:
print("\nTrained ChaoGate Parameters:")
# print(
#     f"DELTA: {chao_gate.DELTA}, X0: {chao_gate.X0}, X_THRESHOLD: {chao_gate.X_THRESHOLD}"
# )

print(
    f"Constrained: DELTA: {jax.nn.sigmoid(chao_gate.DELTA)}, X0: {jax.nn.sigmoid(chao_gate.X0)}, X_THRESHOLD: {jax.nn.sigmoid(chao_gate.X_THRESHOLD)}"
)


Trained ChaoGate Parameters:
Constrained: DELTA: 0.9994286894798279, X0: 0.6300225853919983, X_THRESHOLD: 0.00010577990906313062


In [101]:
[
    (
        bool(x1.item()),
        bool(x2.item()),
        (
            Map(
                jax.nn.sigmoid(chao_gate.X0)
                + x1 * jax.nn.sigmoid(chao_gate.DELTA)
                + x2 * jax.nn.sigmoid(chao_gate.DELTA)
            )
            > jax.nn.sigmoid(chao_gate.X_THRESHOLD)
        ).item(),
    )
    for x1, x2 in X
]

[(False, False, True),
 (False, True, False),
 (True, False, False),
 (True, True, False)]

In [102]:
pred_ys = jax.vmap(chao_gate)(X)
num_correct = jnp.sum((pred_ys > 0.5) == Y)
final_accuracy = (num_correct / len(X)).item()
print(f"final_accuracy={final_accuracy}")

final_accuracy=1.0
