# Modelling NIQS Hardware TF

In [44]:
import numpy as np
import qiskit as qk
import matplotlib.pyplot as plt
import multiprocessing as mp
import random
import torch.optim as optim
import tensorflow as tf

from qiskit.quantum_info import DensityMatrix
from qiskit.quantum_info import Operator
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm


from src_tf import *

## Test

In [86]:
n = 2
d = 2**n

state_input_list = [prepare_input(numberToBase(i, 6, n)) for i in range(6**n)]

np.random.seed(43)
tf.random.set_seed(43)

optimizer = tf.optimizers.SGD(learning_rate=1e-8)

X_target = generate_ginibre(d**2, 2)
choi_target = generate_choi(X_target)
state_target_list = [apply_map(state_input, choi_target) for state_input in state_input_list]

X_model = generate_ginibre(d**2, 2)

for i in range(100):
    state_input = state_input_list[0]
    state_target = state_target_list[0]
    print("param", grads[0,0])
    
    with tf.GradientTape() as tape:
        tape.watch(X_model)
        choi_model = generate_choi(X_model)
        state_model = apply_map(state_input, choi_model)
        fid = -state_fidelity(state_model, state_target)
    
    grads = tape.gradient(fid, X_model)
    print(grads)
    #optimizer.apply_gradients(zip([grads], [X_model]))
    #print("fid", fid)
    

tf.Tensor((0.999999999999999+8.864045934318847e-17j), shape=(), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300722.55972075 +777072.5524758

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300

 [ 1313174.00192368-1874244.24612342j  -135195.34843322 -994113.91367919j]], shape=(16, 2), dtype=complex128)
param tf.Tensor((-7148428.5963936895+722860.1478224574j), shape=(), dtype=complex128)
tf.Tensor(
[[-7148428.59639369 +722860.14782246j  3344635.2872222 -1701988.48544855j]
 [ 1510335.52961993-2143519.33008589j   332651.82467973 +726036.3995014j ]
 [ -602569.96781076-1867347.92075486j   281438.68263147 +159104.78275908j]
 [ 2045485.47083895-1004726.86275116j   -21274.04612363 +498002.45220749j]
 [-1289616.10426428+1050669.79889739j   223173.71615743 -370169.0300079j ]
 [  374215.81228071 -108445.29379077j   244909.48975163-1322777.40141923j]
 [  659445.09724638 +913069.1564041j  -1399902.86989313-1043189.86436578j]
 [-1058075.12405426 +723951.29161323j  1319569.95345577-1493541.91564771j]
 [ -305925.14703319  -27640.80113209j  1865442.89022878-3487867.62165576j]
 [ 1299376.40700463 -204013.96111753j  -918583.92832523 -162576.43486022j]
 [  273959.3071415 -1245659.95062617j -1300