# Modelling NISQ Hardware with PyTorch

In [1]:
import multiprocessing as mp
import random

import matplotlib.pyplot as plt
import numpy as np
import qiskit as qk
import torch
import torch.optim as optim
from qiskit.quantum_info import DensityMatrix
from qiskit.quantum_info import Operator
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm

from src_torch import *

# Display tensors with eight digits after the decimal point so small
# differences between model and target matrices stay visible.
torch.set_printoptions(precision=8)

## Test

In [2]:
n = 3
d = 2**n

state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

np.random.seed(42)
torch.manual_seed(42)

X_target = generate_ginibre(d**2, 2)

choi_target = generate_choi(X_target)

state_target_list = [apply_map(state_input, choi_target) for state_input in state_input_list]


X = generate_ginibre(d**2, 2).requires_grad_()
optimizer = optim.SGD([X], lr=0.1)

fid_list = []

for i in tqdm(range(1000)):
    optimizer.zero_grad()
    
    choi_model = generate_choi(X)
    index = np.random.randint(0, len(state_input_list)-1)
    state_input = state_input_list[index]
    state_target = state_target_list[index]
    
    state_model = apply_map(state_input, choi_model)
    loss = -state_fidelity(state_model, state_target)
    loss.backward()
    optimizer.step()
    fid = np.abs(loss.detach().numpy())
    fid_list.append(fid)
    print(f"step: {i}, fid: {fid:.4f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

step: 0, fid: 0.4546
step: 1, fid: 0.3962
step: 2, fid: 0.4770
step: 3, fid: 0.2569
step: 4, fid: 0.3745
step: 5, fid: 0.5383
step: 6, fid: 0.2986
step: 7, fid: 0.4444
step: 8, fid: 0.4933
step: 9, fid: 0.5692
step: 10, fid: 0.4783
step: 11, fid: 0.4477
step: 12, fid: 0.4660
step: 13, fid: 0.4864
step: 14, fid: 0.5430
step: 15, fid: 0.3160
step: 16, fid: 0.4068
step: 17, fid: 0.3159
step: 18, fid: 0.4807
step: 19, fid: 0.4180
step: 20, fid: 0.3751
step: 21, fid: 0.4314
step: 22, fid: 0.3448
step: 23, fid: 0.1957
step: 24, fid: 0.4496
step: 25, fid: 0.3748
step: 26, fid: 0.6510
step: 27, fid: 0.4829
step: 28, fid: 0.3193
step: 29, fid: 0.4974
step: 30, fid: 0.3475
step: 31, fid: 0.4015
step: 32, fid: 0.3916
step: 33, fid: 0.3088
step: 34, fid: 0.3453
step: 35, fid: 0.4786
step: 36, fid: 0.2215
step: 37, fid: 0.2597
step: 38, fid: 0.4641
step: 39, fid: 0.6360
step: 40, fid: 0.5088
step: 41, fid: 0.2868
step: 42, fid: 0.3749
step: 43, fid: 0.4533
step: 44, fid: 0.3004
step: 45, fid: 0.464

step: 381, fid: 0.4191
step: 382, fid: 0.3286
step: 383, fid: 0.6489
step: 384, fid: 0.3408
step: 385, fid: 0.6515
step: 386, fid: 0.4752
step: 387, fid: 0.3376
step: 388, fid: 0.5693
step: 389, fid: 0.4319
step: 390, fid: 0.4612
step: 391, fid: 0.4353
step: 392, fid: 0.4356
step: 393, fid: 0.4370
step: 394, fid: 0.3957
step: 395, fid: 0.5098
step: 396, fid: 0.5315
step: 397, fid: 0.3851
step: 398, fid: 0.3435
step: 399, fid: 0.4620
step: 400, fid: 0.5878
step: 401, fid: 0.2662
step: 402, fid: 0.4135
step: 403, fid: 0.3408
step: 404, fid: 0.6333
step: 405, fid: 0.3934
step: 406, fid: 0.6562
step: 407, fid: 0.3691
step: 408, fid: 0.3681
step: 409, fid: 0.3736
step: 410, fid: 0.2851
step: 411, fid: 0.3981
step: 412, fid: 0.5421
step: 413, fid: 0.4392
step: 414, fid: 0.3026
step: 415, fid: 0.4711
step: 416, fid: 0.6250
step: 417, fid: 0.3546
step: 418, fid: 0.4905
step: 419, fid: 0.7105
step: 420, fid: 0.4224
step: 421, fid: 0.4191
step: 422, fid: 0.5185
step: 423, fid: 0.3410
step: 424, 

step: 781, fid: 0.5068
step: 782, fid: 0.6193
step: 783, fid: 0.5093
step: 784, fid: 0.7023
step: 785, fid: 0.1985
step: 786, fid: 0.3038
step: 787, fid: 0.5280
step: 788, fid: 0.4679
step: 789, fid: 0.6228
step: 790, fid: 0.3350
step: 791, fid: 0.6017
step: 792, fid: 0.3648
step: 793, fid: 0.3649
step: 794, fid: 0.3681
step: 795, fid: 0.4215
step: 796, fid: 0.5821
step: 797, fid: 0.6084
step: 798, fid: 0.4708
step: 799, fid: 0.6008
step: 800, fid: 0.5577
step: 801, fid: 0.2932
step: 802, fid: 0.5890
step: 803, fid: 0.5453
step: 804, fid: 0.4792
step: 805, fid: 0.3899
step: 806, fid: 0.4271
step: 807, fid: 0.5204
step: 808, fid: 0.6924
step: 809, fid: 0.4299
step: 810, fid: 0.6047
step: 811, fid: 0.3687
step: 812, fid: 0.4416
step: 813, fid: 0.5627
step: 814, fid: 0.5908
step: 815, fid: 0.3062
step: 816, fid: 0.5092
step: 817, fid: 0.4758
step: 818, fid: 0.2961
step: 819, fid: 0.4630
step: 820, fid: 0.5503
step: 821, fid: 0.5259
step: 822, fid: 0.5593
step: 823, fid: 0.5476
step: 824, 

In [None]:
# Fidelity recorded at each training step, plotted against the step index.
plt.plot(range(len(fid_list)), fid_list)

In [None]:
# Show the Choi matrix from the last training step (computed before that
# step's optimizer update) followed by the target Choi matrix.
print(choi_model, choi_target, sep="\n")

In [None]:
# Sanity check: rebuild the matrix square root of one input state from its
# eigendecomposition and confirm that B @ B reproduces A.
A = state_input_list[15]
print(A)

# eigh treats A as Hermitian: it returns real eigenvalues L and orthonormal
# eigenvectors as the *columns* of V, so that A = V diag(L) V^dagger.
L, V = torch.linalg.eigh(A)
# Input states are presumably density matrices (PSD); clamp round-off
# negatives to 0 and stay in the real dtype rather than casting to complex64,
# which would drop precision for complex128 inputs.
L = torch.sqrt(torch.clamp(L, min=0))

# sqrt(A) = V diag(sqrt(L)) V^dagger = sum_k sqrt(l_k) |v_k><v_k|.
# The ket is v and the bra is conj(v); building conj(v) v^T instead gives the
# conjugate projector, and B @ B would then equal conj(A) rather than A.
B = (V * L.to(V.dtype)) @ V.conj().T

print(A, B @ B)
print(torch.allclose(A, B @ B, atol=1e-6))

In [None]:
# Small (n = 2) example for stepping through the Choi construction by hand.
n = 2
d = 2**n

# Seed both RNGs, as the training cell does, so X is reproducible whichever
# backend generate_ginibre draws its random numbers from.
np.random.seed(42)
torch.manual_seed(42)
X = generate_ginibre(d**2, 2)
print(X)

In [None]:
# Gram matrix X X^dagger: Hermitian and positive semidefinite by construction.
XX = X @ X.conj().T
print(XX)

In [None]:
[[ 0.78061957+0.j          0.07601179+0.68805429j -1.23559901-0.40182879j
   1.18303852+0.45402626j]
 [ 0.07601179-0.68805429j  3.17078055+0.j          0.27068591+2.97482211j
   3.25289552+0.88433786j]
 [-1.23559901+0.40182879j  0.27068591-2.97482211j  3.82883495+0.j
   0.10898574-1.62175541j]
 [ 1.18303852-0.45402626j  3.25289552-0.88433786j  0.10898574+1.62175541j
   6.37437797+0.j        ]]

In [None]:
# Reduce XX with the partial_trace helper (from src_torch) and show the result.
Y = partial_trace(XX)
print(Y)

In [None]:
# Replace Y by square_root_inverse(Y); the Kronecker-lift cell reads this value.
Y = square_root_inverse(Y)
print(Y)

In [None]:
0.46982564+1.52046633e-19j
-0.00539533+2.72307567e-02j
-0.00539533-2.72307567e-02j
0.3258792 -1.16520961e-17j

In [None]:
# Lift Y to the full space as (I_d kron Y)^T. Because kron(P, Q)^T equals
# kron(P^T, Q^T) and the identity is symmetric, that is kron(I_d, Y^T).
# NOTE(review): for a complex Hermitian Y, Y^T == conj(Y) rather than Y --
# confirm the transpose is what the Choi-normalisation convention intends.
I = torch.eye(d, dtype=torch.complex128)
Ykron = torch.kron(I, Y.T)
print(Ykron)

In [None]:
# Sandwich XX between Ykron to form the candidate Choi matrix, then print its
# trace as a quick normalisation check.
choi = (Ykron @ XX) @ Ykron
print(choi.trace())