In [31]:
import jax
import jax.numpy as jnp
import pennylane as qml
from flax import linen as nn
from flax.training import train_state
import optax
from functools import partial
from typing import Callable
import numpy as np

In [None]:
class QuantumAutoEncoder(nn.Module):
    n_qubits: int
    input_size: int
    n_qlayers: int = 2
    backend: str = "default.qubit"
    embedding_dim: int = 32

    def setup(self):
        nq=self.n_qubits
        self.dev = qml.device(self.backend, wires=nq)
        
        @qml.qnode(self.dev, interface='jax')
        def circuit(x, w):
            qml.AmplitudeEmbedding(x, wires=range(nq), normalize=True)
            qml.templates.BasicEntanglerLayers(w, wires=range(nq))
            result=qml.probs(wires=range(nq))#[qml.expval(qml.PauliZ(i)) for i in range(nq)]
            return result
        
        def electron_constriction(x, ne, key):
            """
            x: jnp.ndarray of shape (sequence_length, feature_dim, ?)
            We will modify x[-1, :, -1]
            ne: int, number of electrons
            key: jax.random.PRNGKey
            """
            x2 = jnp.abs(x[-1, :, -1])
            zero_vector = jnp.zeros_like(x2)

            def loop_body(i, carry):
                x2, zero_vector, key = carry
                key1, key2, new_key = jax.random.split(key, 3)

                # ----- Even positions (alpha electrons) -----
                x2_even = x2[::2]
                cumsum_even = jnp.cumsum(x2_even)
                rand_val = jax.random.uniform(key1) * cumsum_even[-1]
                idx_even = jnp.argmax(cumsum_even >= rand_val) * 2

                x2 = x2.at[idx_even].set(0.0)
                zero_vector = zero_vector.at[idx_even].set(1.0)

                # ----- Odd positions (beta electrons) -----
                x2_odd = x2[1::2]
                cumsum_odd = jnp.cumsum(x2_odd)
                rand_val = jax.random.uniform(key2) * cumsum_odd[-1]
                idx_odd = jnp.argmax(cumsum_odd >= rand_val) * 2 + 1

                x2 = x2.at[idx_odd].set(0.0)
                zero_vector = zero_vector.at[idx_odd].set(1.0)

                return (x2, zero_vector, new_key)

            # Loop ne//2 times
            x2, zero_vector, key = jax.lax.fori_loop(0, ne // 2, loop_body, (x2, zero_vector, key))

            # Sanity check (for debugging – remove if using JIT)
            # assert jnp.count_nonzero(zero_vector) == ne, "Electron count mismatch"

            # Replace the last column with the constrained vector
            x = x.at[-1, :, -1].set(zero_vector)

            return x
         

        self.qnode = circuit
        self.batched_qnode = jax.vmap(self.qnode, in_axes=(0, None))

        self.weightsEnc = self.param(
            "weightsEnc", nn.initializers.normal(stddev=0.1),
            (self.n_qlayers, nq)
        )
        self.weightsDec = self.param(
            "weightsDec", nn.initializers.normal(stddev=0.1),
            (self.n_qlayers, nq)
        )

    @nn.compact
    def __call__(self, inputs,train: bool = True):
        # inputs: (batch, input_size) = (16, 8)
        x = nn.Dense(self.embedding_dim)(inputs)         
        x = nn.relu(x)
        x = self.batched_qnode(x, self.weightsEnc)          
        x = nn.Dense(2 * self.embedding_dim)(x)          
        x = nn.Dense(self.embedding_dim)(x)              
        x = self.batched_qnode(x, self.weightsDec)          
        x = nn.Dense(self.input_size)(x)                 # (16, 8)
        x = self.electron_constriction(x)                                # (16, 8)
        return x


In [45]:
def train_model(train_data,train_loader,test_loader,batch,epochs):
    #print(train_data.shape)
    #sample_input=jnp.array(train_data[16,10,8])
    input_shape=[16,8]
    net=QuantumAutoEncoder(n_qubits=5,input_size=8)
    key=jax.random.PRNGKey(0)
    params=net.init(key,jnp.ones(input_shape))
    optimizer=optax.adam(0.01)
    opt_state=optimizer.init(params)
    def binary_accuracy(original, reconstructed, threshold=0.5):
        original = jnp.array(original)
        reconstructed = jnp.array(reconstructed)
        pred = (reconstructed > threshold).astype(jnp.float32)
        correct = (pred == original).astype(jnp.float32)
        return jnp.mean(correct).item()
    @jax.jit
    def train_step(params,opt_state,inputs,targets):
        def loss_fn(params,inputs,targets):
            preds=net.apply(params,inputs)
            loss = -jnp.mean(targets * jnp.log(preds + 1e-7) + (1 - targets) * jnp.log(1 - preds + 1e-7))
            #jax.debug.print(">>> preds mean: {}", jnp.mean(preds))
            return loss
        loss,grad=jax.value_and_grad(loss_fn)(params,inputs,targets)
        updates, opt_state=optimizer.update(grad,opt_state)
        new_params=optax.apply_updates(params,updates)
        return loss, new_params, opt_state
    for epoch in range(epochs):
        epoch_loss = 0.0
        for data in train_loader:
            inputs, targets = data[0], data[1]
            loss, params, opt_state = train_step(params, opt_state, inputs, targets)
            epoch_loss += loss
        epoch_loss /= len(train_loader)

        print(f"Epoch {epoch}, Loss: {epoch_loss}")
        forecasted=[]
        loss_test=[]
        targets_val=[]

    for data in test_loader:

        inputs,targets=data[0],data[1]
        preds=net.apply(params,inputs)
        loss=-jnp.mean(targets * jnp.log(preds + 1e-7) + (1 - targets) * jnp.log(1 - preds + 1e-7))
        
        forecasted.append(preds)
        targets_val.append(targets)
        loss_test.append(loss)
    #error metrics
    loss_test=jnp.array(loss_test).mean()
    print(f"Test Loss: {loss_test}")
    forecasted = jnp.concatenate(forecasted, axis=0)
    targets_val = jnp.concatenate(targets_val, axis=0)
    print("Preds: mean =", jnp.mean(forecasted), ", std =", jnp.std(forecasted))
    print("Targets: mean =", jnp.mean(targets_val), ", std =", jnp.std(targets_val))

    accuracy = binary_accuracy(forecasted, targets_val)
    print(f"RMSE: {accuracy}")
    return net,params
            

In [46]:
def generate_clean_data(n_samples, input_dim):
    return np.random.randint(0, 2, (n_samples, input_dim))
def add_noise(x, noise_level=0.2):
    noise = np.random.randn(*x.shape) < noise_level
    return (x + noise) % 2  # Flip bits
def binary_accuracy(original, reconstructed, threshold=0.5):
    pred = (reconstructed > threshold).float()
    correct = (pred == original).float()
    return correct.mean().item()

In [47]:
import jax_dataloader as jdl
noise_level=0.3
input_size=8
batch_size=16
samples=batch_size*300
clean_data = generate_clean_data(samples, input_size)
noisy_data = add_noise(clean_data, noise_level=noise_level)
train_data=jdl.ArrayDataset(clean_data,noisy_data)
train_loader=jdl.DataLoader(train_data,backend='jax',batch_size=16,shuffle=True,drop_last=True)
test_samples = generate_clean_data(96, input_size)
noisy_test = add_noise(test_samples, noise_level=0.3)
test_data=jdl.ArrayDataset(test_samples,noisy_test)
test_loader=jdl.DataLoader(test_data,backend='jax',batch_size=16,shuffle=False,drop_last=True)
train_model(clean_data,train_loader,test_loader,batch_size,epochs=150)

Epoch 0, Loss: 0.6869564056396484
Epoch 1, Loss: 0.6788758635520935
Epoch 2, Loss: 0.6745942831039429
Epoch 3, Loss: 0.6725912690162659
Epoch 4, Loss: 0.6702226996421814
Epoch 5, Loss: 0.669006884098053
Epoch 6, Loss: 0.6688790917396545
Epoch 7, Loss: 0.667693555355072
Epoch 8, Loss: 0.6671999096870422
Epoch 9, Loss: 0.6669465899467468
Epoch 10, Loss: 0.6665127277374268
Epoch 11, Loss: 0.6655364036560059
Epoch 12, Loss: 0.6649147272109985
Epoch 13, Loss: 0.6649458408355713
Epoch 14, Loss: 0.6636844873428345
Epoch 15, Loss: 0.6640265583992004
Epoch 16, Loss: 0.6637663841247559
Epoch 17, Loss: 0.6635016202926636
Epoch 18, Loss: 0.6629105806350708
Epoch 19, Loss: 0.6631539463996887
Epoch 20, Loss: 0.6626130938529968
Epoch 21, Loss: 0.6627051830291748
Epoch 22, Loss: 0.6626222729682922
Epoch 23, Loss: 0.6622732281684875
Epoch 24, Loss: 0.6621303558349609
Epoch 25, Loss: 0.6616784930229187
Epoch 26, Loss: 0.6617505550384521
Epoch 27, Loss: 0.6612953543663025
Epoch 28, Loss: 0.66112327575683

(QuantumAutoEncoder(
     # attributes
     n_qubits = 5
     input_size = 8
     n_qlayers = 2
     backend = 'default.qubit'
     embedding_dim = 32
 ),
 {'params': {'Dense_0': {'bias': Array([-0.9148683 ,  2.2726054 ,  2.1679316 , -1.5716857 , -0.66860855,
            1.397078  ,  0.15520507,  1.2170259 , -0.2107089 , -1.9452105 ,
            1.4254565 , -2.117494  ,  2.8563974 , -1.8062284 , -1.4366468 ,
            2.577079  ,  0.7471307 , -1.7303091 ,  0.02512974,  0.8536879 ,
           -3.5370615 , -1.2477065 ,  1.2060643 , -2.3516333 ,  1.9831634 ,
            0.14551745, -0.36467347, -0.7905444 ,  1.1378888 , -0.18673113,
            1.0532075 , -1.349688  ], dtype=float32),
    'kernel': Array([[ 1.16617098e-01,  3.77709717e-01, -9.35216546e-01,
             3.38628078e+00, -1.60734951e+00, -1.19939864e+00,
            -3.95994592e+00, -6.06687427e-01, -6.54238462e-02,
             5.52972615e-01, -4.50445986e+00, -4.09472942e+00,
             2.12708235e+00, -1.80171561e+00