In [1]:
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
from pennylane import AngleEmbedding, StronglyEntanglingLayers, RandomLayers
import pennylane as qml
import os
import jax
from jax import numpy as jnp
import equinox as eqx
import time
from tqdm import tqdm
import optax
import tensorflow as tf
from jaxtyping import Array, Float, PyTree
from collections import deque

In [2]:
print(jax.devices())

[cuda(id=0)]


In [3]:
key = jax.random.PRNGKey(0)

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [5]:
x_train = x_train[:50_000]
y_train = y_train[:50_000]

x_test = x_test[:10_000]
y_test = y_test[:10_000]

In [6]:
y_test = y_test.astype(np.float32)
y_train = y_train.astype(np.float32)

x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

In [7]:
x_train = jnp.expand_dims(x_train, axis=-1)
x_test = jnp.expand_dims(x_test, axis=-1)

In [8]:
x_train = np.transpose(x_train, (0, 3, 1, 2))
x_test = np.transpose(x_test, (0, 3, 1, 2))

In [9]:
y_train = jax.nn.one_hot(y_train, 10)
y_test = jax.nn.one_hot(y_test, 10)

In [10]:
class QuantumConv2dLayer(eqx.Module):
    filter_size: int
    circuit_length: int
    padding_mode: str
    weight: jax.Array
    num_qubits: int
    quantum_conv_circuit: list
    kernel_size: tuple[int, int]
    stride: tuple[int, int]
    
    def __init__(self, filter_size, circuit_length, kernel_size, stride, padding_mode, in_channel):
        super().__init__()
        self.filter_size = filter_size
        self.circuit_length = circuit_length
        self.padding_mode = padding_mode
        
        self.num_qubits = in_channel * kernel_size[0] * kernel_size[1]
        self.weight = jax.random.normal(jax.random.PRNGKey(0), (filter_size, circuit_length, self.num_qubits, 3))
        self.kernel_size = kernel_size
        self.stride = stride

        device = qml.device("default.qubit.jax", wires=self.num_qubits)
        @jax.jit
        @qml.qnode(device, interface="jax")
        def quantum_conv_circuit(inputs, weights):
            AngleEmbedding(inputs, wires=range(self.num_qubits))
            StronglyEntanglingLayers(weights, wires=range(self.num_qubits))
            
            return qml.expval(qml.PauliZ(wires=0))

        self.quantum_conv_circuit = [quantum_conv_circuit]
        
    @eqx.filter_jit
    def apply_on_batch(self, inputs):
        # Apply the quantum circuit to each patch
        inputs = jnp.transpose(inputs)
        outputs = []
                
        for i in range(self.filter_size):
            
            outputs.append(
                jnp.expand_dims(
                    jax.vmap(
                        self.quantum_conv_circuit[0], in_axes=(0, None)
                    )(inputs, self.weight[i]),
                    axis=0
                )
            )
        return jnp.concatenate(outputs, axis=0)
    
    @eqx.filter_jit
    def __call__(self, inputs):
        # Extract patches
        inputs = jnp.expand_dims(inputs, axis=0)
        patches = jax.lax.conv_general_dilated_patches(inputs, self.kernel_size, self.stride, self.padding_mode)
        patches_shape = patches.shape
        patches_flat = patches.reshape(patches_shape[0], patches_shape[1], patches_shape[2] * patches_shape[3])

        # Apply the quantum circuit on each patch
        output = jax.vmap(self.apply_on_batch)(patches_flat)
        output_shape = patches_shape
        
        output_shape = (self.filter_size, output_shape[2], output_shape[3])        
        # Reshape the output
        return output.reshape(output_shape)


#quantum_layer = QuantumConv2dLayer(filter_size=2, circuit_length=3, kernel_size=(2, 2), stride=(2, 2), padding_mode='SAME', in_channel=1)

In [11]:
#quantum_layer((np.random.random((1, 28, 28))))

In [12]:
class HybrideModel(eqx.Module):
    layers: list
    
    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 1, kernel_size=4, key=key1, stride=3),
            jax.nn.sigmoid,
            QuantumConv2dLayer(filter_size=8, circuit_length=1, kernel_size=(3, 3), stride=(2, 2), padding_mode='VALID', in_channel=1),
            eqx.nn.Conv2d(8, 12, kernel_size=2, key=key3, stride=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(48, 10, key=key3),
            jax.nn.softmax
        ]
        
    @eqx.filter_jit
    def __call__(self, x):
        
        for layer in self.layers:
            x = layer(x)
        return x
    
key, subkey = jax.random.split(key, 2)
model = HybrideModel(subkey)

In [13]:
@eqx.filter_jit
def categorical_crossentropy(model, inputs, targets):
    
    y_pred = jax.vmap(model)(inputs)
    
    y_pred = jnp.clip(y_pred, 1e-7, 1 - 1e-7)
    
    loss = -jnp.sum(targets * jnp.log(y_pred), axis=-1)
    
    return jnp.mean(loss), y_pred

@eqx.filter_jit
def categorical_accuracy(y_true, y_pred):
    true_labels = jnp.argmax(y_true, axis=-1)
    predicted_labels = jnp.argmax(y_pred, axis=-1)

    # Comparer les indices pour calculer la précision
    return jnp.mean(true_labels == predicted_labels)
    

@eqx.filter_jit
def train_step(model, opt_state, inputs, target):
    
    (loss_value, y_pred), grads = eqx.filter_value_and_grad(categorical_crossentropy, has_aux=True)(model, inputs, target)
    
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    
    accuracy = categorical_accuracy(target, y_pred)
    
    
    return model, opt_state, loss_value, accuracy

@eqx.filter_jit
def test_step(model, inputs, target):
    loss_value, y_pred = categorical_crossentropy(model, inputs, target)
    
    accuracy = categorical_accuracy(target, y_pred)
    
    return loss_value, accuracy

In [14]:
def batch(array, batch_size):
    return np.array_split(array, np.ceil(len(array) / batch_size))

In [15]:
optim = optax.adamax(0.002)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
epochs = 20
batch_size = 32

In [16]:
x_train_batch = batch(x_train, batch_size=batch_size)
y_train_batch = batch(y_train, batch_size=batch_size)

x_test_batch = batch(x_test, batch_size=batch_size)
y_test_batch = batch(y_test, batch_size=batch_size)

In [17]:
for step in range(epochs):
    accuracys = deque()
    losss = deque()
    print(f"Starting epoch: {step + 1}")
    
    for (x_batch, y_batch) in tqdm(zip(x_train_batch, y_train_batch), total=len(x_train_batch)):
        model, opt_state, loss_val, accuracy = train_step(model, opt_state, x_batch, y_batch)
        accuracys.append(accuracy)
        losss.append(loss_val)
    print(f"Train Loss: {np.mean(losss)}")
    print(f"Train Accuracy: {np.mean(accuracys)}")
    
    test_accuracys = deque()
    test_losss = deque()
    for (x_batch, y_batch) in tqdm(zip(x_test_batch, y_test_batch), total=len(x_test_batch)):
        loss_val, accuracy = test_step(model, x_batch, y_batch)
        
        test_accuracys.append(accuracy)
        test_losss.append(loss_val)
    print(f"Test Loss: {np.mean(test_losss)}")
    print(f"Test Accuracy: {np.mean(test_accuracys)}")
        

Starting epoch: 1


100%|██████████| 1563/1563 [01:19<00:00, 19.60it/s]


Train Loss: 1.5509504079818726
Train Accuracy: 0.4727655053138733


100%|██████████| 313/313 [00:32<00:00,  9.72it/s] 


Test Loss: 1.0191282033920288
Test Accuracy: 0.6668491363525391
Starting epoch: 2


100%|██████████| 1563/1563 [00:26<00:00, 58.03it/s]


Train Loss: 0.8971230983734131
Train Accuracy: 0.7117419242858887


100%|██████████| 313/313 [00:01<00:00, 222.16it/s]


Test Loss: 0.765321671962738
Test Accuracy: 0.7597682476043701
Starting epoch: 3


100%|██████████| 1563/1563 [00:26<00:00, 57.96it/s]


Train Loss: 0.7299824953079224
Train Accuracy: 0.767480194568634


100%|██████████| 313/313 [00:01<00:00, 221.16it/s]


Test Loss: 0.6462812423706055
Test Accuracy: 0.797723650932312
Starting epoch: 4


100%|██████████| 1563/1563 [00:27<00:00, 57.83it/s]


Train Loss: 0.6368333697319031
Train Accuracy: 0.7987392544746399


100%|██████████| 313/313 [00:01<00:00, 221.67it/s]


Test Loss: 0.5741549730300903
Test Accuracy: 0.8203132748603821
Starting epoch: 5


100%|██████████| 1563/1563 [00:27<00:00, 57.79it/s]


Train Loss: 0.5766146779060364
Train Accuracy: 0.8180736899375916


100%|██████████| 313/313 [00:01<00:00, 221.69it/s]


Test Loss: 0.5240447521209717
Test Accuracy: 0.8378174901008606
Starting epoch: 6


100%|██████████| 1563/1563 [00:27<00:00, 57.89it/s]


Train Loss: 0.5367886424064636
Train Accuracy: 0.8307941555976868


100%|██████████| 313/313 [00:01<00:00, 222.06it/s]


Test Loss: 0.498129665851593
Test Accuracy: 0.8455567955970764
Starting epoch: 7


100%|██████████| 1563/1563 [00:27<00:00, 57.83it/s]


Train Loss: 0.508546769618988
Train Accuracy: 0.8412120342254639


100%|██████████| 313/313 [00:01<00:00, 220.24it/s]


Test Loss: 0.47465822100639343
Test Accuracy: 0.8529417514801025
Starting epoch: 8


100%|██████████| 1563/1563 [00:27<00:00, 57.60it/s]


Train Loss: 0.4848669469356537
Train Accuracy: 0.8501137495040894


100%|██████████| 313/313 [00:01<00:00, 220.16it/s]


Test Loss: 0.4518565237522125
Test Accuracy: 0.8623299598693848
Starting epoch: 9


100%|██████████| 1563/1563 [00:27<00:00, 57.57it/s]


Train Loss: 0.4636886417865753
Train Accuracy: 0.8568955063819885


100%|██████████| 313/313 [00:01<00:00, 220.25it/s]


Test Loss: 0.4357653856277466
Test Accuracy: 0.8667325377464294
Starting epoch: 10


100%|██████████| 1563/1563 [00:27<00:00, 57.57it/s]


Train Loss: 0.44566941261291504
Train Accuracy: 0.8640937805175781


100%|██████████| 313/313 [00:01<00:00, 215.08it/s]


Test Loss: 0.41635072231292725
Test Accuracy: 0.8717471957206726
Starting epoch: 11


100%|██████████| 1563/1563 [00:27<00:00, 57.45it/s]


Train Loss: 0.42392560839653015
Train Accuracy: 0.8708587288856506


100%|██████████| 313/313 [00:01<00:00, 220.36it/s]


Test Loss: 0.4053879380226135
Test Accuracy: 0.8774154782295227
Starting epoch: 12


100%|██████████| 1563/1563 [00:27<00:00, 57.54it/s]


Train Loss: 0.4014393091201782
Train Accuracy: 0.8798971772193909


100%|██████████| 313/313 [00:01<00:00, 220.56it/s]


Test Loss: 0.40277573466300964
Test Accuracy: 0.8764299154281616
Starting epoch: 13


100%|██████████| 1563/1563 [00:27<00:00, 57.53it/s]


Train Loss: 0.385634183883667
Train Accuracy: 0.884599506855011


100%|██████████| 313/313 [00:01<00:00, 219.82it/s]


Test Loss: 0.39774754643440247
Test Accuracy: 0.8774283528327942
Starting epoch: 14


100%|██████████| 1563/1563 [00:27<00:00, 57.51it/s]


Train Loss: 0.3734549582004547
Train Accuracy: 0.8884589076042175


100%|██████████| 313/313 [00:01<00:00, 219.53it/s]


Test Loss: 0.39158493280410767
Test Accuracy: 0.8802206516265869
Starting epoch: 15


100%|██████████| 1563/1563 [00:27<00:00, 57.54it/s]


Train Loss: 0.36302825808525085
Train Accuracy: 0.8918585181236267


100%|██████████| 313/313 [00:01<00:00, 219.29it/s]


Test Loss: 0.3858073055744171
Test Accuracy: 0.8827230930328369
Starting epoch: 16


100%|██████████| 1563/1563 [00:27<00:00, 57.54it/s]


Train Loss: 0.35342854261398315
Train Accuracy: 0.8945388793945312


100%|██████████| 313/313 [00:01<00:00, 220.84it/s]


Test Loss: 0.38029932975769043
Test Accuracy: 0.8828293681144714
Starting epoch: 17


100%|██████████| 1563/1563 [00:27<00:00, 57.52it/s]


Train Loss: 0.344897985458374
Train Accuracy: 0.8968981504440308


100%|██████████| 313/313 [00:01<00:00, 220.29it/s]


Test Loss: 0.37547919154167175
Test Accuracy: 0.8846200704574585
Starting epoch: 18


100%|██████████| 1563/1563 [00:27<00:00, 57.49it/s]


Train Loss: 0.33763134479522705
Train Accuracy: 0.8986575603485107


100%|██████████| 313/313 [00:01<00:00, 220.33it/s]


Test Loss: 0.37032097578048706
Test Accuracy: 0.8860242366790771
Starting epoch: 19


100%|██████████| 1563/1563 [00:27<00:00, 57.52it/s]


Train Loss: 0.33085858821868896
Train Accuracy: 0.9003982543945312


100%|██████████| 313/313 [00:01<00:00, 219.94it/s]


Test Loss: 0.36589038372039795
Test Accuracy: 0.8861272931098938
Starting epoch: 20


100%|██████████| 1563/1563 [00:27<00:00, 57.77it/s]


Train Loss: 0.32440993189811707
Train Accuracy: 0.9022377133369446


100%|██████████| 313/313 [00:01<00:00, 223.54it/s]


Test Loss: 0.36045876145362854
Test Accuracy: 0.8882303833961487
