In [4]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx


# ChaoGate class (same as before)
class ChaoGate(eqx.Module):
    DELTA: jnp.ndarray
    X0: jnp.ndarray
    X_THRESHOLD: jnp.ndarray

    def logistic_map(self, x: float, a: float = 4):
        return a * x * (1 - x)

    def __call__(self, x: jnp.ndarray):
        x1, x2, x3 = x  # Now includes Cin as the third input
        return jax.nn.sigmoid(
            self.logistic_map(
                self.X0 + x1 * self.DELTA + x2 * self.DELTA + x3 * self.DELTA
            )
            - self.X_THRESHOLD
        )


# A network of ChaoGates with fixed weights of 1
class ChaoGateNetwork(eqx.Module):
    gates: list  # No more weights; all gates contribute equally

    def __call__(self, inputs: jnp.ndarray):
        # Pass inputs to each gate in the array
        gate_outputs = jnp.array([gate(inputs) for gate in self.gates])

        # Fixed weights of 1: simply sum the gate outputs
        sum_of_gates = jnp.sum(gate_outputs)

        # Final output is sigmoid to map it to [0, 1]
        final_output = jax.nn.sigmoid(sum_of_gates)
        return final_output


# Full-adder with a flexible array of ChaoGates
class FlexibleFullAdder(eqx.Module):
    sum_network: ChaoGateNetwork  # Network that will learn to compute the sum
    carry_network: ChaoGateNetwork  # Network that will learn to compute the carry

    def __call__(self, inputs: jnp.ndarray):
        sum_output = self.sum_network(inputs)
        carry_output = self.carry_network(inputs)
        return sum_output, carry_output


# Input-output pairs for a full-adder
def full_adder_truth_table():
    inputs = jnp.array(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
    )
    sum_output = jnp.array([0, 1, 1, 0, 1, 0, 0, 1])  # Full adder sum (A XOR B XOR Cin)
    carry_output = jnp.array([0, 0, 0, 1, 0, 1, 1, 1])  # Full adder carry
    return inputs, sum_output, carry_output


# Loss function to train the network
def loss_fn(model, inputs, sum_target, carry_target):
    sum_pred, carry_pred = jax.vmap(model)(inputs)
    sum_loss = -jnp.mean(
        sum_target * jnp.log(sum_pred + 1e-15)
        + (1 - sum_target) * jnp.log(1 - sum_pred + 1e-15)
    )
    carry_loss = -jnp.mean(
        carry_target * jnp.log(carry_pred + 1e-15)
        + (1 - carry_target) * jnp.log(1 - carry_pred + 1e-15)
    )
    return sum_loss + carry_loss


# Training function
def train_full_adder(model, optimizer, inputs, sum_target, carry_target, steps=10000):
    opt_state = optimizer.init(model)

    @jax.jit
    def step(model, opt_state, inputs, sum_target, carry_target):
        loss, grads = jax.value_and_grad(loss_fn)(
            model, inputs, sum_target, carry_target
        )
        updates, opt_state = optimizer.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for step_idx in range(steps):
        model, opt_state, loss = step(
            model, opt_state, inputs, sum_target, carry_target
        )
        if step_idx % 100 == 0:
            print(f"Step {step_idx}, Loss: {loss}")

    return model


# Initialize a network of ChaoGates
def initialize_chao_network(num_gates, key):
    gates = [
        ChaoGate(
            DELTA=jax.random.uniform(key, ()),
            X0=jax.random.uniform(key, ()),
            X_THRESHOLD=jax.random.uniform(key, ()),
        )
        for _ in range(num_gates)
    ]

    return ChaoGateNetwork(gates=gates)


# Main: Create and train a flexible full-adder model
key = jax.random.PRNGKey(0)

# Initialize the sum and carry networks with 4 gates each (can adjust the number of gates)
sum_network = initialize_chao_network(4, key)
carry_network = initialize_chao_network(4, key)

# Create the flexible full-adder
full_adder_model = FlexibleFullAdder(
    sum_network=sum_network, carry_network=carry_network
)

# Define optimizer (using Optax)
optimizer = optax.adam(learning_rate=0.01)

# Get the truth table for the full-adder
inputs, sum_target, carry_target = full_adder_truth_table()

# Train the full-adder model
trained_full_adder = train_full_adder(
    full_adder_model, optimizer, inputs, sum_target, carry_target
)

Step 0, Loss: 2.261190176010132
Step 100, Loss: 1.2721980810165405
Step 200, Loss: 1.1749132871627808
Step 300, Loss: 1.1548521518707275
Step 400, Loss: 1.1468359231948853
Step 500, Loss: 1.142775297164917
Step 600, Loss: 1.1403920650482178
Step 700, Loss: 1.1388503313064575
Step 800, Loss: 1.1377828121185303
Step 900, Loss: 1.1370059251785278
Step 1000, Loss: 1.1364188194274902
Step 1100, Loss: 1.135961651802063
Step 1200, Loss: 1.1355972290039062
Step 1300, Loss: 1.1353009939193726
Step 1400, Loss: 1.1350562572479248
Step 1500, Loss: 1.1348512172698975
Step 1600, Loss: 1.1346774101257324
Step 1700, Loss: 1.134528636932373
Step 1800, Loss: 1.1344000101089478
Step 1900, Loss: 1.1342881917953491
Step 2000, Loss: 1.1341900825500488
Step 2100, Loss: 1.134103536605835
Step 2200, Loss: 1.1340267658233643
Step 2300, Loss: 1.1339584589004517
Step 2400, Loss: 1.1338971853256226
Step 2500, Loss: 1.13384211063385
Step 2600, Loss: 1.1337924003601074
Step 2700, Loss: 1.1337474584579468
Step 2800, 