In [None]:

import datetime
import os
import time
from pathlib import Path
from typing import Optional

import numpy as np
import pennylane as qml
import torch
from ax import ChoiceParameterConfig, RangeParameterConfig
from matchcake import NonInteractingFermionicDevice
from matchcake.operations import (
    SingleParticleTransitionMatrixOperation,
)
from torchvision.transforms import Resize

from bolightningpipeline.datasets import *
from bolightningpipeline.modules.classification_model import ClassificationModel
from bolightningpipeline.tr_pipeline.lightning_pipeline import LightningPipeline

In [None]:
class NIFCNN(ClassificationModel):
    """
    Hybrid convolutional / non-interacting-fermion classifier.

    A convolutional encoder maps each (possibly up-sampled) input image to three sets of
    coefficients:
    - one "local field" per qubit;
    - one "ZZ" coupling per qubit pair;
    - one "XX" coupling per qubit pair.

    Independently of the input, one trainable single-particle transition matrix (SPTM) per
    output unit is applied to the all-zeros basis state on a ``NonInteractingFermionicDevice``.
    The device probabilities on single wires and on wire pairs are contracted with those
    coefficients and with fixed eigenvalue vectors. The result is one score per output unit.
    """

    MODEL_NAME = "NIFCNN"
    DEFAULT_N_QUBITS = 16
    DEFAULT_LEARNING_RATE = 2e-4
    DEFAULT_ENCODER_OUTPUT_ACTIVATION = "Tanh"
    # Inputs smaller than this in either spatial dimension are resized to exactly this size.
    MIN_INPUT_SIZE = (28, 28)

    def __init__(
            self,
            input_shape: Optional[tuple[int, ...]],
            output_shape: Optional[tuple[int, ...]],
            learning_rate: float = DEFAULT_LEARNING_RATE,
            n_qubits: int = DEFAULT_N_QUBITS,
            encoder_output_activation: str = DEFAULT_ENCODER_OUTPUT_ACTIVATION,
            **kwargs,
    ):
        """
        :param input_shape: Shape of one sample, without the batch dimension. It is used for a
            dummy forward pass that materializes the lazy layers. When ``None``, that pass is
            skipped.
        :param output_shape: Output shape. ``prod(output_shape)`` is the number of output units;
            there is one SPTM per unit.
        :param learning_rate: Forwarded to ``ClassificationModel``.
        :param n_qubits: Number of wires of the fermionic device.
        :param encoder_output_activation: Name of a ``torch.nn`` activation class (e.g. ``"Tanh"``).
            It is instantiated at the end of each coefficient head.
        :param kwargs: Forwarded to ``ClassificationModel``.
        """
        super().__init__(input_shape=input_shape, output_shape=output_shape, learning_rate=learning_rate, **kwargs)
        self.save_hyperparameters("learning_rate", "n_qubits", "encoder_output_activation")
        self.n_qubits = n_qubits
        self.R_DTYPE = torch.float32
        self.C_DTYPE = torch.cfloat
        # The free parameters of the 2n x 2n SPTM generator are its strict upper triangle.
        # The indices are cached because get_sptm_weights scatters into them on every forward.
        self._sptm_triu_indices = np.triu_indices(2 * self.n_qubits, k=1)
        self._n_params = self._sptm_triu_indices[0].size
        self.q_device = NonInteractingFermionicDevice(
            wires=self.n_qubits, r_dtype=self.R_DTYPE, c_dtype=self.C_DTYPE, show_progress=False
        )
        self.encoder_output_activation = encoder_output_activation
        self.input_resize = Resize(self.MIN_INPUT_SIZE)
        # NOTE(review): there is no activation function between these convolutions, only
        # BatchNorm layers. Confirm this is intended. Adding activations would change the
        # architecture and this Sequential's state_dict keys.
        self.local_fields_encoder = torch.nn.Sequential(
            torch.nn.LazyConv2d(512, kernel_size=7),
            torch.nn.LazyBatchNorm2d(),
            torch.nn.LazyConv2d(128, kernel_size=5),
            torch.nn.LazyBatchNorm2d(),
            torch.nn.LazyConv2d(64, kernel_size=3),
            torch.nn.LazyBatchNorm2d(),
        )

        # Head producing one coefficient per qubit.
        self.local_fields_head = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.LazyLinear(self.n_qubits),
            getattr(torch.nn, encoder_output_activation)()
        )
        # Number of unordered qubit pairs (i < j), i.e. one coupling per pair.
        self._hamiltonian_n_params = np.triu_indices(self.n_qubits, k=1)[0].size
        self.zz_body_couplings_head = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.LazyLinear(self._hamiltonian_n_params),
            getattr(torch.nn, encoder_output_activation)()
        )
        self.xx_body_couplings_head = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.LazyLinear(self._hamiltonian_n_params),
            getattr(torch.nn, encoder_output_activation)()
        )

        # NOTE(review): these two parameters are created but not read anywhere in this class.
        self.zz_body_coupling_weights = torch.nn.Parameter(torch.randn(self._hamiltonian_n_params), requires_grad=True)
        self.xx_body_coupling_weights = torch.nn.Parameter(torch.randn(self._hamiltonian_n_params), requires_grad=True)

        # Fixed, non-trainable eigenvalue vectors contracted with the probability vectors in forward.
        self.local_fields_op_eigvals = torch.nn.Parameter(torch.from_numpy(np.array([1.0, -1.0])).float(),
                                                          requires_grad=False)  # eigvals(Z)
        self.zz_eigvals = torch.nn.Parameter(torch.from_numpy(np.array([1.0, -1.0, -1.0, 1.0])).float(),
                                             requires_grad=False)  # eigvals(ZZ)
        # NOTE(review): assume `probability` returns computational-basis probabilities in the
        # order |00>, |01>, |10>, |11>. Weighting them with [1, -1, 1, -1] then gives the
        # expectation of Z on the second wire of the pair, not of X (x) X. An XX term would
        # need X-basis probabilities. Confirm against matchcake.
        self.xx_eigvals = torch.nn.Parameter(torch.from_numpy(np.array([1.0, -1.0, 1.0, -1.0])).float(),
                                             requires_grad=False)  # eigvals(XX)

        # Wire groups whose probabilities are queried: every single wire and every pair i < j.
        self.local_fields_wires = [[i] for i in range(self.n_qubits)]
        self.couplings_wires = [[i, j] for i, j in np.vstack(np.triu_indices(self.n_qubits, k=1)).T]

        # One row of SPTM generator parameters per output unit.
        self.weights = torch.nn.Parameter(torch.rand((int(self.output_size), self._n_params)), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.weights)
        self._build()

    def _build(self):
        """
        Materialize the lazy layers with one dummy forward pass on random data.

        The pass runs in eval mode. In train mode, BatchNorm layers update their running
        mean/variance (and batch counter) even under ``torch.no_grad()``, which would fold
        random-noise statistics into them. The previous train/eval mode is restored
        afterwards. The pass is skipped when ``input_shape`` is ``None``.

        :return: ``self``.
        """
        if self.input_shape is None:
            return self
        dummy_input = torch.randn((3, *self.input_shape)).to(device=self.device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                self(dummy_input)
        finally:
            self.train(was_training)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute one score per output unit for a batch of inputs.

        :param x: Batch of inputs; see ``preprocess_input`` for how it is reshaped.
        :return: Tensor of shape ``(batch, k)``, where ``k`` is the leading dimension of the
            device probabilities (presumably ``output_size``, one per SPTM).
        """
        x = self.preprocess_input(x)
        embeddings = self.local_fields_encoder(x)

        # Input-dependent coefficients: (batch, n_qubits), (batch, n_pairs), (batch, n_pairs).
        local_fields = self.local_fields_head(embeddings)
        zz_couplings = self.zz_body_couplings_head(embeddings)
        xx_couplings = self.xx_body_couplings_head(embeddings)

        # The circuit does not depend on x; it is built only from the trainable SPTM weights.
        self.q_device.execute_generator(self.circuit_gen(), reset=True)

        # The einsums below expect probabilities laid out as (k, n_wire_groups, 2**group_size).
        local_fields_probs = (
            self.q_device.probability(wires=self.local_fields_wires)
            .to(dtype=self.q_device.R_DTYPE, device=self.torch_device)
        )
        couplings_probs = (
            self.q_device.probability(wires=self.couplings_wires)
            .to(dtype=self.q_device.R_DTYPE, device=self.torch_device)
        )
        # For each sample b and output k: sum_i coeff[b, i] * sum_j probs[k, i, j] * eigvals[j].
        weighted_local_eigvals = torch.einsum(
            "bi,kij,j->bk", local_fields, local_fields_probs, self.local_fields_op_eigvals
        )
        weighted_zz_eigvals = torch.einsum(
            "bi,kij,j->bk", zz_couplings, couplings_probs, self.zz_eigvals
        )
        weighted_xx_eigvals = torch.einsum(
            "bi,kij,j->bk", xx_couplings, couplings_probs, self.xx_eigvals
        )

        expval = weighted_local_eigvals + weighted_zz_eigvals + weighted_xx_eigvals
        return expval

    def circuit_gen(self):
        """
        Yield the circuit operations.

        The circuit prepares the all-zeros basis state, then applies the SPTM operation.
        """
        yield qml.BasisState(self.initial_basis_state, wires=self.wires)
        yield self.get_sptm_weights()
        return

    def get_sptm_weights(self):
        """
        Build the single-particle transition matrix (SPTM) operation from ``self.weights``.

        Each row of ``self.weights`` is scattered into the strict upper triangle of a
        ``2 * n_qubits`` square matrix ``h``. ``h`` is then anti-symmetrized (``h - h^T``), and its
        matrix exponential is taken. For a real skew-symmetric ``h``, the result is orthogonal.

        :return: A ``SingleParticleTransitionMatrixOperation`` wrapping the SPTM, of shape
            ``(output_size, 2 * n_qubits, 2 * n_qubits)``, on all device wires.
        :rtype: SingleParticleTransitionMatrixOperation
        """
        h = torch.zeros(
            (int(self.output_size), 2 * self.n_qubits, 2 * self.n_qubits), dtype=self.R_DTYPE, device=self.torch_device
        )
        rows, cols = self._sptm_triu_indices
        h[:, rows, cols] = self.weights
        h = h - h.mT
        sptm = torch.matrix_exp(h)
        return SingleParticleTransitionMatrixOperation(sptm, wires=self.wires)

    def preprocess_input(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Reshape ``x`` to ``(batch, channels, H, W)`` and upsample small inputs.

        All dimensions between the batch dimension and the last two are merged into the
        channel dimension. If either spatial dimension is below ``MIN_INPUT_SIZE``, the image
        is resized to exactly ``MIN_INPUT_SIZE``.
        """
        x = x.reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
        if x.shape[-2] < self.MIN_INPUT_SIZE[-2] or x.shape[-1] < self.MIN_INPUT_SIZE[-1]:
            x = self.input_resize(x)
        return x

    @property
    def torch_device(self):
        """Torch device the module currently lives on (alias of ``self.device``)."""
        return self.device

    @property
    def wires(self):
        """Wires of the underlying fermionic device."""
        return self.q_device.wires

    @property
    def initial_basis_state(self):
        """All-zeros computational basis state of length ``n_qubits``."""
        return np.zeros(self.n_qubits, dtype=int)

    @property
    def output_size(self):
        """Number of output units: the product of ``output_shape``."""
        return int(np.prod(self.output_shape))


In [None]:
# Experiment configuration: dataset and loader settings, model hyper-parameters, output paths.
dataset_name = "Digits2D"
fold_id = 0
batch_size = 32
random_state = 0
num_workers = 0
model_cls = NIFCNN
model_args = dict(
    n_qubits=16,
    learning_rate=2e-4,
    encoder_output_activation="Tanh",
)

# Seed the global RNGs so that model initialization (which draws from torch's RNG)
# is reproducible under Restart & Run All. random_state alone only reaches the datamodule.
np.random.seed(random_state)
torch.manual_seed(random_state)

# All job outputs live under job_output_folder_root / <dataset> / <model>.
job_output_folder_root = Path(os.getcwd()) / "data" / "lightning"
job_output_folder = job_output_folder_root / dataset_name / model_cls.MODEL_NAME

checkpoint_folder = job_output_folder / "checkpoints"

In [None]:
# Data: one fold of the configured dataset.
datamodule_kwargs = dict(
    fold_id=fold_id,
    batch_size=batch_size,
    random_state=random_state,
    num_workers=num_workers,
)
datamodule = DataModule.from_dataset_name(dataset_name, **datamodule_kwargs)

# Training pipeline: the model class plus its hyper-parameters, with a small training budget.
pipeline_kwargs = dict(
    checkpoint_folder=checkpoint_folder,
    max_epochs=10,
    max_time="00:00:03:00",  # DD:HH:MM:SS
    overwrite_fit=True,
    verbose=True,
)
lightning_pipeline = LightningPipeline(
    model_cls=model_cls,
    datamodule=datamodule,
    **pipeline_kwargs,
    **model_args,
)

In [None]:
# Fit, then evaluate on the test split; the reported wall-clock time covers both.
start_time = time.perf_counter()
separator = "⚡" * 20

metrics = lightning_pipeline.run()
print(separator, "\nValidation Metrics:\n", metrics, "\n", separator)

test_metrics = lightning_pipeline.run_test()
print(separator, "\nTest Metrics:\n", test_metrics, "\n", separator)

end_time = time.perf_counter()
elapsed_time = datetime.timedelta(seconds=end_time - start_time)
print(f"Time taken: {elapsed_time}")