# 0. Prerequisites

## Hardware Information

In [None]:
# Inspect the host CPU (model, core and thread counts).
!lscpu

In [None]:
# Inspect the visible NVIDIA GPUs and the driver / CUDA version.
!nvidia-smi

In [None]:
# Show the installed CUDA compiler (nvcc) version.
!nvcc -V

## Environment Variables

In [None]:
# TODO: Check out again while using kokkos
# OpenMP settings for the CPU simulators. `!export VAR=...` would only affect a throw-away
# subshell, so the variables are set on the kernel process itself. This cell runs before
# numpy/torch are imported in "1. Import Libraries", so their OpenMP runtimes see the values.
import os

os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OMP_PROC_BIND"] = "spread"
os.environ["OMP_PLACES"] = "threads"

## Install Dependencies

In [None]:
# Set to True to also install the CUDA / cuQuantum packages and the lightning GPU plugin
# in the GPU installation cell.
USE_GPU: bool = False

In [None]:
# Basic Libraries
%pip install tqdm
%pip install torch torchvision torchaudio tensorboard
%pip install numpy scipy scikit-learn matplotlib pandas
%pip install pennylane --upgrade
%pip install pennylane-lightning
%pip install pennylane-lightning-kokkos
%pip install pennylane-qulacs["cpu"]
%pip install pyquafu

In [None]:
if USE_GPU:
    %pip install -v --no-cache-dir cuquantum-python-cu12
    %pip install nvidia-cuda-cupti-cu12 == 12.1.105
    %pip install nvidia-cuda-nvrtc-cu12 == 12.1.105
    %pip install nvidia-cudnn-cu12 == 8.9.2.26
    %pip install nvidia-cufft-cu12 == 11.0.2.54
    %pip install nvidia-curand-cu12 == 10.3.2.106
    %pip install nvidia-cusolver-cu12 == 11.4.5.107
    %pip install nvidia-nccl-cu12 == 2.19.3
    %pip install nvidia-nvtx-cu12 == 12.1.105
    %pip install nvidia-cusparse-cu12 == 12.1.0.106
    %pip install nvidia-cublas-cu12 == 12.1.3.1
    %pip install nvidia-cuda-runtime-cu12 == 12.1.105
    # %pip install nvidia-cusparse-cu12
    # %pip install nvidia-cublas-cu12
    # %pip install nvidia-cuda-runtime-cu12
    %pip install custatevec_cu12
    %pip install pennylane-lightning[gpu]
    # %pip install pennylane-qulacs["gpu"]

# 1. Import Libraries

In [None]:
import json
import logging
import math
import os
import random
import time
from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable, Dict, Optional, Type, Union, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pennylane as qml
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from quafu import QuantumCircuit, User, Task, simulate
from sklearn.metrics import recall_score, f1_score, confusion_matrix
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

# 2. Quantum Convolutional Components

## 2.1 Quantum Convolutional Kernel

In [None]:
ArrayLike = Union[list, np.ndarray, torch.Tensor]


class QKernel:
    """Quantum convolution kernel: the circuit evaluated on a single image patch.

    The default circuit uses one qubit per patch element
    (``quantum_channels * kernel_height * kernel_width`` qubits):

    1. Encoding: ``H`` followed by ``RY(inputs[q])`` on every qubit ``q``.
    2. ``num_param_blocks`` trainable blocks, each a ring of ``CRZ`` gates
       (control ``q``, target ``(q + 1) % num_qubits``) followed by an ``RY`` on every qubit.
    3. Readout: ``<Z>`` on every qubit.

    ``circuit`` and ``weight_shapes`` are handed to ``qml.QNode`` / ``qml.qnn.TorchLayer``
    by ``Quanv2d``.
    """

    def __init__(
            self,
            quantum_channels: int,
            kernel_size: Union[int, Tuple[int, int]] = 2,
            num_param_blocks: int = 2,
            kernel_circuit: Optional[Callable[..., Any]] = None,
            weight_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
    ):
        """Quantum Kernel.

        Args:
            quantum_channels: number of input channels encoded per patch.
            kernel_size: patch size, either an int or a ``(height, width)`` tuple.
            num_param_blocks: number of trainable entangle + rotate blocks in the default circuit.
            kernel_circuit: optional replacement circuit ``f(inputs, **weights)``; requires
                ``weight_shapes``.
            weight_shapes: shapes of the circuit's trainable arguments. Defaults to
                ``{"weights": (num_param_blocks, 2 * num_qubits)}`` for the default circuit.

        Raises:
            ValueError: if any argument is invalid.
        """
        self.validate_params(quantum_channels, kernel_size, num_param_blocks, kernel_circuit, weight_shapes)

        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.num_qubits = quantum_channels * self.kernel_size[0] * self.kernel_size[1]
        self.num_param_blocks = num_param_blocks

        self.circuit = kernel_circuit if kernel_circuit else self._default_circuit
        self.weight_shapes = weight_shapes if weight_shapes else {"weights": (num_param_blocks, 2 * self.num_qubits)}

    @staticmethod
    def validate_params(
            quantum_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            num_param_blocks: int,
            kernel_circuit: Optional[Callable[..., Any]],
            weight_shapes: Optional[Dict[str, Tuple[int, ...]]],
    ):
        """Raise ``ValueError`` for invalid constructor arguments."""
        if not isinstance(quantum_channels, int) or quantum_channels <= 0:
            raise ValueError("quantum_channels must be a positive integer")

        # Accept a single int or exactly two ints. Tuples of another length would silently
        # produce a wrong qubit count, and non-int entries would fail with a TypeError.
        if isinstance(kernel_size, int):
            sizes = (kernel_size,)
        elif isinstance(kernel_size, tuple) and len(kernel_size) == 2:
            sizes = kernel_size
        else:
            sizes = ()
        if not sizes or not all(isinstance(size, int) and size > 0 for size in sizes):
            raise ValueError("kernel_size must be a positive integer or a tuple of positive integers")

        if not isinstance(num_param_blocks, int) or num_param_blocks <= 0:
            raise ValueError("num_param_blocks must be a positive integer")
        if kernel_circuit and not weight_shapes:
            raise ValueError("Must provide weight_shapes for custom kernel circuit")

    def _default_circuit(self, inputs: ArrayLike, weights: ArrayLike):
        """Default kernel circuit; returns one ``<Z>`` expectation per qubit.

        ``inputs`` is indexed per qubit (``inputs[q]``). ``weights`` is indexed as
        ``weights[block, index]`` and is expected to have shape
        ``(num_param_blocks, 2 * num_qubits)``: the first half holds the CRZ angles and the
        second half the RY angles.
        """
        n = self.num_qubits

        # Encoding Layer
        for qubit in range(n):
            qml.Hadamard(wires=qubit)
            qml.RY(inputs[qubit], wires=qubit)

        # Parametric Layer
        for layer in range(self.num_param_blocks):
            # Entanglement ring. With a single qubit the ring would be CRZ on wires [0, 0],
            # which PennyLane rejects (wires must be unique), so it is skipped.
            if n > 1:
                for qubit in range(n):
                    qml.CRZ(weights[layer, qubit], wires=[qubit, (qubit + 1) % n])
            # Rotation
            for qubit in range(n):
                qml.RY(weights[layer, n + qubit], wires=qubit)

        # Observation Layer
        return [qml.expval(qml.PauliZ(wires=qubit)) for qubit in range(n)]

    def __repr__(self):
        return f"QKernel(num_qubits={self.num_qubits}, kernel_size={self.kernel_size}, num_param_blocks={self.num_param_blocks})"


## 2.2 Quantum Convolutional Layer

In [None]:
class _QuanvNd(nn.Module):
    __constants__ = ["in_channels", "out_channels", "kernel_size", "stride", "padding", "num_qlayers", "qdevice",
                     "diff_method"]

    def __init__(self):
        super().__init__()
        pass


In [None]:
class Quanv2d(_QuanvNd):
    """2D quanvolution layer backed by a PennyLane QNode.

    Every ``kernel_size x kernel_size`` patch (across all input channels) is evaluated by the
    quantum kernel, which returns one value per patch element. These values are folded back
    onto the (padded) image, overlapping contributions are averaged, and a classical 1x1
    convolution maps ``in_channels`` to ``out_channels``. This assumes the kernel has
    ``in_channels * kernel_size**2`` qubits, as the default ``QKernel`` does.

    Unlike ``nn.Conv2d``, the spatial size is not reduced. For an ``H x W`` input the output is
    ``(H + 2 * padding) x (W + 2 * padding)``. ``stride`` only selects which patches are
    evaluated. Pixels that no patch covers are set to 0 before the 1x1 convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0,
                 qkernel: QKernel = None,
                 num_qlayers: int = 2,
                 qdevice: str = "default.qubit",
                 qdevice_kwargs: dict = None,
                 diff_method: str = "best"):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels produced by the final 1x1 convolution.
            kernel_size: patch size.
            stride: step between evaluated patches.
            padding: zero padding added on every side before extracting patches.
            qkernel: quantum kernel to use. Defaults to a ``QKernel`` over ``in_channels``.
            num_qlayers: number of parametric blocks of the default ``QKernel``.
            qdevice: PennyLane device name passed to ``qml.device``.
            qdevice_kwargs: extra keyword arguments for ``qml.device``.
            diff_method: differentiation method of the ``qml.QNode``.
        """
        super(Quanv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.qkernel = qkernel or QKernel(quantum_channels=in_channels, kernel_size=kernel_size,
                                          num_param_blocks=num_qlayers)
        self.qdevice_kwargs = qdevice_kwargs or {}
        self.qdevice = qml.device(qdevice, wires=self.qkernel.num_qubits, **self.qdevice_kwargs)
        self.qnode = qml.QNode(self.qkernel.circuit, device=self.qdevice, interface="torch", diff_method=diff_method)
        self.qlayer = qml.qnn.TorchLayer(self.qnode, self.qkernel.weight_shapes)

        # Use 1x1 classical convolution to match the desired output channels
        self.classical_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # Verified params
        assert len(x.shape) == 4, "Input tensor must have 4 dimensions: (batch_size, channels, height, width)"
        assert x.dtype == torch.float32, "Input tensor must have dtype torch.float32"
        assert x.shape[1] == self.in_channels, f"Input tensor must have {self.in_channels} input channels"

        # Apply quantum convolution
        x = self.quantum_conv(x)

        # Apply 1x1 classical convolution to match the desired output channels
        x = self.classical_conv(x)

        return x

    def quantum_conv(self, x):
        """Evaluate the quantum kernel on every patch and fold the results back onto the image.

        Args:
            x: ``(batch, channels, H, W)`` tensor.

        Returns:
            ``(batch, C, H + 2 * padding, W + 2 * padding)`` tensor, where
            ``C = num_qubits / kernel_size**2``.
        """
        bs, _, h, w = x.shape

        # Zero-pad exactly once. unfold/fold below operate on the padded image, so they must
        # not be given ``padding`` as well: that would pad a second time.
        if self.padding != 0:
            x = F.pad(x, (self.padding,) * 4, mode="constant", value=0)
            h += 2 * self.padding
            w += 2 * self.padding

        # (bs, C * k * k, L) -> (bs, L, C * k * k): one row per patch
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride).transpose(1, 2)
        num_patches, patch_dim = patches.shape[1], patches.shape[2]

        # Evaluate the quantum layer patch by patch (sample-major order)
        flat_patches = patches.reshape(-1, patch_dim)
        out = torch.stack([self.qlayer(patch) for patch in flat_patches])
        out = out.reshape(bs, num_patches, -1).transpose(1, 2)

        # Sum the per-patch outputs back onto the image grid
        out = F.fold(out, output_size=(h, w), kernel_size=self.kernel_size, stride=self.stride)

        # Normalization: divide by how many patches covered each pixel. Pixels covered by no
        # patch (e.g. an odd size with stride 2) have count 0; clamping avoids 0 / 0 = NaN.
        coverage = F.fold(
            F.unfold(torch.ones_like(x), kernel_size=self.kernel_size, stride=self.stride),
            output_size=(h, w), kernel_size=self.kernel_size, stride=self.stride,
        )
        return out / coverage.clamp(min=1)


## 2.3 Quafu Extensions

### 2.3.1 Quafu Convolutional Kernel

In [None]:
class QuafuQKernel:
    """Quantum convolution kernel executed with Quafu (local simulator or a cloud backend).

    The default circuit mirrors ``QKernel``'s default circuit, with every ``CRZ`` decomposed
    into ``RZ`` / ``CNOT`` gates:

    1. Encoding: ``H`` followed by ``RY(inputs[q])`` on every qubit.
    2. ``num_qlayers`` blocks, each a CRZ ring (control ``q``, target ``(q + 1) % n``)
       followed by ``RY`` on every qubit.
    3. Measure all qubits and convert the outcome probabilities into per-qubit ``<Z>``.
    """

    def __init__(self,
                 in_channels: int,
                 kernel_size: int = 2,
                 num_qlayers: int = 2,
                 kernel_circuit: Optional[Callable[..., Any]] = None,
                 weight_shapes: Optional[Dict[str, tuple]] = None,
                 qdevice: str = "simulator",
                 api_token: str = "",
                 shots: int = 2000):
        """Quafu Quantum Kernel.

        Args:
            in_channels: number of input channels encoded per patch.
            kernel_size: side length of the square patch.
            num_qlayers: number of trainable blocks in the default circuit.
            kernel_circuit: optional replacement circuit; requires ``weight_shapes``.
            weight_shapes: shapes of the circuit's trainable arguments. Defaults to
                ``{"weights": (num_qlayers, 2 * num_qubits)}``.
            qdevice: ``"simulator"`` for Quafu's local ``simulate``. Any other value is used as
                the backend name passed to ``Task.config``.
            api_token: Quafu API token. When non-empty it is saved with ``User.save_apitoken``
                before the ``Task`` is created. When empty, nothing is saved.
            shots: number of shots per execution on a remote backend.

        Raises:
            ValueError: if any argument is invalid.
        """
        self.validate_params(in_channels, kernel_size, num_qlayers, kernel_circuit, weight_shapes)

        self.num_qubits = in_channels * kernel_size ** 2
        self.num_qlayers = num_qlayers

        self.circuit = kernel_circuit if kernel_circuit else self._default_circuit
        self.weight_shapes = weight_shapes if weight_shapes else {"weights": (num_qlayers, 2 * self.num_qubits)}

        self.qdevice = qdevice
        self.api_token = api_token
        self.shots = shots

        self.qcircuit = None
        if qdevice != "simulator":
            # Save the token *before* creating the Task so the Task can pick up the stored
            # credentials. NOTE(review): this relies on Task() loading the saved token
            # (pyquafu >= 0.3); confirm for the installed version. An empty token is not
            # saved, so it cannot overwrite a previously stored one.
            if self.api_token:
                User(self.api_token).save_apitoken()
            self.task = Task()
            self.task.config(backend=self.qdevice, shots=self.shots, compile=True)

    @staticmethod
    def validate_params(in_channels, kernel_size, num_qlayers, kernel_circuit, weight_shapes):
        """Raise ``ValueError`` for invalid constructor arguments."""
        if not isinstance(in_channels, int) or in_channels <= 0:
            raise ValueError("in_channels must be a positive integer")
        if not isinstance(kernel_size, int) or kernel_size <= 0:
            raise ValueError("kernel_size must be a positive integer")
        if not isinstance(num_qlayers, int) or num_qlayers <= 0:
            raise ValueError("num_qlayers must be a positive integer")
        if kernel_circuit and not weight_shapes:
            raise ValueError("Must provide weight_shapes for custom parametric_layer")

    @staticmethod
    def expval_z_expectations(probabilities, num_qubits):
        """
        Calculate the expectation values of the Z operator for each qubit.

        ``probabilities`` is either a ``{bitstring: probability}`` dict or a sequence of
        ``2 ** num_qubits`` probabilities indexed by the integer value of the bitstring.
        Character ``i`` of a bitstring is treated as qubit ``i``.
        NOTE(review): this assumes Quafu puts qubit 0 first in its bitstrings; confirm.
        """
        z_expectations = [0] * num_qubits
        probs = probabilities

        if not isinstance(probabilities, dict):
            bases = [format(i, "0" + str(num_qubits) + "b") for i in range(2 ** num_qubits)]
            probs = dict(zip(bases, probabilities))

        for base, prob in probs.items():
            for i in range(num_qubits):
                # For Z operator: |0> contributes +1, |1> contributes -1
                z_expectations[i] += prob * (1 if base[i] == "0" else -1)

        return z_expectations

    def _default_circuit(self, inputs: "ArrayLike", weights: "ArrayLike"):
        """Build and run the default circuit, returning one ``<Z>`` expectation per qubit.

        The circuit is rebuilt on every call and kept in ``self.qcircuit``.
        """
        n = self.num_qubits
        self.qcircuit = QuantumCircuit(n)

        # TODO: Check Data Type (np.ndarray seems to be supported with a warning; torch.Tensor is not).
        # Convert to plain Python lists so the gate angles are Python floats.
        if isinstance(inputs, (np.ndarray, torch.Tensor)):
            inputs = inputs.tolist()
        if isinstance(weights, (np.ndarray, torch.Tensor)):
            weights = weights.tolist()

        # Encoding Layer
        for qubit in range(n):
            self.qcircuit.h(qubit)
            self.qcircuit.ry(qubit, inputs[qubit])

        # Parametric Layer
        for layer in range(self.num_qlayers):
            # Entanglement: CRZ(weights[layer][i]) from qubit i to qubit i + 1 (mod n).
            # Decomposition CRZ = RZ + CNOT + RZ + CNOT:
            # q_0: ─────────────■────────────────■──
            #      ┌─────────┐┌─┴─┐┌──────────┐┌─┴─┐
            # q_1: ┤ Rz(λ/2) ├┤ X ├┤ Rz(-λ/2) ├┤ X ├
            #      └─────────┘└───┘└──────────┘└───┘
            # With a single qubit, control and target would coincide, so the ring is skipped.
            if n > 1:
                for i in range(n):
                    target = (i + 1) % n
                    half_angle = weights[layer][i] / 2
                    self.qcircuit.rz(target, half_angle)
                    self.qcircuit.cx(i, target)
                    self.qcircuit.rz(target, -half_angle)
                    self.qcircuit.cx(i, target)
            # Rotation
            for qubit in range(n):
                self.qcircuit.ry(qubit, weights[layer][n + qubit])

        # Observation Layer
        self.qcircuit.measure()

        if self.qdevice == "simulator":
            results = simulate(self.qcircuit, output="probabilities").probabilities
        else:
            results = self.task.send(self.qcircuit, wait=True).probabilities

        return self.expval_z_expectations(results, n)


### 2.3.2 QuafuTorchLayer

In [None]:
class QuafuTorchLayer(nn.Module):
    """Torch module wrapping a ``QuafuQKernel`` circuit, modelled on ``qml.qnn.TorchLayer``.

    The circuit's trainable arguments (from ``weight_shapes``) are registered as parameters,
    initialised uniformly on ``[0, 2π]``.

    NOTE(review): the circuit result is converted into a tensor with ``torch.as_tensor``.
    With the default ``QuafuQKernel`` circuit (which returns plain Python floats), no autograd
    graph links the output to ``inputs`` or to the registered weights, so the quantum weights
    are not trained by backpropagation. Training them would require e.g. a parameter-shift
    ``torch.autograd.Function``.
    """

    def __init__(self, qkernel: "QuafuQKernel", weight_shapes: dict):
        super().__init__()
        # Normalise sizes the same way PennyLane does: iterable -> tuple, 1 -> scalar (),
        # any other int n -> (n,).
        weight_shapes = {
            weight: (
                tuple(size)
                if isinstance(size, Iterable)
                else () if size == 1 else (size,)
            )
            for weight, size in weight_shapes.items()
        }
        self.qkernel = qkernel
        self.circuit = qkernel.circuit
        self.qnode_weights: Dict[str, torch.nn.Parameter] = {}
        self._init_weights(weight_shapes=weight_shapes)

    def forward(self, inputs):
        """Evaluates a forward pass through the circuit based upon input data and the initialized
        weights.

        Args:
            inputs (tensor): data to be processed. A 1-D tensor is one datapoint; any leading
                dimensions are unstacked and evaluated datapoint by datapoint.

        Returns:
            tensor: output data
        """
        if inputs.dim() > 1:
            # Recurse over the first dimension and stack the results back into shape
            return torch.stack([self.forward(x) for x in torch.unbind(inputs)])

        return self._evaluate_qnode(inputs)

    def _evaluate_qnode(self, x):
        """Evaluates the circuit for a single input datapoint.

        Args:
            x (tensor): the datapoint

        Returns:
            tensor: output datapoint, with the dtype and device of ``x``
        """
        kwargs = {"inputs": x}
        kwargs.update({arg: weight.to(x) for arg, weight in self.qnode_weights.items()})

        result = self.circuit(**kwargs)
        # Build the result directly with x's dtype *and* device (previously it always
        # landed on the CPU).
        return torch.as_tensor(result, dtype=x.dtype, device=x.device)

    def _init_weights(self, weight_shapes: Dict[str, tuple]):
        """Initialize and register the weights, and weights are randomly initialized from the uniform distribution
        on the interval [0, 2π].
        """
        for name, size in weight_shapes.items():
            # torch.empty(()) gives a 0-dim scalar; torch.Tensor(*()) would give an empty (0,) tensor.
            weight = torch.empty(size)
            torch.nn.init.uniform_(weight, a=0.0, b=2 * math.pi)
            self.qnode_weights[name] = torch.nn.Parameter(weight)
            self.register_parameter(name, self.qnode_weights[name])


### 2.3.3 Quafu Convolutional Layer

In [None]:
class QuafuQuanv2d(_QuanvNd):
    """2D quanvolution layer whose kernel is executed with Quafu through ``QuafuTorchLayer``.

    Every ``kernel_size x kernel_size`` patch (across all input channels) is evaluated by the
    Quafu kernel, which returns one value per patch element. These values are folded back onto
    the (padded) image, overlapping contributions are averaged, and a classical 1x1
    convolution maps ``in_channels`` to ``out_channels``.

    The spatial size is not reduced. For an ``H x W`` input the output is
    ``(H + 2 * padding) x (W + 2 * padding)``. ``stride`` only selects which patches are
    evaluated. Pixels that no patch covers are set to 0 before the 1x1 convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 1,
                 padding: int = 0,
                 qkernel: QuafuQKernel = None,
                 num_qlayers: int = 2,
                 qdevice: str = "simulator",
                 api_token: str = ""):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels produced by the final 1x1 convolution.
            kernel_size: side length of the square patch.
            stride: step between evaluated patches.
            padding: zero padding added on every side before extracting patches.
            qkernel: Quafu kernel to use. Defaults to a ``QuafuQKernel`` over ``in_channels``.
            num_qlayers: number of parametric blocks of the default kernel.
            qdevice: ``"simulator"`` or a Quafu backend name.
            api_token: Quafu API token forwarded to the default kernel.
        """
        super(QuafuQuanv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.qkernel = qkernel or QuafuQKernel(in_channels=in_channels, kernel_size=kernel_size,
                                               num_qlayers=num_qlayers, qdevice=qdevice, api_token=api_token)
        self.qlayer = QuafuTorchLayer(self.qkernel, self.qkernel.weight_shapes)

        # Use 1x1 classical convolution to match the desired output channels
        self.classical_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # Verified params
        assert len(x.shape) == 4, "Input tensor must have 4 dimensions: (batch_size, channels, height, width)"
        assert x.dtype == torch.float32, "Input tensor must have dtype torch.float32"
        assert x.shape[1] == self.in_channels, f"Input tensor must have {self.in_channels} input channels"

        # Apply quantum convolution
        x = self.quantum_conv(x)

        # Apply 1x1 classical convolution to match the desired output channels
        x = self.classical_conv(x)

        return x

    def quantum_conv(self, x):
        """Evaluate the Quafu kernel on every patch and fold the results back onto the image.

        Args:
            x: ``(batch, channels, H, W)`` tensor.

        Returns:
            ``(batch, C, H + 2 * padding, W + 2 * padding)`` tensor, where
            ``C = num_qubits / kernel_size**2``.
        """
        bs, _, h, w = x.shape

        # Zero-pad exactly once. unfold/fold below operate on the padded image, so they must
        # not be given ``padding`` as well: that would pad a second time.
        if self.padding != 0:
            x = F.pad(x, (self.padding,) * 4, mode="constant", value=0)
            h += 2 * self.padding
            w += 2 * self.padding

        # (bs, C * k * k, L) -> (bs, L, C * k * k): one row per patch
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride).transpose(1, 2)
        num_patches, patch_dim = patches.shape[1], patches.shape[2]

        # Evaluate the quantum layer patch by patch (sample-major order)
        flat_patches = patches.reshape(-1, patch_dim)
        out = torch.stack([self.qlayer(patch) for patch in flat_patches])
        out = out.reshape(bs, num_patches, -1).transpose(1, 2)

        # Sum the per-patch outputs back onto the image grid
        out = F.fold(out, output_size=(h, w), kernel_size=self.kernel_size, stride=self.stride)

        # Normalization: divide by how many patches covered each pixel. Pixels covered by no
        # patch have count 0; clamping avoids 0 / 0 = NaN.
        coverage = F.fold(
            F.unfold(torch.ones_like(x), kernel_size=self.kernel_size, stride=self.stride),
            output_size=(h, w), kernel_size=self.kernel_size, stride=self.stride,
        )
        return out / coverage.clamp(min=1)


# 3. Models

## 3.1 Benchmark Models

In [None]:
class ClassicNet(nn.Module):
    """Fully classical baseline: 1x1 conv (1 -> 4) -> ReLU -> 2x2/stride-2 conv (4 -> 8) -> ReLU -> MLP.

    ``fc1`` is sized for ``8 x 7 x 7`` features. Since conv1 keeps the spatial size and conv2
    halves it, this corresponds to a ``1 x 14 x 14`` input.
    """

    def __init__(self, num_classes=10):
        super(ClassicNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(8 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(features.view(batch_size, -1)))
        return self.fc2(hidden)


In [None]:
class HybridNet(nn.Module):
    """Hybrid counterpart of ``ClassicNet``: a ``Quanv2d`` (1 -> 4 channels, 2x2 quantum kernel)
    takes the place of ClassicNet's first convolution.

    Extra keyword arguments are forwarded to ``Quanv2d`` (e.g. ``qdevice``, ``num_qlayers``).
    ``fc1`` is sized for ``8 x 7 x 7`` features. With ``padding=0``, ``Quanv2d`` keeps the
    spatial size, so this presumably means a ``1 x 14 x 14`` input (TODO confirm against the
    data pipeline).
    """

    def __init__(self, num_classes=10, stride=2, **kwargs):
        super(HybridNet, self).__init__()
        self.quanv = Quanv2d(in_channels=1, out_channels=4, kernel_size=2, stride=stride, padding=0, **kwargs)
        self.conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(8 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        # The quantum features go straight into the classical conv (no activation in between).
        features = F.relu(self.conv(self.quanv(x)))
        hidden = F.relu(self.fc1(features.view(batch_size, -1)))
        return self.fc2(hidden)


In [None]:
class QuafuHybridNet(nn.Module):
    """Same architecture as ``HybridNet``, but the quantum layer is a ``QuafuQuanv2d`` executed with Quafu.

    Extra keyword arguments are forwarded to ``QuafuQuanv2d`` (e.g. ``qdevice``, ``api_token``).
    ``fc1`` is sized for ``8 x 7 x 7`` features (presumably a ``1 x 14 x 14`` input; TODO confirm).
    """

    def __init__(self, num_classes=10, stride=2, **kwargs):
        super(QuafuHybridNet, self).__init__()
        self.quanv = QuafuQuanv2d(in_channels=1, out_channels=4, kernel_size=2, stride=stride, padding=0, **kwargs)
        self.conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(8 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        # The quantum features go straight into the classical conv (no activation in between).
        features = F.relu(self.conv(self.quanv(x)))
        hidden = F.relu(self.fc1(features.view(batch_size, -1)))
        return self.fc2(hidden)


## 3.2 Hybrid Models

### 3.2.1 VGG

In [None]:
class SimpleVGG(nn.Module):
    """Small VGG-style classical CNN (three conv stages + 3-layer MLP head).

    Stage 2 starts with a 1x1 convolution that reduces 32 channels to 3, the same bottleneck
    ``HybridVGG`` places in front of its ``Quanv2d``. ``AdaptiveAvgPool2d`` fixes the head
    input at ``128 x 7 x 7`` regardless of the image size.
    """

    def __init__(self, num_classes=10):
        super(SimpleVGG, self).__init__()

        def conv_relu(c_in, c_out):
            return [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU(inplace=True)]

        def pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        stage1 = [*conv_relu(3, 32), *conv_relu(32, 32), pool()]
        stage2 = [nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), *conv_relu(3, 64), *conv_relu(64, 64), pool()]
        stage3 = [*conv_relu(64, 128), *conv_relu(128, 128), pool()]
        self.features = nn.Sequential(*stage1, *stage2, *stage3)

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        head = []
        for d_in, d_out in ((128 * 7 * 7, 512), (512, 128)):
            head += [nn.Linear(d_in, d_out), nn.ReLU(inplace=True), nn.Dropout(0.5)]
        head.append(nn.Linear(128, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))


In [None]:
class HybridVGG(nn.Module):
    """``SimpleVGG`` with the first 3x3 convolution of stage 2 replaced by a ``Quanv2d``.

    Stage 2 first reduces the channels, then expands them. A 1x1 convolution squeezes 32
    channels to 3 (12 qubits with the default ``QKernel``: 3 channels x 2 x 2), and the
    ``Quanv2d`` then expands to 64 channels. Extra keyword arguments are forwarded to ``Quanv2d``.
    """

    def __init__(self, num_classes=10, **kwargs):
        super(HybridVGG, self).__init__()

        def conv_relu(c_in, c_out):
            return [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU(inplace=True)]

        def pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        stage1 = [*conv_relu(3, 32), *conv_relu(32, 32), pool()]
        stage2 = [
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            Quanv2d(3, 64, kernel_size=2, stride=2, padding=0, **kwargs),
            nn.ReLU(inplace=True),
            *conv_relu(64, 64),
            pool(),
        ]
        # With padding=0 Quanv2d keeps the spatial size, so stage 2 only halves it once
        # (at its max-pool).
        stage3 = [*conv_relu(64, 128), *conv_relu(128, 128), pool()]
        self.features = nn.Sequential(*stage1, *stage2, *stage3)

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        head = []
        for d_in, d_out in ((128 * 7 * 7, 512), (512, 128)):
            head += [nn.Linear(d_in, d_out), nn.ReLU(inplace=True), nn.Dropout(0.5)]
        head.append(nn.Linear(128, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))


In [None]:
# TODO: Hybrid Quafu

### 3.2.2 GoogLeNet

In [None]:
class SimpleGoogLeNet(nn.Module):
    """Scaled-down GoogLeNet (Inception v1) built from ``SimpleInception`` blocks.

    In training mode with ``aux_logits=True`` it returns ``(logits, aux_logits)``; otherwise
    it returns just the logits. ``fc`` is sized for ``216 x 2 x 2 = 864`` features, which
    corresponds to a ``3 x 64 x 64`` input.
    """

    def __init__(
            self,
            num_classes: int = 10,
            aux_logits: bool = True,
            dropout: float = 0.4,
            dropout_aux: float = 0.5,
    ):
        super(SimpleGoogLeNet, self).__init__()

        self.aux_logits = aux_logits

        # Stem
        self.conv1 = BasicConv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(16, 16, kernel_size=1)
        self.conv3 = BasicConv2d(16, 32, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Inception stages (trailing comment: output channels of the block)
        self.inception3a = SimpleInception(32, 24, 12, 24, 4, 8, 12)  # 68
        self.inception3b = SimpleInception(68, 48, 24, 32, 8, 12, 24)  # 116
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.inception4a = SimpleInception(116, 64, 12, 36, 4, 12, 36)  # 148
        self.inception4b = SimpleInception(148, 48, 24, 48, 12, 24, 36)  # 156
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.inception5a = SimpleInception(156, 72, 36, 64, 4, 12, 48)  # 196
        self.inception5b = SimpleInception(196, 72, 36, 64, 24, 32, 48)  # 216

        self.aux = SimpleInceptionAux(156, num_classes, dropout=dropout_aux) if aux_logits else None

        self.avgpool = nn.AvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(864, num_classes)

    def forward(self, x):
        # Shapes below assume a N x 3 x 64 x 64 input.
        # Stem: -> N x 16 x 32 x 32 -> N x 16 x 16 x 16 -> N x 32 x 16 x 16 -> N x 32 x 8 x 8
        for layer in (self.conv1, self.maxpool1, self.conv2, self.conv3, self.maxpool2):
            x = layer(x)

        # Stages 3-4: -> N x 68 x 8 x 8 -> N x 116 x 8 x 8 -> N x 116 x 4 x 4
        #             -> N x 148 x 4 x 4 -> N x 156 x 4 x 4
        for layer in (self.inception3a, self.inception3b, self.maxpool3, self.inception4a, self.inception4b):
            x = layer(x)

        # The auxiliary head taps the N x 156 x 4 x 4 features, only while training.
        aux = self.aux(x) if (self.training and self.aux is not None) else None

        # Stage 5: -> N x 156 x 2 x 2 -> N x 196 x 2 x 2 -> N x 216 x 2 x 2
        for layer in (self.maxpool4, self.inception5a, self.inception5b):
            x = layer(x)

        # avgpool has a 1x1 window (identity); flatten to N x 864
        flat = torch.flatten(self.avgpool(x), 1)
        logits = self.fc(self.dropout(flat))

        if self.training and self.aux_logits:
            return logits, aux
        return logits


class SimpleInception(nn.Module):
    """Reduced Inception block: four parallel branches concatenated along the channel axis.

    Output channels are ``ch1x1 + ch3x3 + ch5x5 + pool_proj``. The "5x5" branch uses a 3x3
    convolution here. Every branch preserves the spatial size.
    """

    def __init__(
            self,
            in_channels: int,
            ch1x1: int,
            ch3x3red: int,
            ch3x3: int,
            ch5x5red: int,
            ch5x5: int,
            pool_proj: int):
        super(SimpleInception, self).__init__()

        def reduce_then_conv3x3(reduced, out):
            # 1x1 channel reduction followed by a padded 3x3 convolution
            return nn.Sequential(
                BasicConv2d(in_channels, reduced, kernel_size=1),
                BasicConv2d(reduced, out, kernel_size=3, padding=1),
            )

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = reduce_then_conv3x3(ch3x3red, ch3x3)
        self.branch3 = reduce_then_conv3x3(ch5x5red, ch5x5)
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class SimpleInceptionAux(nn.Module):
    """Auxiliary classifier head: 2x2 avg-pool -> 1x1 conv (128) -> dropout -> FC(512 -> 128) -> FC.

    ``fc1`` expects ``128 x 2 x 2 = 512`` features, i.e. an ``in_channels x 4 x 4`` input.
    ``dropout`` is stored as a float and applied functionally.
    """

    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            dropout: float = 0.5,
    ):
        super(SimpleInceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = dropout

    def forward(self, x):
        # N x C x 4 x 4 -> N x C x 2 x 2 -> N x 128 x 2 x 2 -> N x 512
        flat = torch.flatten(self.conv(self.averagePool(x)), start_dim=1)
        hidden = F.dropout(flat, self.dropout, training=self.training)
        # N x 512 -> N x 128
        hidden = F.relu(self.fc1(hidden), inplace=True)
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.fc2(hidden)


class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` -> ``BatchNorm2d(eps=1e-3)`` -> in-place ``ReLU``.

    Extra keyword arguments (kernel_size, stride, padding, ...) are passed to ``nn.Conv2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super(BasicConv2d, self).__init__()
        # The convolution has no bias because BatchNorm follows it.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


In [None]:
class HybridGoogLeNet(nn.Module):
    """``SimpleGoogLeNet`` in which inception 3a, 4a and 5a are ``HybridInception`` blocks
    (their "5x5" branch is a quantum ``HybridConv2d``).

    Extra keyword arguments are forwarded to every ``HybridInception``. In training mode with
    ``aux_logits=True`` it returns ``(logits, aux_logits)``; otherwise it returns just the logits.
    ``fc`` is sized for ``216 x 2 x 2 = 864`` features (a ``3 x 64 x 64`` input).
    """

    def __init__(
            self,
            num_classes: int = 10,
            aux_logits: bool = True,
            dropout: float = 0.4,
            dropout_aux: float = 0.5,
            **kwargs: Any,
    ):
        super(HybridGoogLeNet, self).__init__()

        self.aux_logits = aux_logits

        # Stem
        self.conv1 = BasicConv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(16, 16, kernel_size=1)
        self.conv3 = BasicConv2d(16, 32, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Inception stages (trailing comment: output channels of the block)
        self.inception3a = HybridInception(32, 24, 12, 24, 4, 8, 12, **kwargs)  # 68
        self.inception3b = SimpleInception(68, 48, 24, 32, 8, 12, 24)  # 116
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.inception4a = HybridInception(116, 64, 12, 36, 4, 12, 36, **kwargs)  # 148
        self.inception4b = SimpleInception(148, 48, 24, 48, 12, 24, 36)  # 156
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.inception5a = HybridInception(156, 72, 36, 64, 4, 12, 48, **kwargs)  # 196
        self.inception5b = SimpleInception(196, 72, 36, 64, 24, 32, 48)  # 216

        self.aux = SimpleInceptionAux(156, num_classes, dropout=dropout_aux) if aux_logits else None

        self.avgpool = nn.AvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(864, num_classes)

    def forward(self, x):
        # Shapes below assume a N x 3 x 64 x 64 input.
        # Stem: -> N x 16 x 32 x 32 -> N x 16 x 16 x 16 -> N x 32 x 16 x 16 -> N x 32 x 8 x 8
        for layer in (self.conv1, self.maxpool1, self.conv2, self.conv3, self.maxpool2):
            x = layer(x)

        # Stages 3-4: -> N x 68 x 8 x 8 -> N x 116 x 8 x 8 -> N x 116 x 4 x 4
        #             -> N x 148 x 4 x 4 -> N x 156 x 4 x 4
        for layer in (self.inception3a, self.inception3b, self.maxpool3, self.inception4a, self.inception4b):
            x = layer(x)

        # The auxiliary head taps the N x 156 x 4 x 4 features, only while training.
        aux = self.aux(x) if (self.training and self.aux is not None) else None

        # Stage 5: -> N x 156 x 2 x 2 -> N x 196 x 2 x 2 -> N x 216 x 2 x 2
        for layer in (self.maxpool4, self.inception5a, self.inception5b):
            x = layer(x)

        # avgpool has a 1x1 window (identity); flatten to N x 864
        flat = torch.flatten(self.avgpool(x), 1)
        logits = self.fc(self.dropout(flat))

        if self.training and self.aux_logits:
            return logits, aux
        return logits


class HybridInception(nn.Module):
    """Inception block whose 5x5-style branch is replaced by a quantum convolution.

    Four branches run on the same input and are concatenated along the channel axis,
    so the block emits ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        ch1x1: Output channels of the 1x1 branch.
        ch3x3red: Bottleneck width before the 3x3 convolution.
        ch3x3: Output channels of the 3x3 branch.
        ch5x5red: Bottleneck width before the quantum convolution.
        ch5x5: Output channels of the quantum branch.
        pool_proj: Output channels of the max-pool projection branch.
        **kwargs: Forwarded unchanged to ``HybridConv2d`` (and from there to ``Quanv2d``).

    NOTE(review): the quantum branch is built with kernel_size=2, stride=2, padding=0.
    ``torch.cat`` needs every branch to share a spatial size, so this presumably relies on
    ``Quanv2d`` not halving the feature map -- confirm against ``Quanv2d``.
    """

    def __init__(self, in_channels: int, ch1x1: int, ch3x3red: int, ch3x3: int,
                 ch5x5red: int, ch5x5: int, pool_proj: int, **kwargs: Any):
        super().__init__()
        # Branch 1: plain 1x1 projection.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction, then a padded 3x3 convolution.
        reduce_3x3 = BasicConv2d(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction, then the quantum convolution.
        reduce_q = BasicConv2d(in_channels, ch5x5red, kernel_size=1)
        quantum_conv = HybridConv2d(ch5x5red, ch5x5, kernel_size=2, stride=2, padding=0, **kwargs)
        self.branch3 = nn.Sequential(reduce_q, quantum_conv)

        # Branch 4: size-preserving 3x3/1 max-pool, then a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        project = BasicConv2d(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, project)

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class HybridConv2d(nn.Module):
    """Quantum convolution (``Quanv2d``) followed by batch normalisation.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by ``Quanv2d`` and normalised by the BN layer.
        **kwargs: Forwarded unchanged to ``Quanv2d`` (kernel/stride/padding and
            quantum-device options).
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.quanv = Quanv2d(in_channels, out_channels, **kwargs)
        # eps=0.001 rather than PyTorch's default 1e-5.
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        quantum_features = self.quanv(x)
        return self.bn(quantum_features)


In [None]:
# TODO: Hybrid Quafu

### 3.2.3 ResNet

In [None]:
class BasicBlock(nn.Module):
    """Residual block used by ``SimpleResNet`` (and as the classical block of ``HybridResNet``).

    Main path: 3x3 conv (carries ``stride``) -> BN -> ReLU -> 1x1 squeeze to 4 channels
    -> 3x3 conv back to ``out_channels`` -> ReLU -> 3x3 conv -> BN. The result is added to
    the input (passed through ``downsample`` when given) and sent through a final ReLU.

    The 4-channel squeeze mirrors the 4-channel input ``HybridBlock`` feeds to its quantum
    convolution, presumably to keep the classical and hybrid variants comparable.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output (``expansion`` is 1).
        stride: Stride of the first 3x3 convolution.
        downsample: Optional shortcut projection; required when the stride or channel
            count changes so the residual sum lines up.
    """

    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
                 downsample: Optional[nn.Module] = None):
        super().__init__()
        # Layer 1: strided 3x3 convolution + BN.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Layer 2: 1x1 squeeze to 4 channels, then 3x3 back to out_channels (no BN here).
        self.conv2 = nn.Conv2d(out_channels, 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv3 = nn.Conv2d(4, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        # Layer 3: 3x3 convolution + BN whose output is summed with the shortcut.
        self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.conv3(self.conv2(residual)))
        residual = self.bn2(self.conv4(residual))

        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class SimpleResNet(nn.Module):
    """Small ResNet for 3-channel images.

    Stem: 3x3 conv to 16 channels -> BN -> ReLU -> 3x3/2 max-pool. Then four residual
    stages of widths 16, 32, 64, 128 (strides 1, 2, 2, 2), global average pooling and a
    linear classifier.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute and accept
            ``(in_channels, channels, stride=..., downsample=...)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classification output.
    """

    def __init__(
            self,
            block: Type[nn.Module],
            num_blocks: List[int],
            num_classes: int = 10,
    ):
        super(SimpleResNet, self).__init__()

        # Running channel count; _make_layer updates it after every stage.
        self.in_channels = 16

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After layer4, self.in_channels == 128 * block.expansion. Sizing the classifier
        # from it keeps fc consistent with _make_layer for blocks whose expansion != 1.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(
            self,
            block: Type[nn.Module],
            channels: int,
            block_num: int,
            stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage of ``block_num`` blocks producing ``channels * block.expansion`` channels.

        Only the first block applies ``stride``. When the stride or width changes, that
        block also gets a 1x1 conv + BN shortcut projection. Updates ``self.in_channels``.
        """
        out_channels = channels * block.expansion
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = [
            block(self.in_channels, channels, stride=stride, downsample=downsample),
        ]
        self.in_channels = out_channels

        for _ in range(1, block_num):
            layers.append(block(self.in_channels, channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def simple_resnet18(num_classes=10):
    """Build a ``SimpleResNet`` with the ResNet-18 stage layout (2, 2, 2, 2) of ``BasicBlock``s."""
    return SimpleResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=num_classes)


def simple_resnet34(num_classes=10):
    """Build a ``SimpleResNet`` with the ResNet-34 stage layout (3, 4, 6, 3) of ``BasicBlock``s."""
    return SimpleResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], num_classes=num_classes)


In [None]:
class HybridBlock(nn.Module):
    """Residual block whose middle stage is a quantum convolution (``Quanv2d``).

    Main path: 3x3 conv (carries ``stride``) -> BN -> ReLU -> 1x1 squeeze to 4 channels
    -> ``Quanv2d`` (kernel 2, stride 2, no padding) back to ``out_channels`` -> 3x3 conv
    -> BN. The result is added to the input (passed through ``downsample`` when given) and
    sent through a final ReLU. Unlike ``BasicBlock``, no ReLU follows the middle stage.

    NOTE(review): the residual sum requires the main path and the shortcut to share a
    spatial size. The quantum stage is configured with kernel_size=2/stride=2/padding=0,
    so this only lines up if ``Quanv2d`` does not halve the feature map -- confirm.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output (``expansion`` is 1).
        stride: Stride of the first 3x3 convolution.
        downsample: Optional shortcut projection.
        **kwargs: Forwarded unchanged to ``Quanv2d`` (quantum device/circuit options).
    """

    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
                 downsample: Optional[nn.Module] = None, **kwargs: Any):
        super().__init__()
        # Layer 1: strided 3x3 convolution + BN.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Layer 2: 1x1 squeeze to 4 channels, then the quantum convolution.
        self.conv2 = nn.Conv2d(out_channels, 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.quanv = Quanv2d(in_channels=4, out_channels=out_channels, kernel_size=2, stride=2,
                             padding=0, **kwargs)
        # Layer 3: 3x3 convolution + BN whose output is summed with the shortcut.
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.quanv(self.conv2(features))
        features = self.bn2(self.conv3(features))

        shortcut = x if self.downsample is None else self.downsample(x)
        features += shortcut
        return self.relu(features)


class HybridResNet(nn.Module):
    """ResNet variant whose first block in stages 3 and 4 is a quantum ``HybridBlock``.

    The layout matches ``SimpleResNet``:
    - stem: 3x3 conv to 16 channels -> BN -> ReLU -> 3x3/2 max-pool;
    - four stages of widths 16, 32, 64, 128 (strides 1, 2, 2, 2);
    - global average pooling and a linear classifier.

    Stages 1-2 use only ``block``. In stages 3-4 the first (strided) block is a
    ``HybridBlock``, and the remaining blocks are ``block``.

    Args:
        block: Classical residual block class (``expansion`` attribute; constructor
            ``(in_channels, channels, stride=..., downsample=...)``).
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classification output.
        **kwargs: Forwarded to every ``HybridBlock`` (and from there to ``Quanv2d``).
    """

    def __init__(
            self,
            block: Type[nn.Module],
            num_blocks: List[int],
            num_classes: int = 10,
            **kwargs: Any,
    ):
        super(HybridResNet, self).__init__()

        # Running channel count; _make_layer updates it after every stage.
        self.in_channels = 16

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, use_quanv2d=False, **kwargs)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, use_quanv2d=False, **kwargs)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, use_quanv2d=True, **kwargs)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2, use_quanv2d=True, **kwargs)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now equals the last stage's output width (128 * expansion).
        # Sizing fc from it keeps the classifier consistent with _make_layer.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(
            self,
            block: Type[nn.Module],
            channels: int,
            block_num: int,
            stride: int = 1,
            use_quanv2d: bool = False,
            **kwargs: Any,
    ) -> nn.Sequential:
        """Build one stage of ``block_num`` blocks producing ``channels * block.expansion`` channels.

        The first block applies ``stride``. When the stride or width changes, it also gets
        a 1x1 conv + BN shortcut projection. With ``use_quanv2d`` that first block is a
        ``HybridBlock`` receiving ``kwargs``; otherwise ``kwargs`` are ignored.
        Updates ``self.in_channels`` to the stage's output width.
        """
        out_channels = channels * block.expansion
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = []
        if use_quanv2d:
            layers.append(HybridBlock(self.in_channels, channels, stride=stride, downsample=downsample, **kwargs))
        else:
            layers.append(block(self.in_channels, channels, stride=stride, downsample=downsample))
        self.in_channels = out_channels

        for _ in range(1, block_num):
            layers.append(block(self.in_channels, channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def hybrid_resnet18(num_classes=10, **kwargs):
    """Build a ``HybridResNet`` with the ResNet-18 stage layout (2, 2, 2, 2); ``kwargs`` go to ``Quanv2d``."""
    return HybridResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=num_classes, **kwargs)


def hybrid_resnet34(num_classes=10, **kwargs):
    """Build a ``HybridResNet`` with the ResNet-34 stage layout (3, 4, 6, 3); ``kwargs`` go to ``Quanv2d``."""
    return HybridResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], num_classes=num_classes, **kwargs)


In [None]:
# TODO: Hybrid Quafu

In [None]:
# Registry used by load_model_with_weights: model name -> factory, called as
# factory(num_classes=..., **kwargs). The ResNet entries must be the preset builders,
# because SimpleResNet/HybridResNet also require `block` and `num_blocks` and raise
# TypeError when called that way. The ResNet-18 presets match what get_models
# instantiates under these names, so saved weights load into the same architecture.
ALL_MODELS = {
    'ClassicNet': ClassicNet,
    'HybridNet': HybridNet,
    'SimpleVGG': SimpleVGG,
    'HybridVGG': HybridVGG,
    'SimpleGoogLeNet': SimpleGoogLeNet,
    'HybridGoogLeNet': HybridGoogLeNet,
    'SimpleResNet': simple_resnet18,
    'HybridResNet': hybrid_resnet18,
}

# 4. Dataset

In [None]:
class GarbageDataset(Dataset):
    """Image-folder dataset for the 10-class garbage classification data.

    Expects ``root_dir/<class_name>/<image file>`` for every name in ``classes``. The label
    of an image is the index of its folder name in ``classes``.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.classes = [
            'battery', 'biological', 'brown-glass', 'cardboard', 'clothes',
            'metal', 'paper', 'plastic', 'shoes', 'trash',
        ]

        self.images = []
        self.labels = []
        for index, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root_dir, class_name)
            # Sort so the sample order (and therefore seeded shuffling) does not depend on
            # the filesystem's directory-listing order.
            for file_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, file_name)
                # Skip hidden files (e.g. .DS_Store) and sub-directories, which are not images.
                if file_name.startswith('.') or not os.path.isfile(image_path):
                    continue
                self.images.append(image_path)
                self.labels.append(index)

    def set_transform(self, transform):
        """Replace the transform applied in ``__getitem__``."""
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is converted to RGB, then transformed if set."""
        image_path = self.images[index]
        label = self.labels[index]

        # The context manager closes the file handle once convert() has copied the pixels.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return len(self.images)


# 5. Tools

## 5.1 Dataset Tools

In [None]:
def extract_images_from_datasets(dataset, num_pics, out_dir='./pics/ext_pics/'):
    """
    Save randomly selected dataset samples as ``out_dir/<class_name>/<index>.jpg``.

    Args:
        dataset: Indexable dataset with a ``classes`` list whose items are
            ``(image, label_index)``. The images must provide ``.save(path)``, e.g. PIL
            images from a dataset without a tensor-producing transform.
        num_pics: Number of distinct samples to export; clamped to ``[0, len(dataset)]``.
        out_dir: Root output directory; one sub-directory per class is created.

    Failures to save an individual image are printed and skipped (best effort).
    """
    labels = dataset.classes
    # random.sample raises ValueError for a sample larger than the population or a
    # negative size; clamp instead so callers can ask for "up to N" images.
    num_pics = max(0, min(num_pics, len(dataset)))
    random_indexes = random.sample(range(len(dataset)), num_pics)
    os.makedirs(out_dir, exist_ok=True)

    # Save data
    for idx in random_indexes:
        image, label = dataset[idx]
        class_name = labels[label]
        image_path = os.path.join(out_dir, class_name, f'{idx}.jpg')
        os.makedirs(os.path.dirname(image_path), exist_ok=True)

        try:
            image.save(image_path)
        except Exception as e:
            print(f"Failed to save image {idx}: {e}")


def save_class_indices(dataset, json_file_path):
    """
    Write ``{index: class_name}`` for ``dataset.classes`` to ``json_file_path`` as JSON.

    JSON object keys are strings, so the indices appear as ``"0"``, ``"1"``, ... in the file.
    """
    index_to_class = dict(enumerate(dataset.classes))

    with open(json_file_path, "w") as handle:
        json.dump(index_to_class, handle, indent=4)

    print(f"Class labels and their indices successfully saved to {json_file_path}.")


## 5.2 Quantum Transforms

In [None]:
class ToTensor4Quantum:
    """
    Transform that converts a PIL Image or ndarray to a tensor scaled for angle encoding.

    The input is converted with ``TF.to_tensor`` and multiplied by π. For 8-bit inputs,
    ``to_tensor`` yields values in [0, 1], so pixel intensities end up in [0, π] and can
    be used directly as rotation angles.

    Usage:
        transform = ToTensor4Quantum()
        tensor_image = transform(image)
    """

    def __init__(self) -> None:
        pass

    def __call__(self, pic):
        return np.pi * TF.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


## 5.3 Visualization Tools

In [None]:
# Constants
# --- Font sizes (points) applied by set_plot_style and the legends ---
TITLE_FONTSIZE = 16
LABEL_FONTSIZE = 14
TICK_FONTSIZE = 12
LEGEND_FONTSIZE = 12

# --- Figure geometry (inches): fixed width, height scales with the number of subplot rows ---
# Line charts from plot_evaluation_metrics.
FIGSIZE_WIDTH = 12
FIGSIZE_HEIGHT_PER_ROW = 3
# Bar charts from plot_probabilities.
PROBABILITY_FIGSIZE_WIDTH = 18
PROBABILITY_FIGSIZE_HEIGHT_PER_ROW = 6


def save_confusion_matrix(confusion_matrix_data, output_path):
    """Render ``confusion_matrix_data`` as an annotated heatmap and save it to ``output_path``.

    Args:
        confusion_matrix_data: 2-D array of counts with rows = true labels and columns =
            predicted labels (the axis labels below assume this layout); ``fmt='d'``
            expects integer cells.
        output_path: Destination image path; its directory must already exist.

    The figure is closed after saving so repeated calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(confusion_matrix_data, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)


def set_plot_style(ax, title, xlabel, ylabel):
    """
    Apply the notebook's shared title/label/tick font sizes to a Matplotlib axes.

    Args:
        ax: Target axes.
        title: Axes title text.
        xlabel: X-axis label text ('' for none).
        ylabel: Y-axis label text ('' for none).
    """
    ax.set_title(title, fontsize=TITLE_FONTSIZE)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=LABEL_FONTSIZE)
    ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)


def plot_evaluation_metrics(models_linestyles, data_types, data_dir='./output/', show=True, save=False,
                            save_path='./output/evaluation_metrics_chart.png', names=None):
    """
    Plot models' per-epoch evaluation metrics as line charts, one subplot per metric.

    Args:
        models_linestyles: Mapping ``model_name -> matplotlib linestyle``; one line per model.
        data_types: Metric names such as ``'train_loss'`` or ``'test_accuracy'``. Each is
            read with ``load_evaluation_metrics(model_name, data_type, data_dir)``.
        data_dir: Root directory the metric ``.npy`` files were saved under.
        show: Display the figure with ``plt.show()``; otherwise the figure is closed.
        save: Save the figure to ``save_path``.
        save_path: Output image path; its directory is created when needed.
        names: Optional mapping ``model_name -> legend label``.

    Models whose metric file is missing are skipped for that subplot
    (``load_evaluation_metrics`` already logs a warning and returns None).
    """
    if not data_types:
        print("No data types to plot.")
        return

    # Set subplot layout
    num_data_types = len(data_types)
    num_columns = 2
    num_rows = (num_data_types + num_columns - 1) // num_columns
    fig, axes = plt.subplots(num_rows, num_columns,
                             figsize=(FIGSIZE_WIDTH, FIGSIZE_HEIGHT_PER_ROW * num_rows), squeeze=False)
    axes = axes.flatten()

    # Plot subplots
    for ax, data_type in zip(axes, data_types):
        title = data_type.replace('_', ' ').title()
        # 'train_loss' -> 'Loss'; fall back to the whole name when it has no '_' suffix.
        name_parts = data_type.split('_')
        ylabel = (name_parts[1] if len(name_parts) > 1 else name_parts[0]).title()
        set_plot_style(ax, title, 'Epoch', ylabel)

        for model_name, linestyle in models_linestyles.items():
            data = load_evaluation_metrics(model_name, data_type, data_dir)
            if data is None:
                continue
            label = names.get(model_name, model_name) if names else model_name
            ax.plot(data, label=label, linestyle=linestyle)

        # Build the legend once per subplot, and only if at least one line was drawn.
        if ax.get_lines():
            ax.legend(fontsize=LEGEND_FONTSIZE)

    # Remove the unused cell(s) of the grid.
    for unused_ax in axes[num_data_types:]:
        fig.delaxes(unused_ax)

    fig.tight_layout()

    if save:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # a bare file name has no directory to create
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)

    if show:
        plt.show()
    else:
        # Release figures that are not displayed so repeated calls do not pile them up.
        plt.close(fig)


def plot_probabilities(results, labels, show=True, save=False, save_path='./probabilities.png', colors=None):
    """
    Plot class-probability bar charts for several models' predictions, three per row.

    Args:
        results: Mapping ``model_name -> result``. ``result['probabilities'][0]`` is the
            probability vector (one value per entry of ``labels``) that gets plotted.
        labels: Class names used as bar positions and x tick labels.
        show: Display the figure with ``plt.show()``; otherwise the figure is closed.
        save: Save the figure to ``save_path``.
        save_path: Output image path; its directory is created when needed.
        colors: Bar colours, cycled per model; defaults to six named colours.

    The y-axis spans [0, 0.3] for comparability, but is raised (with 5% headroom) when a
    probability exceeds 0.3, so confident predictions are not clipped.
    """
    if not results:
        print("No results to plot.")
        return

    # Default color list
    if colors is None:
        colors = ['blue', 'green', 'red', 'purple', 'orange', 'cyan']

    # Set subplot layout
    num_models = len(results)
    num_columns = 3
    num_rows = (num_models + num_columns - 1) // num_columns
    fig, axes = plt.subplots(num_rows, num_columns,
                             figsize=(PROBABILITY_FIGSIZE_WIDTH, PROBABILITY_FIGSIZE_HEIGHT_PER_ROW * num_rows),
                             squeeze=False)
    axes = axes.flatten()

    # Plot probability distribution charts
    for idx, (ax, (model_name, result)) in enumerate(zip(axes, results.items())):
        # ax is the subplot for the current model.
        probabilities = result['probabilities'][0]
        ax.bar(labels, probabilities, color=colors[idx % len(colors)])
        set_plot_style(ax, f'{model_name} Prediction', '', '')
        ax.set_xticks(labels)
        ax.set_xticklabels(labels, rotation=45, ha='right')

        values = np.asarray(probabilities, dtype=float)
        peak = float(values.max()) if values.size else 0.0
        ax.set_ylim(0, max(0.3, peak * 1.05))

    # Clear unused subplots
    for unused_ax in axes[num_models:]:
        fig.delaxes(unused_ax)

    # Adjust spacing between subplots
    fig.tight_layout()

    # Save probability distribution charts
    if save:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # a bare file name has no directory to create
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)

    # Show probability distribution charts, or release the figure when not shown.
    if show:
        plt.show()
    else:
        plt.close(fig)


## 5.4 Model Tools

In [None]:
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices), and make cuDNN deterministic.

    Turning off cuDNN benchmark mode trades some speed for run-to-run reproducible kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_models(dataset_name, qdevice, qdevice_kwargs, diff_method):
    """Instantiate every model configuration studied on ``dataset_name``.

    Args:
        dataset_name: ``'FashionMNIST'`` or ``'GarbageDataset'``.
        qdevice: Quantum device name for the hybrid models (e.g. ``'lightning.qubit'``).
        qdevice_kwargs: Extra device keyword arguments (may be None).
        diff_method: Differentiation method for the hybrid models (e.g. ``'adjoint'``).

    Returns:
        Dict ``name -> model instance`` (all with 10 classes). Every entry is constructed
        eagerly, including all hybrid models, even if the caller uses only one.

    Raises:
        ValueError: for an unsupported ``dataset_name``.
    """
    quantum_cfg = dict(qdevice=qdevice, qdevice_kwargs=qdevice_kwargs, diff_method=diff_method)

    if dataset_name == 'FashionMNIST':
        return {
            'ClassicNet': ClassicNet(num_classes=10),
            'HybridNet': HybridNet(num_classes=10, qkernel=None, num_qlayers=1, **quantum_cfg),
            'HybridNetDeeper': HybridNet(num_classes=10, qkernel=None, num_qlayers=2, **quantum_cfg),
            'HybridNetStrideOne': HybridNet(num_classes=10, qkernel=None, num_qlayers=2, stride=1, **quantum_cfg),
            # TODO: About Barren Plateau
            'HybridNetDeeper2': HybridNet(num_classes=10, qkernel=None, num_qlayers=3, **quantum_cfg),
        }

    if dataset_name == 'GarbageDataset':
        return {
            'SimpleVGG': SimpleVGG(num_classes=10),
            'HybridVGG': HybridVGG(num_classes=10, qkernel=None, num_qlayers=2, **quantum_cfg),
            'SimpleGoogLeNet': SimpleGoogLeNet(num_classes=10),
            'HybridGoogLeNet': HybridGoogLeNet(num_classes=10, qkernel=None, num_qlayers=2, **quantum_cfg),
            'SimpleResNet': simple_resnet18(num_classes=10),
            'HybridResNet': hybrid_resnet18(num_classes=10, qkernel=None, num_qlayers=2, **quantum_cfg),
        }

    raise ValueError(f"Unsupported dataset: {dataset_name}")


def get_transforms(model_name, dataset_name):
    """Return ``(train_transform, test_transform)`` for a model/dataset pair.

    Models whose name contains ``'Hybrid'`` end their pipelines with ``ToTensor4Quantum``
    (values scaled by π). All other models use ``torchvision.transforms.ToTensor``.

    * FashionMNIST: resize to 14x14, then convert. The same Compose object is returned
      for train and test.
    * GarbageDataset:
      - train: resize 72 -> random 64 crop -> horizontal flip -> ±15° rotation ->
        colour jitter -> convert;
      - test: resize 72 -> centre 64 crop -> convert.

    Raises:
        ValueError: for an unsupported ``dataset_name``.
    """
    if dataset_name not in ('FashionMNIST', 'GarbageDataset'):
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    transforms = torchvision.transforms
    # A factory rather than an instance, so every pipeline gets its own converter object.
    make_to_tensor = ToTensor4Quantum if 'Hybrid' in model_name else transforms.ToTensor

    if dataset_name == 'FashionMNIST':
        shared = transforms.Compose([
            transforms.Resize((14, 14)),
            make_to_tensor(),
        ])
        return shared, shared

    train_transform = transforms.Compose([
        transforms.Resize(72),
        transforms.RandomCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        make_to_tensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(72),
        transforms.CenterCrop(64),
        make_to_tensor(),
    ])
    return train_transform, test_transform


def get_dataset(dataset_name, data_dir, train_transform, test_transform, batch_size):
    """Build ``(train_loader, test_loader)`` for ``dataset_name``.

    FashionMNIST is downloaded into ``data_dir`` if needed. GarbageDataset is read from
    ``data_dir/GarbageDataset/{train,test}``. The training loader shuffles; the test
    loader does not.

    Raises:
        ValueError: for an unsupported ``dataset_name``.
    """
    if dataset_name == 'FashionMNIST':
        def build_split(is_train, transform):
            return torchvision.datasets.FashionMNIST(root=data_dir, train=is_train, transform=transform,
                                                     download=True)
    elif dataset_name == 'GarbageDataset':
        def build_split(is_train, transform):
            split = 'train' if is_train else 'test'
            return GarbageDataset(root_dir=os.path.join(data_dir, 'GarbageDataset', split),
                                  transform=transform)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    train_set = build_split(True, train_transform)
    test_set = build_split(False, test_transform)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )


def construct_file_path(model_name, data_type, data_dir='./output/'):
    """
    Return ``<data_dir>/<model_name>/<model_name>_<data_type>.npy``, where a metric history is stored.
    """
    file_name = f'{model_name}_{data_type}.npy'
    return os.path.join(data_dir, model_name, file_name)


def load_model_with_weights(model_name, model_weights_path, num_classes, device, **kwargs):
    """
    Build ``ALL_MODELS[model_name](num_classes=num_classes, **kwargs)``, load its weights and
    return it on ``device`` in eval mode.

    Best effort: any failure (unknown name, missing file, construction or load error) is
    logged and None is returned instead of raising.
    """
    try:
        if model_name not in ALL_MODELS:
            # Message text means "invalid model name".
            raise ValueError("无效的模型名称")

        model_factory = ALL_MODELS[model_name]
        model = model_factory(num_classes=num_classes, **kwargs)

        try:
            # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
            state_dict = torch.load(model_weights_path, map_location=device)
        except FileNotFoundError:
            logging.error(f"Model weights file not found: {model_weights_path}")
            return None
        model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
        return model

    except Exception as e:
        logging.error(f"Error loading the model: {str(e)}")
        return None


def save_evaluation_metrics(model_name, data, data_type, data_dir='./output/'):
    """
    Save a metric history ``data`` as ``.npy`` at ``construct_file_path(model_name, data_type, data_dir)``.

    The per-model directory is created if missing.
    """
    target_path = construct_file_path(model_name, data_type, data_dir)
    target_dir = os.path.dirname(target_path)
    os.makedirs(target_dir, exist_ok=True)
    np.save(target_path, data)


def load_evaluation_metrics(model_name, data_type, data_dir='./output/'):
    """
    Load a metric array written by ``save_evaluation_metrics``; warn and return None if the file is missing.
    """
    data_file_path = construct_file_path(model_name, data_type, data_dir)
    if not os.path.exists(data_file_path):
        logging.warning(f"Evaluation metrics file not found: {data_file_path}")
        return None
    return np.load(data_file_path)


# 6. Train

In [None]:
def _split_outputs(outputs):
    """Return ``(logits, aux_logits)`` from a forward result; ``aux_logits`` is None for single-output models.

    Models with an auxiliary classifier (e.g. HybridGoogLeNet in training mode) return a
    ``(logits, aux_logits)`` tuple. Other models return the logits tensor itself.
    """
    if isinstance(outputs, tuple):
        logits, aux_logits = outputs
        return logits, aux_logits
    return outputs, None


def _run_epoch(model, loader, criterion, device, desc, optimizer=None, aux_weight=0.4):
    """Iterate once over ``loader`` and return ``(mean_loss, accuracy)`` over all samples.

    With an ``optimizer`` the pass trains: gradients are zeroed, the loss (plus
    ``aux_weight`` times the auxiliary loss when the model returns one) is
    back-propagated, and the optimizer steps. Without one the pass only evaluates;
    the caller is responsible for eval mode and ``torch.no_grad()``.
    """
    is_training = optimizer is not None
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    progress = tqdm(loader, desc=desc)
    for img, label in progress:
        img, label = img.to(device), label.to(device)
        if is_training:
            optimizer.zero_grad()

        output, aux_output = _split_outputs(model(img))
        loss = criterion(output, label)
        if is_training and aux_output is not None:
            loss = loss + aux_weight * criterion(aux_output, label)

        if is_training:
            loss.backward()
            optimizer.step()

        batch_size = img.size(0)
        correct = torch.sum(torch.argmax(output, dim=1) == label).item()
        total_loss += loss.item() * batch_size
        total_correct += correct
        total_samples += batch_size

        progress.set_postfix(loss=loss.item(), accuracy=correct / batch_size)

    # An empty loader yields 0.0 instead of raising ZeroDivisionError.
    denominator = max(total_samples, 1)
    return total_loss / denominator, total_correct / denominator


def train(model_name, model, train_loader, test_loader, optimizer, scheduler, criterion, device, num_epochs, output_dir,
          tensorboard_dir, aux_weight=0.4):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``test_loader`` after each one.

    Each epoch runs:
    - one optimisation pass and one evaluation pass under ``torch.no_grad()``;
    - one ``scheduler.step()``;
    - TensorBoard scalars ``Loss/{train,test}`` and ``Accuracy/{train,test}``, plus a log line.

    Checkpointing: the state dict with the best test accuracy so far is written to
    ``<output_dir>/<model_name>/<model_name>_model.pth``. The first epoch always saves,
    so a checkpoint exists even if the accuracy never rises above 0.

    After training, the per-epoch histories are saved with ``save_evaluation_metrics``
    as ``train_loss``, ``train_accuracy``, ``test_loss`` and ``test_accuracy``.

    Args:
        aux_weight: Weight of the auxiliary-classifier loss for models that return
            ``(logits, aux_logits)`` in training mode.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Start training for {num_epochs} epochs.")

    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    model_weights_path = os.path.join(model_dir, f"{model_name}_model.pth")

    writer = SummaryWriter(log_dir=tensorboard_dir)

    # Move the model to the specified device
    model = model.to(device)

    best_test_acc = 0.0
    history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

    try:
        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            # Training
            model.train()
            train_loss, train_acc = _run_epoch(model, train_loader, criterion, device,
                                               f"Epoch {epoch} - Training",
                                               optimizer=optimizer, aux_weight=aux_weight)

            # Testing
            model.eval()
            with torch.no_grad():
                test_loss, test_acc = _run_epoch(model, test_loader, criterion, device,
                                                 f"Epoch {epoch} - Testing")

            history["train_loss"].append(train_loss)
            history["train_accuracy"].append(train_acc)
            history["test_loss"].append(test_loss)
            history["test_accuracy"].append(test_acc)

            scheduler.step()

            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Accuracy/train", train_acc, epoch)
            writer.add_scalar("Loss/test", test_loss, epoch)
            writer.add_scalar("Accuracy/test", test_acc, epoch)

            # Always checkpoint after epoch 1; afterwards keep only strict improvements.
            if epoch == 1 or test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save(model.state_dict(), model_weights_path)

            elapsed_time = time.time() - start_time
            logger.info(f"Epoch [{epoch}/{num_epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Elapsed Time: {elapsed_time:.2f}s")
    finally:
        # Flush and release the TensorBoard event file even if an epoch raises.
        writer.close()

    # Save evaluation metrics
    for data_type, values in history.items():
        save_evaluation_metrics(model_name, values, data_type, output_dir)
    logger.info(f"Training completed. Best test accuracy: {best_test_acc:.4f}")


In [None]:
def main4train(args):
    """Run one training job configured by the ``args`` dict.

    Keys read: seed, device (None -> CUDA when available, else CPU), dataset, qdevice,
    qdevice_kwargs, diff_method, model, lr, weight_decay, data_dir, batch_size,
    output_dir, tensorboard_dir, epochs.

    Training setup: Adam (with ``weight_decay``), StepLR halving the learning rate every
    5 epochs, and cross-entropy loss.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger = logging.getLogger(__name__)

    set_random_seed(args['seed'])

    device_name = args['device']
    if device_name is None:
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    logger.info(f"Using device: {device}")

    candidates = get_models(args['dataset'], args['qdevice'], args['qdevice_kwargs'], args['diff_method'])
    model = candidates[args['model']]
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    criterion = nn.CrossEntropyLoss().to(device)

    train_transform, test_transform = get_transforms(args['model'], args['dataset'])
    train_loader, test_loader = get_dataset(args['dataset'], args['data_dir'], train_transform, test_transform,
                                            args['batch_size'])

    for directory in (args['output_dir'], args['tensorboard_dir']):
        os.makedirs(directory, exist_ok=True)

    train(args['model'], model, train_loader, test_loader, optimizer, scheduler, criterion, device, args['epochs'],
          args['output_dir'], args['tensorboard_dir'])


In [None]:
# Set TRAIN = False to skip training when re-running the notebook.
TRAIN = True
train_args = dict(
    model="ClassicNet",            # key of get_models(...) for the chosen dataset
    dataset="FashionMNIST",        # 'FashionMNIST' or 'GarbageDataset'
    data_dir="../datasets",
    epochs=10,
    batch_size=64,
    lr=1e-3,
    weight_decay=1e-4,
    seed=42,
    log_interval=10,               # NOTE(review): not read by main4train
    output_dir="../output/Demo",   # weights/metrics go to <output_dir>/<model>/
    tensorboard_dir="../tensorboard",
    device='cpu',                  # None -> CUDA if available, else CPU
    qdevice="lightning.qubit",     # quantum device used by hybrid models
    qdevice_kwargs=None,
    diff_method="adjoint",
)
if TRAIN:
    main4train(train_args)

# 7. Test

In [None]:
def test(model_name, model, test_loader, criterion, device, output_dir):
    """Evaluate ``model`` on ``test_loader`` and persist the results.

    Writes two files to ``<output_dir>/<model_name>/``, creating the directory if needed:
    - ``<model_name>_test_metrics.json``: mean loss, accuracy, macro recall and macro F1;
    - ``<model_name>_confusion_matrix.png``.

    Returns:
        The metrics dict written to JSON.
    """
    logger = logging.getLogger(__name__)
    logger.info("Start testing.")

    # Move the model to the specified device
    model = model.to(device)

    # Evaluate the model
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    test_samples = 0
    test_labels = []
    test_predictions = []
    with torch.no_grad():
        test_bar = tqdm(test_loader, desc="Testing")
        for img, label in test_bar:
            img, label = img.to(device), label.to(device)
            output = model(img)
            loss = criterion(output, label)
            predictions = torch.argmax(output, dim=1)

            batch_size = img.size(0)
            correct = torch.sum(predictions == label).item()
            test_loss += loss.item() * batch_size
            test_acc += correct
            test_samples += batch_size

            test_labels.extend(label.cpu().numpy())
            test_predictions.extend(predictions.cpu().numpy())

            test_bar.set_postfix(loss=loss.item(), accuracy=correct / batch_size)

    test_loss /= test_samples
    test_acc /= test_samples
    test_recall = recall_score(test_labels, test_predictions, average='macro')
    test_f1 = f1_score(test_labels, test_predictions, average='macro')
    confusion_matrix_data = confusion_matrix(test_labels, test_predictions)

    # Save evaluation metrics; the per-model directory may not exist when no training
    # run has written to this output_dir yet.
    output_path = os.path.join(output_dir, model_name)
    os.makedirs(output_path, exist_ok=True)
    metrics = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "test_recall": test_recall,
        "test_f1": test_f1,
    }
    with open(os.path.join(output_path, f"{model_name}_test_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=4)

    # Save Confusion Matrix
    save_confusion_matrix(confusion_matrix_data, os.path.join(output_path, f"{model_name}_confusion_matrix.png"))

    logger.info(f"Testing completed. Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, "
                f"Test Recall: {test_recall:.4f}, Test F1: {test_f1:.4f}")
    return metrics


In [None]:
def main4test(args):
    """Evaluate a trained model configured by the ``args`` dict.

    Steps:
    1. Build a fresh instance of ``args['model']`` via ``get_models``.
    2. Load ``<output_dir>/<model>/<model>_model.pth`` into it; this is the path ``train``
       writes to.
    3. Run ``test`` on the dataset's test split.

    The device defaults to CUDA when available and ``args['device']`` is None.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger = logging.getLogger(__name__)

    set_random_seed(args['seed'])

    device_name = args['device']
    if device_name is None:
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    logger.info(f"Using device: {device}")

    model_name = args['model']
    candidates = get_models(args['dataset'], args['qdevice'], args['qdevice_kwargs'], args['diff_method'])
    model = candidates[model_name]
    model_weights_path = os.path.join(args['output_dir'], model_name, f"{model_name}_model.pth")
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust
    # (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(state_dict)

    criterion = nn.CrossEntropyLoss().to(device)

    train_transform, test_transform = get_transforms(model_name, args['dataset'])
    _, test_loader = get_dataset(args['dataset'], args['data_dir'], train_transform, test_transform, args['batch_size'])

    os.makedirs(args['output_dir'], exist_ok=True)

    test(model_name, model, test_loader, criterion, device, args['output_dir'])


In [None]:
# Set TEST = False to skip evaluation when re-running the notebook.
TEST = True
test_args = {
    "model": "ClassicNet",
    "dataset": "FashionMNIST",
    "data_dir": "../datasets",
    "batch_size": 64,
    "seed": 42,
    # main4test loads <output_dir>/<model>/<model>_model.pth. When this notebook just
    # trained (TRAIN is True), read from the training run's output_dir so Restart & Run All
    # evaluates the fresh weights; otherwise use the pre-existing ../output tree.
    "output_dir": train_args["output_dir"] if TRAIN else "../output",
    "device": None,
    "qdevice": "default.qubit",
    "qdevice_kwargs": None,
    "diff_method": "best"
}
if TEST:
    main4test(test_args)

# 8. Predict