# Code Teleportation Module
All of the functions in this notebook assume that we are teleporting between two Steane codes. It's not a particularly useful scenario, but it keeps the
qubit counts somewhat manageable and let's us verify that all of the circuits are functioning properly.

There are 4 components to the code teleportation module:
1. Teleportation-enabled CAT state generation
2. Code A, B $|+\rangle$ state generation
3. Transversal CNOTs between $|+\rangle$ and half of the CAT
4. Logical Z-basis measurements on each half of the CAT state

In [196]:
import time
import qiskit
from qiskit.circuit.library import XGate, ZGate

In [4]:
backend = qiskit.Aer.get_backend('aer_simulator')

## 1. CAT state generation

In [123]:
def cat_state(num_A_qubits, num_B_qubits, teleportation=False, verification=False):
    cat_A_reg = qiskit.QuantumRegister(num_A_qubits, name='cat_A')
    cat_B_reg = qiskit.QuantumRegister(num_B_qubits, name='cat_B')
    epr_reg = qiskit.QuantumRegister(2, name='epr')
    
    if teleportation:
        cat_circ = qiskit.QuantumCircuit(cat_A_reg, epr_reg, cat_B_reg)
    else:
        cat_circ = qiskit.QuantumCircuit(cat_A_reg, cat_B_reg)
    
    # Create CAT state in each half - is there a way to do this in parallel?
    #for qubits in [cat_A_reg, cat_B_reg]:
    #    cat_circ.h(qubits[0])
    #    for i in range(len(qubits) - 1):
    #        cat_circ.cx(qubits[i], qubits[i+1])
    cat_circ.h(cat_A_reg[0])
    for i in range(len(cat_A_reg)-1):
        cat_circ.cx(cat_A_reg[i], cat_A_reg[i+1])
    
    # Entangle the two halves
    if teleportation:
        cat_circ.barrier()
        
        creg = qiskit.ClassicalRegister(2, name='creg')
        cat_circ.add_register(creg)
        
        # EPR generation (perfect)
        cat_circ.h(epr_reg[0])
        cat_circ.cx(epr_reg[0], epr_reg[1])
        cat_circ.barrier(epr_reg)
        
        # Teleport a CNOT between A_reg and B_reg
        cat_circ.cx(cat_A_reg[-1], epr_reg[0])
        cat_circ.cx(epr_reg[1], cat_B_reg[0])
        cat_circ.h(epr_reg[1])
        
        cat_circ.measure(epr_reg[0], creg[0])
        cat_circ.measure(epr_reg[1], creg[1])
        
        cat_circ.append(XGate(), [cat_B_reg[0]]).c_if(creg[0], 1)
        cat_circ.append(ZGate(), [cat_A_reg[-1]]).c_if(creg[1], 1)
        
        cat_circ.barrier()
    else:
        cat_circ.cx(cat_A_reg[-1], cat_B_reg[0])
        
    # Complete the CAT state in B register (is there a way to do this in parallel to the CNOT ladder in the A reg?)
    for i in range(len(cat_B_reg)-1):
        cat_circ.cx(cat_B_reg[i], cat_B_reg[i+1])
    cat_circ.barrier()
    
    if verification:
        # TODO: there are 2 remote CNOTs in the verification step that should be added here
        cat_qubits = [*cat_A_reg, *cat_B_reg]
        check = qiskit.QuantumRegister(num_A_qubits + num_B_qubits, name='check')
        cat_circ.add_register(check)
        for i, targ_qb in enumerate(check):
            cat_circ.cx(cat_qubits[i], targ_qb)
            cat_circ.cx(cat_qubits[(i+1) % len(cat_qubits)], targ_qb)
        cat_circ.barrier()
    
    return cat_circ

In [165]:
circ = cat_state(3, 4, teleportation=True, verification=False)

regA = [qreg for qreg in circ.qregs if qreg.name == 'cat_A'][0]
regB = [qreg for qreg in circ.qregs if qreg.name == 'cat_B'][0]
        
c = qiskit.ClassicalRegister(len(regA) + len(regB), name='c')
circ.add_register(c)

circ.measure([*regA, *regB], c)
print(circ.draw(fold=-1))

result = qiskit.execute(circ, backend).result()
counts = result.get_counts()
counts

         ┌───┐           ░                                                           ░                 ░ ┌─┐                  
cat_A_0: ┤ H ├──■────────░───────────────────────────────────────────────────────────░─────────────────░─┤M├──────────────────
         └───┘┌─┴─┐      ░                                                           ░                 ░ └╥┘┌─┐               
cat_A_1: ─────┤ X ├──■───░───────────────────────────────────────────────────────────░─────────────────░──╫─┤M├───────────────
              └───┘┌─┴─┐ ░                                                ┌───┐      ░                 ░  ║ └╥┘┌─┐            
cat_A_2: ──────────┤ X ├─░────────────────■───────────────────────────────┤ Z ├──────░─────────────────░──╫──╫─┤M├────────────
                   └───┘ ░ ┌───┐      ░ ┌─┴─┐     ┌─┐                     └─╥─┘      ░                 ░  ║  ║ └╥┘            
  epr_0: ────────────────░─┤ H ├──■───░─┤ X ├─────┤M├───────────────────────╫────────░─────────────────░──╫──╫─

{'1111111 11': 133,
 '0000000 11': 120,
 '1111111 01': 124,
 '0000000 01': 129,
 '0000000 00': 103,
 '0000000 10': 140,
 '1111111 10': 150,
 '1111111 00': 125}

### CAT state generation w/ SeqOp layout
- SWAP into/out of memory
- SWAP between two transmons

## 2. Logical $|+\rangle$ state generation
Logical Steane $|0\rangle_L=|0000000\rangle + |1010101\rangle + |0110011\rangle + |1100110\rangle + |0001111\rangle + |1011010\rangle + |0111100\rangle + |1101001\rangle$

(The above was taken from Equation 41 of this [paper](https://www2.physics.ox.ac.uk/sites/default/files/ErrorCorrectionSteane06.pdf)) Also note that the ordering of qubits is reversed because Qiskit is little endian...

The $|+\rangle$ state can be prepared directly (see Figure 12 of this [paper](https://arxiv.org/pdf/quant-ph/0504218.pdf)). To check this circuit works, I'll prepare $|+\rangle$ directly then apply logical $H$ (obtained by transversal Hadamards on each data qubit) to convice ourselves that this is the correct state.

In [126]:
def steane_plus():
    # Steane code logical |+> state
    steane_reg = qiskit.QuantumRegister(7, name='steane')
    steane_plus = qiskit.QuantumCircuit(steane_reg)
    steane_plus.h([2,4,5,6])
    for pair in [(2,0), (5,3), (6,1), (4,0), (6,3), (5,1), (6,0), (2,1), (4,3)]:
        steane_plus.cx(pair[0], pair[1])
    return steane_plus

In [127]:
plus_circ = steane_plus()

plus_circ.barrier()

# Apply logical Hadamard
for qubit in steane_reg:
    plus_circ.h(qubit)
plus_circ.save_statevector()

print(plus_circ.draw())

# Check statevector
result = qiskit.execute(plus_circ, backend).result()
sv = result.get_statevector()
sv.probabilities_dict(decimals=5) # check that statevector matches |0> written out above

               ┌───┐     ┌───┐          ┌───┐      ░ ┌───┐ ░ 
steane_0: ─────┤ X ├─────┤ X ├──────────┤ X ├──────░─┤ H ├─░─
               └─┬─┘┌───┐└─┬─┘     ┌───┐└─┬─┘┌───┐ ░ ├───┤ ░ 
steane_1: ───────┼──┤ X ├──┼───────┤ X ├──┼──┤ X ├─░─┤ H ├─░─
          ┌───┐  │  └─┬─┘  │       └─┬─┘  │  └─┬─┘ ░ ├───┤ ░ 
steane_2: ┤ H ├──■────┼────┼─────────┼────┼────■───░─┤ H ├─░─
          └───┘┌───┐  │    │  ┌───┐  │    │  ┌───┐ ░ ├───┤ ░ 
steane_3: ─────┤ X ├──┼────┼──┤ X ├──┼────┼──┤ X ├─░─┤ H ├─░─
          ┌───┐└─┬─┘  │    │  └─┬─┘  │    │  └─┬─┘ ░ ├───┤ ░ 
steane_4: ┤ H ├──┼────┼────■────┼────┼────┼────■───░─┤ H ├─░─
          ├───┤  │    │         │    │    │        ░ ├───┤ ░ 
steane_5: ┤ H ├──■────┼─────────┼────■────┼────────░─┤ H ├─░─
          ├───┤       │         │         │        ░ ├───┤ ░ 
steane_6: ┤ H ├───────■─────────■─────────■────────░─┤ H ├─░─
          └───┘                                    ░ └───┘ ░ 


{'0000000': 0.125,
 '0011110': 0.125,
 '0101101': 0.125,
 '0110011': 0.125,
 '1001011': 0.125,
 '1010101': 0.125,
 '1100110': 0.125,
 '1111000': 0.125}

## 3. Transversal CNOTs

In [128]:
def do_cnots():
    plus_circ = steane_plus()
    cat_circ = cat_state(7,7)

    code_A_reg = qiskit.QuantumRegister(7, name='code_A')
    code_B_reg = qiskit.QuantumRegister(7, name='code_B')
    cat_A_reg = qiskit.QuantumRegister(7, name='cat_A')
    cat_B_reg = qiskit.QuantumRegister(7, name='cat_B')

    full_circ = qiskit.QuantumCircuit(code_A_reg, cat_A_reg, cat_B_reg, code_B_reg)

    full_circ.compose(plus_circ, code_A_reg, inplace=True)
    full_circ.compose(cat_circ, [*cat_A_reg, *cat_B_reg], inplace=True)
    full_circ.compose(plus_circ, code_B_reg, inplace=True)

    for ctrl, targ in zip(code_A_reg, cat_A_reg):
        full_circ.cx(ctrl, targ)

    for ctrl, targ in zip(code_B_reg, cat_B_reg):
        full_circ.cx(ctrl, targ)
    
    return full_circ

In [129]:
circ = do_cnots()
circ.draw(fold=-1)

## 4. Logical Z-basis measurement
### Test projective measurement on logical $|0\rangle$ state
For the Steane code, the logical Z operator is transversal, so to measure our logical (encoded) qubit in the logical Z basis, we use a projective measurement.
See the circuit shown in Equation 10 of [Dave Bacon's lecture notes] for an explanation of how the projective measurements work.

In [120]:
# Steane code logical |0> state
steane_reg = qiskit.QuantumRegister(7, name='steane')
steane_zero = qiskit.QuantumCircuit(steane_reg)
steane_zero.h([0,1,3])
for pair in [(0,2), (3,5), (1,6), (0,4), (3,6), (1,5), (0,6), (1,2), (3,4)]:
    steane_zero.cx(pair[0], pair[1])
steane_zero.barrier()

anc_reg = qiskit.QuantumRegister(1, name='ancilla')
steane_zero.add_register(anc_reg)
steane_zero.h(anc_reg[0])
for qubit in steane_reg:
    steane_zero.cz(anc_reg[0], qubit)
steane_zero.h(anc_reg[0])
steane_zero.barrier()

anc_bit_reg = qiskit.ClassicalRegister(1, name='z_meas')
steane_bit_reg = qiskit.ClassicalRegister(7, name='steane_meas')
steane_zero.add_register(anc_bit_reg, steane_bit_reg)
steane_zero.measure(steane_reg, steane_bit_reg)
steane_zero.measure(anc_reg, anc_bit_reg)
steane_zero.draw(fold=-1)

In [122]:
result = qiskit.execute(steane_zero, backend, shots=10000).result()
counts = result.get_counts()

print(counts)

# parse the counts to be human readable
for anc_result in ['0', '1']:
    print(f'When the ancilla was measured to be in state |{anc_result}>')
    print('The other qubits were measured to be in (BIG ENDIAN):')
    post_selected_states = {key.split()[0]: val for key, val in counts.items() if key[-1] == anc_result}
    post_shot_total = sum(post_selected_states.values())
    for state, shots in post_selected_states.items():
        print(f'\t{"".join(list(state)[::-1][:-1])}: {shots / post_shot_total:.3f}')

{'1111000 0': 1309, '1100110 0': 1168, '0011110 0': 1236, '1001011 0': 1259, '0000000 0': 1262, '1010101 0': 1274, '0110011 0': 1256, '0101101 0': 1236}
When the ancilla was measured to be in state |0>
The other qubits were measured to be in (BIG ENDIAN):
	000111: 0.131
	011001: 0.117
	011110: 0.124
	110100: 0.126
	000000: 0.126
	101010: 0.127
	110011: 0.126
	101101: 0.124
When the ancilla was measured to be in state |1>
The other qubits were measured to be in (BIG ENDIAN):


### Logical Z measurement on CAT state halves

In [190]:
def ct_module():
    circ = do_cnots()

    code_A_reg = [qreg for qreg in circ.qregs if qreg.name == 'code_A'][0]
    code_B_reg = [qreg for qreg in circ.qregs if qreg.name == 'code_B'][0]
    code_regs = [code_A_reg, code_B_reg]
    
    cat_A_reg = [qreg for qreg in circ.qregs if qreg.name == 'cat_A'][0]
    cat_B_reg = [qreg for qreg in circ.qregs if qreg.name == 'cat_B'][0]
    cat_regs = [cat_A_reg, cat_B_reg]
    
    # Add ancilla qubits for projective measurement
    anc_reg = qiskit.QuantumRegister(2, name='ancilla')
    bit_reg = qiskit.ClassicalRegister(2, name='cat_bits')
    circ.add_register(anc_reg, bit_reg)
    circ.barrier()
    
    # Logical Z measurements on each CAT state half
    for i, anc_qb in enumerate(anc_reg):
        circ.h(anc_qb)
        for qubit in cat_regs[i]:
            circ.cz(anc_qb, qubit)
        circ.h(anc_qb)
        circ.measure(anc_qb, bit_reg[i])
    
    logical_X = qiskit.QuantumCircuit(7, name='Steane X')
    logical_X.x(range(7))
    #print(logical_X.draw())
    
    # Cannot compare two CBits together (to see whether they're different), so just check separately whether the bit_reg is 01 = 0x1 or 10 = 0x2
    circ.append(logical_X, code_B_reg).c_if(bit_reg, 1)
    circ.append(logical_X, code_B_reg).c_if(bit_reg, 2)

    return circ

In [191]:
circ = ct_module()
circ.draw(fold=-1)

## Full CT module simulation
The output of the CT module should be the (logical) state $|\phi_{AB}\rangle = \frac{1}{\sqrt{2}}(|0_A0_B\rangle + |1_A1_B\rangle)$

In [192]:
mps_simulator = qiskit.Aer.get_backend('aer_simulator_matrix_product_state')

In [193]:
circ = ct_module()
circ.barrier()

# Do logical Z basis measurements on each of the logical qubits
code_A_reg = [qreg for qreg in circ.qregs if qreg.name == 'code_A'][0]
code_B_reg = [qreg for qreg in circ.qregs if qreg.name == 'code_B'][0]
code_regs = [code_A_reg, code_B_reg]
# Reuse the ancilla qubits
anc_reg = [qreg for qreg in circ.qregs if qreg.name == 'ancilla'][0]
circ.reset(anc_reg)

creg = qiskit.ClassicalRegister(2, name='creg')
circ.add_register(creg)

for i, anc_qb in enumerate(anc_reg):
    circ.h(anc_qb)
    for qubit in code_regs[i]:
        circ.cz(anc_qb, qubit)
    circ.h(anc_qb)
    circ.measure(anc_qb, creg[i])
circ.draw(fold=-1)

In [197]:
# Simulate
# This cell can take a bit of time, about 10 minutes on my MacBook, but reasonable enough to ensure this circuit is working properly
print('Simulation started...')
start = time.time()
result = qiskit.execute(circ, mps_simulator, shots=10000).result()
end = time.time()
counts = result.get_counts()
print('Simulation ended!')
print(f'Elapsed time: {end-start:.3f} seconds')

Simulation started...
Simulation ended!
Elapsed time: 636.211 seconds


In [199]:
print('Raw counts (creg cat_bits):')
print(counts)

print('When we logically measure the Code A qubit and Code B qubit, we find:')
outputs = {}
for key, val in counts.items():
    logical_output = key.split()[0]
    try:
        outputs[logical_output] += val
    except KeyError:
        outputs[logical_output] = val
total_shots = sum(outputs.values())
for output, shots in outputs.items():
    print(f'\tThe logical state: {output} with probability {shots/total_shots*100:.3f}%')
print('This is what we expect from the output of the CT module!!!')

Raw counts (creg cat_bits):
{'11 01': 1229, '00 10': 1225, '00 00': 1284, '00 01': 1223, '00 11': 1222, '11 11': 1259, '11 10': 1276, '11 00': 1282}
When we logically measure the Code A qubit and Code B qubit, we find:
	The logical state: 11 with probability 50.460%
	The logical state: 00 with probability 49.540%
This is what we expect from the output of the CT module!!!


## Extra Steane code stuff

In [5]:
# Steane code logical |0> state
steane_reg = qiskit.QuantumRegister(7, name='steane')
steane_zero = qiskit.QuantumCircuit(steane_reg)
steane_zero.h([0,1,3])
for pair in [(0,2), (3,5), (1,6), (0,4), (3,6), (1,5), (0,6), (1,2), (3,4)]:
    steane_zero.cx(pair[0], pair[1])
steane_zero.save_statevector()
#steane_zero.measure_all()
steane_zero.draw()

Logical Steane $|0\rangle_L=|0000000\rangle + |1010101\rangle + |0110011\rangle + |1100110\rangle + |0001111\rangle + |1011010\rangle + |0111100\rangle + |1101001\rangle$

(The above was taken from Equation 41 of this [paper](https://www2.physics.ox.ac.uk/sites/default/files/ErrorCorrectionSteane06.pdf)) Also note that the ordering of qubits is reversed because Qiskit is little endian...

In [11]:
result = qiskit.execute(steane_zero, backend).result()
sv = result.get_statevector()
counts = result.get_counts()
counts

{'0000000': 0.125,
 '0011110': 0.125,
 '0101101': 0.125,
 '0110011': 0.125,
 '1001011': 0.125,
 '1010101': 0.125,
 '1100110': 0.125,
 '1111000': 0.125}

In [13]:
# Steane code logical |1> state
steane_reg = qiskit.QuantumRegister(7, name='steane')
steane_one = qiskit.QuantumCircuit(steane_reg)
steane_one.h([0,1,3])
for pair in [(0,2), (3,5), (1,6), (0,4), (3,6), (1,5), (0,6), (1,2), (3,4)]:
    steane_one.cx(pair[0], pair[1])
steane_one.barrier()
steane_one.x([0,1,2,3,4,5,6])
steane_one.save_statevector()
steane_one.draw()

Logical Steane $|1\rangle_L= X_{1111111}|0\rangle_L$

In [14]:
result = qiskit.execute(steane_one, backend).result()
sv = result.get_statevector()
counts = result.get_counts()
counts

{'0000111': 0.125,
 '0011001': 0.125,
 '0101010': 0.125,
 '0110100': 0.125,
 '1001100': 0.125,
 '1010010': 0.125,
 '1100001': 0.125,
 '1111111': 0.125}

In [44]:
result = qiskit.execute(steane_plus, backend, shots=10000).result()
counts = result.get_counts()

# parse the counts to be human readable
for anc_result in ['0', '1']:
    print(f'When the ancilla was measured to be in state |{anc_result}>')
    print('The other qubits were measured to be in (BIG ENDIAN):')
    post_selected_states = {key: val for key, val in counts.items() if key[0] == anc_result}
    post_shot_total = sum(post_selected_states.values())
    for state, shots in post_selected_states.items():
        print(f'\t{"".join(list(state)[::-1][:-1])}: {shots / post_shot_total:.3f}')

When the ancilla was measured to be in state |0>
The other qubits were measured to be in (BIG ENDIAN):
	0110011: 0.122
	0001111: 0.127
	0111100: 0.122
	0000000: 0.124
	1101001: 0.135
	1100110: 0.122
	1010101: 0.119
	1011010: 0.129
When the ancilla was measured to be in state |1>
The other qubits were measured to be in (BIG ENDIAN):
	0100101: 0.124
	1000011: 0.123
	0011001: 0.122
	1001100: 0.117
	0101010: 0.132
	0010110: 0.132
	1111111: 0.121
	1110000: 0.130
