In [1]:
import numpy as np
from qiskit.quantum_info import random_statevector

# 2 Qubits

In [2]:
# Number of qubits; used as 2**n_qubits for the dimension of the random state.
# NOTE: the two-qubit projectors and the 4x4 parameterization later in this
# notebook are written out explicitly for two qubits, so changing this value
# alone does not generalize the notebook.
n_qubits = 2

### Generate "unknown" state you want to learn

In [3]:
# Fixed seed so that a fresh "Restart & Run All" regenerates the same target
# state (and therefore the same probabilities and fit) every time.
STATE_SEED = 7

# Amplitudes of the random pure state as a column vector of length 2**n_qubits.
psi = random_statevector(2**n_qubits, seed=STATE_SEED).data.reshape(-1, 1)

# Density matrix of that pure state: the outer product |psi><psi|.
unkown_dm = psi @ psi.conj().T

### Define measurement bases 

In [4]:
# Text labels for the six single-qubit states, in the same order as the
# vectors collected in `basis_states` at the end of this cell.
measurement_strings = ["0", "1", "+", "-", "i", "-i"]

# Shared normalization factor for the equal-superposition states.
_INV_SQRT2 = 1 / np.sqrt(2)

# Every state is a (2, 1) column vector.
# Computational (Z) basis.
zero = np.array([[1], [0]])
one = np.array([[0], [1]])
# X basis.
plus = _INV_SQRT2 * np.array([[1], [1]])
minus = _INV_SQRT2 * np.array([[1], [-1]])
# Y basis.
i = _INV_SQRT2 * np.array([[1], [1j]])
minus_i = _INV_SQRT2 * np.array([[1], [-1j]])

basis_states = [zero, one, plus, minus, i, minus_i]

### Create measurement projectors

In [5]:
# 2 qubits
# Single-qubit projector |v><v| for each basis state.
measurement_operators = [ket @ ket.conj().T for ket in basis_states]

# Kronecker product of every qubit-1 projector with every qubit-2 projector.
# The qubit-1 index varies slowest, so entry k corresponds to
# (k // len(measurement_operators), k % len(measurement_operators)).
two_qubit_measurement_operators = [
    np.kron(first_op, second_op)
    for first_op in measurement_operators
    for second_op in measurement_operators
]

### Compute probabilities to use as the data to fit to
Experimentally, you would compute these from measurement counts

In [6]:
def get_probabilities(dm):
    """Return the real part of Tr(dm @ P) for each two-qubit projector P.

    The projectors are read from the module-level list
    `two_qubit_measurement_operators`; the returned list follows its order.
    """
    outcome_probs = []
    for projector in two_qubit_measurement_operators:
        overlap = np.trace(dm @ projector)
        outcome_probs.append(overlap.real)
    return outcome_probs

probabilities = get_probabilities(unkown_dm)

# One (qubit-1 label, qubit-2 label) pair per two-qubit projector. The nesting
# order mirrors how the projectors are built (qubit 1 outer, qubit 2 inner),
# so label k describes probability k. Zipping the six single-qubit labels
# directly against the probabilities would silently keep only the first six
# entries and attach the wrong labels to them.
two_qubit_measurement_strings = [
    (qubit_1_label, qubit_2_label)
    for qubit_1_label in measurement_strings
    for qubit_2_label in measurement_strings
]
assert len(two_qubit_measurement_strings) == len(probabilities), (
    "expected one label per two-qubit measurement probability"
)
prob_dict = dict(zip(two_qubit_measurement_strings, probabilities))

### Create tunable state
Parameterization from: https://research.physics.illinois.edu/QI/Photonics/tomography-files/tomo_chapter_2004.pdf

In [7]:
def get_parameterized_dm(t):
    """Build a 4x4 density matrix from the tunable parameters `t`.

    `t` is an indexable sequence; entries 0..15 are read. They fill a
    lower-triangular matrix T: t[0..3] go on the diagonal and the remaining
    twelve are consumed in (real, imaginary) pairs for the entries below it.
    The returned matrix is T^dagger T divided by its trace.
    """
    T = np.zeros((4, 4), dtype=complex)

    # Diagonal entries.
    T[0, 0], T[1, 1], T[2, 2], T[3, 3] = t[0], t[1], t[2], t[3]
    # First sub-diagonal.
    T[1, 0] = t[4] + 1j * t[5]
    T[2, 1] = t[6] + 1j * t[7]
    T[3, 2] = t[8] + 1j * t[9]
    # Second sub-diagonal.
    T[2, 0] = t[10] + 1j * t[11]
    T[3, 1] = t[12] + 1j * t[13]
    # Bottom-left corner.
    T[3, 0] = t[14] + 1j * t[15]

    gram = T.conj().T @ T
    return gram / np.trace(gram)

In [8]:
# Initial guess for the 16 tunable parameters, drawn uniformly from [-1, 1).
# A seeded generator keeps the optimizer's starting point (and so the fitted
# result) identical across "Restart & Run All".
INIT_SEED = 0
dm_parameters = np.random.default_rng(INIT_SEED).uniform(low=-1, high=1, size=16)

In [9]:
# Density matrix produced by the random initial parameters; the following
# cell checks that it is a valid density matrix.
parameterized_dm = get_parameterized_dm(dm_parameters)

### Test that DM is valid

In [10]:
# Check that trace is close to one
assert np.allclose(1, np.trace(parameterized_dm)), "trace is not 1"

# Check Hermitian
assert np.allclose(parameterized_dm, parameterized_dm.conj().T), "matrix is not Hermitian"

# Check positive semi-definite <--> eigenvalues >= 0.
# The eigenvalues are computed in floating point, so an eigenvalue that is
# mathematically zero can come back as a tiny negative number. Compare
# against a small tolerance rather than exactly 0 to avoid spurious failures.
PSD_TOLERANCE = 1e-12
assert np.all(np.linalg.eigvalsh(parameterized_dm) >= -PSD_TOLERANCE), (
    "matrix has a negative eigenvalue"
)

### Optimize with SciPy

In [11]:
def get_probabilites_from_parameters(t):
    """Map the parameter sequence `t` to measurement probabilities.

    Builds the density matrix with `get_parameterized_dm(t)` and returns
    whatever `get_probabilities` produces for it.
    """
    return get_probabilities(get_parameterized_dm(t))

In [12]:
from scipy.special import kl_div

def kl_divergence(true_probs, t):
    """Sum the elementwise `kl_div(true_probs, probs)` terms.

    `probs` are the probabilities obtained from the parameters `t` via
    `get_probabilites_from_parameters`.
    """
    model_probs = get_probabilites_from_parameters(t)
    pointwise_terms = kl_div(true_probs, model_probs)
    return sum(pointwise_terms)

def kl_divergence_from_parameters(params):
    """Objective of `params` only: KL divergence against the module-level `probabilities`."""
    return kl_divergence(probabilities, params)

In [13]:
from scipy.optimize import minimize

# Minimize the KL objective, starting from the random initial parameters.
optimization_result = minimize(
    fun=kl_divergence_from_parameters,
    x0=dm_parameters,
    method="L-BFGS-B",
)

In [14]:
# Parameter vector at which the optimizer stopped.
found_params = optimization_result.x

# Density matrix reconstructed from those parameters.
found_dm = get_parameterized_dm(found_params)

In [15]:
def dm_fidelity(d1, d2):
    """Return the real part of Tr(d1^dagger @ d2).

    NOTE(review): this overlap coincides with the quantum state fidelity only
    when at least one of the two density matrices is a pure state; for two
    mixed states it is not the fidelity.
    """
    overlap = d1.T.conj() @ d2
    return np.trace(overlap).real

# Overlap between the target state and the reconstruction; as the last
# expression in the cell, its value is shown as the cell output.
dm_fidelity(unkown_dm, found_dm)

np.float64(0.9999466697396566)

In [16]:
# Target density matrix, printed for side-by-side comparison with the fit.
print(unkown_dm)

[[ 0.0913432 -2.51810988e-19j -0.09697933+1.20899084e-01j
   0.04376399+1.04823794e-02j -0.22421975-8.17213081e-02j]
 [-0.09697933-1.20899084e-01j  0.26298158+3.94767279e-18j
  -0.03259019-6.90538662e-02j  0.12989089+3.83534210e-01j]
 [ 0.04376399-1.04823794e-02j -0.03259019+6.90538662e-02j
   0.02217097-3.73054537e-19j -0.11680548-1.34229402e-02j]
 [-0.22421975+8.17213081e-02j  0.12989089-3.83534210e-01j
  -0.11680548+1.34229402e-02j  0.62350425+8.55389772e-18j]]


In [17]:
# Reconstructed density matrix obtained from the optimized parameters.
print(found_dm)

[[ 0.09133651+0.j         -0.09697913+0.12089427j  0.04375413+0.01048741j
  -0.22419312-0.0817061j ]
 [-0.09697913-0.12089427j  0.26300422+0.j         -0.03257572-0.06906503j
   0.12987943+0.38351325j]
 [ 0.04375413-0.01048741j -0.03257572+0.06906503j  0.02218797+0.j
  -0.11678858-0.01343374j]
 [-0.22419312+0.0817061j   0.12987943-0.38351325j -0.11678858+0.01343374j
   0.6234713 +0.j        ]]
