# Modelling NIQS Hardware

In [115]:
import numpy as np
import qiskit as qk
from qiskit.quantum_info import DensityMatrix
from scipy.linalg import sqrtm

np.set_printoptions(precision=2)

In [184]:
def partial_trace(X, discard_first = True):
    d = X.shape[0]
    d_red = int(np.sqrt(d))
    Y = np.zeros((d_red, d_red), dtype = "complex128")
    I = np.eye(d_red)
    
    for i in range(d_red):
        basis_vec = np.zeros((d_red, 1),  dtype = "complex128")
        basis_vec[i, 0] = 1
        
        if discard_first:
            basis_vec = np.kron(basis_vec, I)
        else:
            basis_vec = np.kron(I, basis_vec)
        
        Y = Y + basis_vec.T@X@basis_vec
    
    return Y

def state_fidelity(A, B):
    sqrtA = sqrtm(A)
    fidelity = np.trace(sqrtm(sqrtA@B@sqrtA))
    return fidelity

def apply_map(state, choi):
    d = state.shape[0]
    
    #reshuffle
    choi = choi.reshape(d,d,d,d).swapaxes(1,2).reshape(d**2, d**2)
    
    #flatten
    state = state.reshape(-1, 1)
    
    state = (choi@state).reshape(d, d)
    return state
    

def prepare_input(config):
    """1 = |0>, 2 = |1>, 3 = |+>, 4 = |->, 5 = |+i>, 6 = |-i>"""
    n = len(config)
    circuit = qk.QuantumCircuit(n)
    for i, gate in enumerate(config):
        if gate == 2:
            circuit.x(i)
        if gate == 3:
            circuit.h(i)
        if gate == 4:
            circuit.x(i)
            circuit.h(i)
        if gate == 5:
            circuit.h(i)
            circuit.s(i)
        if gate == 6:
            circuit.x(i)
            circuit.h(i)
            circuit.s(i)
        
            
    rho = DensityMatrix(circuit)
    return rho.data

In [186]:
n = 1         # number of qubits
d = 2**n      # dim of Hilbert space
I = np.eye(d)

np.random.seed(42)

#Ginibre matrix
X = np.random.normal(0, 1, (d**2, d**2)) + 1j*np.random.normal(0, 1, (d**2, d**2))

#partial trace
Y = partial_trace(X@(X.conj().T), discard_first = True)
sqrtYinv = np.linalg.inv(sqrtm(Y))

#choi
choi = np.kron(I, sqrtYinv)@X@(X.conj().T)@np.kron(I, sqrtYinv)

In [185]:
#1 = |0>, 2 = |1>, 3 = |+>, 4 = |->, 5 = |i+>, 6 = |i->
state = prepare_input([6])

state = apply_map(state, choi)
print(state)
print(np.trace(state))
print(np.trace(state@state))
print(state_fidelity(state, state))

[[ 0.68-5.72e-17j -0.06+1.25e-01j]
 [-0.06-1.25e-01j  0.32-4.68e-17j]]
(0.9999999999999987-1.0408340855860843e-16j)
(0.5999423172468369-1.0668549377257364e-16j)
(0.9999999999999982-1.1325462882223177e-16j)
