# Steane-Code Error Correction with a Pauli Frame Update

In [4]:
from qiskit import __version__

# Qiskit version the outputs in this notebook were produced with.
print(f"{__version__}")

1.4.2


In [605]:
from qiskit import QuantumCircuit, ClassicalRegister
from qiskit.quantum_info import Statevector, state_fidelity, partial_trace, DensityMatrix
from qiskit_aer import AerSimulator
from qiskit.visualization import plot_histogram
from qiskit import transpile 
import numpy as np
from qiskit_aer.noise import NoiseModel, depolarizing_error, ReadoutError
from qiskit.circuit.controlflow import IfElseOp
from qiskit.circuit.library import XGate, ZGate
import matplotlib.pyplot as plt

In [410]:
import importlib
import steane_ec_decoder1
importlib.reload(steane_ec_decoder1)
from steane_ec_decoder1 import lookup

# Function for Encoding

In [451]:
def encoding(qc: QuantumCircuit):
    """Encode a single-qubit input state into the Steane code on qubits 0-6 of ``qc``.

    Qubit 0 is initialised to cos(theta/2)|0> + sin(theta/2)|1> with
    theta = arctan(sqrt((sqrt(5) - 1) / 2)). The CNOT network then spreads it over
    the seven data qubits. Each data qubit also receives an ``id`` gate so that a
    noise model attached to ``id`` can inject idle errors.

    Qubit labelling: the extraction circuits ``flag1``/``flag2``/``unflag`` check the
    supports {0,1,2,3}, {1,2,4,5} and {2,3,5,6}. The encoder must therefore produce
    a code with exactly those stabilizers. The network below gives:

    * X stabilizers X{0,1,2,3}, X{0,3,4,5}, X{1,3,4,6} (from parity qubits 2, 5, 6);
    * Z stabilizers Z{0,1,5,6}, Z{2,3,5,6}, Z{0,2,4,6};
    * logical X = X{0,1,4} and logical Z = Z{0,2,5}, each equal to the transversal
      operator up to a stabilizer.

    These generate the same groups as the checked supports.

    Why not the textbook encoder: the usual network (H on 4, 5, 6; cx 0->1, 0->2;
    6->{0,1,3}; 5->{0,2,3}; 4->{1,2,3}) yields stabilizers on {0,1,3,6}, {0,2,3,5}
    and {1,2,3,4}. That is the same code under a different qubit labelling. With it,
    none of the operators measured by the extraction circuits is a stabilizer of
    the encoded state, so every syndrome outcome is random. This network is the
    textbook encoder with qubits relabelled 0->0, 1->1, 2->4, 3->3, 4->6, 5->5, 6->2.

    Args:
        qc: circuit with at least 7 qubits; gates are appended in place.
    """
    theta = np.arctan(np.sqrt((np.sqrt(5) - 1) / 2))
    amp_0 = np.cos(theta / 2)
    amp_1 = np.sin(theta / 2)
    qc.initialize([amp_0, amp_1], 0)

    # Idle slot on every data qubit (targeted by an 'id' noise channel, if any).
    for q in range(7):
        qc.id(q)

    # Parity qubits start in |+>; each one fans out into one X stabilizer.
    parity_qubits = (2, 5, 6)
    for q in parity_qubits:
        qc.h(q)

    # Copy the input onto qubits 1 and 4, so logical X becomes X{0,1,4}.
    qc.cx(0, 1)
    qc.cx(0, 4)

    # Fan-outs: 2 -> {0,1,3}, 5 -> {0,4,3}, 6 -> {1,4,3}.
    fan_out = (
        (2, (0, 1, 3)),
        (5, (0, 4, 3)),
        (6, (1, 4, 3)),
    )
    for ctrl, targets in fan_out:
        for tgt in targets:
            qc.cx(ctrl, tgt)

# Functions for Stabilizer Extraction

In [396]:
def flag1(qc: QuantumCircuit, first_qubit: int, c1: ClassicalRegister):
    """First flagged round of syndrome extraction for one Steane block.

    Data qubits are ``first_qubit + 0..6``; ancillas are ``first_qubit + 7..9``.

    * Ancilla +7 is prepared in |+> and collects the X-type check on data
      {0, 1, 2, 3}. Its CNOTs into +8 and +9 couple it to the other two ancillas.
    * Ancillas +8 and +9 collect the Z-type checks on data {1, 2, 4, 5} and
      {2, 3, 5, 6}. This assumes they start in |0>; the caller resets them.

    The CNOT order matters for fault tolerance, so the schedule is kept verbatim.
    All three ancillas are measured into ``c1``, with +7 read out in the X basis.
    """
    base = first_qubit
    x_ancilla = base + 7
    # (control, target) offsets relative to first_qubit, in circuit order.
    schedule = (
        (7, 3), (2, 9), (5, 8), (7, 8),
        (7, 0), (3, 9), (4, 8), (7, 1),
        (6, 9), (2, 8), (7, 9), (7, 2),
        (5, 9), (1, 8),
    )

    qc.h(x_ancilla)  # |0> -> |+>
    for ctrl, tgt in schedule:
        qc.cx(base + ctrl, base + tgt)
    qc.h(x_ancilla)  # rotate back so a Z measurement reads out X

    qc.measure([base + 7, base + 8, base + 9], c1)

In [397]:
def unflag(qc: QuantumCircuit, first_qubit: int, c2: ClassicalRegister):
    """Extract all six Steane checks with one ancilla each and no flag couplings.

    ``QEC`` runs this whenever a flagged round reports a nonzero outcome.
    Data qubits are ``first_qubit + 0..6``; ancillas are ``first_qubit + 7..12``.

    * +7, +11 and +12 are prepared in |+>. They collect the X-type checks on data
      {0,1,2,3}, {1,2,4,5} and {2,3,5,6} respectively.
    * +10, +8 and +9 collect the Z-type checks on the same supports, {0,1,2,3},
      {1,2,4,5} and {2,3,5,6} respectively. This assumes they start in |0>.

    Outcomes are written to ``c2`` in the order of ancillas +7, +8, +9, +10, +11,
    +12, i.e. X0123, Z1245, Z2356, Z0123, X1245, X2356. The gate order is kept
    verbatim.
    """
    x_ancillas = (7, 11, 12)
    # X{0,1,2,3} on +7 interleaved with the Z checks on +8 / +9.
    first_layer = (
        (7, 3), (2, 9), (5, 8),
        (7, 0), (3, 9), (4, 8),
        (7, 1), (6, 9), (2, 8),
        (7, 2), (5, 9), (1, 8),
    )
    # Z{0,1,2,3} on +10 interleaved with the X checks on +11 / +12.
    second_layer = (
        (3, 10), (12, 2), (11, 5),
        (0, 10), (12, 3), (11, 4),
        (1, 10), (12, 6), (11, 2),
        (2, 10), (12, 5), (11, 1),
    )

    for offset in x_ancillas:
        qc.h(first_qubit + offset)
    for ctrl, tgt in first_layer + second_layer:
        qc.cx(first_qubit + ctrl, first_qubit + tgt)
    for offset in x_ancillas:  # back to the Z basis for readout
        qc.h(first_qubit + offset)

    qc.measure([first_qubit + k for k in range(7, 13)], c2)


In [398]:
def flag2(qc: QuantumCircuit, first_qubit: int, c3: ClassicalRegister):
    """Second flagged round: the X/Z mirror image of ``flag1``.

    Data qubits are ``first_qubit + 0..6``; ancillas are ``first_qubit + 7..9``.

    * Ancilla +7 collects the Z-type check on data {0, 1, 2, 3}. This assumes it
      starts in |0>; the caller resets it.
    * Ancillas +8 and +9 are prepared in |+> and collect the X-type checks on data
      {1, 2, 4, 5} and {2, 3, 5, 6}. Their CNOTs into +7 couple them to that ancilla.

    The CNOT order is kept verbatim. All three ancillas are measured into ``c3``,
    with +8 and +9 read out in the X basis.
    """
    x_ancillas = (8, 9)
    # (control, target) offsets relative to first_qubit, in circuit order.
    schedule = (
        (3, 7), (9, 2), (8, 5), (8, 7),
        (0, 7), (9, 3), (8, 4), (1, 7),
        (9, 6), (8, 2), (9, 7), (2, 7),
        (9, 5), (8, 1),
    )

    for offset in x_ancillas:  # |0> -> |+>
        qc.h(first_qubit + offset)
    for ctrl, tgt in schedule:
        qc.cx(first_qubit + ctrl, first_qubit + tgt)
    for offset in x_ancillas:  # measure the last two ancillas in the X basis
        qc.h(first_qubit + offset)

    qc.measure([first_qubit + k for k in (7, 8, 9)], c3)

# Function for QEC

In [447]:
def QEC(qc: QuantumCircuit, c1: ClassicalRegister, c2: ClassicalRegister, c3: ClassicalRegister):
    """Append one round of flagged syndrome extraction to ``qc``.

    Decision tree:

    1. Run ``flag1`` (outcomes -> ``c1``), then reset ancillas 7-9.
    2. If ``c1 == 0``: run ``flag2`` (outcomes -> ``c3``) and reset ancillas 7-9.
       Then run ``unflag`` (outcomes -> ``c2``) only if ``c3 != 0``.
    3. If ``c1 != 0``: run ``unflag`` straight away.

    ``qc`` is expected to have 13 qubits (data 0-6, ancillas 7-12) and to already
    contain ``c1``, ``c2`` and ``c3``. Every branch circuit is 13 qubits wide and
    carries all three registers, so it can be appended across the whole of ``qc``.
    """
    width = 13
    ancillas = [7, 8, 9]
    every_qubit = list(range(width))
    every_clbit = c1[:] + c2[:] + c3[:]

    def new_branch():
        # Empty 13-qubit circuit sharing the three classical registers.
        branch = QuantumCircuit(width)
        branch.add_register(c1, c2, c3)
        return branch

    # Step 1: first flagged extraction.
    flag1(qc, 0, c1)
    qc.reset(ancillas)

    # Step 2, inner branches selected by c3.
    flag2_quiet = new_branch()
    flag2_fired = new_branch()
    unflag(flag2_fired, 0, c2)

    # Step 2 body: second flagged extraction followed by the test on c3.
    second_round = new_branch()
    flag2(second_round, 0, c3)
    second_round.reset(ancillas)
    second_round.append(
        IfElseOp((c3, 0), true_body=flag2_quiet, false_body=flag2_fired),
        every_qubit,
        every_clbit,
    )

    # Step 3 branch: flag1 reported something.
    flag1_fired = new_branch()
    unflag(flag1_fired, 0, c2)

    qc.append(
        IfElseOp((c1, 0), true_body=second_round, false_body=flag1_fired),
        every_qubit,
        every_clbit,
    )

# Encoding + One Round of Error Correction

In [459]:
noise_model = NoiseModel()

# Depolarizing noise: parameter 0.005 on idle (id) and Hadamard gates, 0.05 on CNOTs.
noise_model.add_all_qubit_quantum_error(depolarizing_error(0.005, 1), ['id'])
noise_model.add_all_qubit_quantum_error(depolarizing_error(0.005, 1), ['h'])
noise_model.add_all_qubit_quantum_error(depolarizing_error(0.05, 2), ['cx'])

# 0.5% chance of flipping 0 <-> 1 on readout.
readout_err = ReadoutError([[0.995, 0.005],  # P(measured 0 | actual 0), P(1 | 0)
                            [0.005, 0.995]]) # P(0 | 1), P(1 | 1)

# Only the ancillas (qubits 7-12) are ever measured, so only they get readout noise.
for qubit in range(7, 13):
    noise_model.add_readout_error(readout_err, [qubit])

In [620]:
# One encoded block (qubits 0-6) plus six ancillas (qubits 7-12).
qc = QuantumCircuit(13)
c1 = ClassicalRegister(3, "c1")  # flag1 outcomes
c2 = ClassicalRegister(6, "c2")  # unflag outcomes
c3 = ClassicalRegister(3, "c3")  # flag2 outcomes
qc.add_register(c1, c2, c3)

encoding(qc)

QEC(qc, c1, c2, c3)

# Keep each shot's full 13-qubit state, keyed by that shot's measurement record.
qc.save_statevector(label='statevector_post', pershot=True, conditional=True)

SHOTS = 4
SEED = 2024  # fixed seed so Restart & Run All reproduces the same shots

backend = AerSimulator(noise_model=noise_model)
# No optimisation, so the circuit (including the noisy id slots) is kept as written.
transpiled = transpile(qc, backend, optimization_level=0)
job = backend.run(transpiled, shots=SHOTS, memory=True, seed_simulator=SEED)
result = job.result()
memory = result.get_memory()

print(result)
print(memory)

Result(backend_name='aer_simulator', backend_version='0.14.2', qobj_id='', job_id='fff50121-6381-4680-8b71-f9605eda1f7f', success=True, results=[ExperimentResult(shots=4, success=True, meas_level=2, data=ExperimentResultData(counts={'0x1b9': 1, '0x9a': 1, '0x16': 1, '0x17c': 1}, memory=['0x17c', '0x16', '0x9a', '0x1b9'], statevector_post={'0x1b9': [Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))], '0x9a': [Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))], '0x16': [Statevector([ 0.+0.j,  0.+0.j,  0.+0.j, ..., -0.+0.j, -0.+0.j, -0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))], '0x17c': [Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))]}), header=QobjExperimentHeader(creg_sizes=[['c1', 3], ['c2', 6], ['c3', 3]], global_phase=4.71238898038469, 

In [621]:
# Per-shot measurement records as hex strings, in shot order.
hex_mem = result.data()["memory"]
print(hex_mem)

['0x17c', '0x16', '0x9a', '0x1b9']


In [622]:
from collections import defaultdict

# Re-pair every shot with the statevector Aer saved for it. With pershot=True and
# conditional=True the saved states are grouped by memory string. The k-th time a
# string appears in the memory list is therefore matched with the k-th state
# stored under that string. This assumes Aer keeps each list in shot order.
hex_mem = result.data()['memory']
saved_by_record = result.data()['statevector_post']
next_slot = defaultdict(int)
ordered_result = []
for record in hex_mem:
    ordered_result.append({record: saved_by_record[record][next_slot[record]]})
    next_slot[record] += 1

print(ordered_result)

[{'0x17c': Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))}, {'0x16': Statevector([ 0.+0.j,  0.+0.j,  0.+0.j, ..., -0.+0.j, -0.+0.j, -0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))}, {'0x9a': Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))}, {'0x1b9': Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))}]


In [623]:
# Print each shot's hex record followed by its 13-qubit post-QEC statevector.
for entry in ordered_result:
    bitstr_13, sv13 = next(iter(entry.items()))
    print(bitstr_13, sv13, sep="\n")

0x17c
Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))
0x16
Statevector([ 0.+0.j,  0.+0.j,  0.+0.j, ..., -0.+0.j, -0.+0.j, -0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))
0x9a
Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))
0x1b9
Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))


In [624]:
# Pauli-frame update: for every shot, look up the correction for its measurement
# record and apply it to the 13-qubit statevector saved for that shot.
# NOTE(review): `memory` comes from result.get_memory() and is in the same shot
# order as `ordered_result`. It may be formatted differently from the hex keys
# stored in `ordered_result`, so `lookup` keeps receiving `memory` entries as
# before -- confirm which format lookup expects.
sv_correction = []
for shot, entry in enumerate(ordered_result):
    _, sv13 = next(iter(entry.items()))
    cz, cx = lookup(memory[shot])
    # Only the seven data qubits (0-6) are corrected; further entries are ignored.
    for qubit, op in zip(range(7), cz):
        if op == 'Z':
            sv13 = sv13.evolve(ZGate(), [qubit])
    for qubit, op in zip(range(7), cx):
        if op == 'X':
            sv13 = sv13.evolve(XGate(), [qubit])
    sv_correction.append(sv13)
print(sv_correction)

[Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)), Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)), Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)), Statevector([0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
            dims=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))]


In [625]:
# Reduced 7-qubit data state of one corrected shot.
# partial_trace accepts the Statevector directly, so the 8192 x 8192 density matrix
# of the whole 13-qubit register (~1 GiB as DensityMatrix(sv)) is never built.
SHOT_INDEX = 1
red_dens = partial_trace(sv_correction[SHOT_INDEX], list(range(7, 13)))  # trace out ancillas 7-12
# The ancillas end in definite basis states (measured and/or reset), so the data
# block is expected to be pure here; to_statevector() raises if it is not.
sv = red_dens.to_statevector()
display(sv.draw("latex"))

<IPython.core.display.Latex object>

In [626]:
# Noise-free reference: run the encoder alone on a fresh 7-qubit register and
# display the resulting encoded statevector.
reference_circuit = QuantumCircuit(7)
encoding(reference_circuit)
test = Statevector(reference_circuit)
display(test.draw("latex"))

<IPython.core.display.Latex object>