# Quantum State Tomography via MLE in Qiskit

In [28]:
# Importing standard Qiskit modules
from qiskit import QuantumCircuit, QuantumRegister, IBMQ, execute, transpile
from qiskit.providers.aer import QasmSimulator
from qiskit.tools.monitor import job_monitor
from qiskit.circuit import Parameter, Instruction
from qiskit.quantum_info import Pauli

# Import state tomography modules
from qiskit.ignis.verification.tomography import state_tomography_circuits, StateTomographyFitter
from qiskit.ignis.verification.tomography.fitters.lstsq_fit import lstsq_fit
from qiskit.ignis.verification.tomography.fitters.cvx_fit import cvx_fit

from qiskit.quantum_info import state_fidelity
from qiskit.opflow import Zero, One, I, X, Y, Z

# suppress warnings
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import copy
import qutip as qt
import itertools

import seaborn as sns

## Target State Parities
We consider the target state $|110\rangle$ and generate the parity values associated with each Pauli string. 

In [102]:
g = qt.basis(2,0)
e = qt.basis(2,1)

# fidelity: the reconstructed state has the (flipped) ordering |q5q3q1> 
target_state_qt = qt.tensor(e,e,g)
target_state_qt = qt.ket2dm(target_state_qt)
target_state = target_state_qt.full()

# parity: "XYZ" corresponds to X measurement on q1, Y measurement on q3, and Z measurement on q5
target_state_parity_qt = qt.ket2dm(qt.tensor(g,e,e))
target_state_parity = target_state_parity_qt.full()

In [103]:
target_state_parity

array([[0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]])

In [114]:
infid_level = 0.05

probs = np.random.uniform(0, 1, 8)

# Make sure the element 3 corresponding to target state parity has the right prob
probs[3] = ((1-infid_level) / infid_level) * (np.sum(probs) - probs[3]) 

# Normalize so that the total probability sums to one
probs = probs / np.sum(probs) 

target_state_parity_noisy = probs * np.eye(8)

In [115]:
pauli = {"X":qt.sigmax(),"Y":qt.sigmay(),"Z":qt.sigmaz(),"I":qt.identity(2)}
target_parity = {}
for k1, p1 in pauli.items():
    for k2, p2 in pauli.items():
        for k3, p3 in pauli.items():
            pauli_string = k1+k2+k3
            if pauli_string == "III":
                continue
            op = qt.tensor(p1,p2,p3)
            meas = (target_state_parity_qt*op).tr()
            
            # Noisy simulation
            meas2 = np.trace(op.full() @ target_state_parity_noisy)
            
            target_parity[pauli_string] = meas2

In [116]:
target_parity

{'XXX': 0j,
 'XXY': 0j,
 'XXZ': 0j,
 'XXI': 0j,
 'XYX': 0j,
 'XYY': 0j,
 'XYZ': 0j,
 'XYI': 0j,
 'XZX': 0j,
 'XZY': 0j,
 'XZZ': 0j,
 'XZI': 0j,
 'XIX': 0j,
 'XIY': 0j,
 'XIZ': 0j,
 'XII': 0j,
 'YXX': 0j,
 'YXY': 0j,
 'YXZ': 0j,
 'YXI': 0j,
 'YYX': 0j,
 'YYY': 0j,
 'YYZ': 0j,
 'YYI': 0j,
 'YZX': 0j,
 'YZY': 0j,
 'YZZ': 0j,
 'YZI': 0j,
 'YIX': 0j,
 'YIY': 0j,
 'YIZ': 0j,
 'YII': 0j,
 'ZXX': 0j,
 'ZXY': 0j,
 'ZXZ': 0j,
 'ZXI': 0j,
 'ZYX': 0j,
 'ZYY': 0j,
 'ZYZ': 0j,
 'ZYI': 0j,
 'ZZX': 0j,
 'ZZY': 0j,
 'ZZZ': (0.9248143085350639+0j),
 'ZZI': (-0.9433733384552355+0j),
 'ZIX': 0j,
 'ZIY': 0j,
 'ZIZ': (-0.9589953781743025+0j),
 'ZII': (0.9388826264514168+0j),
 'IXX': 0j,
 'IXY': 0j,
 'IXZ': 0j,
 'IXI': 0j,
 'IYX': 0j,
 'IYY': 0j,
 'IYZ': 0j,
 'IYI': 0j,
 'IZX': 0j,
 'IZY': 0j,
 'IZZ': (0.9372711713602284+0j),
 'IZI': (-0.9462511630369251+0j),
 'IIX': 0j,
 'IIY': 0j,
 'IIZ': (-0.9504120139868282+0j)}

## Tomography

In [117]:
basis_matrix = []
data = []

for pauli, val in target_parity.items():
    ### Parity "XYZ" > actual order "ZYX" > vectorize to get |ZYX⟩⟩ > conjugate ket to its dual
    row = Pauli(pauli[::-1]).to_matrix().flatten(order='F').conjugate()
    basis_matrix.append(row)
    data.append(val)

basis_matrix = np.array(basis_matrix)

In [118]:
rho_fit = cvx_fit(data=data, basis_matrix=basis_matrix, weights=None, trace=1)

In [119]:
rho_fit

array([[2.42015644e-04+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 9.90997384e-03+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 7.12678383e-03+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        7.51521428e-03+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 1.20725065e-02+0.j, 0.00000000e+00+0.j,
        0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j,
       

In [120]:
state_fidelity(rho_fit, target_state)

0.9499999998747717

In [121]:
def mu_optimize(mu, n):
    # calculate eigenvalues of µ matrix
    eigen = np.linalg.eig(mu)
    vals  = eigen[0]
    vecs  = eigen[1].transpose()

    # order eigenvalues from largest to smallest
    eig_vals = sorted(vals, reverse=True)
    idx      = []
    for val in eig_vals:
        idx.append(np.where(vals == val)[0][0])
    eig_vecs = []
    for i in idx:
        eig_vecs.append(vecs[i])
        
    # calculate eigenvalues of the density matrix
    accumulator = 0
    lamb_vals   = [None] * len(eig_vals)
    for i in range(len(eig_vals) - 1, -1, -1):
        if eig_vals[i] + (accumulator / (i + 1)) >= 0:
            for j in range(i + 1):
                lamb_vals[j] = eig_vals[j] +  (accumulator / (i + 1))
            break
        else:
            lamb_vals[i] = 0
            accumulator  += eig_vals[i]

    # calculate density matrix
    predicted_state = np.zeros((2 ** n, 2 ** n), 'complex')
    for idx, lamb_val in enumerate(lamb_vals):
        predicted_state += lamb_vals[idx] * np.outer(eig_vecs[idx], eig_vecs[idx].conj())
    
    return predicted_state

def measurements_strings(n, arr=['X', 'Y', 'Z']):
    strs  = []
    combs = list(itertools.combinations_with_replacement(arr, n))
    for comb in combs:
        for item in set(list(itertools.permutations(comb))):
            strs.append("".join(item))
    return strs

def tensor_operator(arr):
    arr = list(arr)[::-1]
    
    I = np.array([[1, 0], [0, 1]])
    X = np.array([[0, 1], [1, 0]])
    Y  = np.array([[0, -1j], [1j, 0]])
    Z = np.array([[1, 0], [0, -1]])
    
    first = arr.pop(0)
    if first == 'I':
        out = I
    elif first == 'X':
        out = X
    elif first == 'Y':
        out = Y
    else:
        out = Z
        
    for op in arr:
        if op == 'I':
            out = np.kron(out, I)
        elif op == 'X':
            out = np.kron(out, X)
        elif op == 'Y':
            out = np.kron(out, Y)
        else:
            out = np.kron(out, Z)

    return out.astype('complex')

In [122]:
# get all expectation values
ops      = measurements_strings(3, arr=['I', 'X', 'Y', 'Z'])[1:] # we need to consider identity operator here
exp_vals = []
for op in ops:
    exp_vals.append(target_parity[op])

# calculate µ matrix 
mu = tensor_operator(['I' for _ in range(3)])
for idx, op in enumerate(ops):
    mu += exp_vals[idx] * tensor_operator(op)
mu /= (2 ** 3)

# optimize the µ matrix to get the predicted density matrix
predicted_state = mu_optimize(mu, 3)

# calculate fidelity
fidelity = target_state.conj().dot(predicted_state).dot(target_state)

In [123]:
state_fidelity(predicted_state, target_state)

0.9500000000000001