# Modelling NIQS Hardware pyTorch

In [1]:
import numpy as np
import qiskit as qk
import matplotlib.pyplot as plt
import multiprocessing as mp
import random
import torch.optim as optim

from qiskit.quantum_info import DensityMatrix
from qiskit.quantum_info import Operator
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm


from src_torch import *

torch.set_printoptions(precision=8)

## Test

In [2]:
n = 3
d = 2**n

state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

np.random.seed(42)
torch.manual_seed(42)

X_target, _, _ = generate_ginibre(d**2, 2)

choi_target = generate_choi(X_target)

state_target_list = [apply_map(state_input, choi_target) for state_input in state_input_list]


X, A, B = generate_ginibre(d**2, 2, requires_grad=True)
optimizer = optim.Adam([A, B], lr=0.01)

fid_list = []

In [None]:
for i in tqdm(range(10000)):
    optimizer.zero_grad()
    X = A + 1j*B
    choi_model = generate_choi(X)
    index = np.random.randint(0, len(state_input_list)-1)
    state_input = state_input_list[index]
    state_target = state_target_list[index]
    
    state_model = apply_map(state_input, choi_model)
    fid = np.abs(state_fidelity(state_model, state_target).detach().numpy())
    loss = torch.norm(state_model - state_target)
    loss.backward()
    optimizer.step()
    fid_list.append(fid)
    fidelity = state_fidelity
    print(f"step: {i}, fid: {fid:.4f}, norm: {loss.detach().numpy():.4f}")

In [None]:
plt.plot(fid_list)

In [4]:
print(choi_model)
print(choi_target)

tensor([[ 0.03807269-7.80625564e-18j, -0.01845127-4.90851245e-02j,
         -0.05548101-9.20165183e-03j,  ...,
         -0.02541757+1.30500987e-02j, -0.03580205+2.84096706e-02j,
         -0.08998313+2.88232404e-02j],
        [-0.01845127+4.90851245e-02j,  0.08739023-8.45677695e-18j,
          0.06797071-7.26427153e-02j,  ...,
          0.05556582-5.04095075e-02j,  0.00535633-9.13056476e-02j,
         -0.01680889-1.35073369e-01j],
        [-0.05548101+9.20165183e-03j,  0.06797071+7.26427153e-02j,
          0.14141996-1.38777878e-17j,  ...,
          0.15378825-2.48849924e-02j,  0.10429895-1.01460473e-01j,
          0.08122187-8.21124852e-02j],
        ...,
        [-0.02541757-1.30500987e-02j,  0.05556582+5.04095075e-02j,
          0.15378825+2.48849924e-02j,  ...,
          0.26784321-6.50521303e-18j,  0.15462760-1.12616032e-01j,
         -0.01837281-2.59314602e-02j],
        [-0.03580205-2.84096706e-02j,  0.00535633+9.13056476e-02j,
          0.10429895+1.01460473e-01j,  ...,
        

In [None]:
A = state_input_list[15]
print(A)

L, V = torch.linalg.eigh(A)
L = torch.sqrt(L.type(torch.complex64))

B = torch.zeros_like(A)
for l, v in zip(L, V.T):
    B += l*torch.conj(v.reshape(-1,1))@v.reshape(1,-1)

print(A, B@B)

In [None]:
n = 2
d = 2**n

np.random.seed(42)
X = generate_ginibre(d**2, 2)
print(X)

In [None]:
XX = X@X.T.conj()
print(XX)

In [None]:
[[ 0.78061957+0.j          0.07601179+0.68805429j -1.23559901-0.40182879j
   1.18303852+0.45402626j]
 [ 0.07601179-0.68805429j  3.17078055+0.j          0.27068591+2.97482211j
   3.25289552+0.88433786j]
 [-1.23559901+0.40182879j  0.27068591-2.97482211j  3.82883495+0.j
   0.10898574-1.62175541j]
 [ 1.18303852-0.45402626j  3.25289552-0.88433786j  0.10898574+1.62175541j
   6.37437797+0.j        ]]

In [None]:
Y = partial_trace(XX)
print(Y)

In [None]:
Y = square_root_inverse(Y)
print(Y)

In [None]:
0.46982564+1.52046633e-19j
-0.00539533+2.72307567e-02j
-0.00539533-2.72307567e-02j
0.3258792 -1.16520961e-17j

In [None]:
I = torch.eye(d).type(torch.complex128)
Ykron = torch.kron(I, Y).T
print(Ykron)

In [None]:
choi = Ykron@XX@Ykron
print(torch.trace(choi))

In [30]:
n = 6
d = 2**n

rank = d**2

np.random.seed(42)
X, A, B = generate_ginibre(rank*d, d, requires_grad=True)
optimizer = optim.Adam([A, B], lr=0.01)

In [31]:
U, D, _ = torch.svd(X)
print(U.shape)
K = [U[i*d:(i+1)*d, :d] for i in range(rank)]

I = K[0].T.conj()@K[0] + K[1].T.conj()@K[1] + K[2].T.conj()@K[2] + K[3].T.conj()@K[3]
print(I)

torch.Size([262144, 64])
tensor([[ 9.84445812e-04+0.00000000e+00j, -1.88895202e-05-3.00589453e-05j,
         -1.67161390e-05+2.91756523e-05j,  ...,
          3.34065134e-05-4.28406897e-06j,  3.21246322e-06-1.14678946e-05j,
         -2.26206844e-07+7.45552601e-05j],
        [-1.88895202e-05+3.00589453e-05j,  8.82433347e-04+0.00000000e+00j,
         -4.17848200e-05+4.67988951e-05j,  ...,
          7.10125854e-05+2.88255037e-06j,  1.74818108e-05+2.48267567e-05j,
          4.94567582e-05+4.92595726e-05j],
        [-1.67161390e-05-2.91756523e-05j, -4.17848200e-05-4.67988951e-05j,
          1.04879530e-03+0.00000000e+00j,  ...,
         -8.67924494e-05-1.90293428e-05j, -8.46423041e-05-1.96070972e-05j,
         -3.30675121e-06+3.93717263e-05j],
        ...,
        [ 3.34065134e-05+4.28406897e-06j,  7.10125854e-05-2.88255037e-06j,
         -8.67924494e-05+1.90293428e-05j,  ...,
          1.03522385e-03+0.00000000e+00j,  9.22356606e-06-2.16760632e-05j,
         -1.38209246e-05+3.01280949e-05j]

In [32]:
eigen = D.detach().numpy()

min_val = 1e12

for i in eigen:
    for j in eigen:
        if i != j:
            new_min_val = np.abs((i**2 - j**2))
            if new_min_val < min_val:
                min_val = new_min_val

print(1/min_val)

0.008143678667540663


In [30]:
n = 2
d = 2**n
rank = 2

state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

np.random.seed(42)
torch.manual_seed(42)

X_target, _, _ = generate_ginibre(rank*d, d)
kraus_list_target = generate_kraus(X_target, d, rank)

state_target_list = [apply_kraus(state_input, kraus_list_target) for state_input in state_input_list]


X, A, B = generate_ginibre(rank*d, d, requires_grad=True)
optimizer = optim.Adam([A, B], lr=0.01)

In [31]:
fid_list = []

for i in tqdm(range(10000)):
    optimizer.zero_grad()
    X = A + 1j*B
    kraus_list_model = generate_kraus(X, d, rank)
    index = np.random.randint(0, len(state_input_list)-1)
    state_input = state_input_list[index]
    state_target = state_target_list[index]
    
    state_model = apply_kraus(state_input, kraus_list_model)
    #loss = torch.norm(state_model - state_target)
    loss = -state_fidelity(state_model, state_target)
    loss.backward()
    optimizer.step()
    print(f"step: {i}, fid: {loss.detach().numpy():.4f}")
    

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

step: 0, fid: -0.7118-0.0000j
step: 1, fid: -0.8966-0.0000j
step: 2, fid: -0.7543+0.0000j
step: 3, fid: -0.7912-0.0000j
step: 4, fid: -0.8429-0.0000j
step: 5, fid: -0.9135-0.0000j
step: 6, fid: -0.8002-0.0000j
step: 7, fid: -0.7765-0.0000j
step: 8, fid: -0.7416-0.0000j
step: 9, fid: -0.6573-0.0000j
step: 10, fid: -0.5951+0.0000j
step: 11, fid: -0.6365+0.0000j
step: 12, fid: -0.5897+0.0000j
step: 13, fid: -0.5320+0.0000j
step: 14, fid: -0.5883-0.0000j
step: 15, fid: -0.7033+0.0000j
step: 16, fid: -0.4222+0.0000j
step: 17, fid: -0.6006-0.0000j
step: 18, fid: -0.4222+0.0000j
step: 19, fid: -0.5872-0.0000j
step: 20, fid: -0.6918-0.0000j
step: 21, fid: -0.5317+0.0000j
step: 22, fid: -0.8533-0.0000j
step: 23, fid: -0.6358-0.0000j
step: 24, fid: -0.6278+0.0000j
step: 25, fid: -0.3946-0.0000j
step: 26, fid: -0.4253+0.0000j
step: 27, fid: -0.8139-0.0000j
step: 28, fid: -0.6462+0.0000j
step: 29, fid: -0.6077+0.0000j
step: 30, fid: -0.8411-0.0000j
step: 31, fid: -0.5419-0.0000j
step: 32, fid: -0.

step: 329, fid: -0.6946+0.0000j
step: 330, fid: -0.8559+0.0000j
step: 331, fid: -0.8523+0.0000j
step: 332, fid: -0.9445+0.0000j
step: 333, fid: -0.7123+0.0000j
step: 334, fid: -0.4311-0.0000j
step: 335, fid: -0.7756-0.0000j
step: 336, fid: -0.7329+0.0000j
step: 337, fid: -0.8354-0.0000j
step: 338, fid: -0.9014+0.0000j
step: 339, fid: -0.7235-0.0000j
step: 340, fid: -0.9055-0.0000j
step: 341, fid: -0.7391-0.0000j
step: 342, fid: -0.7273+0.0000j
step: 343, fid: -0.9023+0.0000j
step: 344, fid: -0.9092+0.0000j
step: 345, fid: -0.9054+0.0000j
step: 346, fid: -0.8684-0.0000j
step: 347, fid: -0.9140+0.0000j
step: 348, fid: -0.9107+0.0000j
step: 349, fid: -0.9128-0.0000j
step: 350, fid: -0.7793+0.0000j
step: 351, fid: -0.7642+0.0000j
step: 352, fid: -0.8603+0.0000j
step: 353, fid: -0.7619+0.0000j
step: 354, fid: -0.8194-0.0000j
step: 355, fid: -0.8058+0.0000j
step: 356, fid: -0.9249+0.0000j
step: 357, fid: -0.6908-0.0000j
step: 358, fid: -0.8082+0.0000j
step: 359, fid: -0.7670+0.0000j
step: 36

step: 664, fid: -0.6693+0.0000j
step: 665, fid: -0.9022-0.0000j
step: 666, fid: -0.8033+0.0000j
step: 667, fid: -0.6830+0.0000j
step: 668, fid: -0.9071-0.0000j
step: 669, fid: -0.9350+0.0000j
step: 670, fid: -0.7658-0.0000j
step: 671, fid: -0.9137-0.0000j
step: 672, fid: -0.9099-0.0000j
step: 673, fid: -0.9455-0.0000j
step: 674, fid: -0.6631-0.0000j
step: 675, fid: -0.8328+0.0000j
step: 676, fid: -0.6810-0.0000j
step: 677, fid: -0.8406-0.0000j
step: 678, fid: -0.7825+0.0000j
step: 679, fid: -0.8248-0.0000j
step: 680, fid: -0.8033-0.0000j
step: 681, fid: -0.7616-0.0000j
step: 682, fid: -0.7015-0.0000j
step: 683, fid: -0.9131+0.0000j
step: 684, fid: -0.7233-0.0000j
step: 685, fid: -0.9448+0.0000j
step: 686, fid: -0.8561-0.0000j
step: 687, fid: -0.9060+0.0000j
step: 688, fid: -0.8669-0.0000j
step: 689, fid: -0.6139+0.0000j
step: 690, fid: -0.8073+0.0000j
step: 691, fid: -0.8066-0.0000j
step: 692, fid: -0.7723-0.0000j
step: 693, fid: -0.5813-0.0000j
step: 694, fid: -0.8904-0.0000j
step: 69

In [6]:
kraus_list_target[0].T.conj()@kraus_list_target[0] + kraus_list_target[1].T.conj()@kraus_list_target[1]

tensor([[ 1.00000000e+00+0.00000000e+00j, -2.08166817e-17-1.51788304e-17j,
         -1.21430643e-17-2.77555756e-17j, -7.63278329e-17+3.46944695e-18j,
          0.00000000e+00+1.38777878e-17j,  0.00000000e+00+1.38777878e-17j,
          0.00000000e+00+6.93889390e-17j,  0.00000000e+00+0.00000000e+00j],
        [-2.08166817e-17+2.19008839e-17j,  1.00000000e+00+0.00000000e+00j,
          6.93889390e-17+1.66533454e-16j,  4.16333634e-17-5.55111512e-17j,
         -1.11022302e-16+5.55111512e-17j,  8.32667268e-17+3.12250226e-17j,
         -8.32667268e-17-8.32667268e-17j,  1.80411242e-16+1.11022302e-16j],
        [-1.21430643e-17+0.00000000e+00j,  6.93889390e-17-1.80411242e-16j,
          1.00000000e+00+0.00000000e+00j,  6.93889390e-17+4.16333634e-17j,
          1.11022302e-16-6.93889390e-17j, -6.24500451e-17+2.77555756e-17j,
          4.68375339e-17+0.00000000e+00j, -1.38777878e-17+6.93889390e-17j],
        [-7.63278329e-17+0.00000000e+00j,  4.16333634e-17+1.38777878e-17j,
          6.93889390e-

In [23]:
n = 2
d = 2**n
rank = 2

state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

np.random.seed(42)
torch.manual_seed(42)

X_target, _, _ = generate_ginibre(rank*d, rank*d)
kraus_list_target, U, Q, R = generate_kraus(X_target, d, rank)

In [24]:
print(torch.diag(R))

tensor([-2.21147645+0.j, -3.93841711+0.j,  3.78146082+0.j, -3.12663678+0.j,
         2.12048085+0.j,  2.57248815+0.j, -2.04652399+0.j,  0.48867031+0.j],
       dtype=torch.complex128)


In [25]:
print(Q@Q.conj().T)

tensor([[ 1.00000000e+00+0.00000000e+00j,  2.08166817e-17+2.77555756e-17j,
         -2.94902991e-17-4.16333634e-17j,  3.46944695e-18-1.28369537e-16j,
          1.24900090e-16-1.14491749e-16j,  1.97758476e-16+1.38777878e-17j,
         -1.87350135e-16-2.77555756e-17j, -3.08997619e-17-1.21430643e-16j],
        [ 2.08166817e-17-5.55111512e-17j,  1.00000000e+00+0.00000000e+00j,
          9.02056208e-17-8.32667268e-17j,  8.32667268e-17-2.77555756e-17j,
         -2.77555756e-17+5.55111512e-17j, -1.38777878e-17+4.16333634e-17j,
         -4.16333634e-17+0.00000000e+00j, -5.72458747e-17+0.00000000e+00j],
        [-2.94902991e-17+5.55111512e-17j,  9.02056208e-17+4.51028104e-17j,
          1.00000000e+00+0.00000000e+00j,  1.35308431e-16-5.03069808e-17j,
          6.24500451e-17-2.77555756e-17j,  2.77555756e-17-7.63278329e-17j,
          4.85722573e-17-4.85722573e-17j,  2.62376926e-17+4.85722573e-17j],
        [ 3.46944695e-18+1.11022302e-16j,  8.32667268e-17+1.38777878e-17j,
          1.35308431e-

In [28]:
print((U@U.conj().T).shape)

torch.Size([8, 8])


In [29]:
print(U.conj().T@U)

tensor([[ 1.00000000e+00+0.00000000e+00j,  5.55111512e-17-5.55111512e-17j,
         -2.08166817e-17+0.00000000e+00j,  0.00000000e+00+2.77555756e-17j,
          4.51028104e-17+5.55111512e-17j, -4.16333634e-17+1.73472348e-17j,
          1.11022302e-16+2.77555756e-17j, -1.50053581e-16+6.93889390e-18j],
        [ 5.55111512e-17+6.93889390e-17j,  1.00000000e+00+0.00000000e+00j,
         -1.56125113e-17-3.98986399e-17j, -1.38777878e-17+9.02056208e-17j,
         -8.67361738e-18-4.16333634e-17j, -9.71445147e-17+8.93382590e-17j,
         -1.38777878e-17-2.60208521e-17j, -2.97071395e-17-2.77555756e-17j],
        [-2.08166817e-17+0.00000000e+00j, -1.56125113e-17+3.98986399e-17j,
          1.00000000e+00+0.00000000e+00j,  4.85722573e-17-2.08166817e-17j,
          5.89805982e-17+3.46944695e-17j,  3.46944695e-17+5.72458747e-17j,
          2.08166817e-17-1.56125113e-17j, -8.27246258e-17-1.38777878e-17j],
        [ 0.00000000e+00-2.77555756e-17j, -1.38777878e-17-7.28583860e-17j,
          4.85722573e-

In [11]:
print(U@U.conj().T)

tensor([[ 1.00000000e+00+0.00000000e+00j,  2.08166817e-17+2.77555756e-17j,
         -2.94902991e-17-4.16333634e-17j,  3.46944695e-18-1.28369537e-16j,
          1.24900090e-16-1.14491749e-16j,  1.97758476e-16+1.38777878e-17j,
         -1.87350135e-16-2.77555756e-17j, -3.08997619e-17-1.21430643e-16j],
        [ 2.08166817e-17-5.55111512e-17j,  1.00000000e+00+0.00000000e+00j,
          9.02056208e-17-8.32667268e-17j,  8.32667268e-17-2.77555756e-17j,
         -2.77555756e-17+5.55111512e-17j, -1.38777878e-17+4.16333634e-17j,
         -4.16333634e-17+0.00000000e+00j, -5.72458747e-17+0.00000000e+00j],
        [-2.94902991e-17+5.55111512e-17j,  9.02056208e-17+4.51028104e-17j,
          1.00000000e+00+0.00000000e+00j,  1.35308431e-16-5.03069808e-17j,
          6.24500451e-17-2.77555756e-17j,  2.77555756e-17-7.63278329e-17j,
          4.85722573e-17-4.85722573e-17j,  2.62376926e-17+4.85722573e-17j],
        [ 3.46944695e-18+1.11022302e-16j,  8.32667268e-17+1.38777878e-17j,
          1.35308431e-