In [None]:
import quimb.tensor as qtn

In [None]:
from __future__ import annotations

__all__ = ["MPS"]

import copy
import numpy as np
from numpy.typing import NDArray
import quimb.tensor as qtn # type: ignore
from scipy import linalg # type: ignore
from typing import Type, Union

# Import `qickit.data.Data`
from qickit.data import Data # type: ignore

# Import `qickit.circuit.Circuit`
from qickit.circuit import Circuit # type: ignore

# Import `qickit.types.Collection` and `qickit.types.NestedCollection`
from qickit.types import Collection, NestedCollection # type: ignore

# Define `NumberType` as an alias for `int | float | complex`, i.e. the scalar types
# accepted as statevector amplitudes in `MPS.__init__` (`NestedCollection[NumberType]`)
NumberType = int | float | complex


class MPS:
    """ `qmprs.mps.MPS` is the class for creating and manipulating matrix product states (MPS).

    Parameters
    ----------
    `statevector` : qickit.data.Data | NestedCollection[NumberType], optional
        The statevector of the quantum system. Mutually exclusive with `mps`.
    `mps` : qtn.MatrixProductState, optional
        The matrix product state (MPS) of the quantum system. Mutually exclusive
        with `statevector`.
    `bond_dimension` : int, optional, default=64
        The maximum bond dimension of the MPS. It is only used to truncate the MPS
        built from `statevector`; an MPS passed via `mps` is stored as is.

    Attributes
    ----------
    `statevector` : qickit.data.Data
        The statevector of the quantum system.
    `mps` : qtn.MatrixProductState
        The matrix product state (MPS) of the quantum system.
    `bond_dimension` : int
        The maximum bond dimension of the MPS.
    `num_sites` : int
        The number of sites for the MPS (the number of qubits of `statevector`).
    `physical_dimension` : int
        The physical dimension of the MPS (always 2).
    `canonical_form` : str
        The canonical form of the MPS ("left" or "right"). It is recorded on the
        underlying quimb object (`self.mps.canonical_form`) by `.canonicalize()`.
    `normalized` : bool
        Whether the MPS is normalized, i.e. whether its norm is (numerically) 1.

    Raises
    ------
    TypeError
        If `mps` is not a `qtn.MatrixProductState` instance.
    ValueError
        - If `bond_dimension` is not an integer greater than 0.
        - If both `statevector` and `mps` are passed.
        - If neither `statevector` nor `mps` is passed.
        - If the physical dimension of the MPS is not 2.

    Usage
    -----
    >>> statevector = [1, 2, 3, 4]
    >>> bond_dimension = 2
    >>> mps = MPS(statevector=statevector, bond_dimension=bond_dimension)
    """
    def __init__(self,
                 statevector: Union[Data, NestedCollection[NumberType]] | None = None,
                 mps: qtn.MatrixProductState | None = None,
                 bond_dimension: int = 64) -> None:
        """ Initialize a `qmprs.mps.MPS` instance. Pass only `statevector` to define
        the MPS from the statevector. Pass only `mps` to define the MPS from an existing
        `qtn.MatrixProductState`.
        """
        # Ensure `bond_dimension` is an integer greater than 0
        if not isinstance(bond_dimension, int) or bond_dimension < 1:
            raise ValueError("`bond_dimension` must be an integer greater than 0.")

        # Determine initialization path
        if statevector is not None and mps is not None:
            raise ValueError("Cannot initialize with both `statevector` and `mps`.")

        elif statevector is not None:
            # Initialize from statevector
            # NOTE: when a `Data` instance is passed in, `.to_quantumstate()` is called on
            # the caller's object itself (no copy is made)
            if not isinstance(statevector, Data):
                statevector = Data(statevector)
            statevector.to_quantumstate()
            self.statevector = statevector
            self.mps = self.from_statevector(statevector, bond_dimension)

        elif mps is not None:
            # Initialize from MPS (the given MPS is not truncated to `bond_dimension`)
            if not isinstance(mps, qtn.MatrixProductState):
                raise TypeError("`mps` must be a `qtn.MatrixProductState` instance.")
            self.mps = mps
            self.statevector = self.to_statevector(mps)

        else:
            raise ValueError("Must provide either `statevector` or `mps`.")

        # Check if the MPS is normalized. The norm is a floating point number obtained by
        # contracting the network, so compare with a tolerance instead of exact equality
        self.normalized = bool(np.isclose(self.mps.norm(), 1))

        # Define the maximum bond dimension
        self.bond_dimension = bond_dimension

        # Define the number of sites for the MPS
        self.num_sites = self.statevector.num_qubits

        # Define the physical dimension of the MPS (for qubits this must be 2)
        if self.mps.phys_dim() != 2:
            raise ValueError("Only supports MPS with physical dimension=2.")
        else:
            self.physical_dimension = 2

    @staticmethod
    def from_statevector(statevector: Data,
                         max_bond_dimension: int) -> qtn.MatrixProductState:
        """ Define the MPS from the statevector.

        Parameters
        ----------
        `statevector` : qickit.data.Data
            The statevector of the quantum system.
        `max_bond_dimension` : int
            The maximum bond dimension every bond of the MPS is compressed to.

        Returns
        -------
        `mps` : qtn.MatrixProductState
            The MPS of the quantum system.

        Usage
        -----
        >>> statevector = Data([1, 2, 3, 4])
        >>> max_bond_dimension = 2
        >>> mps = MPS.from_statevector(statevector, max_bond_dimension)
        """
        # Define the number of sites
        num_sites = statevector.num_qubits

        # Reshape the vector to N sites where N is the number of qubits
        # needed to represent the state vector
        site_dimensions = [2] * num_sites

        # Generate MPS from the tensor arrays
        mps = qtn.MatrixProductState.from_dense(statevector.data, site_dimensions)

        # Compress the bond dimension of the MPS to the maximum bond dimension specified,
        # one bond (between site i and site i+1) at a time
        for i in range(num_sites-1):
            qtn.tensor_core.tensor_compress_bond(mps[i], mps[i+1], max_bond = max_bond_dimension)

        return mps

    @staticmethod
    def to_statevector(mps: qtn.MatrixProductState) -> Data:
        """ Convert the MPS to a statevector.

        Parameters
        ----------
        `mps` : qtn.MatrixProductState
            The matrix product state (MPS) of the quantum system.

        Returns
        -------
        `statevector` : qickit.data.Data
            The statevector of the quantum system.

        Usage
        -----
        >>> mps = qtn.MatrixProductState.random([2, 2, 2, 2], 2)
        >>> statevector = MPS.to_statevector(mps)
        """
        # Define the statevector from the MPS using `.to_dense()` method
        statevector = Data(mps.to_dense())

        # Convert the statevector to a quantum state
        statevector.to_quantumstate()

        return statevector

    def normalize(self) -> None:
        """ Normalize the MPS in place and refresh the `normalized` attribute.

        Usage
        -----
        >>> mps.normalize()
        """
        self.mps.normalize()

        # Keep the `normalized` flag in sync with the (now modified) MPS
        self.normalized = bool(np.isclose(self.mps.norm(), 1))

    def canonicalize(self,
                     mode: str,
                     normalize=False) -> None:
        """ Convert the MPS to the canonical form. This states how the singular values from
        the SVD are absorbed into the left or right tensors.
        - If `mode` is "left", the singular values are absorbed into the tensors to their
        right. (all tensors contract to unit matrix from left)

                          i              i
            >->->->->->->-o-o-         +-o-o-
            | | | | | | | | | ...  =>  | | | ...
            >->->->->->->-o-o-         +-o-o-

        - If `mode` is "right", the singular values are absorbed into the tensors to their
        left. (all tensors contract to unit matrix from right)

                   i                           i
                -o-o-<-<-<-<-<-<-<          -o-o-+
             ... | | | | | | | | |   ->  ... | | |
                -o-o-<-<-<-<-<-<-<          -o-o-+

        The chosen form is recorded on the underlying quimb object as
        `self.mps.canonical_form`.

        Parameters
        ----------
        `mode` : str
            The mode of canonicalization, either "left" or "right".
        `normalize` : bool, optional
            Whether to normalize the state. This is different from the `.normalize()` method.

        Raises
        ------
        ValueError
            If `mode` is not "left" or "right".

        Usage
        -----
        >>> mps.canonicalize("left")
        >>> mps.canonicalize("right")
        """
        if mode == "left":
            self.mps.left_canonize(normalize=normalize)
            self.mps.canonical_form = "left"
        elif mode == "right":
            self.mps.right_canonize(normalize=normalize)
            self.mps.canonical_form = "right"
        else:
            raise ValueError("`mode` must be either 'left' or 'right'.")

    def compress(self,
                 max_bond_dimension: int | None = None,
                 mode: str | None = None) -> None:
        """ SVD Compress the bond dimension of the MPS.

         a)│   │        b)│        │        c)│       │
         ━━●━━━●━━  ->  ━━>━━○━━○━━<━━  ->  ━━>━━━M━━━<━━
           │   │          │  ....  │          │       │
          <*> <*>          contract              <*>
          QR   LQ            -><-                SVD

            d)│            │        e)│   │
        ->  ━━>━━━ML──MR━━━<━━  ->  ━━●───●━━
              │....    ....│          │   │
            contract  contract          ^compressed bond
               -><-      -><-

        Parameters
        ----------
        `max_bond_dimension` : int, optional
            The maximum bond dimension of the MPS. A falsy value (None or 0) is treated
            as "not specified".
        `mode` : str, optional
            The mode of compression, either "left", "right", or "flat".

        Raises
        ------
        ValueError
            If `mode` is not "left", "right", or "flat".

        Usage
        -----
        >>> mps.compress(max_bond_dimension=16)
        """
        # If no parameters are passed, perform a left canonicalization followed by right compression
        if not (max_bond_dimension or mode):
            self.mps.compress()

        # If only `max_bond_dimension` is specified, compress every bond of the MPS to it
        elif not mode:
            for i in range(self.num_sites-1):
                qtn.tensor_core.tensor_compress_bond(self.mps[i], self.mps[i+1], max_bond = max_bond_dimension)

        else:
            if mode in ["left", "right", "flat"]:
                # If only `mode` is specified, compress the MPS with the specified mode
                if not max_bond_dimension:
                    self.mps.compress(form=mode)

                # If both `mode` and `max_bond_dimension` are specified, compress the MPS with the specified mode
                else:
                    self.mps.compress(form=mode, max_bond=max_bond_dimension)
            else:
                raise ValueError(f"`mode` must be either 'left', 'right', or 'flat'. Received {mode}.")

    def contract_site(self,
                      sites: Collection[int]) -> None:
        """ Contract/block tensors sites together.

        # TODO: Add Figure

        Parameters
        ----------
        `sites` : Collection[int]
            The sites to contract.

        Raises
        ------
        TypeError
            If `sites` is not a collection of integers.
            If any element of `sites` is not an integer.

        Usage
        -----
        >>> mps.contract_site([0, 1])
        """
        if not isinstance(sites, Collection):
            raise TypeError("`sites` must be a collection of integers.")
        elif not all(isinstance(index, int) for index in sites):
            raise TypeError("All elements of `sites` must be integers.")

        # Pass the site tags as a concrete tuple rather than a one-shot generator,
        # so the tag collection can safely be iterated more than once
        self.mps ^= tuple(self.mps.site_tag(i) for i in sites)

    def contract_index(self,
                       index: str) -> None:
        """ Contract tensors connected by the given index.

        # TODO: Add Figure

        Parameters
        ----------
        `index` : str
            The index to contract.

        Usage
        -----
        >>> mps.contract_index("k0")
        """
        self.mps.contract_ind(index)

    # TODO: Should this return sth, or should we access the isometries and positive semidefinite matrix from the MPS?
    def polar_decompose(self,
                        indices: Collection[int]) -> None:
        """ Perform a polar decomposition on the MPS to retrieve the
        isometries V and positive semidefinite matrix P.

        # TODO: Add diagram (9) from Wei et al.

        Nothing is returned: the tensor at site `indices[0]` is split in place inside
        `self.mps`, and the two resulting tensors are tagged "V" (isometries) and
        "P" (positive semidefinite matrix).

        Parameters
        ----------
        `indices` : Collection[int]
            The sites whose physical indices form the left indices of the split.

        Raises
        ------
        ValueError
            If `indices` is empty.

        Usage
        -----
        >>> mps.polar_decompose([0])
        """
        # Materialize `indices` so non-indexable collections (e.g. sets) can be used
        indices = list(indices)
        if not indices:
            raise ValueError("`indices` must contain at least one site.")

        self.mps.split_tensor(
            tags=self.mps.site_tag(indices[0]),
            left_inds=[self.mps.site_ind(i) for i in indices],
            method="polar_right",
            ltags="V",
            rtags="P",
        )

    def permute(self,
                shape: str) -> None:
        """ Permute the indices of each tensor in the MPS to match `shape`.

        Parameters
        ----------
        `shape` : str
            The shape to permute, being "lrp" or "lpr".

        Raises
        ------
        ValueError
            If `shape` is not "lrp" or "lpr".

        Usage
        -----
        >>> mps.permute("lrp")
        >>> mps.permute("lpr")
        """
        if shape in ["lrp", "lpr"]:
            self.mps.permute_arrays(shape)
        else:
            raise ValueError(f"`shape` must be either 'lrp' or 'lpr'. Received {shape}.")

    def get_submps_indices(self) -> list[tuple[int, int]]:
        """ Split the MPS into sub-MPS blocks and return their (start, end) site pairs.

        A sub-MPS is a run of consecutive sites connected by bonds of dimension >= 2.
        A site whose left and right bonds both have dimension < 2 forms a sub-MPS of
        its own, reported as `(site, site)`.

        Returns
        -------
        `submps_indices` : list[tuple[int, int]]
            The (start site, end site) pairs of the sub-MPS blocks, both inclusive.

        Usage
        -----
        >>> mps.get_submps_indices()
        """
        # A single-site MPS is trivially one sub-MPS
        if self.num_sites == 1:
            return [(0, 0)]

        submps_indices = []

        # NOTE(review): the shape unpacking below reads the bond dimensions assuming the
        # arrays are in quimb's "lpr" layout (see `.permute("lpr")`): the end sites are
        # (p, r) / (l, p) and the inner sites are (l, p, r) -- confirm callers permute first
        for site in range(self.num_sites):
            # Boundary sites have no outer bond, which is treated as dimension 1
            dim_left, dim_right = 1, 1

            if site == 0:
                _, dim_right = self.mps[site].shape

            elif site == (self.num_sites - 1):
                dim_left, _ = self.mps[site].shape

            else:
                dim_left, _, dim_right = self.mps[site].shape

            # Isolated site: it is a sub-MPS on its own
            if dim_left < 2 and dim_right < 2:
                submps_indices.append((site, site))

            # Only the right bond is non-trivial: a new sub-MPS starts here
            elif dim_left < 2 and dim_right >= 2:
                submps_start = site

            # Only the left bond is non-trivial: the current sub-MPS ends here
            elif dim_left >= 2 and dim_right < 2:
                submps_indices.append((submps_start, site))

        return submps_indices

    def _generate_two_site_unitary(self,
                                   mps_data: NDArray[np.complex128],
                                   generated_unitaries: list[qtn.Tensor]) -> list[qtn.Tensor]:
        """ Generate the two-site unitary for an inner tensor of a sub-MPS.

        The first "row" of the unitary is the MPS tensor itself and the remaining rows
        are filled with an orthonormal basis of the null space of that tensor, which
        completes it to a unitary.

        Parameters
        ----------
        `mps_data` : NDArray[np.complex128]
            The data of the MPS tensor at the specified index. Presumably of shape
            (phy_dim, phy_dim, phy_dim), since it is assigned to `unitary[0, :, :, :]`.
        `generated_unitaries` : list[qtn.Tensor]
            The list of generated unitaries. It is appended to in place.

        Returns
        -------
        `generated_unitaries` : list[qtn.Tensor]
            The same list, with the new unitary appended.
        """
        # Define the physical dimension of the MPS
        phy_dim = self.physical_dimension

        # Initialize the unitary with 0s
        unitary = np.zeros((phy_dim, phy_dim, phy_dim, phy_dim), dtype=np.complex128)

        # Set the first row of the unitary to the MPS tensor at the specified site
        unitary[0, :, :, :] = mps_data

        # Fill the remaining rows with the null space of the MPS tensor at the specified site
        kernel = linalg.null_space(mps_data.reshape((phy_dim, -1)).conj())

        # Fix the global phase of each kernel vector, so that its first entry is real
        kernel = kernel * (1 / np.exp(1j * np.angle(kernel[0, :])))
        unitary[1:phy_dim, :, :, :] = kernel.reshape((phy_dim, phy_dim, phy_dim, phy_dim - 1)).transpose((3, 2, 0, 1))

        # Transpose the unitary, such that the indices of the unitary are ordered as unitary(L,B,T,R)
        unitary = unitary.transpose((0, 1, 3, 2))

        # Transpose the unitary, such that the indices of the unitary are ordered as unitary(B,L,R,T)
        unitary = unitary.transpose((1, 0, 3, 2))

        # Convert the unitary to a qtn.Tensor
        # .T at the end is useful for the application of unitaries as quantum circuit
        unitary = qtn.Tensor(unitary.reshape((phy_dim**2, phy_dim**2)).T, inds=["L", "R"], tags={"G"})

        # Append the unitary to the list of generated unitaries
        generated_unitaries.append(unitary)

        return generated_unitaries

    def _generate_first_site_unitary(self,
                                     mps_data: NDArray[np.complex128],
                                     generated_unitaries: list[qtn.Tensor]) -> list[qtn.Tensor]:
        """ Generate the two-site unitary for the first tensor of a sub-MPS.

        The (0, 0) block of the unitary is the MPS tensor itself and every other block
        is filled with one null-space vector of the flattened tensor.

        Parameters
        ----------
        `mps_data` : NDArray[np.complex128]
            The data of the MPS tensor at the specified index. Its `phy_dim**2` entries
            are reshaped to (phy_dim, phy_dim).
        `generated_unitaries` : list[qtn.Tensor]
            The list of generated unitaries. It is appended to in place.

        Returns
        -------
        `generated_unitaries` : list[qtn.Tensor]
            The same list, with the new unitary appended.
        """
        # Define the physical dimension of the MPS
        phy_dim = self.physical_dimension

        # Initialize the unitary with 0s
        unitary = np.zeros((phy_dim, phy_dim, phy_dim, phy_dim), dtype=np.complex128)

        # Set the first block of the unitary to the data of the MPS at the specified index
        unitary[0, 0, :, :] = mps_data.reshape((phy_dim, -1))

        # Orthonormal basis of the complement of the (flattened) MPS tensor
        kernel = linalg.null_space(mps_data.reshape((1, -1)).conj())

        # Fill every (i, j) block except (0, 0) with the next kernel vector
        for i in range(phy_dim):
            for j in range(phy_dim):
                if i == 0 and j == 0:
                    continue

                # Flattened position of block (i, j); block (0, 0) is skipped, hence the -1 below
                index = i * phy_dim + j

                unitary[i, j, :, :] = kernel[:, index - 1].reshape((phy_dim, phy_dim))

        # Transpose the unitary, such that the indices of the unitary are ordered as unitary(L,B,T,R)
        unitary = unitary.transpose((0, 1, 3, 2))

        # Transpose the unitary, such that the indices of the unitary are ordered as unitary(B,L,R,T)
        unitary = unitary.transpose((1, 0, 3, 2))

        # Convert the unitary to a qtn.Tensor
        # .T at the end is useful for the application of unitaries as quantum circuit
        unitary = qtn.Tensor(unitary.reshape((phy_dim**2, phy_dim**2)).T, inds=["L", "R"], tags={"G"})

        # Append the unitary to the list of generated unitaries
        generated_unitaries.append(unitary)

        return generated_unitaries

    def _generate_single_site_unitary(self,
                                      mps_data: NDArray[np.complex128],
                                      start_index: int,
                                      end_index: int,
                                      generated_unitaries: list[qtn.Tensor]) -> list[qtn.Tensor]:
        """ Generate the single-site unitary for the last tensor of a sub-MPS.

        Parameters
        ----------
        `mps_data` : NDArray[np.complex128]
            The data of the MPS tensor at the specified index.
        `start_index` : int
            The starting index of the sub-MPS.
        `end_index` : int
            The ending index of the sub-MPS.
        `generated_unitaries` : list[qtn.Tensor]
            The list of generated unitaries. It is appended to in place.

        Returns
        -------
        `generated_unitaries` : list[qtn.Tensor]
            The same list, with the new unitary appended.
        """
        # Define the physical dimension of the MPS
        phy_dim = self.physical_dimension

        # Check if the sub-MPS has only one site
        if end_index == start_index:
            # Initialize the unitary with 0s
            unitary = np.zeros((phy_dim, phy_dim), dtype=np.complex128)

            # Set the first row of the unitary to the data of the MPS at the specified index
            unitary[0, :] = mps_data.reshape((1, -1))

            # Set the second row of the unitary to the null space of the data of the MPS at the specified index
            unitary[1, :] = linalg.null_space(mps_data.reshape(1, -1).conj()).reshape(1, -1)
        else:
            # If the sub-MPS has more than one site, the unitary is the MPS tensor at the specified site
            unitary = mps_data

        # Convert the unitary to a qtn.Tensor
        # .T at the end is useful for the application of unitaries as quantum circuit
        unitary = qtn.Tensor(unitary.reshape((phy_dim, phy_dim)).T, inds=("v", "p"), tags={"G"})

        # Append the unitary to the list of generated unitaries
        generated_unitaries.append(unitary)

        return generated_unitaries

    def generate_unitaries(self) -> list:
        """ Generate the unitaries of the MPS.

        Returns
        -------
        `generated_unitary_list` : list
            One `[start_index, end_index, unitaries]` entry per sub-MPS (see
            `.get_submps_indices()`), where `unitaries` holds one `qtn.Tensor`
            per site of that sub-MPS.

        Raises
        ------
        ValueError
            If any of the generated matrices is not unitary.

        Usage
        -----
        >>> mps.generate_unitaries()
        """
        # Work on a copy so the stored MPS is never touched by this method
        mps_copy = copy.deepcopy(self.mps)

        # Initialize the list of generated unitaries
        generated_unitary_list = []

        # Get the (start, end) sites of each sub-MPS
        sub_mps_indices = self.get_submps_indices()

        for start_index, end_index in sub_mps_indices:
            generated_unitaries: list[qtn.Tensor] = []

            # Iterate over the range from start_index to end_index (inclusive)
            for index in range(start_index, end_index + 1):
                if index == end_index:
                    # The last site of the sub-MPS gets a single site unitary
                    generated_unitaries = self._generate_single_site_unitary(mps_copy[index].data,
                                                                             start_index,
                                                                             end_index,
                                                                             generated_unitaries)

                elif index != start_index:
                    # Inner sites of the sub-MPS get a two site unitary
                    generated_unitaries = self._generate_two_site_unitary(mps_copy[index].data,
                                                                          generated_unitaries)

                else:
                    # The first site of the sub-MPS gets the first site unitary
                    generated_unitaries = self._generate_first_site_unitary(mps_copy[index].data,
                                                                            generated_unitaries)

            # Check that U U^dagger = I and U^dagger U = I for every generated matrix
            for generated_unitary in generated_unitaries:
                if not np.allclose(np.eye(generated_unitary.shape[0]) - generated_unitary.data @ generated_unitary.data.T.conj(), 0):
                    raise ValueError("ValueError : every generated unitary in the list must be a unitary.")
                if not np.allclose(np.eye(generated_unitary.shape[0]) - generated_unitary.data.T.conj() @ generated_unitary.data, 0):
                    raise ValueError("ValueError : every generated unitary in the list must be a unitary.")

            generated_unitary_list.append([start_index,
                                           end_index,
                                           generated_unitaries])

        return generated_unitary_list

    def generate_bond_d_unitary(self) -> list:
        """ Generate the unitary for the bond-d (physical dimension) compression of the MPS.

        A copy of this MPS is right-compressed to a bond dimension equal to the physical
        dimension, right-canonicalized (with normalization), and then converted to
        unitaries with `.generate_unitaries()`. The instance itself is not modified.

        Returns
        -------
        `generated_unitary_list` : list
            A list of unitaries to be applied to the MPS.

        Usage
        -----
        >>> mps.generate_bond_d_unitary()
        """
        # Copy the MPS (as the MPS will be modified in place with `.compress` and `.canonicalize` methods)
        mps_copy = copy.deepcopy(self)

        # Compress the MPS to a bond dimension of the physical dimension of the MPS
        mps_copy.compress(mode="right", max_bond_dimension=self.physical_dimension)

        # Right canonicalize the compressed MPS
        mps_copy.canonicalize(mode="right", normalize=True)

        # Generate the unitaries
        generated_unitary_list = mps_copy.generate_unitaries()

        return generated_unitary_list

    def _apply_unitary_layer(self,
                             generated_unitary_list: list) -> None:
        """ Apply the unitary layer on the MPS (in place), then permute the arrays to
        "lpr" and right-compress the MPS.

        Parameters
        ----------
        `generated_unitary_list` : list
            A list of `[start_index, end_index, unitaries]` entries, as returned by
            `.generate_unitaries()`.
        """
        for start_index, end_index, generated_unitaries in generated_unitary_list:
            for index in range(start_index, end_index + 1):
                if index == end_index:
                    # Apply the generated unitary gates to the MPS (use `.gate_` as the operation is inplace)
                    # o-o-o-o-o-o-o
                    # | | | | | | |
                    #     GGG
                    #     | |
                    self.mps.gate_(generated_unitaries[index - start_index].data, where=[index])

                    # After gating, the site whose tag now selects more than one tensor (so
                    # `self.mps[site]` is a tuple) is the one that must be contracted
                    loc = np.where([isinstance(self.mps[site], tuple) for site in range(self.num_sites)])[0][0]

                    # Contract the tensors at the specified location
                    # o-o-o-GGG-o-o-o
                    # | | | / \ | | |
                    self.contract_index(self.mps[loc][-1].inds[-1])

                else:
                    # Apply a two-site gate and then split resulting tensor to retrieve the MPS form:
                    #     -o-o-A-B-o-o-
                    #      | | | | | |            -o-o-GGG-o-o-           -o-o-X~Y-o-o-
                    #      | | GGG | |     ==>     | | | | | |     ==>     | | | | | |
                    #      | | | | | |                 i j                     i j
                    #          i j
                    # As might be found in Time-evolving block decimation (TEBD) algorithm
                    self.mps.gate_split_(generated_unitaries[index - start_index].data,
                                         where=[index, index + 1])

        # Permute the arrays of the MPS
        self.permute(shape="lpr")

        # Compress the MPS
        self.compress(mode="right")

    def _apply_inverse_unitary_layer(self,
                                     generated_unitary_list: list) -> None:
        """ Apply the inverse (conjugate transpose) of the unitary layer on the MPS
        (in place), visiting the sites of each sub-MPS in reverse order, then permute
        the arrays to "lpr".

        Parameters
        ----------
        `generated_unitary_list` : list
            A list of `[start_index, end_index, unitaries]` entries, as returned by
            `.generate_unitaries()`.
        """
        for start_index, end_index, generated_unitaries in generated_unitary_list:
            # Iterate over the indices of the sub-MPS in reverse order
            for index in reversed(range(start_index, end_index + 1)):
                if index == end_index:
                    # Add the generated unitary gates to the MPS (use `.gate_` as the operation is inplace)
                    # o-o-o-o-o-o-o
                    # | | | | | | |
                    #     GGG
                    #     | |
                    self.mps.gate_(generated_unitaries[index - start_index].data.conj().T, where=[index])

                    # After gating, the site whose tag now selects more than one tensor (so
                    # `self.mps[site]` is a tuple) is the one that must be contracted
                    loc = np.where([isinstance(self.mps[site], tuple) for site in range(self.num_sites)])[0][0]

                    # Contract the tensors at the specified location
                    # o-o-o-GGG-o-o-o
                    # | | | / \ | | |
                    self.contract_index(self.mps[loc][-1].inds[-1])

                else:
                    # Apply a two-site gate and then split resulting tensor to retrieve the MPS form:
                    #     -o-o-A-B-o-o-
                    #      | | | | | |            -o-o-GGG-o-o-           -o-o-X~Y-o-o-
                    #      | | GGG | |     ==>     | | | | | |     ==>     | | | | | |
                    #      | | | | | |                 i j                     i j
                    #          i j
                    # As might be found in Time-evolving block decimation (TEBD) algorithm
                    self.mps.gate_split_(generated_unitaries[index - start_index].data.conj().T,
                                         where=[index, index + 1])

        # Permute the arrays of the MPS
        self.permute(shape='lpr')

    def apply_unitary_layer(self,
                            unitary_layer: list,
                            inverse: bool = False) -> None:
        """ Apply the unitary layer on the MPS.

        Parameters
        ----------
        `unitary_layer` : list
            A list of unitaries to be applied to the MPS.
        `inverse` : bool, optional
            Whether to apply the inverse unitary layer.

        Usage
        -----
        >>> mps.apply_unitary_layer(unitary_layer, inverse=True)
        """
        if inverse:
            self._apply_inverse_unitary_layer(unitary_layer)
        else:
            self._apply_unitary_layer(unitary_layer)

    def apply_unitary_layers(self,
                             unitary_layers: list[list],
                             inverse: bool = False) -> None:
        """ Apply the unitary layers on the MPS, last layer first.

        Parameters
        ----------
        `unitary_layers` : list[list]
            A list of unitary layers to be applied to the MPS.
        `inverse` : bool, optional
            Whether to apply the inverse unitary layers.

        Usage
        -----
        >>> mps.apply_unitary_layers(unitary_layers, inverse=True)
        """
        # Iterate over the unitary layers in reverse order, and apply the unitary layers to the MPS
        for layer in reversed(unitary_layers):
            self.apply_unitary_layer(layer, inverse=inverse)

    @staticmethod
    def _circuit_from_unitary_layer(circuit: Circuit,
                                    unitary_layer: list) -> None:
        """ Apply a unitary layer to the quantum circuit (in place).

        Parameters
        ----------
        `circuit` : qickit.circuit.Circuit
            The quantum circuit.
        `unitary_layer` : list
            The unitary layer to be applied to the circuit.
        """
        for start_index, end_index, generated_unitaries in unitary_layer:
            for index in range(start_index, end_index + 1):
                unitary = generated_unitaries[index - start_index].data

                # The last site of the sub-MPS gets a single-qubit unitary
                if index == end_index:
                    circuit.unitary(unitary, [index])

                # Every other site gets a two-qubit unitary on the next and current qubits
                else:
                    circuit.unitary(unitary, [index + 1, index])

    def circuit_from_unitary_layers(self,
                                    qc_framework: Type[Circuit],
                                    unitary_layers: list[list]) -> Circuit:
        """ Generate a quantum circuit from the MPS unitary layers.

        Parameters
        ----------
        `qc_framework` : type[qickit.circuit.Circuit]
            The quantum circuit framework.
        `unitary_layers` : list[list]
            A list of unitary layers to be applied to the circuit. The layers are
            added to the circuit last layer first.

        Returns
        -------
        `circuit` : qickit.circuit.Circuit
            The quantum circuit, with `num_sites` qubits and `num_sites` classical bits.

        Usage
        -----
        >>> circuit = mps.circuit_from_unitary_layers(QiskitCircuit, unitary_layers)
        """
        circuit = qc_framework(self.num_sites, self.num_sites)

        # Apply the unitary layers in reverse order
        for layer in reversed(unitary_layers):
            MPS._circuit_from_unitary_layer(circuit, layer)

        return circuit

    # TODO: Specify the parameters for best readability
    def draw(self) -> None:
        """ Draw the MPS using quimb's `.draw()`. Nothing is returned.

        Usage
        -----
        >>> mps.draw()
        """
        self.mps.draw()

In [None]:
from __future__ import annotations

__all__ = ["Sequential"]

# Import `qickit.circuit.Circuit`
from qickit.circuit import Circuit # type: ignore

# Import `qmprs.mps.MPS`
from qmprs.mps import MPS

# Import `qmprs.synthesis.mps_encoding.MPSEncoder`
from qmprs.synthesis.mps_encoding import MPSEncoder


class Sequential(MPSEncoder):
    r""" `qmprs.synthesis.mps_encoder.Sequential` is the class for preparing MPS
    using Sequential encoding. The circuit is constructed using the disentangling
    algorithm described in the 2019 paper by Shi-Ju Ran.

    ref: https://arxiv.org/abs/1908.07958

    Notes
    -----
    - The circuit depth scales $O(N * \chi^2)$ where N is the number of qubits and $\chi$
    is the bond dimension.
    - The sequential encoding approach allows for encoding of long-range correlated states.

    Parameters
    ----------
    `circuit_framework` : Type[Circuit]
        The quantum circuit framework.
    """
    def prepare_mps(self,
                    mps: MPS,
                    **kwargs) -> Circuit:
        """ Prepare the quantum state using MPS.

        Parameters
        ----------
        `mps` : qmprs.mps.MPS
            The MPS to be prepared. NOTE: this object is modified in place
            (it is normalized, compressed, and then disentangled by the inverse
            unitary layers), so pass a copy if the original is still needed.
        `num_layers` : int
            The number of sequential layers. Passed as a kwarg.

        Returns
        -------
        `circuit` : qickit.circuit.Circuit
            A circuit built with `self.circuit_framework` on `mps.num_sites`
            qubits (and classical bits) containing the generated unitary
            layers, vertically reversed.

        Raises
        ------
        ValueError
            If `num_layers` is missing or is not a positive integer.
        """
        # Define the number of layers
        num_layers = kwargs.get("num_layers")

        # Check if the number of layers is a positive integer
        if not isinstance(num_layers, int) or num_layers < 1:
            raise ValueError("The number of layers must be a positive integer.")

        # Normalize the MPS
        mps.normalize()

        # Compress the MPS to canonical form
        mps.compress(mode="right")

        # Define the circuit to prepare the MPS
        circuit = self.circuit_framework(mps.num_sites, mps.num_sites)

        # Peel `num_layers` unitary layers off the MPS (mutates `mps`)
        unitary_layers = Sequential._sequential_unitary_layers(mps, num_layers)

        # Generate the quantum circuit from the unitary layers and add it to
        # the circuit on all sites
        circuit.add(mps.circuit_from_unitary_layers(type(circuit), unitary_layers), range(mps.num_sites))

        # Apply a vertical reverse
        circuit.vertical_reverse()

        return circuit

    @staticmethod
    def _sequential_unitary_layers(mps: MPS,
                                   num_layers: int) -> list:
        """ Generate the unitary layers that sequentially disentangle the MPS.

        After each layer is generated, its inverse is applied to `mps`, which
        is then re-canonicalized and re-compressed before the next layer.

        Parameters
        ----------
        `mps` : qmprs.mps.MPS
            The MPS state. Modified in place.
        `num_layers` : int
            The number of layers to generate.

        Returns
        -------
        `unitary_layers` : list
            The generated unitary layers, in generation order.
        """
        # Define the unitary layers
        unitary_layers: list = []

        # Permute the arrays to the left-right canonical form
        mps.permute(shape="lpr")
        mps.canonicalize("right", normalize=True)

        # Iterate over the number of layers
        for _ in range(num_layers):
            # Generate the unitary for the bond-d (physical dimension) compression of the MPS.
            unitary_layer = mps.generate_bond_d_unitary()

            # Append the unitary layer to the list of unitary layers
            unitary_layers.append(unitary_layer)

            # Apply the inverse unitary layer on the wavefunction
            mps.apply_unitary_layer(unitary_layer, inverse=True)

            # Normalize the MPS and convert to right canonical form
            mps.canonicalize(mode="right", normalize=True)

            # Re-compress the MPS using `compress`'s default mode
            mps.compress()

        return unitary_layers