In [20]:
!pip install pennylane



In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms
import pennylane as qml
import numpy as np
from typing import Tuple, List
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

    Quantum filter pour extraire des features d'un patch 2x2

    Paramètres optimisables:
    - params : Tensor de forme (n_layers, 2, n_qubits)
      * params[layer, 0, qubit] : angle θ pour RY(θ)
      * params[layer, 1, qubit] : angle φ pour RZ(φ)

    Pour n_layers=2, n_qubits=4 : 2 × 2 × 4 = 16 paramètres par filtre

        ⚠️ PARAMÈTRES QUANTIQUES OPTIMISABLES ⚠️
        Ces paramètres seront mis à jour par backpropagation !
        Forme: (n_layers, 2, n_qubits)
        - Dimension 0: layer du circuit (0 à n_layers-1)
        - Dimension 1: type de rotation (0=RY, 1=RZ)
        - Dimension 2: qubit index (0 à n_qubits-1)

        Circuit quantique variationnel

        Structure du circuit:
        1. ENCODING: RX(θᵢ) pour encoder les pixels
        2. LAYER 1: CNOT → RY(φ₁) → RZ(ψ₁)  [paramètres optimisables]
        3. LAYER 2: CNOT → RY(φ₂) → RZ(ψ₂)  [paramètres optimisables]
        4. MEASUREMENT: ⟨Zᵢ⟩ pour chaque qubit

        Paramètres:
        - inputs: angles d'encoding [θ₀, θ₁, θ₂, θ₃] (NON optimisables, viennent des pixels)
        - params: angles variationnels (OPTIMISABLES via gradient descent)

        Traite un patch 2x2 et retourne une feature scalaire

        Args:
            patch: Tensor de taille (4,) avec valeurs [0, 255]

        Returns:
            Feature scalaire agrégée (Tensor avec gradient)

In [None]:
class QuantumFilter(nn.Module):
    """Variational quantum filter that maps one flattened image patch to a scalar.

    Circuit structure (see ``_circuit``):
      1. Encoding: ``RX(inputs[i])`` on qubit ``i``; the angles come from pixel
         values and are not trainable.
      2. ``n_layers`` variational layers, each a chain of CNOTs followed by
         ``RY`` and ``RZ`` rotations on every qubit (trainable angles).
      3. Measurement: expectation value of Pauli-Z on every qubit.

    Trainable angles are stored in ``self.params`` with shape
    ``(n_layers, 2, n_qubits)``:
      * ``params[layer, 0, qubit]`` is the RY angle,
      * ``params[layer, 1, qubit]`` is the RZ angle.
    With the defaults (2 layers, 4 qubits) that is 2 x 2 x 4 = 16 parameters.
    """

    def __init__(self, n_qubits: int = 4, n_layers: int = 2):
        super(QuantumFilter, self).__init__()

        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Quantum device (simulator).
        self.dev = qml.device('default.qubit', wires=n_qubits)

        # nn.Parameter already enables gradients, so no explicit requires_grad
        # flag is needed on the wrapped tensor.
        self.params = nn.Parameter(
            torch.tensor(
                np.random.uniform(0, 2*np.pi, (n_layers, 2, n_qubits)),
                dtype=torch.float32,
            )
        )

        print(f"    QuantumFilter initialized: {self.params.numel()} trainable quantum parameters")

        self.qnode = qml.QNode(self._circuit, self.dev, interface='torch', diff_method='backprop')

    def _circuit(self, inputs, params):
        """Quantum function executed by ``self.qnode``.

        Args:
            inputs: encoding angles, one per qubit.
            params: variational angles, indexed as ``[layer, rotation, qubit]``.

        Returns:
            A list with one Pauli-Z expectation value per qubit.
        """
        # Angle encoding of the inputs.
        for i in range(self.n_qubits):
            qml.RX(inputs[i], wires=i)

        for layer in range(self.n_layers):
            # Entangle neighbouring qubits with a linear CNOT chain.
            for i in range(self.n_qubits - 1):
                qml.CNOT(wires=[i, i+1])

            for i in range(self.n_qubits):
                qml.RY(params[layer, 0, i], wires=i)

            for i in range(self.n_qubits):
                qml.RZ(params[layer, 1, i], wires=i)

        return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

    def forward(self, patch: torch.Tensor) -> torch.Tensor:
        """Run the circuit on one flattened patch and return a scalar feature.

        Args:
            patch: 1-D tensor with one value per qubit; values are scaled as if
                they lie in [0, 255].

        Returns:
            The mean of the per-qubit Pauli-Z expectation values.
        """
        # Scale pixel values to rotation angles in [0, 2*pi].
        # NOTE(review): RX is 2*pi-periodic up to a global phase, so pixel
        # values 0 and 255 yield the same measurements; scaling to [0, pi]
        # would avoid that ambiguity. Left unchanged to keep model behaviour.
        theta = (patch / 255.0) * 2 * np.pi

        measurements = self.qnode(theta, self.params)

        feature = sum(measurements) / len(measurements)

        return feature

    def get_params(self) -> torch.Tensor:
        """Return a detached copy of the quantum parameters.

        A copy is returned so that callers cannot modify the live parameter
        through the returned tensor.
        """
        return self.params.detach().clone()

    def set_params(self, params: torch.Tensor):
        """Overwrite the quantum parameters with ``params``.

        Raises:
            ValueError: if ``params`` does not have the shape of ``self.params``.
        """
        if tuple(params.shape) != tuple(self.params.shape):
            raise ValueError(
                f"Expected params of shape {tuple(self.params.shape)}, "
                f"got {tuple(params.shape)}"
            )
        # In-place copy keeps the registered Parameter object (and therefore
        # any optimizer reference to it) and its dtype/device.
        with torch.no_grad():
            self.params.copy_(params)

    Couche quanvolutionnelle complète avec N filtres quantiques

    Chaque filtre a ses propres paramètres VQC optimisables
    Pour n_filters=8, n_layers=2, n_qubits=4:
    Total paramètres quantiques = 8 × (2 × 2 × 4) = 128 paramètres

        Applique tous les filtres quantiques à l'image

        Args:
            image: Tensor de taille (32, 32, 1) - grayscale

        Returns:
            Feature map de taille (16, 16, n_filters)

In [None]:
class QuanvolutionalLayer(nn.Module):
    """Quanvolutional layer made of ``n_filters`` independent quantum filters.

    Every filter owns its own trainable parameters. With n_filters=8,
    n_layers=2 and n_qubits=4 the layer holds 8 x (2 x 2 x 4) = 128 quantum
    parameters.
    """

    def __init__(self, n_filters: int = 8, n_qubits: int = 4, n_layers: int = 2):
        super(QuanvolutionalLayer, self).__init__()

        self.n_filters = n_filters

        self.filters = nn.ModuleList([
            QuantumFilter(n_qubits, n_layers) for _ in range(n_filters)
        ])

        total_quantum_params = sum(f.params.numel() for f in self.filters)
        print(f"  QuanvolutionalLayer: {n_filters} filters × {n_qubits} qubits × {n_layers} layers")
        print(f"  Total quantum parameters: {total_quantum_params}")

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Apply every quantum filter to non-overlapping 2x2 patches of ``image``.

        Args:
            image: tensor indexed as ``[row, column, channel]``; only channel 0
                is read.

        Returns:
            float32 feature map of shape ``(out_h, out_w, n_filters)``, where
            ``out_h = (H - 2) // 2 + 1`` and ``out_w = (W - 2) // 2 + 1``
            (16 x 16 for a 32 x 32 image).
        """
        H, W = image.shape[:2]
        patch_size = 2
        stride = 2

        out_h = (H - patch_size) // stride + 1
        out_w = (W - patch_size) // stride + 1

        features_list = []

        for filter_idx, quantum_filter in enumerate(self.filters):
            print(f"  Processing filter {filter_idx + 1}/{self.n_filters}...", end='\r')

            # Collect the scalar features and stack them once. This avoids
            # in-place writes into a pre-allocated CPU buffer, so the result
            # follows the device of the filter outputs.
            cells = []
            for i in range(out_h):
                for j in range(out_w):
                    # Extract the 2x2 patch.
                    h_start = i * stride
                    w_start = j * stride
                    patch = image[h_start:h_start+patch_size, w_start:w_start+patch_size, 0]

                    # Call the module (not .forward) so that hooks run.
                    feature = quantum_filter(patch.flatten())
                    cells.append(torch.as_tensor(feature).reshape(()))

            if cells:
                filter_output = torch.stack(cells).reshape(out_h, out_w)
            else:
                # Image smaller than one patch: empty map, as before.
                filter_output = torch.zeros(out_h, out_w, device=image.device)

            features_list.append(filter_output)

        print()

        # Stack the per-filter maps on the last axis: (out_h, out_w, n_filters).
        feature_map = torch.stack(features_list, dim=-1)

        return feature_map.float()

    def get_all_params(self) -> List[torch.Tensor]:
        """Return the quantum parameters of every filter, in filter order."""
        return [f.get_params() for f in self.filters]

    def set_all_params(self, params_list: List[torch.Tensor]):
        """Assign ``params_list[i]`` to filter ``i``."""
        for i, params in enumerate(params_list):
            self.filters[i].set_params(params)

    Quanvolutional Neural Network complète

    PARAMÈTRES OPTIMISABLES:
    1. Paramètres quantiques (VQC):
       - n_filters × n_layers × 2 × n_qubits paramètres
       - Exemple: 8 × 2 × 2 × 4 = 128 paramètres quantiques

    2. Paramètres classiques (FC layers):
       - FC1: (256×n_filters) × 128 + 128 bias
       - FC2: 128 × 64 + 64 bias
       - FC3: 64 × n_classes + n_classes bias
       - Exemple pour n_filters=8, n_classes=10:
         (2048×128 + 128) + (128×64 + 64) + (64×10 + 10) = 271,178 paramètres

    TOTAL: ~128 quantiques + ~271k classiques = ~271k paramètres

        Convertit RGB en grayscale

        Args:
            image: Tensor de taille (32, 32, 3)

        Returns:
            Tensor de taille (32, 32, 1)

        Forward pass complet

        Args:
            x: Batch d'images de taille (batch_size, 32, 32, 3)

        Returns:
            Probabilités de classe (batch_size, n_classes)

In [None]:
class QuanNN(nn.Module):
    """Quanvolutional neural network: quantum filters followed by three FC layers.

    Trainable parameters:
      * Quantum: ``n_filters x n_layers x 2 x n_qubits`` rotation angles.
      * Classical: ``fc1`` (256*n_filters -> 128), ``fc2`` (128 -> 64) and
        ``fc3`` (64 -> n_classes), each with a bias.

    ``fc1`` is sized for a 16 x 16 feature map per filter, i.e. for 32 x 32
    input images.
    """

    def __init__(self, n_filters: int = 8, n_classes: int = 10,
                 n_qubits: int = 4, n_layers: int = 2):
        super(QuanNN, self).__init__()

        self.n_filters = n_filters
        self.n_classes = n_classes

        print(f"\n{'='*60}")
        print(f"Initializing QuanNN Model")
        print(f"{'='*60}")

        # Quanvolutional layer.
        self.quanv_layer = QuanvolutionalLayer(n_filters, n_qubits, n_layers)

        # 16 x 16 positions per filter for a 32 x 32 input.
        flatten_dim = 256 * n_filters

        # Classical fully connected layers.
        self.fc1 = nn.Linear(flatten_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, n_classes)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        # Parameter counts.
        quantum_params = sum(p.numel() for p in self.quanv_layer.parameters())
        classical_params = sum(p.numel() for p in [*self.fc1.parameters(),
                                                     *self.fc2.parameters(),
                                                     *self.fc3.parameters()])
        total_params = quantum_params + classical_params

        print(f"\n  Parameter Breakdown:")
        print(f"  {'─'*58}")
        print(f"  Quantum (VQC) parameters:    {quantum_params:>10,}")
        print(f"  Classical (FC) parameters:   {classical_params:>10,}")
        print(f"  {'─'*58}")
        print(f"  TOTAL trainable parameters:  {total_params:>10,}")
        print(f"  {'='*60}\n")

    def rgb_to_grayscale(self, image: torch.Tensor) -> torch.Tensor:
        """Convert a channels-last RGB image to a single grayscale channel.

        Args:
            image: tensor of shape ``(H, W, C)``.

        Returns:
            ``(H, W, 1)`` tensor when ``C == 3``; otherwise ``image`` unchanged.
        """
        if image.shape[2] == 3:
            # Integer images would truncate the fractional weights to zero.
            if not image.is_floating_point():
                image = image.float()
            # Luma weights: 0.299*R + 0.587*G + 0.114*B, created on the image's
            # device so that non-CPU inputs do not raise a device mismatch.
            weights = torch.tensor([0.299, 0.587, 0.114],
                                   dtype=image.dtype, device=image.device)
            gray = torch.sum(image * weights, dim=2, keepdim=True)
            return gray
        return image

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass.

        Args:
            x: batch of channels-last images, ``(batch_size, H, W, C)``
               (``(batch_size, 32, 32, 3)`` for the default layer sizes).

        Returns:
            Class probabilities of shape ``(batch_size, n_classes)``. These are
            softmax outputs, not logits, so pair them with a loss that expects
            (log-)probabilities such as ``nn.NLLLoss`` on ``torch.log(probs)``.

        Raises:
            ValueError: if the flattened feature size does not match ``fc1``.
        """
        batch_size = x.shape[0]
        outputs = []

        for i in range(batch_size):
            print(f"Processing image {i+1}/{batch_size}")

            # Step 1: grayscale conversion.
            img = x[i]
            gray_img = self.rgb_to_grayscale(img)

            # Step 2: quanvolutional layer (module call so hooks run).
            feature_map = self.quanv_layer(gray_img)
            print(f"  Shape after quanv_layer: {feature_map.shape}")

            # Step 3: flatten (out_h, out_w, n_filters) into a 1-D vector.
            features = feature_map.reshape(-1)
            print(f"  Shape after flatten: {features.shape}")

            outputs.append(features)

        # Stack all per-image feature vectors.
        batch_features = torch.stack(outputs).float()
        print(f"Shape after stacking batch features: {batch_features.shape}")

        # Fail early with an actionable message instead of an opaque matmul
        # error, e.g. when images are channels-first or not 32 x 32.
        if batch_features.shape[1] != self.fc1.in_features:
            raise ValueError(
                f"Feature size {batch_features.shape[1]} does not match fc1 input "
                f"size {self.fc1.in_features}; expected channels-last 32x32 images "
                f"of shape (batch, 32, 32, C), got input of shape {tuple(x.shape)}"
            )

        # Fully connected layers.
        h1 = self.relu(self.fc1(batch_features))
        h2 = self.relu(self.fc2(h1))
        logits = self.fc3(h2)

        # Softmax for probabilities.
        probs = self.softmax(logits)

        return probs

In [None]:
def test_quannn(model: QuanNN, test_loader: DataLoader, device='cpu'):
    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    test_loss = 0
    criterion = nn.CrossEntropyLoss()

    print("\n=== Testing QuanNN ===")

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)

            print(f"Testing batch {batch_idx + 1}/{len(test_loader)}...", end='\r')

            # Forward pass
            outputs = model(images)

            # Calcul de la loss
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Prédictions
            _, predicted = torch.max(outputs.data, 1)

            # Statistiques
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Sauvegarde pour matrice de confusion
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calcul des métriques
    accuracy = 100 * correct / total
    avg_loss = test_loss / len(test_loader)

    print(f"\n\n{'='*50}")
    print(f"Test Results:")
    print(f"{'='*50}")
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}% ({correct}/{total})")
    print(f"{'='*50}\n")

    return accuracy, all_predictions, all_labels

In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names=None):
    """Plot the confusion matrix as a heatmap and save it to 'confusion_matrix.png'.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: optional tick labels. When given, labels are assumed to be
            the integers ``0 .. len(class_names) - 1`` (this is how the rest of
            this notebook indexes ``class_names``), and the matrix always has
            one row/column per name even if a class never occurs.
    """
    if class_names:
        # Pin the label set so the matrix size always matches the tick labels.
        label_ids = list(range(len(class_names)))
        tick_labels = list(class_names)
    else:
        label_ids = None
        # Let seaborn derive tick labels rather than passing None.
        tick_labels = 'auto'

    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    print("Confusion matrix saved as 'confusion_matrix.png'")
    plt.close()

In [None]:
def plot_training_history(train_losses, train_accs, val_losses=None, val_accs=None):
    """Draw loss and accuracy curves side by side and save them to 'training_history.png'.

    Args:
        train_losses: per-epoch training losses.
        train_accs: per-epoch training accuracies (percent).
        val_losses: optional per-epoch validation losses; skipped when empty/None.
        val_accs: optional per-epoch validation accuracies; skipped when empty/None.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # One entry per panel: (train series, val series, metric name, y-axis label, title).
    panels = [
        (train_losses, val_losses, 'Loss', 'Loss', 'Training and Validation Loss'),
        (train_accs, val_accs, 'Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
    ]

    for axis, (train_series, val_series, metric, y_label, title) in zip(axes, panels):
        axis.plot(train_series, label=f'Train {metric}', marker='o')
        if val_series:
            axis.plot(val_series, label=f'Val {metric}', marker='s')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png')
    print("Training history saved as 'training_history.png'")
    plt.close()

In [None]:
def visualize_predictions(model: QuanNN, test_loader: DataLoader,
                         class_names=None, n_images=10):
    """Show the first images of the first test batch with true/predicted labels.

    The figure is saved to 'predictions_visualization.png'. At most
    ``n_images`` images are shown, and never more than the first batch holds.

    Args:
        model: trained network; called on channels-last images.
        test_loader: yields ``(images, labels)`` batches.
        class_names: optional list indexed by integer label.
        n_images: maximum number of images to display.
    """
    model.eval()
    images, labels = next(iter(test_loader))

    n_show = min(n_images, len(images))
    if n_show <= 0:
        print("No images to visualize")
        return

    with torch.no_grad():
        outputs = model(images[:n_show])
        _, predictions = torch.max(outputs, 1)

    # Size the grid from the number of images actually shown (up to 5 per row)
    # instead of a fixed 2 x 5 layout; 10 images still give a 15 x 6 figure.
    n_cols = min(5, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows),
                             squeeze=False)
    axes = axes.ravel()

    for i in range(n_show):
        img = images[i].cpu().numpy()

        # Convert for display: RGB as-is, otherwise the first channel in gray.
        if img.shape[2] == 3:
            img = img.astype(np.uint8)
        else:
            img = img[:, :, 0].astype(np.uint8)

        axes[i].imshow(img, cmap='gray' if len(img.shape) == 2 else None)

        true_idx = int(labels[i])
        pred_idx = int(predictions[i])
        true_label = class_names[true_idx] if class_names else true_idx
        pred_label = class_names[pred_idx] if class_names else pred_idx

        color = 'green' if pred_idx == true_idx else 'red'
        axes[i].set_title(f'True: {true_label}\nPred: {pred_label}', color=color)
        axes[i].axis('off')

    # Hide grid cells that received no image.
    for unused_ax in axes[n_show:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig('predictions_visualization.png')
    print("Predictions visualization saved as 'predictions_visualization.png'")
    plt.close()

In [None]:
def load_cifar10_subset(n_classes=3, n_train=100, n_test=30):
    """Load a small CIFAR-10 subset restricted to the first ``n_classes`` classes.

    Images are returned channels-last, ``(32, 32, 3)``, with pixel values in
    [0, 255], which is the layout ``QuanNN.forward`` indexes.

    Args:
        n_classes: keep samples whose label is below this value.
        n_train: maximum number of training samples (first matches in order).
        n_test: maximum number of test samples (first matches in order).

    Returns:
        ``(train_subset, test_subset, class_names)``.
    """
    print(f"\n=== Loading CIFAR-10 Subset ===")
    print(f"Classes: {n_classes}, Train samples: {n_train}, Test samples: {n_test}")

    # ToTensor yields channels-first (3, 32, 32). The model reads images as
    # (H, W, C), so permute to channels-last as the MNIST loader does; without
    # this the grayscale conversion and patch extraction see the wrong axes.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.permute(1, 2, 0)),  # (32, 32, 3)
        transforms.Lambda(lambda x: x * 255)  # Pixel values [0, 255]
    ])

    # Load CIFAR-10.
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform)

    # Keep only the first n_classes classes.
    class_indices_train = [i for i, label in enumerate(trainset.targets) if label < n_classes]
    class_indices_test = [i for i, label in enumerate(testset.targets) if label < n_classes]

    # Sub-sample.
    selected_train = class_indices_train[:n_train]
    selected_test = class_indices_test[:n_test]

    # Build the subsets.
    train_subset = torch.utils.data.Subset(trainset, selected_train)
    test_subset = torch.utils.data.Subset(testset, selected_test)

    # Class names.
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck'][:n_classes]

    print(f"Loaded: {len(train_subset)} train, {len(test_subset)} test samples")
    print(f"Classes: {class_names}\n")

    return train_subset, test_subset, class_names

In [None]:
# NOTE(review): this cell redefines visualize_predictions, which is already
# defined in an earlier cell. The later definition is the one in effect after
# Run All, so one of the two cells should be removed.
def visualize_predictions(model: QuanNN, test_loader: DataLoader,
                         class_names=None, n_images=10):
    """Show the first images of the first test batch with true/predicted labels.

    The figure is saved to 'predictions_visualization.png'. At most
    ``n_images`` images are shown, and never more than the first batch holds.

    Args:
        model: trained network; called on channels-last images.
        test_loader: yields ``(images, labels)`` batches.
        class_names: optional list indexed by integer label.
        n_images: maximum number of images to display.
    """
    model.eval()
    images, labels = next(iter(test_loader))

    n_show = min(n_images, len(images))
    if n_show <= 0:
        print("No images to visualize")
        return

    with torch.no_grad():
        outputs = model(images[:n_show])
        _, predictions = torch.max(outputs, 1)

    # Size the grid from the number of images actually shown (up to 5 per row)
    # instead of a fixed 2 x 5 layout; 10 images still give a 15 x 6 figure.
    n_cols = min(5, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows),
                             squeeze=False)
    axes = axes.ravel()

    for i in range(n_show):
        img = images[i].cpu().numpy()

        # Convert for display: RGB as-is, otherwise the first channel in gray.
        if img.shape[2] == 3:
            img = img.astype(np.uint8)
        else:
            img = img[:, :, 0].astype(np.uint8)

        axes[i].imshow(img, cmap='gray' if len(img.shape) == 2 else None)

        true_idx = int(labels[i])
        pred_idx = int(predictions[i])
        true_label = class_names[true_idx] if class_names else true_idx
        pred_label = class_names[pred_idx] if class_names else pred_idx

        color = 'green' if pred_idx == true_idx else 'red'
        axes[i].set_title(f'True: {true_label}\nPred: {pred_label}', color=color)
        axes[i].axis('off')

    # Hide grid cells that received no image.
    for unused_ax in axes[n_show:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig('predictions_visualization.png')
    print("Predictions visualization saved as 'predictions_visualization.png'")
    plt.close()

In [None]:
def load_mnist_subset(n_classes=3, n_train=100, n_test=30):
    """Load a small MNIST subset restricted to the first ``n_classes`` digits.

    Each image is resized to 32 x 32, replicated to three channels, moved to
    channels-last layout ``(32, 32, 3)`` and scaled to pixel values in [0, 255].

    Args:
        n_classes: keep samples whose label is below this value.
        n_train: maximum number of training samples (first matches in order).
        n_test: maximum number of test samples (first matches in order).

    Returns:
        ``(train_subset, test_subset, class_names)``.
    """
    print(f"\n=== Loading MNIST Subset ===")
    print(f"Classes: {n_classes}, Train samples: {n_train}, Test samples: {n_test}")

    def _replicate_channels(t):
        # (1, 32, 32) -> (3, 32, 32); tensors with several channels pass through.
        if t.shape[0] == 1:
            return t.repeat(3, 1, 1)
        return t

    def _channels_last(t):
        # (3, 32, 32) -> (32, 32, 3), the layout the model indexes.
        return t.permute(1, 2, 0)

    def _to_pixel_range(t):
        # [0, 1] -> [0, 255]
        return t * 255

    pipeline = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Lambda(_replicate_channels),
        transforms.Lambda(_channels_last),
        transforms.Lambda(_to_pixel_range),
    ])

    # Load both MNIST splits with the same preprocessing.
    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                         download=True, transform=pipeline)
    testset = torchvision.datasets.MNIST(root='./data', train=False,
                                        download=True, transform=pipeline)

    def _first_matching(dataset, limit):
        # Positions of samples in the kept classes, truncated to `limit`.
        kept = []
        for position, target in enumerate(dataset.targets):
            if target < n_classes:
                kept.append(position)
        return torch.utils.data.Subset(dataset, kept[:limit])

    train_subset = _first_matching(trainset, n_train)
    test_subset = _first_matching(testset, n_test)

    class_names = list(map(str, range(n_classes)))

    print(f"Loaded: {len(train_subset)} train, {len(test_subset)} test samples")
    print(f"Classes: {class_names}\n")

    return train_subset, test_subset, class_names

In [None]:
def save_model(model: QuanNN, filepath='quannn_model.pth'):
    """Save the model weights and the settings needed to rebuild it.

    The checkpoint stores the full ``state_dict`` (quantum and classical
    parameters), a separate list of the quantum parameters, and the
    constructor arguments ``n_filters``, ``n_classes`` and, when the model has
    at least one filter, ``n_qubits`` and ``n_layers``.

    Args:
        model: model to serialise.
        filepath: destination path.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),  # Holds ALL parameters
        'quantum_params': model.quanv_layer.get_all_params(),
        'n_filters': model.n_filters,
        'n_classes': model.n_classes
    }

    # Record the circuit shape too, so the checkpoint is self-describing and
    # loading does not depend on the caller remembering these values.
    quantum_filters = model.quanv_layer.filters
    if len(quantum_filters) > 0:
        checkpoint['n_qubits'] = quantum_filters[0].n_qubits
        checkpoint['n_layers'] = quantum_filters[0].n_layers

    torch.save(checkpoint, filepath)
    print(f"✓ Model saved to {filepath}")

In [None]:
def load_model(filepath='quannn_model.pth', n_qubits=4, n_layers=2):
    """Rebuild a QuanNN from a checkpoint written by ``save_model``.

    Args:
        filepath: checkpoint path.
        n_qubits: qubits per filter, used only when the checkpoint does not
            record it.
        n_layers: variational layers per filter, used only when the checkpoint
            does not record it.

    Returns:
        The restored model, with its parameters on the CPU.
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source. map_location lets a checkpoint saved on a GPU machine be
    # opened on a CPU-only one.
    checkpoint = torch.load(filepath, map_location='cpu')

    # Prefer the architecture stored in the checkpoint: the state dict only
    # fits a model built with the same circuit shape.
    model = QuanNN(
        n_filters=checkpoint['n_filters'],
        n_classes=checkpoint['n_classes'],
        n_qubits=checkpoint.get('n_qubits', n_qubits),
        n_layers=checkpoint.get('n_layers', n_layers)
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    # The state dict already contains the quantum parameters; the separate
    # list is applied as well when present, as before.
    if 'quantum_params' in checkpoint:
        model.quanv_layer.set_all_params(checkpoint['quantum_params'])

    print(f"✓ Model loaded from {filepath}")
    return model

In [None]:
def print_parameter_details(model: "QuanNN"):
    """Print a breakdown of the quantum and classical parameters of ``model``.

    Details are printed for at most the first three quantum filters; classical
    parameters are those whose name contains 'fc'.

    Args:
        model: object exposing ``quanv_layer.filters`` (each with a ``params``
            tensor of shape ``(layers, rotations, qubits)``) and
            ``named_parameters()``.
    """
    print("\n" + "="*70)
    print("DETAILED PARAMETER ANALYSIS")
    print("="*70)

    print("\n📊 QUANTUM PARAMETERS (VQC):")
    print("─"*70)

    quantum_filters = list(model.quanv_layer.filters)
    max_shown = 3  # Limit the amount of output

    for idx, qfilter in enumerate(quantum_filters[:max_shown]):
        params = qfilter.params
        print(f"  Filter {idx+1}:")
        print(f"    Shape: {params.shape} = (layers={params.shape[0]}, "
              f"rotations={params.shape[1]}, qubits={params.shape[2]})")
        print(f"    Total: {params.numel()} parameters")
        print(f"    Requires grad: {params.requires_grad}")
        print(f"    Example values (first layer, RY rotations):")
        # .cpu() so parameters on another device can still be converted.
        print(f"      {params[0, 0, :].detach().cpu().numpy()}")

    # Only mention remaining filters when there are some (the previous code
    # printed "(0 more filters)" for exactly three filters).
    remaining = len(quantum_filters) - max_shown
    if remaining > 0:
        print(f"    ... ({remaining} more filters)")

    quantum_total = sum(f.params.numel() for f in quantum_filters)
    print(f"\n  ✓ Total Quantum Params: {quantum_total}")

    print("\n📊 CLASSICAL PARAMETERS (Fully Connected):")
    print("─"*70)

    classical_total = 0
    for name, param in model.named_parameters():
        if 'fc' in name:
            print(f"  {name}:")
            print(f"    Shape: {param.shape}")
            print(f"    Total: {param.numel()} parameters")
            print(f"    Requires grad: {param.requires_grad}")
            classical_total += param.numel()

    print(f"\n  ✓ Total Classical Params: {classical_total}")

    grand_total = quantum_total + classical_total
    # Avoid ZeroDivisionError for a model without parameters.
    quantum_pct = 100 * quantum_total / grand_total if grand_total else 0.0
    classical_pct = 100 * classical_total / grand_total if grand_total else 0.0

    print("\n" + "="*70)
    print(f"GRAND TOTAL: {grand_total:,} trainable parameters")
    print(f"  - Quantum: {quantum_total:,} ({quantum_pct:.2f}%)")
    print(f"  - Classical: {classical_total:,} ({classical_pct:.2f}%)")
    print("="*70 + "\n")

In [None]:
def run_complete_experiment(dataset='cifar10', n_filters=4, n_classes=3,
                           n_train=60, n_test=20, n_epochs=3, batch_size=2):
    """Run the full pipeline: load data, train, test, plot, report and save.

    Args:
        dataset: 'cifar10' or 'mnist'.
        n_filters: number of quantum filters.
        n_classes: number of classes kept from the dataset.
        n_train: number of training samples.
        n_test: number of test samples.
        n_epochs: number of training epochs.
        batch_size: batch size of both data loaders.

    Returns:
        ``(model, accuracy)`` with the trained model and its test accuracy (%).

    Raises:
        ValueError: if ``dataset`` is not 'cifar10' or 'mnist'.
    """
    print("\n" + "="*60)
    print("QuanNN - Complete Experiment")
    print("="*60)

    # 1. Load the data.
    if dataset == 'cifar10':
        train_data, test_data, class_names = load_cifar10_subset(
            n_classes=n_classes, n_train=n_train, n_test=n_test
        )
    elif dataset == 'mnist':
        train_data, test_data, class_names = load_mnist_subset(
            n_classes=n_classes, n_train=n_train, n_test=n_test
        )
    else:
        raise ValueError("Dataset must be 'cifar10' or 'mnist'")

    # Data loaders.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # 2. Build the model.
    n_qubits = 4
    n_layers = 2

    print(f"\nCreating QuanNN model...")
    print(f"  - Quantum filters: {n_filters}")
    print(f"  - Classes: {n_classes}")
    print(f"  - VQC layers: {n_layers}")

    model = QuanNN(n_filters=n_filters, n_classes=n_classes,
                   n_qubits=n_qubits, n_layers=n_layers)

    # Show the model size.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  - Total trainable parameters: {total_params:,}")

    # 3. Train. The device is given explicitly: train_quannn defaults to
    # 'cuda', while the model is built on the CPU and test_quannn below
    # evaluates on the CPU.
    # NOTE(review): the test set doubles as the validation set here, so the
    # validation curves are not an independent estimate.
    train_losses, train_accs, val_losses, val_accs = train_quannn(
        model, train_loader, val_loader=test_loader,
        n_epochs=n_epochs, lr=0.01, device='cpu'
    )

    # 4. Final test.
    accuracy, predictions, labels = test_quannn(model, test_loader)

    # 5. Visualisations.
    print("\n=== Generating Visualizations ===")
    plot_confusion_matrix(labels, predictions, class_names)
    visualize_predictions(model, test_loader, class_names, n_images=10)

    # 6. Classification report. The label set is pinned so the report does not
    # raise when a class is missing from the small test subset.
    print("\n=== Classification Report ===")
    print(classification_report(labels, predictions,
                                labels=list(range(len(class_names))),
                                target_names=class_names,
                                zero_division=0))

    # 7. Save the model.
    save_model(model, f'quannn_{dataset}_model.pth')

    print("\n" + "="*60)
    print("Experiment Complete!")
    print("="*60)
    print(f"\nFiles generated:")
    print(f"  - confusion_matrix.png")
    print(f"  - training_history.png")
    print(f"  - predictions_visualization.png")
    print(f"  - quannn_{dataset}_model.pth")
    print("\n")

    return model, accuracy

In [36]:
def train_quannn(model: QuanNN, train_loader: DataLoader,
                 val_loader: DataLoader = None, n_epochs: int = 5,
                 lr: float = 0.01, device='cuda'):
    """Train the QuanNN model and track loss/accuracy per epoch.

    Args:
        model: network whose forward pass returns class probabilities.
        train_loader: yields ``(images, labels)`` training batches.
        val_loader: optional loader evaluated after every epoch.
        n_epochs: number of epochs.
        lr: Adam learning rate.
        device: target device; falls back to the CPU when CUDA is requested
            but not available.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)``, one entry per
        epoch; the validation lists are empty when ``val_loader`` is None.
    """
    # Requesting CUDA on a machine without it would fail on the first batch.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        device = 'cpu'

    # Keep the model on the same device as the batches moved below.
    model.to(device)

    # The model already applies softmax. nn.CrossEntropyLoss would apply
    # log-softmax a second time, so use NLLLoss on the log-probabilities.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    print("\n=== Training QuanNN ===\n")

    for epoch in range(n_epochs):
        # === TRAINING ===
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"Epoch [{epoch+1}/{n_epochs}] - Training...")

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass.
            outputs = model(images)

            # Loss; the clamp guards against log(0).
            loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)

            # Backward pass.
            loss.backward()
            optimizer.step()

            # Statistics.
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            print(f"  Batch {batch_idx+1}/{len(train_loader)} - Loss: {loss.item():.4f}", end='\r')

        # An empty loader yields zeros instead of ZeroDivisionError.
        n_train_batches = len(train_loader)
        train_accuracy = 100 * correct / total if total else 0.0
        train_loss = total_loss / n_train_batches if n_train_batches else 0.0
        train_losses.append(train_loss)
        train_accs.append(train_accuracy)

        print(f"\n  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

        # === VALIDATION ===
        if val_loader is not None:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            print(f"  Validating...")

            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)

                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            n_val_batches = len(val_loader)
            val_accuracy = 100 * val_correct / val_total if val_total else 0.0
            val_loss_avg = val_loss / n_val_batches if n_val_batches else 0.0
            val_losses.append(val_loss_avg)
            val_accs.append(val_accuracy)

            print(f"  Val Loss: {val_loss_avg:.4f}, Val Accuracy: {val_accuracy:.2f}%")

        print(f"{'-'*60}\n")

    print("=== Training Complete ===\n")

    # Plot the history. Use the same `is not None` test as the epoch loop; the
    # truthiness of a DataLoader depends on its length.
    if val_loader is not None:
        plot_training_history(train_losses, train_accs, val_losses, val_accs)
    else:
        plot_training_history(train_losses, train_accs)

    return train_losses, train_accs, val_losses, val_accs

In [None]:


    # Interactive entry point: prints a menu, reads a choice from stdin and
    # runs the selected experiment or demo.
    # NOTE(review): this body is indented, but no enclosing statement (such as
    # an `if __name__ == "__main__":` guard) is visible above it; confirm the
    # cell has one, otherwise it raises IndentationError.
    print("\n" + "="*60)
    print("QuanNN - Testing Options")
    print("="*60)
    print("\nChoose an option:")
    print("  1. Quick test (CIFAR-10, 3 classes, 60 train/20 test)")
    # The menu text now matches the sizes that option 2 actually runs.
    print("  2. MNIST test (3 classes, 20 train/10 test)")
    print("  3. Extended CIFAR-10 (5 classes, 200 train/50 test)")
    print("  4. Load and test existing model")
    print("  5. Simple forward pass demo")

    choice = input("\nEnter choice (1-5): ").strip()

    if choice == '1':
        # Quick test with CIFAR-10.
        print("\n[Option 1] Quick CIFAR-10 Test")
        model, accuracy = run_complete_experiment(
            dataset='cifar10',
            n_filters=4,
            n_classes=3,
            n_train=60,
            n_test=20,
            n_epochs=3,
            batch_size=2
        )

    elif choice == '2':
        # Test with MNIST.
        print("\n[Option 2] MNIST Test")
        model, accuracy = run_complete_experiment(
            dataset='mnist',
            n_filters=2,
            n_classes=3,
            n_train=20,
            n_test=10,
            n_epochs=2,
            batch_size=2
        )

    elif choice == '3':
        # Extended test.
        print("\n[Option 3] Extended CIFAR-10 Test")
        print("WARNING: This will take significantly longer!")
        confirm = input("Continue? (y/n): ").strip().lower()
        if confirm == 'y':
            model, accuracy = run_complete_experiment(
                dataset='cifar10',
                n_filters=8,
                n_classes=5,
                n_train=200,
                n_test=50,
                n_epochs=5,
                batch_size=2
            )
        else:
            # Make it explicit that nothing was run.
            print("Cancelled.")

    elif choice == '4':
        # Load and test an existing model.
        print("\n[Option 4] Load Existing Model")
        filepath = input("Enter model path (default: quannn_cifar10_model.pth): ").strip()
        if not filepath:
            filepath = 'quannn_cifar10_model.pth'

        try:
            model = load_model(filepath)

            # Load test data with as many classes as the loaded model predicts,
            # rather than a hard-coded 3.
            _, test_data, class_names = load_cifar10_subset(
                n_classes=model.n_classes, n_test=20
            )
            test_loader = DataLoader(test_data, batch_size=2, shuffle=False)

            # Test.
            accuracy, predictions, labels = test_quannn(model, test_loader)
            plot_confusion_matrix(labels, predictions, class_names)
            visualize_predictions(model, test_loader, class_names)

        except FileNotFoundError:
            print(f"Error: Model file '{filepath}' not found!")

    elif choice == '5':
        # Simple forward-pass demo.
        print("\n[Option 5] Simple Forward Pass Demo")

        # Build a dummy image.
        print("\nCreating random test image (32x32x3)...")
        test_img = torch.rand(1, 32, 32, 3) * 255

        # Build the model.
        print("Creating QuanNN model (4 filters, 3 classes)...")
        model = QuanNN(n_filters=4, n_classes=3, n_qubits=4, n_layers=2)
        # Detailed parameter listing.
        print_parameter_details(model)

        # Forward pass without gradients.
        with torch.no_grad():
            output = model(test_img)

        print("\n" + "="*60)
        print("Results:")
        print("="*60)
        print(f"Output shape: {output.shape}")
        print(f"Class probabilities: {output[0].numpy()}")
        print(f"Predicted class: {torch.argmax(output[0]).item()}")
        print(f"Confidence: {torch.max(output[0]).item():.2%}")
        print("="*60)

        # Forward pass with gradients.
        output = model(test_img)
        loss = -torch.log(output[0, 0])  # Simple loss

        # Backward pass.
        loss.backward()

        print("\n✓ Gradient check:")
        has_quantum_grad = any(f.params.grad is not None for f in model.quanv_layer.filters)
        has_classical_grad = model.fc1.weight.grad is not None

        print(f"  Quantum params have gradients: {has_quantum_grad}")
        print(f"  Classical params have gradients: {has_classical_grad}")

        if has_quantum_grad:
            print(f"\n  Example quantum gradient (Filter 1, Layer 1, RY):")
            print(f"    {model.quanv_layer.filters[0].params.grad[0, 0, :].numpy()}")


    else:
        print("\nInvalid choice! Running default demo...")
        print("\n[Default] Simple Forward Pass Demo")
        test_img = torch.rand(1, 32, 32, 3) * 255
        model = QuanNN(n_filters=4, n_classes=3, n_qubits=4, n_layers=2)

        with torch.no_grad():
            output = model(test_img)

        print(f"\nOutput probabilities: {output[0].numpy()}")
        print(f"Predicted class: {torch.argmax(output[0]).item()}")



QuanNN - Testing Options

Choose an option:
  1. Quick test (CIFAR-10, 3 classes, 60 train/20 test)
  2. MNIST test (3 classes, 60 train/20 test)
  3. Extended CIFAR-10 (5 classes, 200 train/50 test)
  4. Load and test existing model
  5. Simple forward pass demo

Enter choice (1-5): 2

[Option 2] MNIST Test

QuanNN - Complete Experiment

=== Loading MNIST Subset ===
Classes: 3, Train samples: 20, Test samples: 10
Loaded: 20 train, 10 test samples
Classes: ['0', '1', '2']


Creating QuanNN model...
  - Quantum filters: 2
  - Classes: 3
  - VQC layers: 2

Initializing QuanNN Model
    QuantumFilter initialized: 16 trainable quantum parameters
    QuantumFilter initialized: 16 trainable quantum parameters
  QuanvolutionalLayer: 2 filters × 4 qubits × 2 layers
  Total quantum parameters: 32

  Parameter Breakdown:
  ──────────────────────────────────────────────────────────
  Quantum (VQC) parameters:            32
  Classical (FC) parameters:       74,115
  ─────────────────────────────