# Custom Operations

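The core design of the library allows us to define custom operations at different levels: by composing existing operations, by declaring an operation without any implementation, or by registering a runtime-specific implementation.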

In [None]:
from traceback import print_exc

import numpy as np

from braandket import ArrayLike, PureStateTensor, pi
from braandket_circuit import BnkParticle, BnkRuntime, BnkState, H, M, QOperation, QParticle, QSystemStruct, Ry, Rz, X, \
    allocate_qubits, register_apply_impl
from braandket_circuit.utils import iter_struct

## 1. Overriding the `__call__` method

As we have seen in the previous sections, we can subclass the `QOperation` class and override the `__call__` method to define our custom operations.

Here is the custom parameterized operation `Uzyz` from [Parameterized Circuit](example4_parameterized_circuit.ipynb), which applies the rotations $R_z$, $R_y$ and $R_z$ in sequence.

In [None]:
class Uzyz(QOperation):
    def __init__(self, thetas: ArrayLike):
        super().__init__()
        self.thetas = thetas

    def __call__(self, qubit: QParticle):
        # apply Rz(thetas[0]), then Ry(thetas[1]), then Rz(thetas[2]) to the same qubit
        Rz(self.thetas[0])(qubit)
        Ry(self.thetas[1])(qubit)
        Rz(self.thetas[2])(qubit)
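Because the three rotations act on the same qubit one after another, the overall unitary is their matrix product with the first gate on the right. The NumPy sketch below makes this ordering explicit. It is independent of braandket, the helper names `rz`, `ry` and `uzyz_matrix` are hypothetical, and it assumes the standard textbook definitions of $R_z$ and $R_y$, which may not match the library's exact convention.

```python
import numpy as np

# Assumed (textbook) conventions, not necessarily braandket's:
#   Rz(t) = diag(exp(-i t/2), exp(+i t/2))
#   Ry(t) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]]
def rz(t: float) -> np.ndarray:
    return np.diag([np.exp(-0.5j * t), np.exp(0.5j * t)])


def ry(t: float) -> np.ndarray:
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=np.complex128)


def uzyz_matrix(thetas) -> np.ndarray:
    # Gates applied in the order Rz(t0), Ry(t1), Rz(t2) compose as the
    # matrix product Rz(t2) @ Ry(t1) @ Rz(t0): the first gate is rightmost.
    return rz(thetas[2]) @ ry(thetas[1]) @ rz(thetas[0])


u = uzyz_matrix([0.3, 1.2, -0.7])
print(np.allclose(u.conj().T @ u, np.eye(2)))  # True: the product is unitary
```

With the middle angle set to zero, the two $R_z$ rotations merge into a single $R_z$ whose angle is the sum of the outer angles.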

Here is another example, a custom operation `FourierSampling` that applies a Hadamard gate to each of the given qubits:

In [None]:
class FourierSampling(QOperation):
    def __call__(self, *qubits: QParticle):
        for qubit in qubits:
            H(qubit)
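Applying $H$ to each of $n$ qubits corresponds to the Kronecker product $H^{\otimes n}$, which maps $|0\cdots0\rangle$ to the uniform superposition over all $2^n$ basis states. The following NumPy sketch illustrates this outside of braandket; `hadamard_all` is a hypothetical helper, not a library function.

```python
from functools import reduce

import numpy as np

H = np.array([[1, 1], [1, -1]], dtype=np.complex128) / np.sqrt(2)


def hadamard_all(n: int) -> np.ndarray:
    # H on each of n qubits equals kron(H, kron(H, ... H)) with n factors
    return reduce(np.kron, [H] * n)


n = 3
zero = np.zeros(2 ** n, dtype=np.complex128)
zero[0] = 1.0  # the basis state |000>
psi = hadamard_all(n) @ zero
print(np.allclose(psi, np.full(2 ** n, 1 / np.sqrt(2 ** n))))  # True
```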

## 2. Without overriding the `__call__` method

It is worth noting that overriding the `__call__` method is optional.


For example, suppose we want to define an operation $P(\theta)$ that applies a phase shift only to the last diagonal element:

$$
P(\theta)=\begin{pmatrix} 1 & & & \\ & \ddots & & \\ & & 1 & \\ & & & e^{i\theta} \end{pmatrix}
$$
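To make the definition concrete, here is a plain NumPy sketch that builds the matrix of $P(\theta)$ for $n$ qubits. It does not use braandket, and `last_diagonal_phase_matrix` is a hypothetical helper. For $\theta = \pi$ on two qubits the result is $\mathrm{diag}(1, 1, 1, -1)$, i.e. the controlled-Z gate.

```python
import numpy as np


def last_diagonal_phase_matrix(theta: float, n_qubits: int) -> np.ndarray:
    # Identity on the 2**n computational basis states, except the last
    # diagonal entry (the all-ones state |1...1>), which becomes exp(i*theta).
    diag = np.ones(2 ** n_qubits, dtype=np.complex128)
    diag[-1] = np.exp(1j * theta)
    return np.diag(diag)


p = last_diagonal_phase_matrix(np.pi, 2)
```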

If we do not yet know how to implement it with existing operations, we can still define it without overriding the `__call__` method.

In [None]:
class LastDiagonalPhase(QOperation):
    def __init__(self, theta: ArrayLike):
        super().__init__()
        self.theta = theta

Since we have given no information about how `LastDiagonalPhase` works, calling it raises a `NotImplementedError`. Although such an operation cannot actually be called, it can still be used in other processes such as visualization and compilation.

In [None]:
q0, q1 = allocate_qubits(2)

try:
    LastDiagonalPhase(pi)(q0, q1)
except NotImplementedError:
    print_exc()

## 3. Registering runtime-specific implementation

In some cases we do not want to describe a custom operation in terms of existing operations, yet we know exactly how it works and want it to be callable. We can then register a runtime-specific implementation via `register_apply_impl`.
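The idea of registering one implementation per (runtime, operation) pair can be illustrated with a small dispatch table. The sketch below is a hypothetical, simplified model of the pattern written in plain Python; `register_impl`, `apply` and the stand-in classes are illustrative names, not `braandket_circuit`'s actual internals.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (runtime type, operation type) -> implementation
_IMPLS: Dict[Tuple[type, type], Callable] = {}


def register_impl(runtime_type: type, op_type: type):
    def decorator(fn: Callable) -> Callable:
        _IMPLS[(runtime_type, op_type)] = fn
        return fn
    return decorator


def apply(runtime, op, *args):
    # Walk both class hierarchies so implementations registered for base classes also match.
    for rt_cls in type(runtime).__mro__:
        for op_cls in type(op).__mro__:
            impl = _IMPLS.get((rt_cls, op_cls))
            if impl is not None:
                return impl(runtime, op, *args)
    # No registered implementation: the operation is declared but not callable here.
    raise NotImplementedError(f"no implementation of {type(op).__name__} for {type(runtime).__name__}")


class Runtime: ...
class Operation: ...
class PhaseOp(Operation): ...


@register_impl(Runtime, PhaseOp)
def phase_impl(rt, op, *args):
    return f"applied PhaseOp to {len(args)} targets"


print(apply(Runtime(), PhaseOp(), "q0", "q1"))  # applied PhaseOp to 2 targets
```

Calling `apply` with an operation that has no registered implementation raises `NotImplementedError`, mirroring the behavior of `LastDiagonalPhase` before registration.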


Taking `LastDiagonalPhase` as an example, we can implement it for the simulator based on `BnkRuntime`. For simplicity, we only implement the calculation for pure states whose values are NumPy arrays; any other case raises `NotImplementedError`.

In [None]:
@register_apply_impl(BnkRuntime, LastDiagonalPhase)
def last_diagonal_phase_impl(rt: BnkRuntime, op: LastDiagonalPhase, *args: QSystemStruct):
    # collect the particles the operation acts on and build their joint state
    particles = tuple(particle for particle in iter_struct(args, atom_typ=BnkParticle))
    state = BnkState.prod(*(particle.state for particle in particles))

    # only pure states are supported
    if not isinstance(state.tensor, PureStateTensor):
        raise NotImplementedError

    # only NumPy-backed values are supported
    values = state.tensor.values()
    if not isinstance(values, np.ndarray):
        raise NotImplementedError

    # flatten the amplitudes into a vector, remembering the spaces and shape
    spaces = state.tensor.ket_spaces
    shape = values.shape
    values = rt.backend.reshape(values, [-1])

    # multiply only the last amplitude by exp(i * theta)
    phase = np.ones_like(values, dtype=np.complex128)
    phase[-1] *= np.exp(1j * op.theta)
    values = values * phase

    # restore the original shape and write the new tensor back to the joint state
    values = rt.backend.reshape(values, shape)
    state.tensor = PureStateTensor.of(values, spaces, backend=rt.backend)
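Stripped of the braandket types, the amplitude manipulation in `last_diagonal_phase_impl` amounts to the NumPy sketch below. The helper name `apply_last_diagonal_phase` is hypothetical, and the sketch assumes the state is stored as a tensor with one axis per qubit, so that the last entry after flattening is the amplitude of $|1\cdots1\rangle$ regardless of axis order.

```python
import numpy as np


def apply_last_diagonal_phase(values: np.ndarray, theta: float) -> np.ndarray:
    # Same steps as above on a plain array: flatten, scale the last amplitude, restore shape.
    shape = values.shape
    flat = values.reshape(-1).astype(np.complex128)  # astype copies, so the input is untouched
    flat[-1] *= np.exp(1j * theta)
    return flat.reshape(shape)


# two-qubit state as a (2, 2) tensor of amplitudes: axis 0 = q0, axis 1 = q1
psi = np.full((2, 2), 0.5, dtype=np.complex128)
out = apply_last_diagonal_phase(psi, np.pi)
```

The result agrees with multiplying the flattened state by the diagonal matrix $P(\pi)$.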

After registering the runtime-specific implementation, we can now call the `LastDiagonalPhase` operation.

In [None]:
LastDiagonalPhase(pi)(q0, q1)

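Next, we verify the effect of `LastDiagonalPhase`. With the control qubit `q0` prepared in $|1\rangle$, the sequence $H$, $P(\pi)$, $H$ on `q1` acts as a CNOT, so measuring `q1` should yield 1.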

In [None]:
q0, q1 = allocate_qubits(2)
X(q0)  # prepare the control qubit q0 in |1>

# this part is equivalent to CNOT(q0, q1)
H(q1)
LastDiagonalPhase(pi)(q0, q1)
H(q1)

M(q1)  # expected outcome: 1
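The claim that $H$, $P(\pi)$, $H$ on the target implements a CNOT can be checked with plain matrices. The NumPy sketch below assumes the basis ordering $|q_0 q_1\rangle$ with `q0` as the most significant bit; it verifies the identity $(I \otimes H)\,\mathrm{CZ}\,(I \otimes H) = \mathrm{CNOT}$ and traces the input $|10\rangle$ produced by `X(q0)`.

```python
import numpy as np

I = np.eye(2)
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
CZ = np.diag([1, 1, 1, -1])  # P(pi) on two qubits
# basis order |q0 q1>: q0 is the control (most significant bit), q1 the target
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]])

H_on_q1 = np.kron(I, H)
circuit = H_on_q1 @ CZ @ H_on_q1
print(np.allclose(circuit, CNOT))  # True

# after X(q0) the input is |10>; the circuit maps it to |11>, so q1 measures as 1
ket10 = np.array([0, 0, 1, 0])
out = circuit @ ket10
```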