In [11]:
%env JAX_PLATFORM_NAME=cpu

import jaxtyping  # noqa: F401

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

env: JAX_PLATFORM_NAME=cpu
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping


In [12]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from tqdm.notebook import trange

from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import DuffingMap
from chaogatenn.utils import grad_norm

In [13]:
# Map = LogisticMap(a=4.0)
# Map = LorenzMap(sigma=10.0, rho=28.0, beta=8/3, dt=0.01, steps=1000)
Map = DuffingMap(
    alpha=1.0, beta=5.0, delta=0.02, gamma=8.0, omega=0.5, dt=0.01, steps=1000
)

In [14]:
# Training data for the Half Adder
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)  # Input combinations
y_sum = jnp.array([0, 1, 1, 0], dtype=bool)  # XOR gate output for Sum
y_carry = jnp.array([0, 0, 0, 1], dtype=bool)  # AND gate output for Carry

In [15]:
@eqx.filter_value_and_grad()
def compute_sum_loss(
    xor_gate: ChaoGate, x: Bool[Array, "batch 2"], y_sum: Bool[Array, "batch"]
) -> Float[Array, ""]:
    pred_sum = jax.vmap(xor_gate)(x)

    # Binary cross-entropy loss for XOR (Sum)
    loss_sum = -jnp.mean(
        y_sum * jnp.log(pred_sum + 1e-15) + (1 - y_sum) * jnp.log(1 - pred_sum + 1e-15)
    )

    return loss_sum

In [16]:
@eqx.filter_value_and_grad()
def compute_carry_loss(
    and_gate: ChaoGate, x: Bool[Array, "batch 2"], y_carry: Bool[Array, "batch"]
) -> Float[Array, ""]:
    pred_carry = jax.vmap(and_gate)(x)
    # Binary cross-entropy loss for AND (Carry)
    loss_carry = -jnp.mean(
        y_carry * jnp.log(pred_carry + 1e-15)
        + (1 - y_carry) * jnp.log(1 - pred_carry + 1e-15)
    )

    # Total loss is the sum of both losses
    return loss_carry

In [17]:
@eqx.filter_jit
def make_step(
    xor_gate: ChaoGate,
    and_gate: ChaoGate,
    X: Bool[Array, "batch 2"],
    y_sum: Bool[Array, "batch"],
    y_carry: Bool[Array, "batch "],
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> (Float[Array, "dim"], ChaoGate, optax.OptState):  # type: ignore
    loss_sum, grads_sum = compute_sum_loss(xor_gate, X, y_sum)

    loss_carry, grads_carry = compute_carry_loss(and_gate, X, y_carry)

    loss = loss_sum + loss_carry
    updates, opt_state = optim.update([grads_sum, grads_carry], opt_state)

    xor_gate = eqx.apply_updates(xor_gate, updates[0])  # type: ignore
    and_gate = eqx.apply_updates(and_gate, updates[1])  # type: ignore
    return loss, xor_gate, and_gate, opt_state

In [18]:
# Initialize the XOR and AND gates with random values
xor_gate = ChaoGate(DELTA=0.5, X0=1.0, X_THRESHOLD=0.4, Map=Map)
and_gate = ChaoGate(DELTA=0.5, X0=1.0, X_THRESHOLD=0.4, Map=Map)

In [19]:
optim = optax.adabelief(learning_rate=3e-4)
opt_state = optim.init(eqx.filter([xor_gate, and_gate], eqx.is_inexact_array))

In [20]:
epochs = 2500

for epoch in trange(epochs):
    loss, xor_gate, and_gate, opt_state = make_step(
        xor_gate,
        and_gate,
        X,
        y_sum,
        y_carry,
        optim,
        opt_state,  # type: ignore
    )
    _, grads = compute_sum_loss(xor_gate, X, y_sum)
    grad_norm_value = grad_norm(grads)

    if epoch % 1 == 0:
        print(f"Epoch {epoch}, loss: {loss}, grad norm: {grad_norm_value}")

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 0, loss: 1.3239755630493164, grad norm: 34.50337600708008
Epoch 1, loss: 1.300273060798645, grad norm: 33.1033821105957
Epoch 2, loss: 1.276551604270935, grad norm: 31.577356338500977
Epoch 3, loss: 1.2530434131622314, grad norm: 29.936603546142578
Epoch 4, loss: 1.229955792427063, grad norm: 28.199138641357422
Epoch 5, loss: 1.2075345516204834, grad norm: 26.383403778076172
Epoch 6, loss: 1.1859779357910156, grad norm: 24.514856338500977
Epoch 7, loss: 1.1655094623565674, grad norm: 22.62105369567871
Epoch 8, loss: 1.1463145017623901, grad norm: 20.727556228637695
Epoch 9, loss: 1.1285209655761719, grad norm: 18.86127471923828
Epoch 10, loss: 1.1122305393218994, grad norm: 17.050796508789062
Epoch 11, loss: 1.0975120067596436, grad norm: 15.314709663391113
Epoch 12, loss: 1.0843632221221924, grad norm: 13.672774314880371
Epoch 13, loss: 1.072744607925415, grad norm: 12.135438919067383
Epoch 14, loss: 1.0625773668289185, grad norm: 10.713972091674805
Epoch 15, loss: 1.05377399921

In [21]:
# Display trained parameters
print("\nTrained XOR Gate Parameters:")
print(
    f"DELTA: {xor_gate.DELTA}, X0: {xor_gate.X0}, X_THRESHOLD: {xor_gate.X_THRESHOLD}"
)

print("\nTrained AND Gate Parameters:")
print(
    f"DELTA: {and_gate.DELTA}, X0: {and_gate.X0}, X_THRESHOLD: {and_gate.X_THRESHOLD}"
)

# Evaluate the trained Half-Adder
print("\nHalf-Adder Evaluation:")
for i in range(len(X)):
    sum_output = xor_gate(X[i])
    carry_output = and_gate(X[i])
    print(
        f"Input: {X[i]}, Predicted Sum: {sum_output:.0f}, Predicted Carry: {carry_output:.0f}"
    )


Trained XOR Gate Parameters:
DELTA: 0.4204155206680298, X0: 1.1994107961654663, X_THRESHOLD: -0.2726741135120392

Trained AND Gate Parameters:
DELTA: 0.42408156394958496, X0: 1.126099944114685, X_THRESHOLD: 1.864433765411377

Half-Adder Evaluation:
Input: [False False], Predicted Sum: 1, Predicted Carry: 0
Input: [False  True], Predicted Sum: 1, Predicted Carry: 0
Input: [ True False], Predicted Sum: 1, Predicted Carry: 0
Input: [ True  True], Predicted Sum: 0, Predicted Carry: 1
