# Pytorch

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import qiskit as qk
import matplotlib.pyplot as plt
import multiprocessing as mp
import random
import pickle
import torch
import torch.optim as optim

from qiskit.quantum_info import DensityMatrix
from qiskit.quantum_info import Operator
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm

from cost_functions import *
from optimization import *
from quantum_maps import *
from quantum_tools import *
from src_torch import *
#np.set_printoptions(threshold=sys.maxsize)

## Implement

In [2]:
def apply_map(state, kraus_list):
    state = [K@state@K.T.conj() for K in kraus_list]
    state = torch.stack(state, dim=0).sum(dim=0)
    return state

In [10]:
random.seed(42)
np.random.seed(42)

n = 3
d = 2**n
rank = 8

_, A, B = generate_ginibre(rank*d, d, requires_grad=False)
G = A + 1j*B

U = generate_unitary(G)
kraus_target_list =  [U[i*d:(i+1)*d, :d] for i in range(rank)]

In [20]:
N = 1000
state_index, observ_index = index_generator(n, N, trace=False)

input_list = []
circuit_list = []
for i, j in zip(state_index, observ_index):

    config = numberToBase(i, 6, n)
    state = prepare_input(config)

    config = numberToBase(j, 3, n)
    observable = pauli_observable(config)
    
    input_list.append([state, observable])

target_list = [expectation_value(apply_map(input[0], kraus_target_list), input[1]) for input in input_list]

_, A, B = generate_ginibre(rank*d, d, requires_grad=True)
optimizer = optim.Adam([A, B], lr=0.01)

In [21]:
for i in tqdm(range(1000)):
    G = A + 1j*B
    U = generate_unitary(G)
    kraus_model_list =  [U[i*d:(i+1)*d, :d] for i in range(rank)]
    pred_list = [expectation_value(apply_map(input[0], kraus_model_list), input[1]) for input in input_list]
    loss = torch.mean(torch.stack([(target - predicted)**2 for target, predicted in zip(target_list, pred_list)]))
    loss.backward()
    optimizer.step()
    print(f"{np.abs(loss.detach().numpy()):.4f},  {np.mean(np.abs(G.detach().numpy())):.4f}, {np.max(np.abs(G.detach().numpy())):.4f}")
    


  0%|          | 0/1000 [00:00<?, ?it/s]

0.0289,  1.2290, 3.5621
0.0277,  1.2293, 3.5510
0.0267,  1.2296, 3.5402
0.0256,  1.2302, 3.5293
0.0247,  1.2309, 3.5182
0.0237,  1.2318, 3.5068
0.0228,  1.2329, 3.4952
0.0218,  1.2341, 3.4832
0.0210,  1.2356, 3.4710
0.0201,  1.2372, 3.4584
0.0193,  1.2390, 3.4454
0.0185,  1.2411, 3.4321
0.0178,  1.2433, 3.4185
0.0171,  1.2457, 3.4045
0.0164,  1.2483, 3.3902
0.0158,  1.2511, 3.3755
0.0152,  1.2541, 3.3605
0.0147,  1.2574, 3.3451
0.0142,  1.2609, 3.3295
0.0137,  1.2647, 3.3134
0.0133,  1.2687, 3.2971
0.0129,  1.2730, 3.2901
0.0125,  1.2775, 3.2968
0.0121,  1.2823, 3.3068
0.0118,  1.2873, 3.3294
0.0115,  1.2926, 3.3521
0.0113,  1.2981, 3.3750
0.0110,  1.3038, 3.3981
0.0108,  1.3098, 3.4214
0.0106,  1.3161, 3.4449
0.0104,  1.3226, 3.4685
0.0102,  1.3293, 3.4922
0.0101,  1.3362, 3.5161
0.0099,  1.3434, 3.5401
0.0098,  1.3507, 3.5643
0.0096,  1.3583, 3.5885
0.0095,  1.3661, 3.6128
0.0094,  1.3742, 3.6372
0.0092,  1.3824, 3.6617
0.0091,  1.3909, 3.6862
0.0090,  1.3997, 3.7108
0.0089,  1.4086,

0.0030,  6.3577, 11.6324
0.0030,  6.3755, 11.6627
0.0030,  6.3934, 11.6930
0.0030,  6.4113, 11.7232
0.0029,  6.4292, 11.7534
0.0029,  6.4470, 11.7836
0.0029,  6.4649, 11.8138
0.0029,  6.4828, 11.8439
0.0029,  6.5007, 11.8741
0.0029,  6.5185, 11.9041
0.0029,  6.5364, 11.9342
0.0029,  6.5543, 11.9642
0.0029,  6.5722, 11.9943
0.0029,  6.5901, 12.0242
0.0029,  6.6079, 12.0542
0.0029,  6.6258, 12.0841
0.0029,  6.6437, 12.1140
0.0029,  6.6616, 12.1439
0.0029,  6.6795, 12.1738
0.0029,  6.6974, 12.2036
0.0029,  6.7153, 12.2334
0.0029,  6.7332, 12.2632
0.0029,  6.7511, 12.2929
0.0029,  6.7690, 12.3227
0.0029,  6.7869, 12.3524
0.0028,  6.8048, 12.3820
0.0028,  6.8227, 12.4117
0.0028,  6.8406, 12.4413
0.0028,  6.8586, 12.4709
0.0028,  6.8765, 12.5005
0.0028,  6.8944, 12.5300
0.0028,  6.9124, 12.5595
0.0028,  6.9303, 12.5890
0.0028,  6.9482, 12.6185
0.0028,  6.9662, 12.6479
0.0028,  6.9841, 12.6773
0.0028,  7.0021, 12.7067
0.0028,  7.0200, 12.7361
0.0028,  7.0380, 12.7654
0.0028,  7.0559, 12.7947


KeyboardInterrupt: 

In [52]:
random.seed(42)
np.random.seed(42)

n = 3
d = 2**n
rank = 1

_, A, B = generate_ginibre(rank*d, d, requires_grad=True)
G = A
C = torch.rand(rank*d, d).type(torch.complex128)
R = torch.rand(rank*d, rank*d).type(torch.complex128)
R, _ = torch.linalg.qr(R)
print(torch.linalg.cond(A))


U1 = generate_unitary(G)
out1 = torch.autograd.grad(U1, G, C)[0]
G = A
G = R@G

U2 = R.T.conj()@generate_unitary(G)
out2 = torch.autograd.grad(U2, G, C)[0]


print(torch.norm(U1-U2))
print(torch.norm(out1-out2))

tensor([[ 0.4967+0.j, -0.1383+0.j,  0.6477+0.j,  1.5230+0.j, -0.2342+0.j, -0.2341+0.j,
          1.5792+0.j,  0.7674+0.j],
        [-0.4695+0.j,  0.5426+0.j, -0.4634+0.j, -0.4657+0.j,  0.2420+0.j, -1.9133+0.j,
         -1.7249+0.j, -0.5623+0.j],
        [-1.0128+0.j,  0.3142+0.j, -0.9080+0.j, -1.4123+0.j,  1.4656+0.j, -0.2258+0.j,
          0.0675+0.j, -1.4247+0.j],
        [-0.5444+0.j,  0.1109+0.j, -1.1510+0.j,  0.3757+0.j, -0.6006+0.j, -0.2917+0.j,
         -0.6017+0.j,  1.8523+0.j],
        [-0.0135+0.j, -1.0577+0.j,  0.8225+0.j, -1.2208+0.j,  0.2089+0.j, -1.9597+0.j,
         -1.3282+0.j,  0.1969+0.j],
        [ 0.7385+0.j,  0.1714+0.j, -0.1156+0.j, -0.3011+0.j, -1.4785+0.j, -0.7198+0.j,
         -0.4606+0.j,  1.0571+0.j],
        [ 0.3436+0.j, -1.7630+0.j,  0.3241+0.j, -0.3851+0.j, -0.6769+0.j,  0.6117+0.j,
          1.0310+0.j,  0.9313+0.j],
        [-0.8392+0.j, -0.3092+0.j,  0.3313+0.j,  0.9755+0.j, -0.4792+0.j, -0.1857+0.j,
         -1.1063+0.j, -1.1962+0.j]], dtype=torch.com

In [49]:
import jax
import jax.numpy as jnp
import numpy as np

x = np.random.randn(3, 3)
dx = np.random.randn(3, 3)

primals, tangents = jax.jvp(jnp.linalg.qr, (x,), (dx,))
q, r = primals
dq, dr = tangents

dt = 1e-6
dq_ = (np.linalg.qr(x + dt * dx)[0] - np.linalg.qr(x)[0]) / dt
dr_ = (np.linalg.qr(x + dt * dx)[1] - np.linalg.qr(x)[1]) / dt

assert jnp.allclose(x, q @ r, atol=1e-5, rtol=1e-5)  # passes
assert jnp.allclose(dq, dq_, atol=1e-5, rtol=1e-5)  # passes
assert jnp.allclose(dx, q @ dr_ + dq_ @ r, atol=1e-5, rtol=1e-5)  # passes
assert jnp.allclose(dr, dr_, atol=1e-5, rtol=1e-5)  # fails

ModuleNotFoundError: No module named 'jax'