In [1]:
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
from pennylane import AngleEmbedding, StronglyEntanglingLayers, RandomLayers
import pennylane as qml
import os
import jax
from jax import numpy as jnp
import equinox as eqx
import time
from tqdm import tqdm
import optax
import tensorflow as tf
from jaxtyping import Array, Float, PyTree
from collections import deque

In [2]:
print(jax.devices())

[cuda(id=0)]


In [3]:
key = jax.random.PRNGKey(0)

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [5]:
x_train = x_train[:50_000]
y_train = y_train[:50_000]

x_test = x_test[:10_000]
y_test = y_test[:10_000]

In [6]:
y_test = y_test.astype(np.float32)
y_train = y_train.astype(np.float32)

x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

In [7]:
x_train = jnp.expand_dims(x_train, axis=-1)
x_test = jnp.expand_dims(x_test, axis=-1)

In [8]:
x_train = np.transpose(x_train, (0, 3, 1, 2))
x_test = np.transpose(x_test, (0, 3, 1, 2))

In [9]:
y_train = jax.nn.one_hot(y_train, 10)
y_test = jax.nn.one_hot(y_test, 10)

In [10]:
class QuantumConv2dLayer(eqx.Module):
    filter_size: int
    circuit_length: int
    padding_mode: str
    weight: jax.Array
    num_qubits: int
    quantum_conv_circuit: list
    kernel_size: tuple[int, int]
    stride: tuple[int, int]
    
    def __init__(self, filter_size, circuit_length, kernel_size, stride, padding_mode, in_channel):
        super().__init__()
        self.filter_size = filter_size
        self.circuit_length = circuit_length
        self.padding_mode = padding_mode
        
        self.num_qubits = in_channel * kernel_size[0] * kernel_size[1]
        self.weight = jax.random.normal(jax.random.PRNGKey(0), (filter_size, circuit_length, self.num_qubits))
        self.kernel_size = kernel_size
        self.stride = stride

        device = qml.device("default.qubit.jax", wires=self.num_qubits)
        @jax.jit
        @qml.qnode(device, interface="jax")
        def quantum_conv_circuit(inputs, weights):
            AngleEmbedding(inputs, wires=range(self.num_qubits))
            RandomLayers(weights, wires=range(self.num_qubits))
            
            return qml.expval(qml.PauliZ(wires=0))

        self.quantum_conv_circuit = [quantum_conv_circuit]
        
    @eqx.filter_jit
    def apply_on_batch(self, inputs):
        # Apply the quantum circuit to each patch
        inputs = jnp.transpose(inputs)
        outputs = []
                
        for i in range(self.filter_size):
            
            outputs.append(
                jnp.expand_dims(
                    jax.vmap(
                        self.quantum_conv_circuit[0], in_axes=(0, None)
                    )(inputs, self.weight[i]),
                    axis=0
                )
            )
        return jnp.concatenate(outputs, axis=0)
    
    @eqx.filter_jit
    def __call__(self, inputs):
        # Extract patches
        inputs = jnp.expand_dims(inputs, axis=0)
        patches = jax.lax.conv_general_dilated_patches(inputs, self.kernel_size, self.stride, self.padding_mode)
        patches_shape = patches.shape
        patches_flat = patches.reshape(patches_shape[0], patches_shape[1], patches_shape[2] * patches_shape[3])

        # Apply the quantum circuit on each patch
        output = jax.vmap(self.apply_on_batch)(patches_flat)
        output_shape = patches_shape
        
        output_shape = (self.filter_size, output_shape[2], output_shape[3])        
        # Reshape the output
        return output.reshape(output_shape)


#quantum_layer = QuantumConv2dLayer(filter_size=2, circuit_length=3, kernel_size=(2, 2), stride=(2, 2), padding_mode='SAME', in_channel=1)

In [11]:
#quantum_layer((np.random.random((1, 28, 28))))

In [12]:
class HybrideModel(eqx.Module):
    layers: list
    
    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 1, kernel_size=4, key=key1, stride=3),
            jax.nn.sigmoid,
            QuantumConv2dLayer(filter_size=8, circuit_length=1, kernel_size=(3, 3), stride=(2, 2), padding_mode='VALID', in_channel=1),
            eqx.nn.Conv2d(8, 12, kernel_size=2, key=key3, stride=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(48, 10, key=key3),
            jax.nn.softmax
        ]
        
    @eqx.filter_jit
    def __call__(self, x):
        
        for layer in self.layers:
            x = layer(x)
        return x
    
key, subkey = jax.random.split(key, 2)
model = HybrideModel(subkey)

In [13]:
@eqx.filter_jit
def categorical_crossentropy(model, inputs, targets):
    
    y_pred = jax.vmap(model)(inputs)
    
    y_pred = jnp.clip(y_pred, 1e-7, 1 - 1e-7)
    
    loss = -jnp.sum(targets * jnp.log(y_pred), axis=-1)
    
    return jnp.mean(loss), y_pred

@eqx.filter_jit
def categorical_accuracy(y_true, y_pred):
    true_labels = jnp.argmax(y_true, axis=-1)
    predicted_labels = jnp.argmax(y_pred, axis=-1)

    # Comparer les indices pour calculer la précision
    return jnp.mean(true_labels == predicted_labels)
    

@eqx.filter_jit
def train_step(model, opt_state, inputs, target):
    
    (loss_value, y_pred), grads = eqx.filter_value_and_grad(categorical_crossentropy, has_aux=True)(model, inputs, target)
    
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    
    accuracy = categorical_accuracy(target, y_pred)
    
    
    return model, opt_state, loss_value, accuracy

@eqx.filter_jit
def test_step(model, inputs, target):
    loss_value, y_pred = categorical_crossentropy(model, inputs, target)
    
    accuracy = categorical_accuracy(target, y_pred)
    
    return loss_value, accuracy

In [14]:
def batch(array, batch_size):
    return np.array_split(array, np.ceil(len(array) / batch_size))

In [15]:
optim = optax.adamax(0.002)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
epochs = 20
batch_size = 32

In [16]:
x_train_batch = batch(x_train, batch_size=batch_size)
y_train_batch = batch(y_train, batch_size=batch_size)

x_test_batch = batch(x_test, batch_size=batch_size)
y_test_batch = batch(y_test, batch_size=batch_size)

In [17]:
for step in range(epochs):
    accuracys = deque()
    losss = deque()
    print(f"Starting epoch: {step + 1}")
    
    for (x_batch, y_batch) in tqdm(zip(x_train_batch, y_train_batch), total=len(x_train_batch)):
        model, opt_state, loss_val, accuracy = train_step(model, opt_state, x_batch, y_batch)
        accuracys.append(accuracy)
        losss.append(loss_val)
    print(f"Train Loss: {np.mean(losss)}")
    print(f"Train Accuracy: {np.mean(accuracys)}")
    
    test_accuracys = deque()
    test_losss = deque()
    for (x_batch, y_batch) in tqdm(zip(x_test_batch, y_test_batch), total=len(x_test_batch)):
        loss_val, accuracy = test_step(model, x_batch, y_batch)
        
        test_accuracys.append(accuracy)
        test_losss.append(loss_val)
    print(f"Test Loss: {np.mean(test_losss)}")
    print(f"Test Accuracy: {np.mean(test_accuracys)}")
        

Starting epoch: 1


100%|██████████| 1563/1563 [00:36<00:00, 43.29it/s]


Train Loss: 1.4758970737457275
Train Accuracy: 0.5172615647315979


100%|██████████| 313/313 [00:05<00:00, 57.58it/s] 


Test Loss: 1.1107656955718994
Test Accuracy: 0.6330065131187439
Starting epoch: 2


100%|██████████| 1563/1563 [00:21<00:00, 72.36it/s]


Train Loss: 1.092655897140503
Train Accuracy: 0.6315327286720276


100%|██████████| 313/313 [00:01<00:00, 300.59it/s]


Test Loss: 1.0371507406234741
Test Accuracy: 0.6539826393127441
Starting epoch: 3


100%|██████████| 1563/1563 [00:21<00:00, 72.42it/s]


Train Loss: 1.041046380996704
Train Accuracy: 0.6477462649345398


100%|██████████| 313/313 [00:01<00:00, 301.00it/s]


Test Loss: 0.9943450093269348
Test Accuracy: 0.6703950762748718
Starting epoch: 4


100%|██████████| 1563/1563 [00:21<00:00, 72.40it/s]


Train Loss: 1.000580906867981
Train Accuracy: 0.6636670231819153


100%|██████████| 313/313 [00:01<00:00, 300.39it/s]


Test Loss: 0.9574816823005676
Test Accuracy: 0.6839669346809387
Starting epoch: 5


100%|██████████| 1563/1563 [00:21<00:00, 71.90it/s]


Train Loss: 0.9626046419143677
Train Accuracy: 0.6762100458145142


100%|██████████| 313/313 [00:01<00:00, 297.00it/s]


Test Loss: 0.9245643615722656
Test Accuracy: 0.6930556297302246
Starting epoch: 6


100%|██████████| 1563/1563 [00:21<00:00, 71.91it/s]


Train Loss: 0.9228246212005615
Train Accuracy: 0.6924512982368469


100%|██████████| 313/313 [00:01<00:00, 298.70it/s]


Test Loss: 0.8798932433128357
Test Accuracy: 0.7123408913612366
Starting epoch: 7


100%|██████████| 1563/1563 [00:21<00:00, 72.08it/s]


Train Loss: 0.8872103691101074
Train Accuracy: 0.7077909111976624


100%|██████████| 313/313 [00:01<00:00, 293.96it/s]


Test Loss: 0.8317278027534485
Test Accuracy: 0.7314683198928833
Starting epoch: 8


100%|██████████| 1563/1563 [00:21<00:00, 71.05it/s]


Train Loss: 0.8447204828262329
Train Accuracy: 0.7222456932067871


100%|██████████| 313/313 [00:01<00:00, 289.30it/s]


Test Loss: 0.7901082038879395
Test Accuracy: 0.7469725608825684
Starting epoch: 9


100%|██████████| 1563/1563 [00:22<00:00, 69.86it/s]


Train Loss: 0.8224063515663147
Train Accuracy: 0.7312434315681458


100%|██████████| 313/313 [00:01<00:00, 290.13it/s]


Test Loss: 0.7689432501792908
Test Accuracy: 0.753678023815155
Starting epoch: 10


100%|██████████| 1563/1563 [00:22<00:00, 69.90it/s]


Train Loss: 0.8035842776298523
Train Accuracy: 0.7373834252357483


100%|██████████| 313/313 [00:01<00:00, 289.58it/s]


Test Loss: 0.7547426223754883
Test Accuracy: 0.7575717568397522
Starting epoch: 11


100%|██████████| 1563/1563 [00:22<00:00, 69.89it/s]


Train Loss: 0.7880181670188904
Train Accuracy: 0.7418838739395142


100%|██████████| 313/313 [00:01<00:00, 289.19it/s]


Test Loss: 0.7462584972381592
Test Accuracy: 0.7592658400535583
Starting epoch: 12


100%|██████████| 1563/1563 [00:22<00:00, 69.95it/s]


Train Loss: 0.7770360112190247
Train Accuracy: 0.7462831139564514


100%|██████████| 313/313 [00:01<00:00, 289.74it/s]


Test Loss: 0.7386224865913391
Test Accuracy: 0.7616523504257202
Starting epoch: 13


100%|██████████| 1563/1563 [00:22<00:00, 69.94it/s]


Train Loss: 0.7676261067390442
Train Accuracy: 0.7487037181854248


100%|██████████| 313/313 [00:01<00:00, 289.50it/s]


Test Loss: 0.7289990186691284
Test Accuracy: 0.7638456225395203
Starting epoch: 14


100%|██████████| 1563/1563 [00:22<00:00, 69.93it/s]


Train Loss: 0.7605056166648865
Train Accuracy: 0.7511241436004639


100%|██████████| 313/313 [00:01<00:00, 289.31it/s]


Test Loss: 0.7212004661560059
Test Accuracy: 0.7659390568733215
Starting epoch: 15


100%|██████████| 1563/1563 [00:22<00:00, 69.94it/s]


Train Loss: 0.7543898224830627
Train Accuracy: 0.7530048489570618


100%|██████████| 313/313 [00:01<00:00, 289.71it/s]


Test Loss: 0.7142605781555176
Test Accuracy: 0.7672305107116699
Starting epoch: 16


100%|██████████| 1563/1563 [00:22<00:00, 69.83it/s]


Train Loss: 0.7480621337890625
Train Accuracy: 0.7545230984687805


100%|██████████| 313/313 [00:01<00:00, 287.97it/s]


Test Loss: 0.7074794173240662
Test Accuracy: 0.7705317139625549
Starting epoch: 17


100%|██████████| 1563/1563 [00:22<00:00, 69.77it/s]


Train Loss: 0.7426234483718872
Train Accuracy: 0.7563456892967224


100%|██████████| 313/313 [00:01<00:00, 249.80it/s]


Test Loss: 0.703238308429718
Test Accuracy: 0.7727314233779907
Starting epoch: 18


100%|██████████| 1563/1563 [00:21<00:00, 71.71it/s]


Train Loss: 0.7376511096954346
Train Accuracy: 0.7574840188026428


100%|██████████| 313/313 [00:01<00:00, 295.94it/s]


Test Loss: 0.6981175541877747
Test Accuracy: 0.7742483019828796
Starting epoch: 19


100%|██████████| 1563/1563 [00:21<00:00, 72.65it/s]


Train Loss: 0.7335332632064819
Train Accuracy: 0.7587036490440369


100%|██████████| 313/313 [00:01<00:00, 303.13it/s]


Test Loss: 0.6942533254623413
Test Accuracy: 0.7762483358383179
Starting epoch: 20


100%|██████████| 1563/1563 [00:21<00:00, 73.21it/s]


Train Loss: 0.7295634746551514
Train Accuracy: 0.759443461894989


100%|██████████| 313/313 [00:01<00:00, 298.47it/s]


Test Loss: 0.6877160668373108
Test Accuracy: 0.7765575051307678
