# Modelling NISQ Hardware with PyTorch

In [1]:
import multiprocessing as mp
import random

import matplotlib.pyplot as plt
import numpy as np
import qiskit as qk
import torch
import torch.optim as optim
from qiskit.quantum_info import DensityMatrix
from qiskit.quantum_info import Operator
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm

from src_torch import *

# Print tensors with 8 digits of precision so small numerical differences are visible.
torch.set_printoptions(precision=8)

## Test

In [2]:
n = 3
d = 2**n

# One prepared input state per index i in range(6**n); numberToBase(i, 6, n) is
# presumably i written as n base-6 digits -- confirm against src_torch.
state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

# Seed both NumPy and PyTorch so the run (including the sampling order below)
# is reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Target map and the states it produces for every input state.
X_target = generate_ginibre(d**2, 2)
choi_target = generate_choi(X_target)
state_target_list = [apply_map(state_input, choi_target) for state_input in state_input_list]

# Trainable parameters of the model map, optimised with plain SGD.
X = generate_ginibre(d**2, 2).requires_grad_()
optimizer = optim.SGD([X], lr=0.1)

fid_list = []

# NOTE(review): the recorded run printed values in the thousands and then failed
# inside linalg_eig_backward. A fidelity between normalised states cannot exceed 1,
# so state_fidelity / apply_map / generate_choi in src_torch likely need checking.
for i in tqdm(range(1000)):
    optimizer.zero_grad()

    choi_model = generate_choi(X)

    # np.random.randint's upper bound is exclusive, so passing len(...) samples
    # every index, including the last one (the previous "len - 1" never did).
    index = np.random.randint(len(state_input_list))
    state_input = state_input_list[index]
    state_target = state_target_list[index]

    state_model = apply_map(state_input, choi_model)

    # Maximise fidelity by minimising its negative.
    loss = -state_fidelity(state_model, state_target)
    loss.backward()
    optimizer.step()

    fid = np.abs(loss.detach().numpy())
    fid_list.append(fid)
    print(f"step: {i}, fid: {fid:.4f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

step: 0, fid: 8121.8144
step: 1, fid: 3611822262.7544



RuntimeError: linalg_eig_backward: The eigenvectors in the complex case are specified up to multiplication by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.

In [None]:
# Plot the value recorded at every optimisation step; labelled so the figure
# can be read on its own.
fig, ax = plt.subplots()
ax.plot(fid_list)
ax.set(title="Fidelity during training", xlabel="step", ylabel="fidelity")
plt.show()

In [None]:
# Show the last model Choi matrix followed by the target, one after the other.
print(choi_model, choi_target, sep="\n")

In [None]:
# Sanity check: build a matrix square root of A from its eigendecomposition
# and print A next to B@B, which should reproduce A.
A = state_input_list[15]
print(A)

# eigh returns real eigenvalues L and the eigenvectors as the COLUMNS of V.
L, V = torch.linalg.eigh(A)
# Cast to complex before the square root so a slightly negative eigenvalue
# yields an imaginary root instead of NaN.
L = torch.sqrt(L.type(torch.complex64))

B = torch.zeros_like(A)
for l, v in zip(L, V.T):
    # Spectral decomposition: sqrt(A) = sum_k sqrt(lambda_k) |v_k><v_k|.
    # The ket is the column vector and the bra carries the conjugate; the
    # previous form conj(v) v^T built the complex conjugate of the projector,
    # so B@B equalled conj(A) for complex A.
    B += l*v.reshape(-1,1)@torch.conj(v.reshape(1,-1))

print(A, B@B)

In [None]:
n = 2
d = 2**n

# Seed both NumPy and PyTorch, matching the training cell: which RNG
# generate_ginibre draws from is not visible here, and seeding only one of
# them could leave X different on every run.
np.random.seed(42)
torch.manual_seed(42)
X = generate_ginibre(d**2, 2)
print(X)

In [None]:
# Multiply X by its conjugate transpose.
XX = X @ X.conj().T
print(XX)

In [None]:
[[ 0.78061957+0.j          0.07601179+0.68805429j -1.23559901-0.40182879j
   1.18303852+0.45402626j]
 [ 0.07601179-0.68805429j  3.17078055+0.j          0.27068591+2.97482211j
   3.25289552+0.88433786j]
 [-1.23559901+0.40182879j  0.27068591-2.97482211j  3.82883495+0.j
   0.10898574-1.62175541j]
 [ 1.18303852-0.45402626j  3.25289552-0.88433786j  0.10898574+1.62175541j
   6.37437797+0.j        ]]

In [None]:
# Reduce XX with the project's partial_trace helper and show the result.
# NOTE(review): which subsystem is traced out is decided inside partial_trace -- confirm in src_torch.
Y = partial_trace(XX)
print(Y)

In [None]:
# Recompute from XX instead of overwriting Y with a function of itself:
# "Y = f(Y)" gives a different result each time the cell is re-run, whereas
# this form is idempotent and yields the same Y on a straight top-to-bottom run.
Y = square_root_inverse(partial_trace(XX))
print(Y)

In [None]:
# NOTE(review): these four complex literals look like pasted reference values
# (possibly the entries of a 2x2 matrix from an earlier run) -- confirm or remove.
# As written the cell only evaluates them and displays the last one.
0.46982564+1.52046633e-19j
-0.00539533+2.72307567e-02j
-0.00539533-2.72307567e-02j
0.3258792 -1.16520961e-17j

In [None]:
# d x d identity in complex128, Kronecker-multiplied with Y, then transposed.
# NOTE(review): this is a plain transpose, not a conjugate transpose -- confirm that is intended.
I = torch.eye(d).to(torch.complex128)
Ykron = torch.kron(I, Y).transpose(0, 1)
print(Ykron)

In [None]:
# Sandwich XX between two copies of Ykron (evaluated left to right) and print the trace.
choi = torch.matmul(torch.matmul(Ykron, XX), Ykron)
print(choi.trace())