In [None]:
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import numpy as np
import matplotlib.pyplot as plt
from pennylane import AngleEmbedding, StronglyEntanglingLayers, RandomLayers
import pennylane as qml
import os
import jax
from jax import numpy as jnp
import equinox as eqx
import time
from tqdm import tqdm
import optax
import tensorflow as tf
from jaxtyping import Array, Float, PyTree
from collections import deque

In [None]:
# Show the devices JAX has detected so the active backend is visible in the notebook output.
print(jax.devices())

In [None]:
# Root PRNG key (seed 0); later cells split it to derive the key used for model initialisation.
key = jax.random.PRNGKey(0)

In [None]:
# Load MNIST via the Keras dataset helper: (images, labels) pairs for the train and test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [None]:
# Cap the number of samples taken from each split (images and labels sliced identically).
# NOTE(review): standard MNIST ships 60,000 train / 10,000 test samples, so these caps
# are presumably no-ops — confirm, and lower them if a real subsample is wanted.
x_train = x_train[:100_000]
y_train = y_train[:100_000]

x_test = x_test[:20_000]
y_test = y_test[:20_000]

In [None]:
# Cast images and labels to float32.
# NOTE(review): only the dtype changes here; pixel values are not rescaled, so they are
# presumably still in the raw 0-255 range — confirm that this is intended.
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)

x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)

In [None]:
# Display the shape of the training images before the extra axis is inserted.
x_train.shape

In [None]:
# Insert a length-1 axis at position 1 of both image arrays (the results are JAX arrays).
# The first model layer is a Conv2d with one input channel, so this axis serves as the channel axis.
x_train = jnp.expand_dims(x_train, axis=1)
x_test = jnp.expand_dims(x_test, axis=1)

In [None]:
# Display the shape again to confirm the new axis at position 1.
x_train.shape

In [None]:
# Peek at the test labels (shape and entries 1..9) before they are one-hot encoded.
y_test.shape, y_test[1:10]

In [None]:
# One-hot encode the labels over 10 classes.
y_test = jax.nn.one_hot(y_test, 10)
y_train = jax.nn.one_hot(y_train, 10)

In [None]:
# Peek at the one-hot training labels (shape and rows 10..19).
y_train.shape, y_train[10:20]

In [None]:
# Geometry of the standalone demo circuit defined in the next cell.
in_channel = 2
kernel_size = [2, 2]
# Qubit count = channels x kernel height x kernel width (evaluates to 8 for the values above).
num_qubits = in_channel * kernel_size[0] * kernel_size[1]

In [None]:
device = qml.device("default.qubit.jax", wires=num_qubits)


@jax.jit
@qml.qnode(device, interface="jax")
def quantum_conv_circuit(inputs, weights):
    """Standalone quantum convolution kernel built on the notebook-level geometry.

    ``inputs`` is angle-embedded on wires ``0 .. num_qubits - 1``. The wires are
    then handled in groups of ``in_channel`` consecutive qubits; the first qubit
    of each group is its "head".

    Args:
        inputs: rotation angles passed to ``AngleEmbedding`` over ``num_qubits`` wires.
        weights: 2-D array of rotation angles. Row ``k`` (columns 0..2) drives the
            ``Rot`` gate on the head of group ``k``; the last row drives the
            closing ``Rot`` gate on wire 0.

    Returns:
        The expectation value of Pauli-Z measured on wire 0.
    """
    AngleEmbedding(inputs, wires=range(num_qubits))

    group_heads = range(0, num_qubits, in_channel)
    for row, head in enumerate(group_heads):
        # Entangle every non-head qubit of the group into the head qubit.
        for offset in range(1, in_channel):
            qml.CNOT(wires=[head + offset, head])

        angles = weights[row]
        qml.Rot(angles[0], angles[1], angles[2], wires=head)

        # Every head except wire 0 itself is folded into wire 0.
        if head > 0:
            qml.CNOT(wires=[head, 0])

    closing = weights[-1]
    qml.Rot(closing[0], closing[1], closing[2], wires=0)

    return qml.expval(qml.PauliZ(wires=0))

In [None]:
class QuantumConv2d(eqx.Module):
    """Convolution-style layer whose kernel is a parameterised quantum circuit.

    Patches are extracted from the input with
    ``jax.lax.conv_general_dilated_patches``; every patch is angle-embedded on
    ``in_channel * kernel_size[0] * kernel_size[1]`` qubits and each of the
    ``filter_size`` filters evaluates the same circuit with its own rotation
    angles. The Pauli-Z expectation value on wire 0 becomes the output pixel.

    NOTE(review): the circuit treats ``in_channel`` consecutive qubits as one
    kernel cell. Verify that this matches the feature ordering produced by
    ``conv_general_dilated_patches`` before using ``in_channel > 1``.
    """

    filter_size: int
    padding_mode: str
    # Rotation angles, shape (filter_size, kernel_h * kernel_w + 1, 3):
    # one row per kernel cell plus one row for the closing rotation on wire 0.
    weight: jax.Array
    num_qubits: int
    # The jitted QNode, wrapped in a one-element list.
    quantum_conv_circuit: list
    kernel_size: tuple[int, int]
    stride: tuple[int, int]

    def __init__(self, filter_size, kernel_size, stride, padding_mode, in_channel, key):
        """Build the layer and its quantum circuit.

        Args:
            filter_size: number of output filters (independent weight sets).
            kernel_size: two-element sequence ``(kernel_h, kernel_w)``.
            stride: two-element sequence of window strides.
            padding_mode: padding argument forwarded to
                ``conv_general_dilated_patches`` (e.g. ``"VALID"``).
            in_channel: number of consecutive qubits that form one kernel cell.
            key: JAX PRNG key used to draw the initial rotation angles.
        """
        super().__init__()
        self.filter_size = filter_size
        # Stored as tuples so the values match the declared field types.
        self.kernel_size = tuple(kernel_size)
        self.stride = tuple(stride)
        self.padding_mode = padding_mode

        num_cells = kernel_size[0] * kernel_size[1]
        num_qubits = in_channel * num_cells
        self.num_qubits = num_qubits

        # The circuit reads one weight row per kernel cell plus a final row.
        # The row count was previously kernel_h + kernel_w + 1, which is too
        # small for any kernel larger than 2x2; JAX clamps out-of-range indices
        # instead of raising, so the trailing cells silently reused the last row.
        self.weight = jax.random.normal(key, (filter_size, num_cells + 1, 3))

        device = qml.device("default.qubit.jax", wires=num_qubits)

        @jax.jit
        @qml.qnode(device, interface="jax")
        def quantum_conv_circuit(inputs, weights):
            AngleEmbedding(inputs, wires=range(num_qubits))

            for weight_set, cell_qubit in enumerate(range(0, num_qubits, in_channel)):
                # Entangle the cell's remaining channel qubits into its first qubit.
                for channel_qubit in range(1, in_channel):
                    qml.CNOT(wires=[cell_qubit + channel_qubit, cell_qubit])
                qml.Rot(weights[weight_set, 0], weights[weight_set, 1], weights[weight_set, 2], wires=cell_qubit)

                # Fold every cell except the first into wire 0.
                if cell_qubit > 0:
                    qml.CNOT(wires=[cell_qubit, 0])

            # Closing rotation, applied once after all cells have been processed
            # (same structure as the standalone circuit defined earlier in the notebook).
            qml.Rot(weights[-1, 0], weights[-1, 1], weights[-1, 2], wires=0)

            return qml.expval(qml.PauliZ(wires=0))

        self.quantum_conv_circuit = [quantum_conv_circuit]

    @eqx.filter_jit
    def apply_on_patches(self, inputs):
        """Evaluate every filter on every patch.

        Args:
            inputs: 2-D array that is transposed before use, so that after the
                transpose each row is one patch fed to the circuit.

        Returns:
            Array whose leading axis has length ``filter_size``; the second axis
            indexes the patches.
        """
        patches = jnp.transpose(inputs)
        run_on_all_patches = jax.vmap(self.quantum_conv_circuit[0], in_axes=(0, None))

        per_filter = [
            jnp.expand_dims(run_on_all_patches(patches, self.weight[i]), axis=0)
            for i in range(self.filter_size)
        ]
        return jnp.concatenate(per_filter, axis=0)

    @eqx.filter_jit
    def __call__(self, inputs):
        """Apply the quantum convolution to one unbatched input.

        Args:
            inputs: a single sample; a leading batch axis of size 1 is added
                internally (presumably laid out as (channels, height, width) —
                confirm against the caller).

        Returns:
            Array of shape ``(filter_size, out_h, out_w)`` where ``out_h`` and
            ``out_w`` are the spatial dimensions of the extracted patch grid.
        """
        batched = jnp.expand_dims(inputs, axis=0)
        patches = jax.lax.conv_general_dilated_patches(batched, self.kernel_size, self.stride, self.padding_mode)
        n_batch, n_features, out_h, out_w = patches.shape

        # Flatten the spatial grid so each column is one patch.
        patches_flat = patches.reshape((n_batch, n_features, out_h * out_w))

        output = jax.vmap(self.apply_on_patches)(patches_flat)

        return output.reshape((self.filter_size, out_h, out_w))
        

In [None]:
class HybrideModel(eqx.Module):
    """Hybrid classifier: classical conv -> quantum conv -> classical conv -> linear head.

    The layers are stored in ``layers`` and applied in sequence; the last layer
    is a softmax, so the model outputs a vector of 10 class probabilities.
    """

    layers: list

    def __init__(self, key):
        """Create all layers, giving each parameterised layer its own PRNG sub-key."""
        stem_key, quantum_key, head_conv_key, dense_key = jax.random.split(key, 4)

        stem_conv = eqx.nn.Conv2d(1, 1, kernel_size=4, key=stem_key, stride=3)
        quantum_conv = QuantumConv2d(
            filter_size=8,
            kernel_size=[3, 3],
            stride=[2, 2],
            padding_mode="VALID",
            in_channel=1,
            key=quantum_key,
        )
        head_conv = eqx.nn.Conv2d(8, 12, kernel_size=2, key=head_conv_key, stride=2)
        dense = eqx.nn.Linear(48, 10, key=dense_key)

        self.layers = [
            stem_conv,
            jax.nn.sigmoid,
            quantum_conv,
            head_conv,
            jax.nn.relu,
            jnp.ravel,
            dense,
            jax.nn.softmax,
        ]

    @eqx.filter_jit
    def __call__(self, x):
        """Run a single (unbatched) sample through every layer in order."""
        activation = x
        for apply_layer in self.layers:
            activation = apply_layer(activation)
        return activation
    
# Split the root key: keep one half for later use and initialise the model with the other.
key, subkey = jax.random.split(key, 2)
model = HybrideModel(subkey)

In [None]:
def confusion_matrix(y_true, y_pred, num_classes):
    """Count (true label, predicted label) pairs.

    Args:
        y_true: array of true class indices (cast to int32 before use).
        y_pred: array of predicted class indices (integer dtype).
        num_classes: number of classes.

    Returns:
        ``(num_classes, num_classes)`` array; entry ``[t, p]`` is the number of
        samples with true class ``t`` and predicted class ``p``.
    """
    # Encode each (true, pred) pair as one flat index into the matrix.
    flat_index = y_true.astype(jnp.int32) * num_classes + y_pred
    counts = jnp.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape((num_classes, num_classes))


def precision_recall_f1(y_true, y_pred, num_classes):
    """Per-class precision, recall and F1 score derived from the confusion matrix.

    Args:
        y_true: array of true class indices.
        y_pred: array of predicted class indices.
        num_classes: number of classes.

    Returns:
        Tuple ``(precision, recall, f1_score)``, each with one entry per class.
    """
    matrix = confusion_matrix(y_true, y_pred, num_classes)

    hits = jnp.diag(matrix)
    predicted_totals = matrix.sum(axis=0)  # column sums: how often each class was predicted
    actual_totals = matrix.sum(axis=1)     # row sums: how often each class truly occurred

    # Denominators are floored at 1 so classes with no samples give 0 instead of dividing by zero.
    precision = hits / jnp.maximum(1.0, predicted_totals)
    recall = hits / jnp.maximum(1.0, actual_totals)

    # The small epsilon keeps the F1 denominator non-zero when precision and recall are both 0.
    f1_score = (2 * precision * recall) / (precision + recall + 1e-7)
    return precision, recall, f1_score

def classification_report(model, x_test, y_target):
    """Compute the per-class F1 score of ``model`` over pre-batched data.

    Args:
        model: callable applied to a single sample; it is vmapped over each batch.
        x_test: sequence of input batches.
        y_target: sequence of one-hot target batches, aligned with ``x_test``.

    Returns:
        Array with one F1 score per class (10 classes).
    """
    predict_batch = jax.vmap(model)

    all_y_pred = []
    all_y_target = []
    # The progress-bar length comes from the argument itself rather than from
    # the notebook-level ``x_test_batch`` variable.
    for x_batch, y_batch in tqdm(zip(x_test, y_target), total=len(x_test)):
        # Reduce probabilities and one-hot targets to class indices.
        all_y_pred.append(jnp.argmax(predict_batch(x_batch), axis=-1))
        all_y_target.append(jnp.argmax(y_batch, axis=-1))

    all_y_pred = jnp.concatenate(all_y_pred, axis=0)
    all_y_target = jnp.concatenate(all_y_target, axis=0)

    _, _, f1_score = precision_recall_f1(all_y_target, all_y_pred, 10)
    return f1_score

In [None]:
@eqx.filter_jit
def categorical_crossentropy(model, inputs, targets):
    """Mean categorical cross-entropy of ``model`` on one batch.

    Args:
        model: callable applied to a single sample; vmapped over ``inputs``.
        inputs: batch of samples (leading axis is the batch axis).
        targets: targets with the same leading axis, multiplied element-wise
            with the log-probabilities (one-hot in this notebook).

    Returns:
        Tuple ``(loss, probabilities)``: the batch-mean loss and the model
        outputs after clipping to ``[1e-7, 1 - 1e-7]``.
    """
    # Clip before the log so a probability of exactly 0 or 1 cannot produce inf/nan.
    probabilities = jnp.clip(jax.vmap(model)(inputs), 1e-7, 1 - 1e-7)
    per_sample_loss = -jnp.sum(targets * jnp.log(probabilities), axis=-1)
    return jnp.mean(per_sample_loss), probabilities

@eqx.filter_jit
def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose arg-max prediction equals the arg-max target.

    Args:
        y_true: targets; the class index is taken as the arg-max over the last axis.
        y_pred: predictions; the class index is taken the same way.

    Returns:
        Scalar mean of the element-wise label matches.
    """
    # Compare class indices to compute the accuracy.
    is_correct = jnp.argmax(y_true, axis=-1) == jnp.argmax(y_pred, axis=-1)
    return jnp.mean(is_correct)
    

@eqx.filter_jit
def train_step(model, opt_state, inputs, target):
    """Run one optimisation step on a single batch.

    Uses the notebook-level ``optim`` optimiser.

    Args:
        model: the model to update.
        opt_state: current optimiser state.
        inputs: batch of input samples.
        target: batch of targets aligned with ``inputs``.

    Returns:
        Tuple ``(model, opt_state, loss_value, accuracy)`` with the updated
        model and optimiser state plus the loss and accuracy on this batch.
    """
    loss_and_grad = eqx.filter_value_and_grad(categorical_crossentropy, has_aux=True)
    (loss_value, y_pred), grads = loss_and_grad(model, inputs, target)

    # Hand the optimiser only the array leaves of the model, mirroring how
    # ``opt_state`` was initialised; the non-array leaves (activation
    # functions, ints, strings) are not parameters.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)

    accuracy = categorical_accuracy(target, y_pred)

    return model, opt_state, loss_value, accuracy
    

@eqx.filter_jit
def test_step(model, inputs, target):
    """Evaluate ``model`` on one batch without updating it.

    Args:
        model: the model to evaluate.
        inputs: batch of input samples.
        target: batch of targets aligned with ``inputs``.

    Returns:
        Tuple ``(loss_value, accuracy)`` for the batch.
    """
    loss_value, probabilities = categorical_crossentropy(model, inputs, target)
    return loss_value, categorical_accuracy(target, probabilities)
    
    

In [None]:
# Adamax optimiser with learning rate 0.002; its state is initialised from the
# model's array leaves only.
optim = optax.adamax(0.002)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
# Number of passes over the training batches, and the target mini-batch size.
epochs = 20
batch_size = 8

In [None]:
def batch(array, batch_size):
    """Split ``array`` along its first axis into near-equal chunks.

    The number of chunks is ``ceil(len(array) / batch_size)``, so no chunk has
    more than ``batch_size`` rows. When the length is not a multiple of
    ``batch_size`` the rows are spread as evenly as possible across the chunks
    (``np.array_split`` semantics) rather than leaving one short final chunk.

    Args:
        array: array-like supporting ``len`` and slicing on the first axis.
        batch_size: maximum number of rows per chunk; must be positive.

    Returns:
        List of sub-arrays; an empty list when ``array`` is empty.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = len(array)
    if n_samples == 0:
        # np.array_split rejects zero sections; an empty input has no batches.
        return []

    # Pass an explicit integer section count instead of the float from np.ceil.
    n_chunks = int(np.ceil(n_samples / batch_size))
    return np.array_split(array, n_chunks)

In [None]:
# Pre-split inputs and targets into lists of mini-batches; both members of each pair use
# the same batch_size so the i-th input batch lines up with the i-th target batch.
x_train_batch = batch(x_train, batch_size=batch_size)
y_train_batch = batch(y_train, batch_size=batch_size)

x_test_batch = batch(x_test, batch_size=batch_size)
y_test_batch = batch(y_test, batch_size=batch_size)

In [None]:
for step in range(epochs):
    # Per-epoch training metrics, one entry per mini-batch.
    accuracys, losss = deque(), deque()
    print(f"Starting epoch: {step + 1}")

    # Training pass: each step replaces the model and optimiser state.
    train_pairs = zip(x_train_batch, y_train_batch)
    for x_batch, y_batch in tqdm(train_pairs, total=len(x_train_batch)):
        model, opt_state, loss_value, accuracy = train_step(model, opt_state, x_batch, y_batch)
        losss.append(loss_value)
        accuracys.append(accuracy)

    print(f"Train Loss: {np.mean(losss)} Train Accuracy: {np.mean(accuracys)}")

    # Evaluation pass on the test batches with the model as it stands after this epoch.
    test_accuracys, test_losss = deque(), deque()
    test_pairs = zip(x_test_batch, y_test_batch)
    for x_batch, y_batch in tqdm(test_pairs, total=len(x_test_batch)):
        loss_value, accuracy = test_step(model, x_batch, y_batch)
        test_losss.append(loss_value)
        test_accuracys.append(accuracy)

    print(f"Test Loss: {np.mean(test_losss)} Test Accuracy {np.mean(test_accuracys)}")

In [None]:
# Per-class F1 scores of the trained model on the test batches, shown as a plain list.
list(classification_report(model, x_test_batch, y_test_batch))