# Kitaev Phase Estimation Algorithm

## Overview

- Type: Quantum subroutine for phase estimation.
- Estimated difficulty: Difficult
- Preliminaries: Hadamard phase estimation **(strongly recommended)**.
- Number of qubits: The number of qubits $U$ acts on, plus one auxiliary qubit.
- Number of operations: $O(1/\text{precision})$ calls to the controlled-$U$, plus $O(\log 1/\text{precision})$ classical processing.
- Maximum depth: $O(1/\text{precision})$

Given a unitary $U$ with eigenvalue $e^{2\pi i \theta}$, the task of phase estimation is to determine the phase $\theta$ to precision $\epsilon$ with probability $1-\delta$. The Kitaev phase estimation algorithm solves this problem by building on Hadamard phase estimation. In fact, it is Hadamard phase estimation repeated with the powers $U^{2^j}$ in place of $U$, followed by classical post-processing of the resulting data. The "quantum" part of the algorithm is therefore essentially unchanged, and for this reason the previously described Hadamard phase estimation is sometimes itself called Kitaev phase estimation.

Kitaev phase estimation is iterative, meaning the same type of operation is repeated multiple times, and nonadaptive, meaning prior measurement outcomes do not affect the choice of subsequent operations. It provides a quadratic improvement over Hadamard phase estimation in the number of controlled-$U$ applications needed to reach precision $\epsilon$: roughly $O(1/\epsilon)$ (up to logarithmic factors) instead of $O(1/\epsilon^2)$.
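To make this gap concrete, here is a rough counting sketch. It assumes the repetition schedule of the Qiskit code in this notebook: `confidence_factor // precision**2` shots for each of the two Hadamard tests (with and without the S gate), an inner precision of $1/16$ for the Kitaev powers, and $U^{2^j}$ counted as $2^j$ controlled-$U$ applications. The function names are hypothetical, and the counts ignore state preparation and the different confidence guarantees of the two schemes.

```python
import math

def hadamard_calls(precision, confidence_factor=2):
    """Controlled-U applications for plain Hadamard phase estimation: two Hadamard
    tests (with and without the S gate), each run confidence_factor // precision**2 times."""
    ncycles = int(confidence_factor // precision**2)
    return 2 * ncycles

def kitaev_calls(precision, confidence_factor=2, inner_precision=1/16):
    """Controlled-U applications for the Kitaev scheme: Hadamard phase estimation at a
    fixed inner precision on U**(2**j) for j = 0..m-1, counting U**(2**j) as 2**j calls."""
    m = math.ceil(math.log2(1 / precision)) - 2
    per_power = hadamard_calls(inner_precision, confidence_factor)
    return sum(per_power * 2**j for j in range(m))

for k in (8, 12, 16):
    eps = 2.0**-k
    print(f"eps = 2^-{k}: Hadamard {hadamard_calls(eps):,}   Kitaev {kitaev_calls(eps):,}")
```

Going from $\epsilon = 2^{-8}$ to $\epsilon = 2^{-16}$ multiplies the Hadamard count by $2^{16}$ but the Kitaev count only by roughly $2^{8}$, which is the quadratic gap described above.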

## Description

### Quantum part
The basic idea of the Kitaev approach is as follows. For $j = 1, \dots, m$, we estimate $2^{j-1}\theta \bmod 1$ to precision at least $1/16$ by running Hadamard phase estimation on $U_j = U^{2^{j-1}}$, whose eigenvalue is $e^{2\pi i\, 2^{j-1}\theta}$. It is important that all of these estimates are taken in sequence on the same register, without re-preparing the initial state; otherwise, when the initial state is a superposition of eigenstates, each estimate may correspond to a different phase. The choice of $1/16$ is a technical one that leaves enough margin for the classical step to decide each bit correctly, but the key point is that this accuracy is a constant, independent of the final desired precision $\epsilon$.
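A minimal classical simulation of this step, assuming an exact eigenstate and noiseless hardware: each $\rho_j$ is computed from sampled Hadamard-test outcomes on $U^{2^{j-1}}$ using the same arctan-based estimator as `Hadamard_phase_estimation` in the Qiskit section, with the same shot count for every $j$. The eigenphase, seed, and shot count are illustrative choices, and `hadamard_test_estimate` is a hypothetical helper.

```python
import numpy as np

def hadamard_test_estimate(phase, shots, rng):
    """Estimate `phase` in [0, 1) from simulated Hadamard tests on an exact eigenstate.

    Without the S gate P(0) = (1 + cos 2*pi*phase)/2; with it P(0) = (1 - sin 2*pi*phase)/2.
    """
    p0_plain = (1 + np.cos(2 * np.pi * phase)) / 2
    p0_s = (1 - np.sin(2 * np.pi * phase)) / 2
    freq_plain = rng.binomial(shots, p0_plain) / shots   # observed frequency of outcome 0
    freq_s = rng.binomial(shots, p0_s) / shots
    angle = np.arctan2(1 - 2 * freq_s, 2 * freq_plain - 1)
    return (angle / (2 * np.pi)) % 1.0

def circle_distance(x, y):
    """Distance between x and y on a circle of circumference 1."""
    d = abs(x - y) % 1.0
    return min(d, 1.0 - d)

rng = np.random.default_rng(seed=7)
theta = 0.3141592        # illustrative eigenphase
m = 10
shots = 512              # same shot count for every power of U
errors = []
for j in range(1, m + 1):
    target = (2**(j - 1) * theta) % 1.0          # phase of U^(2^(j-1))
    rho_j = hadamard_test_estimate(target, shots, rng)
    errors.append(circle_distance(rho_j, target))
print(max(errors))       # typically far below 1/16 = 0.0625
```

Raising $m$ adds more powers of $U$ but, in this sketch, no extra shots per power; in a full analysis the shot count per power grows only logarithmically in $m$ to keep the overall failure probability bounded.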

### Classical part

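The quantum part yields estimates $\rho_1, \rho_2, \dots, \rho_m$, where $\rho_j$ approximates $2^{j-1}\theta \bmod 1$. Writing $\theta$ in binary as $0.\theta_1\theta_2\theta_3\dots$, multiplication by $2^{j-1}$ shifts the binary point $j-1$ places to the right, so $2^{j-1}\theta \bmod 1 = 0.\theta_j\theta_{j+1}\theta_{j+2}\dots$. Each $\rho_j$ therefore contains information about the bits of $\theta$ from the $j$th onward, and is most heavily influenced by $\theta_j$. Simply reading off the leading bit of $\rho_j$ is not reliable, however: since $\rho_j$ is accurate only to $1/16$, that bit can be wrong whenever $2^{j-1}\theta \bmod 1$ lies within $1/16$ of $0$ or $1/2$.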
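A quick exact check of the shift-the-binary-point identity, using Python's `fractions` module so that no floating-point rounding enters. The 8-bit value of $\theta$ is an arbitrary illustrative choice.

```python
from fractions import Fraction

bits = "10110011"                             # theta = 0.10110011 in binary (illustrative)
theta = Fraction(int(bits, 2), 2**len(bits))

leading = []
for j in range(1, len(bits) + 1):
    rho_exact = (2**(j - 1) * theta) % 1      # 2^(j-1) * theta mod 1 = 0.theta_j theta_{j+1} ...
    leading.append(int(rho_exact >= Fraction(1, 2)))  # leading binary digit of rho_exact
    print(j, rho_exact, leading[-1])

# With exact values, the leading bit of 2^(j-1) * theta mod 1 is theta_j
print("".join(map(str, leading)) == bits)
```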

The classical part of the algorithm, given below in pseudocode, is a concrete procedure for extracting the bits $\theta_j$ from the real-valued $\rho_j$. It starts from $\rho_m$, which fixes the least significant bits, and proceeds toward smaller $j$, using the less significant bits already inferred to decide each more significant bit. At the end of the protocol, the bits in the binary expansion of $\theta$ are inferred with reasonably high confidence.

### Pseudocode (Adapted from Svore, Hastings, Freedman)
Input: 
1. A black-box controlled-$U$ acting on the "main register", controlled on one auxiliary qubit.
1. Input state $\vert\psi\rangle$ on the main register.
1. Specified precision $\epsilon$.

Output:
1. An $O(\epsilon)$ estimate $0.\alpha_1\alpha_2\dots\alpha_{m+2}$ of $\theta$ as a bitstring. 
1. An approximate projection of the input state onto the $\theta$ eigenspace of $U$

Procedure:
1. Set $m \in O(\log_2(1/\epsilon))$; the Qiskit implementation below uses $m = \lceil \log_2(1/\epsilon) \rceil - 2$, so that the output has $m+2$ bits.
1. For $j = 1, \dots, m$, use Hadamard phase estimation on $U^{2^{j-1}}$ to obtain an estimate $\rho_j$ of $2^{j-1}\theta \bmod 1$ to precision $1/16$.
1. Set $\beta$ to be $\rho_m$ rounded to the nearest eighth (mod 1). In binary, $\beta = 0.\alpha_m \alpha_{m+1}\alpha_{m+2}$, where the $\alpha_j$ are bits.
1. For $j = m-1$ down to $1$:
    1. Set $\alpha_j = 0$ if $\vert 0.0\alpha_{j+1}\alpha_{j+2} - \rho_j\vert_{\text{mod}\;1} < 1/4$
    1. Set $\alpha_j = 1$ otherwise
1. return $0.\alpha_1\alpha_2\dots\alpha_{m+2}$
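The classical steps above can be sketched in plain Python, with no quantum simulation, to check that they recover $\theta$. The helper names `kitaev_bits` and `bits_to_value` are hypothetical, and the usage example assumes each $\rho_j$ is off by less than $1/16$, modeled here as uniform noise rather than coming from an actual Hadamard test.

```python
import numpy as np

def dist_mod1(x, y):
    """Distance between x and y on a circle of circumference 1."""
    d = abs(x - y) % 1.0
    return min(d, 1.0 - d)

def kitaev_bits(rho):
    """Classical step of the pseudocode: rho[j-1] is the estimate rho_j of
    2^(j-1) * theta mod 1 for j = 1..m. Returns [alpha_1, ..., alpha_{m+2}]."""
    m = len(rho)
    alpha = [0] * (m + 2)                           # alpha[k-1] holds alpha_k
    beta = int(np.round(8 * rho[m - 1])) % 8         # rho_m rounded to the nearest eighth, mod 1
    alpha[m - 1] = (beta >> 2) & 1                   # alpha_m
    alpha[m] = (beta >> 1) & 1                       # alpha_{m+1}
    alpha[m + 1] = beta & 1                          # alpha_{m+2}
    for j in range(m - 1, 0, -1):                    # j = m-1 down to 1
        c = alpha[j] / 4 + alpha[j + 1] / 8          # 0.0 alpha_{j+1} alpha_{j+2}
        alpha[j - 1] = 0 if dist_mod1(c, rho[j - 1]) < 1 / 4 else 1
    return alpha

def bits_to_value(alpha):
    """Value of the binary fraction 0.alpha_1 alpha_2 ..."""
    return sum(a * 2.0 ** -(k + 1) for k, a in enumerate(alpha))

# Usage: feed estimates perturbed by less than 1/16 and compare with 2^-(m+2)
rng = np.random.default_rng(seed=1)
theta, m = 0.7071067, 12                             # illustrative values
rho = [(2**(j - 1) * theta + rng.uniform(-1/16, 1/16)) % 1.0 for j in range(1, m + 1)]
estimate = bits_to_value(kitaev_bits(rho))
print(estimate, dist_mod1(estimate, theta), 2.0 ** -(m + 2))
```

Under that assumption the reconstructed value lies within $2^{-(m+2)}$ of $\theta$ (mod 1): rounding $\rho_m$ pins down $2^{m-1}\theta$ to within $1/8$, and each pass of the loop picks the correct branch and halves the remaining error.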


### Qiskit Implementation

#### Imports

```python
import numpy as np
from qiskit import ClassicalRegister, QuantumRegister, QuantumCircuit
from qiskit_aer.primitives import Sampler
from qiskit.circuit.library.standard_gates import SwapGate, TGate # For demonstration
```

First define a function `distance_mod1` for later convenience. This essentially computes the distance between two points on a circle of circumference $1$. 

```python
def distance_mod1(x,y):
    """
    Args:
        x (float), y (float): Values in [0,1). If outside range, values are taken mod 1.
    Returns:
        The distance between x and y mod 1. That is, the distance between x and y on a circle with circumference 1.
    """
    # Put x and y within proper bounds
    x %= 1
    y %= 1
    return np.min([np.abs(x-y), 1-np.abs(x-y)])
```
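As an aside, the same distance can be computed with a single modulo, because Python's `%` with a positive modulus always returns a value in $[0, 1)$. The name `distance_mod1_alt` is hypothetical; the examples show the wrap-around behavior that makes this a circle distance rather than an ordinary absolute difference.

```python
def distance_mod1_alt(x, y):
    """Same quantity as distance_mod1: the shorter way around a circle of circumference 1."""
    forward = (x - y) % 1.0          # in [0, 1) for a positive modulus, even if x - y < 0
    return min(forward, 1.0 - forward)

print(distance_mod1_alt(0.95, 0.05))   # close to 0.1: the short way wraps past 0
print(distance_mod1_alt(0.2, 0.45))    # close to 0.25
print(distance_mod1_alt(-0.1, 0.1))    # close to 0.2: inputs outside [0, 1) are taken mod 1
```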

#### Hadamard QPE dependencies

Because Kitaev QPE builds on Hadamard QPE, we reuse the functions that implement it.

```python
def hadamard_circuit(U, add_s_gate=False, measure=True):
    """
    Args:
        U (Gate): Unitary gate in phase estimation problem
        add_s_gate (bool): Whether to add an S gate to the auxiliary register following the controlled unitary.
            Equivalently, whether to perform C(U) (False) or C(i U) (True).
        measure (bool): Whether to measure the auxiliary qubit (True) or not (False).
    Returns:
        QuantumCircuit that implements the Hadamard test
    """
    # Initialize registers and circuit
    aux = QuantumRegister(1, 'aux')
    main = QuantumRegister(U.num_qubits, 'q')
    circuit = QuantumCircuit(aux, main)
    
    # Construct controlled-U gate
    cU = U.control(1)
    
    # Add gates to circuit
    circuit.h(aux)
    circuit.append(cU, aux[:] + main[:])
    if add_s_gate:
        circuit.s(aux)
    circuit.h(aux)
    
    # Add measurement if option specified
    if measure:
        creg = ClassicalRegister(1, 'c')
        circuit.add_register(creg)
        circuit.measure(aux, creg)
    
    return circuit
```
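For an eigenstate with $U\vert\psi\rangle = e^{2\pi i\theta}\vert\psi\rangle$, the circuit above gives $\Pr(0) = (1+\cos 2\pi\theta)/2$ without the S gate and $\Pr(0) = (1-\sin 2\pi\theta)/2$ with it; these are the identities that the estimator in `Hadamard_phase_estimation` inverts. A small dense-matrix sketch in plain NumPy, rather than Qiskit, checks both numerically; the helper `hadamard_test_p0` and the diagonal test unitary are hypothetical choices for illustration.

```python
import numpy as np

def hadamard_test_p0(U, psi, add_s_gate=False):
    """P(aux = 0) for H -> controlled-U -> (optional S) -> H on the auxiliary qubit,
    computed with dense linear algebra."""
    psi = psi / np.linalg.norm(psi)
    # After the first H: (|0>|psi> + |1>|psi>)/sqrt(2); controlled-U acts on the |1> branch
    branch0 = psi / np.sqrt(2)
    branch1 = (U @ psi) / np.sqrt(2)
    if add_s_gate:
        branch1 = 1j * branch1          # S multiplies the |1> branch of the auxiliary qubit by i
    # Final H: the main-register amplitude paired with aux = 0 is (branch0 + branch1)/sqrt(2)
    amp0 = (branch0 + branch1) / np.sqrt(2)
    return float(np.vdot(amp0, amp0).real)

theta = 0.3                                      # illustrative eigenphase
U = np.diag([np.exp(2j * np.pi * theta), 1.0])
psi = np.array([1.0, 0.0])                       # eigenstate with eigenvalue e^{2 pi i theta}
p0 = hadamard_test_p0(U, psi)
p0_s = hadamard_test_p0(U, psi, add_s_gate=True)
print(2 * p0 - 1, np.cos(2 * np.pi * theta))     # recovered cosine vs exact
print(1 - 2 * p0_s, np.sin(2 * np.pi * theta))   # recovered sine vs exact
```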

```python
# State of main register is not reset after each measurement. Allows for phase estimation over superpositions.
def coherent_Hadamard_circuit(U, Ncycles):
    """
    Args:
        U (Gate): Unitary gate to measure
        Ncycles (int): Number of measurements (iterations) on the auxiliary qubit for each of the unphased and phased Hadamard tests.
    Returns:
        QuantumCircuit that implements a coherent version of the Hadamard test, where the state of the main register is not reset following measurements.
        Clbits 0..Ncycles-1 hold the unphased outcomes and clbits Ncycles..2*Ncycles-1 the phased ones.
    """
    hadamard_circ = hadamard_circuit(U, add_s_gate = False, measure = True)

    # Initialize circuit with first iteration as initial fencepost
    circuit = hadamard_circuit(U, add_s_gate = False, measure = False)
    creg = ClassicalRegister(2*Ncycles, 'c')
    circuit.add_register(creg)
    circuit.measure(0, creg[0])
    circuit.reset(0)
    
    # Perform unphased measurements (clbits 1..Ncycles-1)
    for k in range(1, Ncycles):
        circuit = circuit.compose(hadamard_circ, qubits = None, clbits = k)
        circuit.reset(0)
        
    hadamard_circ = hadamard_circuit(U, add_s_gate = True, measure = True)
    
    # Perform phased measurements (clbits Ncycles..2*Ncycles-1)
    for k in range(0, Ncycles):
        circuit = circuit.compose(hadamard_circ, qubits = None, clbits = k + Ncycles)
        circuit.reset(0)
    
    return circuit
```

```python
def Hadamard_phase_estimation(U, Uprep = None, precision = 2 * 10**-1, coherent = False, confidence_factor = 2):
    """
    Args:
        U (Gate): Unitary gate in phase estimation problem
        Uprep (Instruction): State preparation on main register starting from all |0> state (If None, no state prep added). 
        precision (float): Desired precision of phase estimate. 
        coherent (bool): Whether to keep state of principal register (True) 
            or reset to original state (False) after each measurement.
        confidence_factor (float): A parameter which scales the number of iterations proportionally. Should be set on the order of 1.
            Larger values increase confidence of answer being within set precision.
    Returns:
        An estimate for the phase of some eigenstate, sampled according to the initial state, with specified precision, and confidence related to confidence_factor.
    """

    Ncycles = int(confidence_factor//(precision**2))
    
    # Do single shot, coherent phase estimation if coherent = True
    if coherent:
        phase_estimation_circuit = coherent_Hadamard_circuit(U, Ncycles)
        # Perform state prep, if any
        if Uprep is not None:
            phase_estimation_circuit.compose(Uprep, qubits = range(1, U.num_qubits + 1), front = True, inplace = True)
        # Run circuit once; the single outcome is a bitstring of 2*Ncycles bits
        job = Sampler().run(phase_estimation_circuit, shots = 1)
        measurements = list(job.result().quasi_dists[0].binary_probabilities())[0]
        
        # Split results into phased and unphased measurements
        # (Qiskit bitstrings list clbit 0 last, so the phased clbits Ncycles..2*Ncycles-1 come first)
        phased_measurements = measurements[0:Ncycles]
        unphased_measurements = measurements[Ncycles:2*Ncycles]

        # Compute Hamming weight to get counts
        phased_counts = np.sum(list(map(int,phased_measurements)))
        unphased_counts = np.sum(list(map(int,unphased_measurements)))
        
        # Extract Pr(0)
        unphased_p0 = 1 - unphased_counts/Ncycles
        phased_p0 = 1 - phased_counts/Ncycles
    else:  
        # Otherwise, do incoherent phase estimation
        unphased_circuit = hadamard_circuit(U, add_s_gate = False, measure = True)
        phased_circuit = hadamard_circuit(U, add_s_gate = True, measure = True)
        # Do state prep, if any
        if Uprep is not None:
            unphased_circuit = unphased_circuit.compose(Uprep, range(1, unphased_circuit.num_qubits), front = True)
            phased_circuit = phased_circuit.compose(Uprep, range(1, phased_circuit.num_qubits), front = True)

        # Run circuits
        unphased_results = Sampler().run(unphased_circuit, shots = Ncycles).result().quasi_dists[0]
        phased_results = Sampler().run(phased_circuit, shots = Ncycles).result().quasi_dists[0]

        # Extract Pr(0)
        unphased_p0 = unphased_results.get(0, 0)
        phased_p0 = phased_results.get(0, 0)

    # Get cosine and sine of phase: Pr(0) = (1 + cos)/2 without S, (1 - sin)/2 with S
    costheta = 2*unphased_p0 - 1
    sintheta = 1 - 2*phased_p0
    
    # Extract phase as final result
    theta = np.arctan2(sintheta, costheta)
    # Make branch correction to [0, 2 pi)
    if theta < 0: theta += 2*np.pi
    # Return number between 0,1
    theta = theta/(2*np.pi)
    return theta
```
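The slicing `measurements[0:Ncycles]` for the phased results relies on Qiskit's bitstring convention, in which classical bit 0 is the rightmost character. A toy sketch with made-up outcomes (no Qiskit needed) shows that, with the unphased results in clbits $0, \dots, N-1$ and the phased ones in clbits $N, \dots, 2N-1$ as in `coherent_Hadamard_circuit`, the first $N$ characters are the phased ones.

```python
Ncycles = 3
# Made-up clbit values: clbits 0..Ncycles-1 are unphased, Ncycles..2*Ncycles-1 are phased
unphased = [1, 0, 0]      # clbits 0, 1, 2
phased = [1, 1, 0]        # clbits 3, 4, 5
clbits = unphased + phased

# Build the bitstring the way Qiskit prints it: clbit 0 is the rightmost character
bitstring = "".join(str(b) for b in reversed(clbits))
print(bitstring)

# Same slicing as in Hadamard_phase_estimation
phased_measurements = bitstring[0:Ncycles]
unphased_measurements = bitstring[Ncycles:2 * Ncycles]
print(sum(map(int, phased_measurements)), sum(map(int, unphased_measurements)))
```

Within each slice the characters come out in reverse clbit order, which is harmless here because only the Hamming weight is used.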

#### Kitaev classical post processing

Next, we define the function performing the full Kitaev protocol. 

```python
def Kitaev_phase_estimation(U, Uprep = None, precision = 10**-2, coherent = False, confidence_factor = 2):
    """
    Args:
        U (Gate): Unitary gate in phase estimation problem
        Uprep (Instruction): State preparation on main register starting from all |0> state (If None, no state prep added). 
        precision (float): Desired precision of phase estimate. 
        coherent (bool): Whether to keep state of principal register (True) 
            or reset to original state (False) after each measurement.
        confidence_factor (float): A parameter which scales the number of iterations proportionally. Should be set on the order of 1.
            Larger values increase confidence of answer being within set precision.          
    Returns:
        An estimate for the phase of some eigenstate, sampled according to the initial state, 
        with specified precision and confidence related to confidence_factor.
    """
    # Parameters of precision: nbits output bits from nreps = nbits - 2 powers of U
    nbits = int(np.ceil(np.log2(1/precision)))
    nreps = nbits - 2
    if nreps < 1:
        raise ValueError("precision must be smaller than 1/4")
    # Perform Hadamard phase estimation on U**(2**j) to fixed precision 1/16.
    # phases[j] estimates 2**j * theta mod 1, i.e. rho_{j+1} in the pseudocode.
    # Note: each call re-applies Uprep, so different powers j are not measured
    # on one coherent copy of the input state.
    phases = np.zeros(nreps)
    for j in range(nreps):
        phases[j] = Hadamard_phase_estimation(U.repeat(2**j), Uprep, precision = 1/16,
                                              coherent = coherent, confidence_factor = confidence_factor)
    
    # Classical part to amplify precision
    # Round last estimate to nearest octant in [0,1)
    beta = (int(np.round(8*phases[nreps - 1]))%8)/8
    # Retrieve last 3 bits in binary form
    last_bits = bin(int(beta*8))[2:]
    last_bits = '0'*(3-len(last_bits)) + last_bits # pad with zeros to get bbb format 
    # Do alpha recursion for remaining bits (alpha[k] stores alpha_{k+1} of the pseudocode)
    alpha = [0]*nbits
    alpha[nreps + 1] = int(last_bits[2])
    alpha[nreps] = int(last_bits[1])
    alpha[nreps-1] = int(last_bits[0])
    for j in range(nreps-2,-1,-1):
        # alpha[j] = 1 iff phases[j] is within 1/4 of 0.1 alpha[j+1] alpha[j+2] (mod 1)
        alpha[j] = int(distance_mod1(phases[j], .5 + alpha[j+1]/4 + alpha[j+2]/8) < .25)
    
    return np.sum(alpha * (2.0**(-np.arange(1, nbits+1))))
```

As an illustration, let's take $U = \mathrm{SWAP}$ and the initial state $\vert10\rangle$. This state is an equal superposition of the symmetric eigenstate $(\vert10\rangle + \vert01\rangle)/\sqrt{2}$ (eigenvalue $+1$, so $\theta = 0$) and the antisymmetric eigenstate $(\vert10\rangle - \vert01\rangle)/\sqrt{2}$ (eigenvalue $-1$, so $\theta = 1/2$) of $\mathrm{SWAP}$. Hence Kitaev QPE, like any correct QPE algorithm, should return $\theta = 0$ or $\theta = 1/2$, each with probability $1/2$. This can be seen by running the cell below multiple times.

```python
# State preparation: X on qubit 1 prepares |10> (Qiskit orders qubits right to left)
prep_circuit = QuantumCircuit(2)
prep_circuit.x(1)

print('Kitaev QPE result: theta =', Kitaev_phase_estimation(SwapGate(), Uprep = prep_circuit, precision = 2**-8))
```

```
Kitaev QPE result: theta = 0.5
```
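The eigen-decomposition used in this example can be checked independently of Qiskit in a few lines of NumPy. This sketch orders the basis as $\vert00\rangle, \vert01\rangle, \vert10\rangle, \vert11\rangle$, so $\vert10\rangle$ sits at index 2; which qubit is written first does not matter here, since $\vert01\rangle$ and $\vert10\rangle$ both split evenly between the two eigenvectors.

```python
import numpy as np

# SWAP in the computational basis ordered |00>, |01>, |10>, |11>
SWAP = np.array([[1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1]], dtype=float)
ket10 = np.array([0.0, 0.0, 1.0, 0.0])

sym = np.array([0.0, 1.0, 1.0, 0.0]) / np.sqrt(2)       # (|10> + |01>)/sqrt(2)
antisym = np.array([0.0, -1.0, 1.0, 0.0]) / np.sqrt(2)  # (|10> - |01>)/sqrt(2)

# Eigenvalue check: SWAP sym = +sym, SWAP antisym = -antisym
print(np.allclose(SWAP @ sym, sym), np.allclose(SWAP @ antisym, -antisym))

# Probability of each eigenstate in |10>
weights = [abs(np.dot(v, ket10)) ** 2 for v in (sym, antisym)]

# Phases theta with eigenvalue exp(2 pi i theta): +1 -> 0, -1 -> 1/2
thetas = [(np.angle(lam) / (2 * np.pi)) % 1.0 for lam in (1.0, -1.0)]
print(weights, thetas)
```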


## References and resources

- Svore, Hastings, Freedman, ["Faster Phase Estimation"](https://arxiv.org/abs/1304.0741) (2013)
- Kitaev, Shen, Vyalyi, ["Classical and Quantum Computation"](https://bookstore.ams.org/gsm-47#:~:text=This%20book%20is%20an%20introduction,of%20complexity%20of%20an%20algorithm.) (2002), pages 128-129