In [1]:
# Pin JAX to the CPU backend. The variable is set here, ahead of the `import jax`
# in the following cell.
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

# Load the jaxtyping IPython extension and select beartype as its type checker.
# NOTE(review): per the jaxtyping docs this makes the shape/dtype annotations on
# functions defined in later cells checked at runtime — confirm for the installed version.
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu


In [2]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import LogisticMap
from chaogatenn.utils import grad_norm

In [3]:
# Map instance handed to both gates below: the logistic map with a = 4.0.
Map = LogisticMap(a=4.0)
# Alternative maps kept for experimentation.
# NOTE(review): LorenzMap and DuffingMap are not imported in this notebook's import
# cell; import them before uncommenting either line.
# Map = LorenzMap(sigma=10.0, rho=28.0, beta=8/3, dt=0.01, steps=1000)
# Map = DuffingMap(
#     alpha=5.0, beta=5.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
# )

In [4]:
# Half-adder truth table: one row per combination of the two input bits (A, B).
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)

# Derive the targets from the inputs so they cannot drift out of sync with X.
_bit_a, _bit_b = X[:, 0], X[:, 1]
y_sum = jnp.logical_xor(_bit_a, _bit_b)  # Sum = A XOR B   -> [0, 1, 1, 0]
y_carry = jnp.logical_and(_bit_a, _bit_b)  # Carry = A AND B -> [0, 0, 0, 1]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [5]:
@eqx.filter_value_and_grad()
def compute_sum_loss(
    xor_gate: ChaoGate, x: Bool[Array, "batch 2"], y_sum: Bool[Array, "batch"]
) -> Float[Array, ""]:
    """Binary cross-entropy between the gate's outputs and the Sum (XOR) targets.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this returns
    a ``(loss, grads)`` pair, with the gradients taken with respect to
    ``xor_gate`` (the first argument).

    Args:
        xor_gate: Gate being trained to produce the Sum bit.
        x: Batch of two-bit inputs.
        y_sum: Target Sum bit for each input row.

    Returns:
        The mean binary cross-entropy over the batch (a scalar).
    """
    eps = 1e-15  # keeps both logarithms finite when a prediction is exactly 0 or 1

    # Apply the gate to every row of the batch.
    predictions = jax.vmap(xor_gate)(x)

    log_p = jnp.log(predictions + eps)
    log_one_minus_p = jnp.log(1 - predictions + eps)
    per_sample = y_sum * log_p + (1 - y_sum) * log_one_minus_p

    return -jnp.mean(per_sample)

In [6]:
@eqx.filter_value_and_grad()
def compute_carry_loss(
    and_gate: ChaoGate, x: Bool[Array, "batch 2"], y_carry: Bool[Array, "batch"]
) -> Float[Array, ""]:
    """Binary cross-entropy between the gate's outputs and the Carry (AND) targets.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this returns
    a ``(loss, grads)`` pair, with the gradients taken with respect to
    ``and_gate`` (the first argument).

    Args:
        and_gate: Gate being trained to produce the Carry bit.
        x: Batch of two-bit inputs.
        y_carry: Target Carry bit for each input row.

    Returns:
        The mean binary cross-entropy over the batch (a scalar). Only the
        Carry loss is returned here; it is not combined with the Sum loss.
    """
    eps = 1e-15  # keeps both logarithms finite when a prediction is exactly 0 or 1

    # Apply the gate to every row of the batch.
    predictions = jax.vmap(and_gate)(x)

    positive_term = y_carry * jnp.log(predictions + eps)
    negative_term = (1 - y_carry) * jnp.log(1 - predictions + eps)

    return -jnp.mean(positive_term + negative_term)

In [7]:
@eqx.filter_jit
def make_step(
    xor_gate: ChaoGate,
    and_gate: ChaoGate,
    X: Bool[Array, "batch 2"],
    y_sum: Bool[Array, "batch"],
    y_carry: Bool[Array, "batch"],
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, ChaoGate, optax.OptState]:
    """Run one joint optimisation step for the Sum (XOR) and Carry (AND) gates.

    Args:
        xor_gate: Gate trained against ``y_sum``.
        and_gate: Gate trained against ``y_carry``.
        X: Batch of two-bit inputs shared by both gates.
        y_sum: Target Sum bit for each input row.
        y_carry: Target Carry bit for each input row.
        optim: Optax gradient transformation used for the update.
        opt_state: Optimiser state whose structure matches the two-element
            list ``[xor_gate params, and_gate params]``.

    Returns:
        ``(loss, xor_gate, and_gate, opt_state)`` where ``loss`` is the sum of
        both gates' losses evaluated *before* the update is applied, followed
        by the two updated gates and the new optimiser state.
    """
    loss_sum, grads_sum = compute_sum_loss(xor_gate, X, y_sum)
    loss_carry, grads_carry = compute_carry_loss(and_gate, X, y_carry)

    # Combined loss is reported only; each gate is updated from its own gradient.
    loss = loss_sum + loss_carry

    # The gradient list must keep the [xor, and] order the optimiser state was built with.
    updates, opt_state = optim.update([grads_sum, grads_carry], opt_state)

    xor_gate = eqx.apply_updates(xor_gate, updates[0])  # type: ignore
    and_gate = eqx.apply_updates(and_gate, updates[1])  # type: ignore
    return loss, xor_gate, and_gate, opt_state

In [8]:
# Both gates start from the same fixed (not random) parameter values and are
# given the same Map instance.
initial_params = dict(DELTA=0.5, X0=1.0, X_THRESHOLD=0.4, Map=Map)
xor_gate = ChaoGate(**initial_params)
and_gate = ChaoGate(**initial_params)

In [9]:
# Only the inexact-array leaves of the two gates are handed to the optimiser.
# The [xor_gate, and_gate] ordering must match the gradient list built in make_step.
trainable_params = eqx.filter([xor_gate, and_gate], eqx.is_inexact_array)

optim = optax.adabelief(learning_rate=3e-4)
opt_state = optim.init(trainable_params)

In [10]:
epochs = 2500
log_every = 10  # print diagnostics once every this many epochs

for epoch in trange(epochs):
    loss, xor_gate, and_gate, opt_state = make_step(
        xor_gate,  # type: ignore
        and_gate,  # type: ignore
        X,
        y_sum,
        y_carry,
        optim,
        opt_state,  # type: ignore
    )

    if epoch % log_every == 0:
        # The gradient norm is purely diagnostic, so the extra gradient evaluation
        # is done only on epochs that are actually logged.
        # Note: `loss` is the combined loss from before this step's update, while
        # the norm is of the XOR gate's Sum-loss gradient after the update.
        _, grads = compute_sum_loss(xor_gate, X, y_sum)
        grad_norm_value = grad_norm(grads)
        print(f"Epoch {epoch}, loss: {loss}, grad norm: {grad_norm_value}")

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 0, loss: 2.115169048309326, grad norm: 6.584014415740967
Epoch 10, loss: 1.3006647825241089, grad norm: 2.2727248668670654
Epoch 20, loss: 1.236580729484558, grad norm: 1.4815900325775146
Epoch 30, loss: 1.2312860488891602, grad norm: 1.6913783550262451
Epoch 40, loss: 1.2257707118988037, grad norm: 1.0811671018600464
Epoch 50, loss: 1.2197433710098267, grad norm: 0.8808463215827942
Epoch 60, loss: 1.2133344411849976, grad norm: 1.1098077297210693
Epoch 70, loss: 1.2062852382659912, grad norm: 0.9026330709457397
Epoch 80, loss: 1.1987131834030151, grad norm: 0.9659523367881775
Epoch 90, loss: 1.1906392574310303, grad norm: 0.9868104457855225
Epoch 100, loss: 1.18207848072052, grad norm: 0.9743662476539612
Epoch 110, loss: 1.1730650663375854, grad norm: 1.0136175155639648
Epoch 120, loss: 1.1636264324188232, grad norm: 1.0104522705078125
Epoch 130, loss: 1.1537952423095703, grad norm: 1.0302644968032837
Epoch 140, loss: 1.143601894378662, grad norm: 1.0360119342803955
Epoch 150, l

In [11]:
# Report the parameters each gate ended up with after training.
print("\nTrained XOR Gate Parameters:")
print(
    f"DELTA: {xor_gate.DELTA}, X0: {xor_gate.X0}, X_THRESHOLD: {xor_gate.X_THRESHOLD}"
)

print("\nTrained AND Gate Parameters:")
print(
    f"DELTA: {and_gate.DELTA}, X0: {and_gate.X0}, X_THRESHOLD: {and_gate.X_THRESHOLD}"
)

# Run every input combination through both gates; the `.0f` format rounds each
# gate output to the nearest integer for display.
print("\nHalf-Adder Evaluation:")
for inputs in X:
    sum_output, carry_output = xor_gate(inputs), and_gate(inputs)
    print(
        f"Input: {inputs}, Predicted Sum: {sum_output:.0f}, Predicted Carry: {carry_output:.0f}"
    )


Trained XOR Gate Parameters:
DELTA: 0.6329566240310669, X0: 0.7230839729309082, X_THRESHOLD: -0.761564314365387

Trained AND Gate Parameters:
DELTA: 0.5060827136039734, X0: 1.0108097791671753, X_THRESHOLD: 3.158766269683838

Half-Adder Evaluation:
Input: [False False], Predicted Sum: 1, Predicted Carry: 0
Input: [False  True], Predicted Sum: 1, Predicted Carry: 0
Input: [ True False], Predicted Sum: 1, Predicted Carry: 0
Input: [ True  True], Predicted Sum: 0, Predicted Carry: 1
