In [1]:
import os

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import optax
from tqdm import trange
from jaxtyping import Array, Bool, Float

from chaogatenn.chaogate import ChaoGate
from chaogatenn.maps import DuffingMap
from chaogatenn.utils import grad_norm

In [2]:
# Truth-table training data for the two-input logic gates.
# Rows of X enumerate every input pair (a, b) in the order 00, 01, 10, 11.
X = jnp.array([[a, b] for a in (0, 1) for b in (0, 1)], dtype=bool)

# Targets are derived from the two input columns so each gate reads as its definition.
AND_Y = X[:, 0] & X[:, 1]  # AND gate output:  0 0 0 1
OR_Y = X[:, 0] | X[:, 1]  # OR gate output:   0 1 1 1
XOR_Y = X[:, 0] ^ X[:, 1]  # XOR gate output:  0 1 1 0
NAND_Y = ~AND_Y  # NAND gate output: 1 1 1 0
NOR_Y = ~OR_Y  # NOR gate output:  1 0 0 0
XNOR_Y = ~XOR_Y  # XNOR gate output: 1 0 0 1

In [3]:
# Mapping from gate name to its target output vector, in the order the sweep visits them.
logic_gates = dict(
    zip(
        ("AND", "OR", "XOR", "NAND", "NOR", "XNOR"),
        (AND_Y, OR_Y, XOR_Y, NAND_Y, NOR_Y, XNOR_Y),
    )
)

In [4]:
@eqx.filter_value_and_grad()
def compute_loss(
    chao_gate: ChaoGate, x: Bool[Array, "batch 2"], y: Bool[Array, "batch"]
) -> Float[Array, ""]:
    """Binary cross-entropy of the gate's predictions against the targets.

    The gate is mapped over the leading axis of ``x`` with ``jax.vmap``. Because
    of the ``eqx.filter_value_and_grad`` decorator, callers receive a
    ``(loss, grads)`` pair rather than the bare loss.

    Args:
        chao_gate: Model being evaluated (and differentiated).
        x: Boolean input pairs, one row per sample.
        y: Boolean target for each sample.

    Returns:
        The mean binary cross-entropy over the batch.
    """
    eps = 1e-15  # keeps both logarithms away from log(0)
    pred = jax.vmap(chao_gate)(x)
    log_true = jnp.log(pred + eps)
    log_false = jnp.log(1 - pred + eps)
    # binary cross entropy
    per_sample = y * log_true + (1 - y) * log_false
    return -jnp.mean(per_sample)


# Function to perform a single optimization step
@eqx.filter_jit
def make_step(
    model: ChaoGate,
    x: Bool[Array, "batch 2"],
    y: Bool[Array, "batch"],
    optim: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Float[Array, ""], ChaoGate, optax.OptState]:
    """Run one gradient update of ``model`` on the batch ``(x, y)``.

    Args:
        model: Gate to update.
        x: Boolean input pairs, one row per sample.
        y: Boolean target for each sample.
        optim: Optax gradient transformation that produces the updates.
        opt_state: Optimizer state matching ``optim``.

    Returns:
        ``(loss, model, opt_state)``: the scalar loss, the updated model and
        the updated optimizer state. The loss is evaluated on the model
        *before* this step's update is applied.
    """
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [5]:
# Directory that receives the per-gate metrics/results text files.
output_dir = "../output/duffing_sweep/"
# Create it up front so the sweep cannot fail at save time because the folder is missing.
os.makedirs(output_dir, exist_ok=True)

# Per-gate sweep records, keyed by gate name and filled by the training loop:
#   metrics_dict[gate] -> list of (beta, loss, accuracy, grad_norm)
#   results_dict[gate] -> list of (beta, DELTA, X0, X_THRESHOLD)
metrics_dict = {}
results_dict = {}

In [6]:
# Sweep the Duffing-map parameter `beta` for every logic gate: train a fresh
# ChaoGate per (gate, beta) pair, then record its final metrics and parameters.
epochs = 1000  # optimisation steps per (gate, beta) configuration

for gate_name, Y in logic_gates.items():
    metrics_dict[gate_name] = []
    results_dict[gate_name] = []
    # 40 geometrically spaced beta values from 1e-3 to 5
    for beta in jnp.geomspace(1e-3, 5, num=40):
        Map = DuffingMap(beta=beta)
        chao_gate = ChaoGate(DELTA=1.0, X0=1.0, X_THRESHOLD=1.0, Map=Map)
        optim = optax.adabelief(3e-4)
        opt_state = optim.init(eqx.filter(chao_gate, eqx.is_inexact_array))

        # `.3g` keeps the small betas of the geometric sweep distinguishable
        # in the progress-bar label (`.2f` would render them all as 0.00).
        for epoch in trange(
            epochs, desc=f"Training {gate_name} gate with beta={beta:.3g}"
        ):
            loss, chao_gate, opt_state = make_step(chao_gate, X, Y, optim, opt_state)

        # Only the gradient norm of the trained model is recorded, so evaluate
        # the gradients once here instead of on every epoch.
        _, grads = compute_loss(chao_gate, X, Y)
        grad_norm_value = grad_norm(grads)

        # Predictions above 0.5 are counted as True when scoring accuracy.
        pred_ys = jax.vmap(chao_gate)(X)
        num_correct = jnp.sum((pred_ys > 0.5) == Y)
        final_accuracy = (num_correct / len(X)).item()
        # `loss` is the value returned by the last make_step call, i.e. it was
        # evaluated just before the final parameter update.
        metrics_dict[gate_name].append(
            (beta, loss.item(), final_accuracy, grad_norm_value)
        )
        results_dict[gate_name].append(
            (beta, chao_gate.DELTA, chao_gate.X0, chao_gate.X_THRESHOLD)
        )

Training AND gate with beta=nan:   7%|▋         | 67/1000 [00:14<02:44,  5.66it/s]

In [11]:
# Print the per-gate sweep results collected in `metrics_dict`.
for gate_name, metrics in metrics_dict.items():
    print(f"\nResults for {gate_name} gate:")
    # Unpack into row-local names and print the beta stored in each row, not
    # the global `beta` left over from the training loop. This also avoids
    # overwriting the globals `loss` and `grad_norm_value`.
    for row_beta, row_loss, row_accuracy, row_grad_norm in metrics:
        # `.3g` keeps the small betas of the geometric sweep distinguishable.
        print(
            f"beta={row_beta:.3g}, Loss={row_loss:.6f}, Accuracy={row_accuracy:.2f}, Grad Norm={row_grad_norm:.6f}"
        )


Results for AND gate:
a=0.00, Loss=0.029688, Accuracy=1.00, Grad Norm=0.036074
a=0.10, Loss=0.038069, Accuracy=1.00, Grad Norm=0.059445
a=0.21, Loss=0.080870, Accuracy=1.00, Grad Norm=0.102492
a=0.31, Loss=0.133663, Accuracy=1.00, Grad Norm=0.147805
a=0.41, Loss=0.200792, Accuracy=1.00, Grad Norm=0.203656
a=0.51, Loss=0.307240, Accuracy=1.00, Grad Norm=0.318958
a=0.62, Loss=0.479412, Accuracy=0.75, Grad Norm=0.362791
a=0.72, Loss=0.558830, Accuracy=0.75, Grad Norm=0.163402
a=0.82, Loss=0.577126, Accuracy=0.75, Grad Norm=0.126004
a=0.92, Loss=0.581904, Accuracy=0.75, Grad Norm=0.146581
a=1.03, Loss=0.583202, Accuracy=0.75, Grad Norm=0.170779
a=1.13, Loss=0.583410, Accuracy=0.75, Grad Norm=0.192272
a=1.23, Loss=0.583229, Accuracy=0.75, Grad Norm=0.211467
a=1.33, Loss=0.582894, Accuracy=0.75, Grad Norm=0.229102
a=1.44, Loss=0.582493, Accuracy=0.75, Grad Norm=0.245686
a=1.54, Loss=0.582057, Accuracy=0.75, Grad Norm=0.261544
a=1.64, Loss=0.581601, Accuracy=0.75, Grad Norm=0.276882
a=1.74, 

In [12]:
# Write each gate's metric rows to "<output_dir><gate>_metrics.txt" as
# comma-separated text, one row per sweep point.
for gate_name in metrics_dict:
    metrics = jnp.array(metrics_dict[gate_name])
    metrics_path = f"{output_dir}{gate_name}_metrics.txt"
    np.savetxt(metrics_path, metrics, delimiter=",")

In [13]:
# Write each gate's trained-parameter rows to "<output_dir><gate>_results.txt"
# as comma-separated text, one row per sweep point.
for gate_name in results_dict:
    results = jnp.array(results_dict[gate_name])
    results_path = f"{output_dir}{gate_name}_results.txt"
    np.savetxt(results_path, results, delimiter=",")