# Initializing a quantum circuit to an arbitrary state

Following https://arxiv.org/pdf/quant-ph/0406176.pdf

In [1]:
#                  n : ──\\──(C-0)─────────────(C-0)────────────
#    |ψ> (n+1)                 │                 │             
#                  1 : ─────── Rz (-phi) ─────── Ry(-theta)──── |ψ''>

Any qubit state can be written as:

$$ |\psi \rangle = \alpha_{0}|0\rangle + \alpha_{1}|1\rangle = r e^{i t/2} \bigg[ e^{-i\phi/2} \cos\big(\frac{\theta}{2}\big)|0\rangle +  e^{+i\phi/2} \sin\big(\frac{\theta}{2}\big)|1\rangle \bigg]$$

- the constant phase $r e^{i t/2}$ can be ignored!
- state is defined by $\theta$ and $\phi$

We can use $R_{z}$ and $R_{y}$ rotations to take a general state to the zero state. First rotate about the z-axis by $-\phi$, then about the y-axis by $-\theta$:

$$R_{y}(-\theta) R_{z}(-\phi) |\psi \rangle =  r e^{i t/2}|0 \rangle$$

$$R_{y}(\theta-\pi) R_{z}(\pi-\phi) |\psi \rangle =  r e^{i (t-\pi)/2}|1 \rangle$$

The equation used in this derivation is the first one:

$$R_{y}(-\theta) R_{z}(-\phi) |\psi \rangle =  r e^{i t/2}|0 \rangle$$

That is, any single-qubit state can be taken to the $|0 \rangle$ state (up to a global phase) using the $\phi$ and $\theta$ rotations!
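The single-qubit statement above can be checked numerically. The following is a minimal sketch, assuming the usual matrix conventions for $R_{z}$ and $R_{y}$ (the same ones `cirq.rz` and `cirq.ry` use); the helper `bloch_angles` is a hypothetical name introduced for illustration and is not the `quchem` implementation.

```python
import numpy as np

def rz(angle):
    """R_z(angle) = diag(exp(-i*angle/2), exp(+i*angle/2))."""
    return np.array([[np.exp(-1j * angle / 2), 0],
                     [0, np.exp(1j * angle / 2)]])

def ry(angle):
    """R_y(angle) = [[cos, -sin], [sin, cos]] evaluated at angle/2."""
    c, s = np.cos(angle / 2), np.sin(angle / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def bloch_angles(alpha_0, alpha_1):
    """Hypothetical helper: return (r, t, theta, phi) for alpha_0|0> + alpha_1|1>."""
    r = np.hypot(abs(alpha_0), abs(alpha_1))
    theta = 2 * np.arctan2(abs(alpha_1), abs(alpha_0))
    arg_0, arg_1 = np.angle(alpha_0), np.angle(alpha_1)
    phi = arg_1 - arg_0   # relative phase between the two amplitudes
    t = arg_1 + arg_0     # twice the overall phase
    return r, t, theta, phi

alpha_0, alpha_1 = 1 / np.sqrt(2), 1j / np.sqrt(2)
r, t, theta, phi = bloch_angles(alpha_0, alpha_1)

psi = np.array([alpha_0, alpha_1])
rotated = ry(-theta) @ rz(-phi) @ psi
expected = r * np.exp(1j * t / 2) * np.array([1, 0])
print(np.allclose(rotated, expected))
```

The second component of `rotated` vanishes, so only the scalar $r e^{i t/2}$ survives on $|0\rangle$.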



For any multi-qubit state, we can write:

$$|\psi \rangle = \alpha_{0_{0}}|000 \rangle + \alpha_{0_{1}}|001\rangle + \alpha_{1_{0}}|010 \rangle + \alpha_{1_{1}}|011\rangle + \alpha_{2_{0}}|100 \rangle + \alpha_{2_{1}}|101\rangle + \alpha_{3_{0}}|110 \rangle + \alpha_{3_{1}}|111\rangle
$$

- Here, in $\alpha_{A_{B}}$, the index $A$ runs from $0$ to $2^{n-1}-1$ (labelling the state of the other $n-1$ qubits) and $B$ is the value of the right-most bit!


Then factor out RIGHT-most bit:

$$|\psi \rangle =  |00 \rangle \otimes \big[ \alpha_{0_{0}}|0 \rangle + \alpha_{0_{1}}|1\rangle \big] + \\ 
\: \: \: \: \: \: \: \: \: \:   |01 \rangle \otimes \big[ \alpha_{1_{0}}|0 \rangle + \alpha_{1_{1}}|1\rangle \big]+ \\ 
\: \: \: \: \: \: \: \: \: \:   |10 \rangle \otimes \big[ \alpha_{2_{0}}|0 \rangle + \alpha_{2_{1}}|1\rangle \big]+ \\ 
\: \: \: \: \: \: \:    |11 \rangle \otimes  \big[ \alpha_{3_{0}}|0 \rangle + \alpha_{3_{1}}|1\rangle \big] $$

We can re-write this as:  
 
$$|\psi \rangle =  |00 \rangle \otimes \big[|\rho_{0} \rangle \big] + \\ 
\: \: \: \: \: \: \: \: \: \: |01 \rangle \otimes \big[|\rho_{1} \rangle \big]+ \\ 
\: \: \: \: \: \: \: \: \: \:  |10 \rangle \otimes \big[|\rho_{2} \rangle \big]+ \\ 
\: \: \: \: \: \: \:   |11 \rangle \otimes  \big[|\rho_{3} \rangle \big] $$
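Factoring out the right-most bit corresponds to reshaping the state vector into rows of two amplitudes. A small sketch, assuming the amplitudes are stored in the usual big-endian order (index $k$ has binary representation $AB$ with $B$ the last bit); the variable names are illustrative only:

```python
import numpy as np

# Example 3-qubit state (amplitudes in order |000>, |001>, ..., |111>)
state = np.array([1, 1, 1, -1j, 1, -1, 1, 1]) / np.sqrt(8)
n_qubits = int(np.log2(state.size))

# Row A of `rho` holds (alpha_{A_0}, alpha_{A_1}), i.e. the unnormalised |rho_A>
rho = state.reshape(-1, 2)

for A, (a0, a1) in enumerate(rho):
    prefix = np.binary_repr(A, width=n_qubits - 1)
    print(f"|{prefix}> (x) [{complex(a0):.3f}|0> + {complex(a1):.3f}|1>]")
```

Each row can then be treated as an independent single-qubit state with its own $\theta$ and $\phi$.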

- Then we need to map each of the states $|\rho_{0} \rangle$ to $|\rho_{(2^{n-1}-1)} \rangle$ to $|0 \rangle$
    - by finding the appropriate $\phi$ and $\theta$ angles for each one


- Doing this simultaneously on all states amounts to applying the following unitary:
    - this disentangles the least significant (right-most) bit at each step!

$$
U=
\left(\begin{array}{cccc}
R_{y}\left(-\theta_{0}\right) R_{z}\left(-\phi_{0}\right) & & & \\
& R_{y}\left(-\theta_{1}\right) R_{z}\left(-\phi_{1}\right) & & \\
& & \ddots & \\
& & & R_{y}\left(-\theta_{2^{n-1}-1}\right) R_{z}\left(-\phi_{2^{n-1}-1}\right)
\end{array}\right)$$

aka the action is:

$$
U|\psi\rangle=|\psi'\rangle \otimes |0\rangle=\left(\begin{array}{c}
r_{0} e^{i t_{0}} \\
r_{1} e^{i t_{1}} \\
\vdots\\
r_{2^{n-1}-1} e^{i t_{2^{n-1}-1}}
\end{array}\right) \otimes|0\rangle$$


NOW $U$ can be implemented using a multiplexed $R_{z}$ gate followed by a multiplexed $R_{y}$ gate!
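The action of $U$ can be checked with plain NumPy before building any circuit. This is a sketch under the same matrix conventions as `cirq.rz` and `cirq.ry`; `disentangling_unitary` is a hypothetical helper written for this illustration, and its sign conventions may differ from the `quchem` functions used later.

```python
import numpy as np

def rz(angle):
    return np.array([[np.exp(-1j * angle / 2), 0],
                     [0, np.exp(1j * angle / 2)]])

def ry(angle):
    c, s = np.cos(angle / 2), np.sin(angle / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def disentangling_unitary(state):
    """Hypothetical helper: block-diagonal U with one R_y(-theta_k) R_z(-phi_k) per block."""
    state = np.asarray(state, dtype=complex).ravel()
    pairs = state.reshape(-1, 2)
    U = np.zeros((state.size, state.size), dtype=complex)
    for k, (a0, a1) in enumerate(pairs):
        theta = 2 * np.arctan2(abs(a1), abs(a0))
        phi = np.angle(a1) - np.angle(a0)
        U[2 * k:2 * k + 2, 2 * k:2 * k + 2] = ry(-theta) @ rz(-phi)
    return U

state = np.array([1, 1, 1, -1j, 1, -1, 1, 1]) / np.sqrt(8)
U = disentangling_unitary(state)
out = U @ state

remaining = out[0::2]   # amplitudes of |psi'> (right-most bit = 0)
leftover = out[1::2]    # amplitudes with right-most bit = 1, should vanish
print(np.around(remaining, 4))
print(np.allclose(leftover, 0))
```

Every odd-indexed amplitude is rotated away, so the right-most qubit factors out as $|0\rangle$.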

# Functions

In [2]:
import numpy as np

## 1. Single Qubit state

- Given $\alpha_{0}$ and $\alpha_{1}$ find theta and phi

$$ |\psi \rangle = \alpha_{0}|0\rangle + \alpha_{1}|1\rangle = r e^{i t/2} \bigg[ e^{-i\phi/2} \cos\big(\frac{\theta}{2}\big)|0\rangle +  e^{+i\phi/2} \sin\big(\frac{\theta}{2}\big)|1\rangle \bigg]$$

In [3]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import Single_qubit_rotation

alpha_0 = 1/np.sqrt(2) 
alpha_1 = 1j/np.sqrt(2)

global_phase , theta, phi = Single_qubit_rotation(alpha_0, alpha_1)

state = global_phase * np.array([[np.exp(-1j*phi/2)*np.cos(theta/2)], [np.exp(1j*phi/2)*np.sin(theta/2)]])
np.around(state, 3)

array([[0.707+0.j   ],
       [0.   +0.707j]])

## 2. Rotations to disentangle right-most bit

Given a state like:

$$|\psi \rangle =  |00 \rangle \otimes \big[|\rho_{0} \rangle \big] + \\ 
\: \: \: \: \: \: \: \: \: \: |01 \rangle \otimes \big[|\rho_{1} \rangle \big]+ \\ 
\: \: \: \: \: \: \: \: \: \:  |10 \rangle \otimes \big[|\rho_{2} \rangle \big]+ \\ 
\: \: \: \: \: \: \:   |11 \rangle \otimes  \big[|\rho_{3} \rangle \big] $$

We want to generate the angles needed to implement:

$$
U=
\left(\begin{array}{cccc}
R_{y}\left(-\theta_{0}\right) R_{z}\left(-\phi_{0}\right) & & & \\
& R_{y}\left(-\theta_{1}\right) R_{z}\left(-\phi_{1}\right) & & \\
& & \ddots & \\
& & & R_{y}\left(-\theta_{2^{n-1}-1}\right) R_{z}\left(-\phi_{2^{n-1}-1}\right)
\end{array}\right)$$

In [4]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import Rotations_to_disentangle

qubit_state_vector = np.array([
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [-1j/np.sqrt(8)],
    [1/np.sqrt(8)],
    [-1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)]
])

remaining_vector, theta_list, phi_list = Rotations_to_disentangle(qubit_state_vector.flat)

np.around(remaining_vector, 4)

array([0.5   +0.j    , 0.3536-0.3536j, 0.    +0.5j   , 0.5   +0.j    ])

In [5]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import prepare_arb_state_cirq_matrix_gate
import cirq
qubits = list(cirq.LineQubit.range(3))

ansatz_circuit = prepare_arb_state_cirq_matrix_gate(qubit_state_vector.flat,
                             start_qubit_ind=0)

undo_circuit = cirq.Circuit(
cirq.rz(phi_list[0]).controlled(num_controls=2, control_values=[0,0]).on(
                    *qubits),
cirq.rz(phi_list[1]).controlled(num_controls=2, control_values=[0,1]).on(
                *qubits),
cirq.rz(phi_list[2]).controlled(num_controls=2, control_values=[1,0]).on(
                    *qubits),
cirq.rz(phi_list[3]).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits),
    
cirq.ry(theta_list[0]).controlled(num_controls=2, control_values=[0,0]).on(
                    *qubits),
cirq.ry(theta_list[1]).controlled(num_controls=2, control_values=[0,1]).on(
                    *qubits),
cirq.ry(theta_list[2]).controlled(num_controls=2, control_values=[1,0]).on(
                    *qubits),
cirq.ry(theta_list[3]).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits),
)


# undo_circuit = cirq.Circuit(
# cirq.rz(phi_list[0]).on(qubits[1]),
# cirq.rz(phi_list[1]).on(qubits[1]),
    
# cirq.ry(theta_list[0]).on(qubits[1]),
# cirq.ry(theta_list[1]).on(qubits[1]),
# )

final_circuit = cirq.Circuit(
ansatz_circuit.all_operations(),
undo_circuit.all_operations() 
)
final_circuit

In [6]:
np.around(final_circuit.final_state_vector(), 4)

array([0.5   +0.j    , 0.    +0.j    , 0.3536-0.3536j, 0.    -0.j    ,
       0.    +0.5j   , 0.    +0.j    , 0.5   +0.j    , 0.    +0.j    ])

In [7]:
remaining_vector

[(0.49999999999999994+0j),
 (0.35355339059327373-0.3535533905932737j),
 (3.0616169978683824e-17+0.49999999999999994j),
 (0.49999999999999994+0j)]

In [8]:
remaining_circuit = prepare_arb_state_cirq_matrix_gate(remaining_vector,
                             start_qubit_ind=0)
remaining_circuit.append(cirq.I.on(qubits[-1]))
np.around(remaining_circuit.final_state_vector(), 4)

array([0.5   +0.j    , 0.    +0.j    , 0.3536-0.3536j, 0.    +0.j    ,
       0.    +0.5j   , 0.    +0.j    , 0.5   +0.j    , 0.    +0.j    ])

## 3. Generate quantum circuit 

In [9]:
# in the first split of a multi-control gate there are no possible CNOT cancellations

# but in all subsequent reductions the gates split into pairs (the last CNOT gate will cancel out at each step!)

In [10]:
qubits = list(cirq.LineQubit.range(4))

In [11]:
THETA = np.pi/7

Rz_control_circ = cirq.Circuit(

cirq.rz(THETA).controlled(num_controls=3, control_values=[0,0,0]).on(
                    *qubits)
)
Rz_control_circ

In [12]:
cirq.Circuit(cirq.decompose(Rz_control_circ))

This can be decomposed into:

In [13]:
Rz_first_decomp = cirq.Circuit(

cirq.rz(THETA/2).controlled(num_controls=2, control_values=[0,0]).on(
                    *qubits[1:]),
    
cirq.CNOT(qubits[0], qubits[-1]),
    
cirq.rz(THETA/2).controlled(num_controls=2, control_values=[0,0]).on(
                    *qubits[1:]),
    
cirq.CNOT(qubits[0], qubits[-1])
)
Rz_first_decomp

In [14]:
np.allclose(Rz_control_circ.unitary(), Rz_first_decomp.unitary())

True

We can repeat this process for all the zero controlled rotation gates!
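The splitting step can be seen in its smallest form with one control and one target. A minimal NumPy sketch, assuming the first qubit is the control (most significant in the Kronecker product) and the same $R_{z}$ convention as `cirq.rz`:

```python
import numpy as np

def rz(angle):
    return np.array([[np.exp(-1j * angle / 2), 0],
                     [0, np.exp(1j * angle / 2)]])

I2 = np.eye(2)
X = np.array([[0, 1], [1, 0]])
P0 = np.diag([1, 0])   # projector onto control = |0>
P1 = np.diag([0, 1])   # projector onto control = |1>
CNOT = np.kron(P0, I2) + np.kron(P1, X)

theta = np.pi / 7

# Rz(theta) applied only when the control is |0>
zero_controlled = np.kron(P0, rz(theta)) + np.kron(P1, I2)

# Circuit order: Rz(theta/2), CNOT, Rz(theta/2), CNOT (matrices multiply right to left)
half = np.kron(I2, rz(theta / 2))
decomposed = CNOT @ half @ CNOT @ half

print(np.allclose(zero_controlled, decomposed))
```

When the control is $|0\rangle$ the two half rotations add up to $R_{z}(\theta)$; when it is $|1\rangle$ the CNOTs flip the sign of one of them and they cancel.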

aka final step is:

In [15]:
Rz_second_decomp = cirq.Circuit(

cirq.rz(THETA/4).controlled(num_controls=1, control_values=[0]).on(
                    *qubits[2:]),
cirq.CNOT(qubits[1], qubits[-1]),

cirq.rz(THETA/4).controlled(num_controls=1, control_values=[0]).on(
                    *qubits[2:]),
cirq.CNOT(qubits[1], qubits[-1]), 
    
cirq.CNOT(qubits[0], qubits[-1]),

cirq.CNOT(qubits[1], qubits[-1]), 
    
cirq.rz(THETA/4).controlled(num_controls=1, control_values=[0]).on(
                    *qubits[2:]),
cirq.CNOT(qubits[1], qubits[-1]),

cirq.rz(THETA/4).controlled(num_controls=1, control_values=[0]).on(
                    *qubits[2:]),
# cirq.CNOT(qubits[1], qubits[-1]), 
    
cirq.CNOT(qubits[0], qubits[-1])
)
Rz_second_decomp

In [16]:
np.allclose(Rz_control_circ.unitary(), Rz_second_decomp.unitary())

True

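Importantly, some CNOT gates will cancel: see the CNOTs in the middle of `Rz_second_decomp` above, where the two controlled on qubit 1 sit either side of the CNOT controlled on qubit 0. All three share the same target, so they commute and the pair controlled on qubit 1 cancels.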

Hence we can use a recursive function to simplify these gates (note we sometimes get CNOT cancellation!)

See page 11 of https://arxiv.org/pdf/quant-ph/0406176.pdf
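The cancellation relies on CNOTs with a common target commuting with each other. A short sketch that checks this on 4 qubits, assuming qubit 0 is the most significant bit of the basis index; `cnot` is a helper written for this illustration:

```python
import numpy as np

def cnot(control, target, n_qubits):
    """Permutation matrix of a CNOT on n_qubits (qubit 0 = most significant bit)."""
    dim = 2 ** n_qubits
    U = np.zeros((dim, dim))
    for basis in range(dim):
        bits = [(basis >> (n_qubits - 1 - q)) & 1 for q in range(n_qubits)]
        if bits[control]:
            bits[target] ^= 1
        image = int("".join(map(str, bits)), 2)
        U[image, basis] = 1
    return U

n = 4
c0 = cnot(0, 3, n)   # CNOT(qubits[0], qubits[-1])
c1 = cnot(1, 3, n)   # CNOT(qubits[1], qubits[-1])

commute = np.allclose(c0 @ c1, c1 @ c0)
collapsed = np.allclose(c1 @ c0 @ c1, c0)   # the two qubit-1 CNOTs cancel
print(commute, collapsed)
```

So a run of three CNOTs controlled on qubits 1, 0, 1 with the same target reduces to the single CNOT controlled on qubit 0.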

In [17]:
tt1 = cirq.Circuit(
cirq.ry(np.pi).controlled(num_controls=1, control_values=[0]).on(
                    *qubits[:2]),
)
tt1

In [18]:
tt2 = cirq.Circuit(
cirq.ry(np.pi/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]),
cirq.ry(np.pi/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]))
tt2

In [19]:
THETA = np.pi/7
Rz_ones = cirq.Circuit(

cirq.rz(THETA).controlled(num_controls=3, control_values=[1, 1, 1]).on(
                    *qubits)
)
Rz_ones

In [20]:
Rz_ones_decomp = cirq.Circuit(

cirq.rz(THETA/2).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits[1:]),
    
cirq.CNOT(qubits[0], qubits[-1]),
    
cirq.rz(-THETA/2).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits[1:]), # NOTE NEGATIVE SIGN
    
cirq.CNOT(qubits[0], qubits[-1])
)
Rz_ones_decomp

In [21]:
Rz_ones_decomp = cirq.Circuit(

cirq.CNOT(qubits[0], qubits[-1]),
cirq.rz(-THETA/2).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits[1:]),
    
cirq.CNOT(qubits[0], qubits[-1]),
    
cirq.rz(+THETA/2).controlled(num_controls=2, control_values=[1,1]).on(
                    *qubits[1:]),
    

)
Rz_ones_decomp

In [22]:
np.allclose(Rz_ones.unitary(), Rz_ones_decomp.unitary())

True
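The negative sign in the ones-controlled case follows from $X R_{z}(a) X = R_{z}(-a)$. A minimal one-control sketch in NumPy, assuming the same conventions as `cirq.rz` and with the control as the first qubit:

```python
import numpy as np

def rz(angle):
    return np.array([[np.exp(-1j * angle / 2), 0],
                     [0, np.exp(1j * angle / 2)]])

I2 = np.eye(2)
X = np.array([[0, 1], [1, 0]])
P0 = np.diag([1, 0])
P1 = np.diag([0, 1])
CNOT = np.kron(P0, I2) + np.kron(P1, X)

theta = np.pi / 7

# Conjugating by X flips the sign of the rotation angle
conjugation_ok = np.allclose(X @ rz(theta) @ X, rz(-theta))

# Rz(theta) applied only when the control is |1>
one_controlled = np.kron(P0, I2) + np.kron(P1, rz(theta))

# Circuit order: CNOT, Rz(-theta/2), CNOT, Rz(+theta/2)
decomposed = (np.kron(I2, rz(theta / 2)) @ CNOT
              @ np.kron(I2, rz(-theta / 2)) @ CNOT)

print(conjugation_ok, np.allclose(one_controlled, decomposed))
```

With the control in $|0\rangle$ the two half rotations have opposite signs and cancel; with the control in $|1\rangle$ the CNOT pair flips the negative one, giving $R_{z}(\theta)$.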

## 3.1 Recursive build

In [23]:
def R_full_decomp(target_gate, Angle, control_list, line_qubit_list, include_last_CNOT=True, last_control_bit=None):
    
#     print(line_qubit_list)
#     print(control_list)
    
    circuit= cirq.Circuit()
    LSB = line_qubit_list[0]
    # case of no multiplexing: base case for recursion
    if len(line_qubit_list)==1:
        if target_gate == 'Ry':
            Ry_gate = cirq.ry(Angle)
            circuit.append(Ry_gate.on(LSB))
        elif target_gate == 'Rz':
                Rz_gate = cirq.rz(Angle)
                circuit.append(Rz_gate.on(LSB))
        else:
            raise ValueError(f'Incorrect gate specified: {target_gate}')
    
        return circuit
    
    
    if ((target_gate == 'Rz') and last_control_bit==1):
        Angle_left = Angle/2
        Angle_right = -Angle/2  # note sign here!
    else:
        Angle_left = Angle/2
        Angle_right = Angle/2  

    MSB = line_qubit_list[-1]
    last_control_bit = control_list[0]
    
    Angle_left = Angle/2
    decomp_left = R_full_decomp(target_gate, 
                                Angle_left, 
                                control_list[1:], 
                                line_qubit_list[:-1], 
                                last_control_bit=last_control_bit,
                                include_last_CNOT=False)

    circuit = cirq.Circuit(
       [
           circuit.all_operations(),
           *decomp_left.all_operations(),
       ]
    )
        
    circuit.append(cirq.CNOT(MSB, LSB))
    
    
    if ((target_gate == 'Rz') and (last_control_bit==1)) or ((target_gate == 'Ry') and (last_control_bit==1)):
        Angle_right = -Angle/2  # note sign here!
    else:
        Angle_right = Angle/2  # note sign here!
        
    decomp_right = R_full_decomp(target_gate, 
                            Angle_right, 
                            control_list[1:], 
                            line_qubit_list[:-1], 
                            last_control_bit=last_control_bit,
                            include_last_CNOT=False)
    
    if len(control_list) > 1:
        
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *list(decomp_right.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
                               ]
                            )
    else:
            
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *decomp_right.all_operations(),
                               ]
                            )
    # attach a final CNOT
    if include_last_CNOT:
        circuit.append(cirq.CNOT(MSB, LSB))
    return circuit

In [24]:
Angle = np.pi
control_list=[1, 0, 1]
line_qubit_list = list(cirq.LineQubit.range(len(control_list)+1))
decomp_R = R_full_decomp('Rz', Angle, control_list, line_qubit_list, include_last_CNOT=True, last_control_bit=None)

decomp_R

In [25]:

full_R = cirq.Circuit(
cirq.rz(Angle).controlled(num_controls=len(control_list), control_values=control_list).on(
                    *line_qubit_list[::-1])
)
full_R

In [26]:
np.allclose(decomp_R.unitary(), full_R.unitary())

True

In [27]:
def R_full_decomp(target_gate, Angle, control_list, start_qubit_ind, include_last_CNOT=True, last_control_bit=None):
    
    """
    0: ───@───────
          │
    1: ───@───────
          │
    2: ───@───────
          │
    3: ───Ry(π)───
    
    becomes

    0: ─────────────────────────────────────────────────────────────────────@─────────────────────────────────────────────────────────────────────@───
                                                                            │                                                                     │
    1: ──────────────────────────────────@──────────────────────────────────┼──────────────────────────────────@──────────────────────────────────┼───
                                         │                                  │                                  │                                  │
    2: ────────────────@─────────────────┼────────────────@─────────────────┼────────────────@─────────────────┼────────────────@─────────────────┼───
                       │                 │                │                 │                │                 │                │                 │
    3: ───Ry(0.125π)───X───Ry(-0.125π)───X───Ry(0.125π)───X───Ry(-0.125π)───X───Ry(0.125π)───X───Ry(-0.125π)───X───Ry(0.125π)───X───Ry(-0.125π)───X───

    
    """
    
    line_qubit_list = cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+len(control_list)+1)
    
    control_list = control_list[::-1]
    circuit= cirq.Circuit()
    LSB = line_qubit_list[-1]
    # case of no multiplexing: base case for recursion
    if len(line_qubit_list)==1:
        if target_gate == 'Ry':
            Ry_gate = cirq.ry(Angle)
            circuit.append(Ry_gate.on(LSB))
        elif target_gate == 'Rz':
                Rz_gate = cirq.rz(Angle)
                circuit.append(Rz_gate.on(LSB))
        else:
            raise ValueError(f'Incorrect gate specified: {target_gate}')
    
        return circuit
    
    
    if ((target_gate == 'Rz') and last_control_bit==1):
        Angle_left = Angle/2
        Angle_right = -Angle/2  # note sign here!
    else:
        Angle_left = Angle/2
        Angle_right = Angle/2  

    MSB = line_qubit_list[0]
    last_control_bit = control_list[-1]
    
    Angle_left = Angle/2
    decomp_left = R_full_decomp(target_gate, 
                                Angle_left, 
                                control_list[:-1], 
                                start_qubit_ind+1, 
                                last_control_bit=last_control_bit,
                                include_last_CNOT=False)

    circuit = cirq.Circuit(
       [
           circuit.all_operations(),
           *decomp_left.all_operations(),
       ]
    )
        
    circuit.append(cirq.CNOT(MSB, LSB))
    
    if ((target_gate == 'Rz') and (last_control_bit==1)) or ((target_gate == 'Ry') and (last_control_bit==1)):
        Angle_right = -Angle/2  # note sign here!
    else:
        Angle_right = Angle/2  # note sign here!
    
    decomp_right = R_full_decomp(target_gate, 
                    Angle_right, 
                        control_list[:-1], 
                        start_qubit_ind+1, 
                    last_control_bit=last_control_bit,
                    include_last_CNOT=False)
    
    if len(control_list) > 1:
       
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *list(decomp_right.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
                               ]
                            )
    else:
            
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *decomp_right.all_operations(),
                               ]
                            )
    # attach a final CNOT
    if include_last_CNOT:
        circuit.append(cirq.CNOT(MSB, LSB))
    return circuit




In [28]:
Angle = np.pi
control_list=[1,1,1]
start_ind = 0

decomp_R = R_full_decomp('Ry', Angle, control_list, start_ind, include_last_CNOT=True, last_control_bit=None)

decomp_R

In [29]:
line_qubit_list = cirq.LineQubit.range(start_ind, start_ind+len(control_list)+1)

full_R = cirq.Circuit(
cirq.ry(Angle).controlled(num_controls=len(control_list), control_values=control_list).on(
                    *line_qubit_list)
)
full_R

In [30]:
np.allclose(decomp_R.unitary(), full_R.unitary())

True

In [31]:
Angle = np.pi
control_list=[]
start_ind = 1

decomp_R = R_full_decomp('Rz', Angle, control_list, start_ind, include_last_CNOT=True, last_control_bit=None)

decomp_R

In [32]:
def R_angle_list(target_gate, list_of_angles, start_qubit_ind):
    
    N_q = int(np.log2(len(list_of_angles)))
    circuit = cirq.Circuit()
    
    for control_ind, angle in enumerate(list_of_angles):
        if angle == 0:
            continue
        
        if N_q ==0:
            control_list=[]
        else:
            control_list = list(map(lambda x: int(x), list(np.binary_repr(control_ind, width=N_q))))
        
        decomp_R = R_full_decomp(target_gate, 
                                 angle,
                                 control_list, 
                                 start_qubit_ind,
                                 include_last_CNOT=True,
                                 last_control_bit=None)
        circuit.append(decomp_R)
    return circuit
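The index-to-control mapping used in `R_angle_list` can be looked at in isolation. A small sketch, where `control_pattern` is a hypothetical helper and the angle values are made up for illustration:

```python
import numpy as np

def control_pattern(index, n_controls):
    """Hypothetical helper: control values (0/1 per control qubit) for the angle at `index`."""
    if n_controls == 0:
        return []
    return [int(bit) for bit in np.binary_repr(index, width=n_controls)]

angles = [0.1, 0.0, 0.3, 0.4]            # illustrative values only
n_controls = int(np.log2(len(angles)))

# Zero angles are skipped, mirroring the `continue` in R_angle_list
patterns = {k: control_pattern(k, n_controls)
            for k, angle in enumerate(angles) if angle != 0}
print(patterns)
```

Each non-zero angle is attached to the controlled rotation whose control values spell out its index in binary.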

In [33]:
qubit_state_vector = np.array([
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [-1j/np.sqrt(8)],
    [1/np.sqrt(8)],
    [-1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)]
])

remaining_vector, theta_list, phi_list = Rotations_to_disentangle(qubit_state_vector.flat)

np.around(remaining_vector, 4)

array([0.5   +0.j    , 0.3536-0.3536j, 0.    +0.5j   , 0.5   +0.j    ])

In [34]:
theta_list

[-1.5707963267948966,
 -1.5707963267948966,
 -1.5707963267948966,
 -1.5707963267948966]

In [35]:
R_angle_list('Rz', theta_list, 0)

In [36]:
remaining_vector, theta_list, phi_list = Rotations_to_disentangle(remaining_vector)

np.around(remaining_vector, 4)

array([0.6533-0.2706j, 0.5   +0.5j   ])

In [37]:
R_angle_list('Rz', theta_list, 0)

In [38]:
remaining_vector, theta_list, phi_list = Rotations_to_disentangle(remaining_vector)

np.around(remaining_vector, 4)

array([0.9808+0.1951j])

In [39]:
R_angle_list('Rz', theta_list, 0)

In [40]:
theta_list

[-1.5707963267948966]

In [79]:
from copy import deepcopy

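def disentangle_circuit(qubit_state_vector, start_qubit_ind):
    """
    Build the circuit that takes qubit_state_vector to the all-zero state
    (up to a global phase) by disentangling one qubit at a time.

    Returns the circuit and the remaining vector after the last step,
    which holds the final global phase.
    """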

    circuit = cirq.Circuit()
    n_qubits = np.log2(len(qubit_state_vector))
    
    if np.ceil(n_qubits) != np.floor(n_qubits):
        raise ValueError('state vector is not a qubit state')
    
    n_qubits = int(n_qubits)
#     print(n_qubits)
    
    N_qubits_remaining_vector= n_qubits
    remaining_vector = deepcopy(qubit_state_vector)
    for qubit_ind in range(start_qubit_ind, start_qubit_ind+n_qubits):
        # work out which rotations must be done to disentangle the LSB
        # qubit (we peel away one qubit at a time)
        remaining_vector, theta_list, phi_list = Rotations_to_disentangle(remaining_vector)
        
        decomp_R = R_angle_list('Rz', phi_list, start_qubit_ind)
        circuit.append(decomp_R)

        decomp_R = R_angle_list('Ry', theta_list, start_qubit_ind)
        circuit.append(decomp_R)
    
    final_global_phase = remaining_vector
    return circuit, final_global_phase
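The peeling loop can be followed numerically without building a circuit. The sketch below is a plain NumPy stand-in, not the `quchem` `Rotations_to_disentangle` function; `peel_last_qubit` is a hypothetical helper that returns positive $\theta$, $\phi$ to be applied as $R_{y}(-\theta) R_{z}(-\phi)$, so its signs may differ from the lists printed elsewhere in this notebook.

```python
import numpy as np

def peel_last_qubit(state):
    """Hypothetical helper: angles that disentangle the right-most qubit, plus what remains."""
    pairs = np.asarray(state, dtype=complex).reshape(-1, 2)
    a0, a1 = pairs[:, 0], pairs[:, 1]
    r = np.hypot(np.abs(a0), np.abs(a1))
    thetas = 2 * np.arctan2(np.abs(a1), np.abs(a0))
    phis = np.angle(a1) - np.angle(a0)
    t = np.angle(a1) + np.angle(a0)
    remaining = r * np.exp(1j * t / 2)   # one amplitude r_k e^{i t_k / 2} per pair
    return remaining, thetas, phis

state = np.array([1, 1, 1, -1j]) / 2

angle_schedule = []      # one (phis, thetas) entry per peeled qubit
remaining = state
while remaining.size > 1:
    remaining, thetas, phis = peel_last_qubit(remaining)
    angle_schedule.append((phis, thetas))

global_phase = remaining[0]
print(np.around(global_phase, 8))
```

After the last step a single complex number of unit modulus is left over, which is the global phase that has to be reapplied when comparing against the target state.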

In [80]:
qubit_state_vector = np.array([
    [1/np.sqrt(4)],
    [1/np.sqrt(4)],
    [1/np.sqrt(4)],
    [-1j/np.sqrt(4)],
])

# qubit_state_vector = np.array([
#     [1/np.sqrt(8)],
#     [1/np.sqrt(8)],
#     [1/np.sqrt(8)],
#     [1/np.sqrt(8)],
#     [1/np.sqrt(8)],
#     [-1/np.sqrt(8)],
#     [1j/np.sqrt(8)],
#     [1/np.sqrt(8)]
# ])

disent, G_phase = disentangle_circuit(qubit_state_vector, 3)
disent

In [43]:
new = cirq.inverse(disent)
new

In [44]:
np.around(G_phase*new.final_state_vector(), 10).reshape((len(qubit_state_vector),1))

array([[0.5+0.j ],
       [0.5+0.j ],
       [0.5+0.j ],
       [0. -0.5j]])

In [45]:
np.around(qubit_state_vector, 10)

array([[ 0.5+0.j ],
       [ 0.5+0.j ],
       [ 0.5+0.j ],
       [-0. -0.5j]])

In [46]:
z = 0.32664074-0.13529902j

In [47]:
np.vdot(z,z)

(0.12499999784070798+0j)

In [48]:
0.35355339**2

0.12499999958049211

In [49]:
G_phase

[array([0.92387953-0.38268343j])]

In [50]:
np.around(G_phase*new.final_state_vector(), 10).reshape((len(qubit_state_vector),1))

array([[0.5+0.j ],
       [0.5+0.j ],
       [0.5+0.j ],
       [0. -0.5j]])

In [68]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import intialization_circuit
start_qubit_ind=0
check_circuit=True

t, gphase= intialization_circuit(qubit_state_vector, start_qubit_ind, check_circuit=check_circuit)
t

In [65]:
qubit_state_vector = np.array([
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [1/np.sqrt(8)],
    [-1/np.sqrt(8)],
    [1j/np.sqrt(8)],
    [1/np.sqrt(8)]
])

start_qubit_ind=0
check_circuit=True
t, gphase = intialization_circuit(qubit_state_vector, start_qubit_ind, check_circuit=check_circuit)
t

In [66]:
np.around(gphase*t.final_state_vector(), 10).reshape((len(qubit_state_vector),1))

array([[ 0.35355339-0.j        ],
       [ 0.35355339-0.j        ],
       [ 0.35355339-0.j        ],
       [ 0.35355339+0.j        ],
       [ 0.35355339-0.j        ],
       [-0.35355339+0.j        ],
       [-0.        +0.35355339j],
       [ 0.35355339-0.j        ]])

In [67]:
np.allclose(np.around(qubit_state_vector.flat, 10), np.around(gphase*t.final_state_vector(),10))

True

In [None]:
np.around(qubit_state_vector.flat, 10)

In [None]:
qubit_state_vector.shape

In [None]:
# def single_decomp_multi_control_ZERO_Rz(target_gate, Angle, start_qubit_ind, N_controls, last_CNOT=True):
#     """
#          0: ───(0)─────              0: ──────────────@──────────────@───
#                │                                      │              │
#          1: ───(0)─────              1: ───(0)────────┼───(0)────────┼───
#                │            TO              │         │   │          │
#          2: ───(0)─────              2: ───(0)────────┼───(0)────────┼───
#                │                            │         │    │         │
#          3: ───Rz(π)───              3: ───Rz(0.5π)───X───Rz(0.5π)───X───
         
         
#             0: ───────────────────────────────────@──────────────────────────────────@───
#                                                   │                                  │
#             1: ───────────────@───────────────@───┼───@───────────────@──────────────┼───
#    TO                         │               │   │   │               │              │
#             2: ───(0)─────────┼───(0)─────────┼───┼───┼───(0)─────────┼───(0)────────┼───
#                   │           │   │           │   │   │   │           │   │          │
#             3: ───Rz(0.25π)───X───Rz(0.25π)───X───X───X───Rz(0.25π)───X───Rz(0.25π)──X───
#                                               ^       ^
#                                             ## cancel ##
#     """
        
    
#     qubits_list = cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+N_controls+1)
    
#     LSB = qubits_list[0] # least significant bit
#     circuit = cirq.Circuit()
    
#     # case of no multiplexing: base case for recursion
#     if N_controls == 0:
#         if target_gate == 'Ry':
#             Ry_gate = cirq.ry(Angle)
#             circuit.append(Ry_gate.on(LSB))
#         elif target_gate == 'Rz':
#             Rz_gate = cirq.rz(Angle)
#             circuit.append(Rz_gate.on(LSB))
#         else:
#             raise ValueError(f'Incorrect gate specificed: {target_gate}')
        
#         return circuit
    

#     MSB = qubits_list[-1] # most significant bit
    
#     # calc the combo angles
#     new_angle = Angle/2
    
#     # recursive step on half the angles fulfilling the above assumption
#     multiplex_1 = single_decomp_multi_control_ZERO_Rz(target_gate,
#                                                       new_angle,
#                                                       start_qubit_ind,
#                                                       N_controls-1, 
#                                                       last_CNOT=False)
#     circuit = cirq.Circuit(
#        [
#            circuit.all_operations(),
#            *multiplex_1.all_operations(),
#        ]
#     )
    
#     circuit.append(cirq.CNOT(MSB, LSB))

#     # implement extra efficiency from the paper of cancelling adjacent
#     # CNOTs (by leaving out last CNOT and reversing (NOT inverting) the
#     # second lower-level multiplex)
#     multiplex_2 = single_decomp_multi_control_ZERO_Rz(target_gate,
#                                                       new_angle,
#                                                       start_qubit_ind,
#                                                       N_controls-1, 
#                                                       last_CNOT=False)
    
#     if N_controls > 1:
#         circuit = cirq.Circuit(
#                                [
#                                    circuit.all_operations(),
#                                    *list(multiplex_2.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
#                                ]
#                             )
#     else:
#         circuit = cirq.Circuit(
#                                [
#                                    circuit.all_operations(),
#                                    *multiplex_2.all_operations(),
#                                ]
#                             )
#     # attach a final CNOT
#     if last_CNOT:
#         circuit.append(cirq.CNOT(MSB, LSB))
    
#     return circuit

In [None]:
def single_decomp_multi_control_ZERO_Rz(target_gate, list_of_angles, start_qubit_ind, N_controls, last_CNOT=True):
    """
         0: ───(0)─────              0: ──────────────@──────────────@───
               │                                      │              │
         1: ───(0)─────              1: ───(0)────────┼───(0)────────┼───
               │            TO              │         │   │          │
         2: ───(0)─────              2: ───(0)────────┼───(0)────────┼───
               │                            │         │    │         │
         3: ───Rz(π)───              3: ───Rz(0.5π)───X───Rz(0.5π)───X───
         
         
            0: ───────────────────────────────────@──────────────────────────────────@───
                                                  │                                  │
            1: ───────────────@───────────────@───┼───@───────────────@──────────────┼───
   TO                         │               │   │   │               │              │
            2: ───(0)─────────┼───(0)─────────┼───┼───┼───(0)─────────┼───(0)────────┼───
                  │           │   │           │   │   │   │           │   │          │
            3: ───Rz(0.25π)───X───Rz(0.25π)───X───X───X───Rz(0.25π)───X───Rz(0.25π)──X───
                                              ^       ^
                                            ## cancel ##
    """
        
    number_angles = len(list_of_angles)
    local_num_qubits = int(np.log2(number_angles)) 
    qubits_list = cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+ local_num_qubits+1)
    
    LSB = qubits_list[0] # least significant bit
    circuit = cirq.Circuit()
    
    # case of no multiplexing: base case for recursion
    if local_num_qubits == 1:
        if target_gate == 'Ry':
            Ry_gate = cirq.ry(list_of_angles[0])
            circuit.append(Ry_gate.on(LSB))
        elif target_gate == 'Rz':
            Rz_gate = cirq.rz(list_of_angles[0])
            circuit.append(Rz_gate.on(LSB))
        else:
            raise ValueError(f'Incorrect gate specified: {target_gate}')
        
        return circuit
    

    MSB = qubits_list[local_num_qubits-1] # most significant bit
    
    # calc the combo angles
    
    # recursive step on half the angles fulfilling the above assumption
    multiplex_1 = single_decomp_multi_control_ZERO_Rz(target_gate,
                                                       list_of_angles[0:(number_angles // 2)],
                                                      start_qubit_ind,
                                                      N_controls-1, 
                                                      last_CNOT=False)
    circuit = cirq.Circuit(
       [
           circuit.all_operations(),
           *multiplex_1.all_operations(),
       ]
    )
    
    circuit.append(cirq.CNOT(MSB, LSB))

    # implement extra efficiency from the paper of cancelling adjacent
    # CNOTs (by leaving out last CNOT and reversing (NOT inverting) the
    # second lower-level multiplex)
    multiplex_2 = single_decomp_multi_control_ZERO_Rz(target_gate,
                                                       list_of_angles[0:(number_angles // 2)],
                                                      start_qubit_ind,
                                                      N_controls-1, 
                                                      last_CNOT=False)
    
    if N_controls > 1:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *list(multiplex_2.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
                               ]
                            )
    else:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *multiplex_2.all_operations(),
                               ]
                            )
    # attach a final CNOT
    if last_CNOT:
        circuit.append(cirq.CNOT(MSB, LSB))
    
    return circuit
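
The rotation, CNOT, rotation, CNOT pattern built above can be checked by hand for the smallest case of a single control qubit. The block below is a plain-NumPy sketch that does not call any of the notebook's functions; `two_angle_multiplexer` and `block_diag2` are hypothetical helper names, and the basis ordering |control, target> is an assumption chosen to match cirq's default big-endian convention.

```python
import numpy as np

def rz(angle):
    """Rz(angle) = exp(-i * angle * Z / 2), the same convention as cirq.rz."""
    return np.array([[np.exp(-0.5j * angle), 0],
                     [0, np.exp(0.5j * angle)]])

I2 = np.eye(2)
# CNOT in the basis |control, target>: the control is the first qubit
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]], dtype=complex)

def two_angle_multiplexer(alpha_0, alpha_1):
    """Apply Rz(alpha_0) if the control is 0 and Rz(alpha_1) if it is 1."""
    a = 0.5 * (alpha_0 + alpha_1)
    b = 0.5 * (alpha_0 - alpha_1)
    # circuit order: Rz(a), CNOT, Rz(b), CNOT (matrices compose right to left)
    return CNOT @ np.kron(I2, rz(b)) @ CNOT @ np.kron(I2, rz(a))

def block_diag2(upper, lower):
    """4x4 block-diagonal matrix: `upper` acts for control 0, `lower` for control 1."""
    zeros = np.zeros((2, 2))
    return np.block([[upper, zeros], [zeros, lower]])

theta = np.pi / 7
# zero-controlled rotation: Rz(theta) when control = 0, identity when control = 1
zero_controlled_ok = np.allclose(two_angle_multiplexer(theta, 0.0),
                                 block_diag2(rz(theta), I2))
# general case with two different angles
general_ok = np.allclose(two_angle_multiplexer(0.3, 1.1),
                         block_diag2(rz(0.3), rz(1.1)))
print(zero_controlled_ok, general_ok)
```

For the zero-controlled case both rotations carry the same half angle, which is why the two lower-level calls above receive the same slice of angles.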

In [None]:
start_qubit_ind = 0
number_Controls = 3
# THETA = np.pi/7

circuit = single_decomp_multi_control_ZERO_Rz('Ry',
                                    theta_list,
                                    start_qubit_ind,
                                    number_Controls, 
                                    last_CNOT=True)
circuit

In [None]:

q_list = list(cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+number_Controls+1))

Ry_control_circ = cirq.Circuit(
    cirq.ry(THETA).controlled(num_controls=3, control_values=[0, 0, 0]).on(
        *q_list[1:], q_list[0])
)
Ry_control_circ

In [None]:
# np.allclose(circuit.unitary(), Ry_control_circ.unitary())

## 3.2 Build circuit from list of thetas

In [None]:
def Rotations_circuit(target_gate, list_angles, start_qubit_ind, N_controls, last_CNOT=True):
    
    circuit = cirq.Circuit()
    
#     for ind, angle in enumerate(list_angles): 
#         circuit.append(single_decomp_multi_control_ZERO_Rz(target_gate, 
#                                                            angle, 
#                                                            start_qubit_ind + ind,  # note plus
#                                                            N_controls - ind, #note minus
#                                                            last_CNOT=last_CNOT))

    circuit.append(single_decomp_multi_control_ZERO_Rz(target_gate, 
                                                       list_angles[0], 
                                                       start_qubit_ind,  
                                                       N_controls, 
                                                       last_CNOT=last_CNOT))
    return circuit

In [None]:
amp = 1 / np.sqrt(8)  # eight equal-magnitude amplitudes, so the state has unit norm
qubit_state_vector = amp * np.array([
    [1],
    [1],
    [1j],
    [-1j],
    [1],
    [-1],
    [1j],
    [1j]
])

remaining_vector, theta_list, phi_list = Rotations_to_disentangle(qubit_state_vector.flat)

c = Rotations_circuit('Rz', theta_list, 0, 2)
c

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import recursive_multiplex
recursive_multiplex('Rz', theta_list, 0, 3)

In [None]:
theta_list

## 3.3 Build disentangling circuit

In [None]:
# from copy import deepcopy
# def disentangle_circuit(qubit_state_vector, start_qubit_ind):
#     """
#     """

#     circuit = cirq.Circuit()
#     n_qubits = np.log2(len(qubit_state_vector))

#     if np.ceil(n_qubits) != np.floor(n_qubits):
#         raise ValueError('state vector is not a qubit state')
    
#     n_qubits = int(n_qubits)

#     remaining_vector = deepcopy(qubit_state_vector)
#     qubit_inds = list(range(start_qubit_ind, start_qubit_ind+n_qubits))
#     for j, qubit_ind in enumerate(qubit_inds):
        
        
#         remaining_vector, theta_list, phi_list = Rotations_to_disentangle(remaining_vector)
#         N_controls = int(np.log2(len(remaining_vector)))
#         add_last_cnot = True
#         if np.linalg.norm(phi_list) != 0 and np.linalg.norm(theta_list) != 0:
#             add_last_cnot = False

#         if np.linalg.norm(phi_list) != 0:
#             rz_mult_circuit = Rotations_circuit('Rz',
#                                                   phi_list,
#                                                   start_qubit_ind + j,
#                                                   N_controls,
#                                                   last_CNOT=add_last_cnot)
#             circuit.append(rz_mult_circuit)

#         if np.linalg.norm(theta_list) != 0:
#             ry_mult_circuit = Rotations_circuit('Ry',
#                                                   theta_list,
#                                                   start_qubit_ind + j,
#                                                   N_controls,
#                                                   last_CNOT=add_last_cnot)
#             circuit = cirq.Circuit(
#                        [
#                            circuit.all_operations(),
#                            *list(ry_mult_circuit.all_operations())[::-1],
#                        ]
#                     )
            
#     return circuit

In [None]:
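from copy import deepcopy

def disentangle_circuit(qubit_state_vector, start_qubit_ind):
    """
    Build the circuit of multiplexed Rz and Ry rotations that disentangles
    the qubits of a state one at a time.

    Args:
        qubit_state_vector (array-like): amplitudes of the state, length 2**n
        start_qubit_ind (int): index of the first qubit the circuit acts on

    Returns:
        cirq.Circuit: the disentangling circuit
    """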

    circuit = cirq.Circuit()
    n_qubits = np.log2(len(qubit_state_vector))

    if np.ceil(n_qubits) != np.floor(n_qubits):
        raise ValueError('state vector is not a qubit state')
    
    n_qubits = int(n_qubits)

    remaining_vector = deepcopy(qubit_state_vector)
    qubit_inds = range(start_qubit_ind, start_qubit_ind+n_qubits)
    for j, qubit_ind in enumerate(qubit_inds):
        # work out which rotations must be done to disentangle the LSB
        # qubit (we peel away one qubit at a time)
        remaining_vector, theta_list, phi_list = Rotations_to_disentangle(remaining_vector)
        N_controls = int(np.log2(len(remaining_vector)))
        
        add_last_cnot = True
        if np.linalg.norm(phi_list) != 0 and np.linalg.norm(theta_list) != 0:
            add_last_cnot = False

        if np.linalg.norm(phi_list) != 0:
            rz_mult_circuit = single_decomp_multi_control_ZERO_Rz('Rz',
                                                  phi_list,
                                                  qubit_ind,
                                                  N_controls,
                                                  last_CNOT=add_last_cnot)
            circuit.append(rz_mult_circuit)

        if np.linalg.norm(theta_list) != 0:
            ry_mult_circuit = single_decomp_multi_control_ZERO_Rz('Ry',
                                                  theta_list,
                                                  qubit_ind,
                                                  N_controls,
                                                  last_CNOT=add_last_cnot)
            circuit = cirq.Circuit(
                       [
                           circuit.all_operations(),
                           *list(ry_mult_circuit.all_operations())[::-1],
                       ]
                    )
    return circuit
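
For a single pair of amplitudes, the angles that drive each disentangling step follow directly from the parametrisation $|\psi\rangle = r e^{it/2}\big[e^{-i\phi/2}\cos(\theta/2)|0\rangle + e^{+i\phi/2}\sin(\theta/2)|1\rangle\big]$. The block below is a minimal NumPy sketch of that one-qubit case; `single_qubit_angles` is a hypothetical helper written for illustration and is not the notebook's `Rotations_to_disentangle`.

```python
import numpy as np

def ry(angle):
    """Ry(angle) = exp(-i * angle * Y / 2)."""
    c, s = np.cos(angle / 2), np.sin(angle / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def rz(angle):
    """Rz(angle) = exp(-i * angle * Z / 2)."""
    return np.array([[np.exp(-0.5j * angle), 0],
                     [0, np.exp(0.5j * angle)]])

def single_qubit_angles(alpha_0, alpha_1):
    """Return (r, t, theta, phi) for the state alpha_0|0> + alpha_1|1>."""
    r = np.hypot(abs(alpha_0), abs(alpha_1))
    theta = 2 * np.arctan2(abs(alpha_1), abs(alpha_0))
    # arg(alpha_0) = (t - phi)/2 and arg(alpha_1) = (t + phi)/2
    phi = np.angle(alpha_1) - np.angle(alpha_0)
    t = np.angle(alpha_1) + np.angle(alpha_0)
    return r, t, theta, phi

alpha_0, alpha_1 = 0.6 * np.exp(0.4j), 0.8 * np.exp(-1.1j)
r, t, theta, phi = single_qubit_angles(alpha_0, alpha_1)

# Ry(-theta) Rz(-phi) |psi> should equal r * exp(i t / 2) |0>
rotated = ry(-theta) @ rz(-phi) @ np.array([alpha_0, alpha_1])
expected = r * np.exp(0.5j * t) * np.array([1, 0])
rotation_reaches_zero_state = np.allclose(rotated, expected)
print(rotation_reaches_zero_state)
```

In the multi-qubit case the same pair of angles is computed for every value of the remaining qubits, which gives the lists that the multiplexed Rz and Ry rotations consume.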

In [None]:
qubit_state_vector = np.array([
    [np.sqrt(0.2)],
    [0],
    [np.sqrt(0.1)],
    [np.sqrt(0.1)],
    [np.sqrt(0.3)],
    [np.sqrt(0.2)],
    [np.sqrt(0.1)],
    [0],
])

# c = disentangle_circuit(qubit_state_vector, 0)
# c

In [None]:
new = cirq.inverse(c)
new

In [None]:
np.around(new.final_state_vector(), 3)

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import intialization_circuit
start_qubit_ind=0
check_circuit=True

t = intialization_circuit(qubit_state_vector, start_qubit_ind, check_circuit=check_circuit, threshold=7)
t

In [None]:
np.around(t.final_state_vector(), 3)
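
Because the construction drops the constant factor $r e^{it/2}$, the prepared state is only expected to match the target up to a global phase, so comparing rounded amplitudes by eye can mislead. The block below is a small NumPy sketch of such a comparison; `same_up_to_global_phase` is a hypothetical helper, and it also normalises both vectors, which is an assumption made here for convenience.

```python
import numpy as np

def same_up_to_global_phase(state_a, state_b, atol=1e-8):
    """True if the two vectors are equal after normalisation and a global phase."""
    a = np.asarray(state_a, dtype=complex).ravel()
    b = np.asarray(state_b, dtype=complex).ravel()
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    # |<a|b>| is 1 exactly when the normalised vectors differ only by a phase
    return bool(np.isclose(abs(np.vdot(a, b)), 1.0, atol=atol))

target = np.sqrt([0.2, 0, 0.1, 0.1, 0.3, 0.2, 0.1, 0])
phase_shifted = np.exp(0.7j) * target   # same state, different global phase
different = np.roll(target, 1)          # amplitudes moved to other basis states

print(same_up_to_global_phase(target, phase_shifted))
print(same_up_to_global_phase(target, different))
```

The same helper could be applied to a circuit's final state vector and the flattened target amplitudes.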

In [None]:
np.allclose(tt1.unitary(), tt2.unitary())

In [None]:
theta_list

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import recursive_multiplex

test = recursive_multiplex('Ry', theta_list, 10, 20)
test

In [None]:
def recursive_multiplex_NEW(target_gate, list_of_angles, start_qubit_ind, last_cnot=True):
    """
    Args:
        target_gate (str): 'Ry' or 'Rz', the rotation to apply to the target qubit,
                           multiplexed over all other "select" qubits
                            
        list_of_angles (list[float]): list of rotation angles to apply Ry and Rz
        
        start_qubit_ind (int): index of the first (most significant) select qubit
        
        last_cnot (bool): add the last cnot if last_cnot = True
    """
    number_angles = len(list_of_angles)
    local_num_qubits = int(np.log2(number_angles)) + 1 # +1 for n+1 qubits!
    
    # exactly local_num_qubits qubits: the select qubits followed by the target
    qubits_list = cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+local_num_qubits)
    
    LSB = qubits_list[-1] # least significant bit
    MSB = qubits_list[0] # most significant bit
    
    circuit = cirq.Circuit()
    
    # case of no multiplexing: base case for recursion
    if local_num_qubits == 1:
        if target_gate == 'Ry':
            Ry_gate = cirq.ry(list_of_angles[0])
            circuit.append(Ry_gate.on(LSB))
        elif target_gate == 'Rz':
            Rz_gate = cirq.rz(list_of_angles[0])
            circuit.append(Rz_gate.on(LSB))
        else:
            raise ValueError(f'Incorrect gate specified: {target_gate}')
        
        return circuit
    
    angle_weight = np.kron([[0.5, 0.5], [0.5, -0.5]],
                               np.identity(2 ** (local_num_qubits - 2)))
    
    # calc the combo angles
    list_of_angles = angle_weight.dot(np.array(list_of_angles)).tolist()
    
    # recursive step on half the angles fulfilling the above assumption
    multiplex_1 = recursive_multiplex_NEW(target_gate, list_of_angles[0:(number_angles // 2)],
                                      start_qubit_ind+1,
                                      False)
    circuit = cirq.Circuit(
       [
           circuit.all_operations(),
           *multiplex_1.all_operations(),
       ]
    )
    
    circuit.append(cirq.CNOT(MSB, LSB))

    # implement extra efficiency from the paper of cancelling adjacent
    # CNOTs (by leaving out last CNOT and reversing (NOT inverting) the
    # second lower-level multiplex)
    multiplex_2 = recursive_multiplex_NEW(target_gate, list_of_angles[(number_angles // 2):],
                                      start_qubit_ind+1,
                                      False)
    
    if number_angles > 1:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *list(multiplex_2.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
                               ]
                            )
    else:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *multiplex_2.all_operations(),
                               ]
                            )
    # attach a final CNOT
    if last_cnot:
        circuit.append(cirq.CNOT(MSB, LSB))
    
    return circuit
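
The `angle_weight` matrix used in the function above replaces the input angles by half-sums and half-differences of angles whose indices differ in the most significant select bit. The block below is a short NumPy sketch of that step alone for four angles; the variable names are chosen here for illustration.

```python
import numpy as np

angles = np.array([0.5, 0.6, 0.7, 0.8])
local_num_qubits = int(np.log2(len(angles))) + 1   # 2 select qubits + 1 target

# [[0.5, 0.5], [0.5, -0.5]] acts on the most significant select bit,
# the identity leaves the remaining select bits untouched
angle_weight = np.kron([[0.5, 0.5], [0.5, -0.5]],
                       np.identity(2 ** (local_num_qubits - 2)))
combined = angle_weight @ angles

# first half: (a0 + a2)/2, (a1 + a3)/2 -> angles for the first sub-multiplexer
# second half: (a0 - a2)/2, (a1 - a3)/2 -> angles for the second sub-multiplexer
first_half, second_half = combined[:2], combined[2:]
print(first_half, second_half)
```

Each half is then decomposed again by the recursive calls, with one select qubit fewer.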

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import recursive_multiplex

test = recursive_multiplex('Ry', [0.5, 0.6, 0.7, 0.8], 0, 3)
test

In [None]:
tt2.unitary()

In [None]:
tt1.unitary()

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import intialization_circuit

qubit_state_vector = [1j/np.sqrt(8) if i%2==0 else 1/np.sqrt(8) for i in range(2**3)]
# qubit_state_vector = [1/np.sqrt(8) for i in range(2**3)]

start_qubit_ind=0
check_circuit=False

c = intialization_circuit(qubit_state_vector, start_qubit_ind, check_circuit=check_circuit, threshold=7)
c

In [None]:
c.final_state_vector()

In [None]:
def recursive_multiplex_NEW(target_gate, 
                            list_control_qubits, 
                            active_qubit, last_cnot=True,
                            running_circuit=cirq.Circuit()):
    """
    Iteratively breaks into:
    
         0: ───(0)─────              0: ──────────────@──────────────@───
               │                                      │              │
         1: ───(0)─────              1: ───(0)────────┼───(0)────────┼───
               │            TO              │         │   │          │
         2: ───(0)─────              2: ───(0)────────┼───(0)────────┼───
               │                            │         │    │         │
         3: ───Rz(π)───              3: ───Rz(0.5π)───X───Rz(0.5π)───X───
    """
    
    
    for ind, control_q in enumerate(list_control_qubits):
        
        if ind == 0:
            # no CNOT cancellation possible
            raise NotImplementedError('decomposition step is not written yet')

In [None]:
class single_decomp_multi_control_ZERO_Rz(cirq.Gate):
    """
         0: ───(0)─────              0: ──────────────@──────────────@───
               │                                      │              │
         1: ───(0)─────              1: ───(0)────────┼───(0)────────┼───
               │            TO              │         │   │          │
         2: ───(0)─────              2: ───(0)────────┼───(0)────────┼───
               │                            │         │    │         │
         3: ───Rz(π)───              3: ───Rz(0.5π)───X───Rz(0.5π)───X───
    """

    def __init__(self, Angle, list_control_qubits, active_qubit, include_last_CNOT=True):

        self.Angle = Angle
        self.list_control_qubits = list_control_qubits
        self.active_qubit = active_qubit
        self.include_last_CNOT = include_last_CNOT


    def _decompose_(self, qubits, include_final_CNOT=True):
        
        divisor = 2**len(self.list_control_qubits)
        
        least_significant_bit = qubits[0] 
        active_bit = qubits[-1]
        if len(self.list_control_qubits) == 0:
            pass
        else:
            
            new_angle = self.Angle/2
            raise NotImplementedError('choice of the remaining control qubits is not written yet')
            
        
        # then do final



    def num_qubits(self):
        return len(self.list_control_qubits) + 1

In [None]:
len([])

In [None]:
from quchem.Qcircuit.Circuit_functions_to_create_arb_state import recursive_multiplex

test = recursive_multiplex('Ry', theta_list, 0, 20)
test

In [None]:
theta_list = [0.125 * np.pi for _ in range(8)]

In [None]:
def recursive_multiplex(target_gate, list_of_angles, start_qubit_ind, last_cnot=True):
    """
    Args:
        target_gate (str): 'Ry' or 'Rz', the rotation to apply to the target qubit,
                           multiplexed over all other "select" qubits
                            
        list_of_angles (list[float]): list of rotation angles to apply Ry and Rz
        
        start_qubit_ind (int): index of the target qubit; the select qubits follow it
        
        last_cnot (bool): add the last cnot if last_cnot = True
    """
    number_angles = len(list_of_angles)
    local_num_qubits = int(np.log2(number_angles)) + 1 # +1 for n+1 qubits!
    
    qubits_list = cirq.LineQubit.range(start_qubit_ind, start_qubit_ind+local_num_qubits)
    
    LSB = qubits_list[0] # least significant bit
    MSB = qubits_list[local_num_qubits-1] # most significant bit
    
    circuit = cirq.Circuit()
    
    # case of no multiplexing: base case for recursion
    if local_num_qubits == 1:
        if target_gate == 'Ry':
            Ry_gate = cirq.ry(list_of_angles[0])
            circuit.append(Ry_gate.on(LSB))
        elif target_gate == 'Rz':
            Rz_gate = cirq.rz(list_of_angles[0])
            circuit.append(Rz_gate.on(LSB))
        else:
            raise ValueError(f'Incorrect gate specified: {target_gate}')
        
        return circuit
    
    angle_weight = np.kron([[0.5, 0.5], [0.5, -0.5]],
                               np.identity(2 ** (local_num_qubits - 2)))
    
    # calc the combo angles
    list_of_angles = angle_weight.dot(np.array(list_of_angles)).tolist()
    
    # recursive step on half the angles fulfilling the above assumption
    # (same start qubit, half the angles, so the current MSB is dropped)
    multiplex_1 = recursive_multiplex(target_gate, list_of_angles[0:(number_angles // 2)],
                                      start_qubit_ind,
                                      False)
    circuit = cirq.Circuit(
       [
           circuit.all_operations(),
           *multiplex_1.all_operations(),
       ]
    )
    
    circuit.append(cirq.CNOT(MSB, LSB))

    # implement extra efficiency from the paper of cancelling adjacent
    # CNOTs (by leaving out last CNOT and reversing (NOT inverting) the
    # second lower-level multiplex)
    multiplex_2 = recursive_multiplex(target_gate, list_of_angles[(number_angles // 2):],
                                      start_qubit_ind,
                                      False)
    
    if number_angles > 1:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *list(multiplex_2.all_operations())[::-1], # reversed (allowed as circuit is symmetric)
                               ]
                            )
    else:
        circuit = cirq.Circuit(
                               [
                                   circuit.all_operations(),
                                   *multiplex_2.all_operations(),
                               ]
                            )
    # attach a final CNOT
    if last_cnot:
        circuit.append(cirq.CNOT(MSB, LSB))
    
    return circuit

In [None]:
0: ────────────────@────────────────@────────────────@───────────────@───
                   │                │                │               │
1: ───Rz(-0.25π)───X───Rz(-0.25π)───X───Rz(-0.25π)───X───Rz(0.25π)───X───

In [None]:
qubits = cirq.LineQubit.range(0,2)
THETA = np.pi

tt = cirq.Circuit(

cirq.CNOT(qubits[0], qubits[1]),
cirq.rz(THETA/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]),
cirq.rz(THETA/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]),
cirq.rz(THETA/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]),
)
tt

In [None]:
tt2 = cirq.Circuit(

cirq.CNOT(qubits[0], qubits[1]),
# cirq.rz(THETA/2).on(qubits[1]),
# cirq.rz(-THETA/2).on(qubits[1]),
# cirq.CNOT(qubits[0], qubits[1]),
# cirq.CNOT(qubits[0], qubits[1]),
cirq.rz(THETA/2).on(qubits[1]),
cirq.CNOT(qubits[0], qubits[1]),
)
tt2

In [None]:
np.allclose(tt.unitary(), tt2.unitary())
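
The comparison of `tt` and `tt2` can be reproduced with plain matrices, which also shows what each gate sequence does for each value of the control qubit. This is a NumPy sketch under the assumption that the basis is ordered |control, target>, as in cirq's default; it uses the identity that conjugating Rz(θ) on the target with a CNOT gives Rz(θ) for control 0 and Rz(-θ) for control 1.

```python
import numpy as np

def rz(angle):
    """Rz(angle) = exp(-i * angle * Z / 2)."""
    return np.array([[np.exp(-0.5j * angle), 0],
                     [0, np.exp(0.5j * angle)]])

CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]], dtype=complex)
THETA = np.pi
R = np.kron(np.eye(2), rz(THETA / 2))   # Rz(THETA/2) on the target qubit

# same gate sequence as tt2: CNOT, Rz, CNOT
one_rotation = CNOT @ R @ CNOT
# same gate sequence as tt: CNOT, Rz, CNOT, Rz, CNOT, Rz, CNOT
three_rotations = CNOT @ R @ CNOT @ R @ CNOT @ R @ CNOT

# upper-left 2x2 block acts when the control is 0, lower-right when it is 1
control_0_block_one = one_rotation[:2, :2]        # Rz(THETA/2)
control_0_block_three = three_rotations[:2, :2]   # Rz(3*THETA/2)
sequences_equal = np.allclose(one_rotation, three_rotations)
print(sequences_equal)
```

Both sequences act as Rz(-THETA/2) when the control is 1 and differ only in the control-0 block, so for THETA = π the two unitaries are not equal.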