In [1]:
import sys
import torch
import tntorch as tn
import numpy as np
sys.path.append('..')
import lib

# --- Target state -----------------------------------------------------------
# `rho` is the reference state that the optimisation below tries to recover;
# it is produced by the project helper with matrix size 2 ** n_qubits.
n_qubits = 5
rho = lib.randomMixedState(2 ** n_qubits)
# Disabled alternative target state:
# rho = lib.randomPureState(2 ** n_qubits)
dim = rho.shape[0]

# --- Hyperparameters --------------------------------------------------------
# projectors_cnt / measurements_cnt are forwarded to lib.generate_dataset.
# tensor_rank is passed as `ranks_tt` when the trainable tensors are built.
# batch_size is read by loss(); lr is forwarded to lib.tn_optimize.
projectors_cnt = 1000
measurements_cnt = 1000
tensor_rank = 2**n_qubits
batch_size = 100
lr = 1e-3

# --- Data -------------------------------------------------------------------
# NOTE(review): train_size is not read anywhere in the visible code.
train_size = projectors_cnt * measurements_cnt
# Train and test sets are generated by two independent calls with the same
# arguments. Presumably X holds measurement operators and y the observed
# outcome frequencies -- TODO confirm against lib.generate_dataset.
train_X, train_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
test_X, test_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
# Labels are cast to float64 before being used in the squared-error terms.
train_y, test_y = [x.astype('float64') for x in [train_y, test_y]]

# --- Trainable parameters ---------------------------------------------------
# Initial guess with the same matrix size as rho. Its real and imaginary parts
# are wrapped separately as tntorch tensors with gradients enabled; these two
# tensors are what lib.tn_optimize updates.
# NOTE(review): this uses lib.simulator.randomMixedState while rho above uses
# lib.randomMixedState -- presumably the same function; verify.
sigma = lib.simulator.randomMixedState(rho.shape[0])
sigma_real, sigma_imag = [tn.Tensor(x, ranks_tt=tensor_rank, requires_grad=True) for x in [np.real(sigma), np.imag(sigma)]]

# NOTE(review): `epoches` is not read anywhere in the visible code.
epoches = 10

def trace(tensor):
    """Return the trace of a matrix, or the per-matrix traces of a stack.

    Args:
        tensor: An object exposing ``shape`` and tuple indexing. A 2-D input
            is treated as a single matrix; a 3-D input is treated as a stack
            of matrices enumerated along the first axis.

    Returns:
        For 2-D input, the sum of ``tensor[i, i]`` over the first dimension.
        For 3-D input, the sum of the slices ``tensor[:, i, i]`` over the
        second dimension, i.e. one value per matrix in the stack.

    Raises:
        ValueError: If ``tensor`` is neither 2-D nor 3-D. Previously such
            input fell through and returned ``None`` silently.
    """
    ndim = len(tensor.shape)
    if ndim == 2:
        return sum(tensor[i, i] for i in range(tensor.shape[0]))
    if ndim == 3:
        return sum(tensor[:, i, i] for i in range(tensor.shape[1]))
    raise ValueError(
        'trace() expects a 2-D or 3-D input, got %d dimension(s)' % ndim
    )

def cholesky(sigma_real, sigma_imag):
    """Build a unit-trace matrix from the real and imaginary parts of a factor.

    Despite the name, this does not compute a Cholesky factorisation. It
    treats ``sigma_real + 1j * sigma_imag`` as a factor ``L`` and forms the
    Gram-type product from it, then divides by the trace of the real part.
    With ``k=1`` the ``dot`` calls are expected to contract the leading index
    of both operands (``A^T B``), which makes the result ``L^H L`` --
    NOTE(review): confirm against the installed tntorch version.

    Args:
        sigma_real: Real part of the factor (2-D, supports ``dot(..., k=1)``).
        sigma_imag: Imaginary part of the factor, same shape as ``sigma_real``.

    Returns:
        A ``(real, imag)`` tuple holding the real and imaginary parts of the
        product, both divided by the sum of the diagonal of the real part.
    """
    gram_real = sigma_real.dot(sigma_real, k=1) + sigma_imag.dot(sigma_imag, k=1)
    gram_imag = sigma_real.dot(sigma_imag, k=1) - sigma_imag.dot(sigma_real, k=1)
    # Take the diagonal length from the matrix itself rather than from the
    # module-level `rho`, and avoid shadowing the module-level trace().
    diag_sum = sum(gram_real[i, i] for i in range(gram_real.shape[0]))
    return gram_real / diag_sum, gram_imag / diag_sum
    
# Baseline error of the initial guess: for the test operators, take the real
# part of trace(test_X . sigma), subtract the test labels, square and sum.
# NOTE(review): despite the name this is a squared-error total, not a trace.
# Assumes test_X stacks the operators along its first axis so that trace()
# takes its 3-D branch -- TODO confirm against lib.generate_dataset.
initial_trace = sum((np.real(trace(test_X.dot(sigma)))-test_y)**2)

def loss(sigma_real, sigma_imag):
    """Squared-error training loss on a random minibatch.

    The two parameter tensors are first mapped through ``cholesky``. Then
    ``batch_size`` indices are drawn with ``np.random.choice`` (which samples
    with replacement by default) from the training set, and for every drawn
    pair the squared difference between the prediction and the label is
    accumulated.

    Args:
        sigma_real: Real part of the trainable factor.
        sigma_imag: Imaginary part of the trainable factor.

    Returns:
        The sum of the squared residuals over the drawn minibatch.
    """
    state_real, state_imag = cholesky(sigma_real, sigma_imag)

    sample_count = train_X.shape[0]
    batch_idx = np.random.choice(np.arange(sample_count), batch_size)
    batch_ops = train_X[batch_idx]
    batch_targets = train_y[batch_idx].astype('float64')

    total = 0
    for operator, target in zip(batch_ops, batch_targets):
        op_real = tn.Tensor(np.real(operator))
        op_imag = tn.Tensor(np.imag(operator))
        # Prediction for this operator, built from the real and imaginary
        # contractions with the current state estimate.
        predicted = op_real.dot(state_real) + op_imag.dot(state_imag)
        total += (predicted - target) ** 2

    # The sum is returned without normalisation (a variant dividing by
    # initial_trace * train_X.shape[0] is disabled).
    return total
    

def eval_loss(sigma_real, sigma_imag):  # any score function can be used here
    """Evaluation score: negative fidelity between the estimate and ``rho``.

    The parameters are mapped through ``cholesky``, converted to NumPy arrays
    detached from the autograd graph, combined into one complex matrix and
    compared with the module-level ``rho`` via ``lib.fidelity``.

    Args:
        sigma_real: Real part of the trainable factor.
        sigma_imag: Imaginary part of the trainable factor.

    Returns:
        ``-lib.fidelity(estimate, rho)``.
    """
    state_real, state_imag = cholesky(sigma_real, sigma_imag)
    real_part = state_real.torch().detach().cpu().numpy()
    imag_part = state_imag.torch().detach().cpu().numpy()
    estimate = real_part + 1j * imag_part
    return -lib.fidelity(estimate, rho)

# Report the baseline squared-error total of the initial guess (the output
# label says "Trace", but the value is the error computed in initial_trace).
print('Trace before: %f'%initial_trace)

# Run the optimiser on the two parameter tensors, with `loss` as the training
# objective and `eval_loss` as the monitored score. As the last expression of
# the cell, the call's return value is what the notebook displays.
# NOTE(review): the meaning of tol / patience / print_freq and of the returned
# tuple is defined in lib.tn_optimize, which is not visible here -- confirm.
lib.tn_optimize([sigma_real, sigma_imag], loss, eval_loss, tol=0, patience=2000,print_freq=10,lr=lr)

Trace before: 0.085483
Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers
iter: 0 |  loss: 0.014046844094991684 | eval_loss: -0.430651532039372 | total time: 0.14774203300476074


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /opt/conda/conda-bld/pytorch_1595629395347/work/torch/csrc/utils/python_arg_parser.cpp:766.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


iter: 10 |  loss: 0.012899213470518589 | eval_loss: -0.4334349501431574 | total time: 1.8333194255828857
iter: 20 |  loss: 0.015273394994437695 | eval_loss: -0.4392879553796723 | total time: 3.526559352874756
iter: 30 |  loss: 0.014648406766355038 | eval_loss: -0.44631364030503584 | total time: 5.223982810974121
iter: 40 |  loss: 0.014566272497177124 | eval_loss: -0.4606333068082315 | total time: 6.921551704406738
iter: 50 |  loss: 0.012453748844563961 | eval_loss: -0.47031423670244205 | total time: 8.66620945930481
iter: 60 |  loss: 0.010590963065624237 | eval_loss: -0.4799415196836345 | total time: 10.379002332687378
iter: 70 |  loss: 0.009406046941876411 | eval_loss: -0.4976516142505171 | total time: 12.094162225723267
iter: 80 |  loss: 0.008594349958002567 | eval_loss: -0.5085975225231042 | total time: 13.805412292480469
iter: 90 |  loss: 0.007401490584015846 | eval_loss: -0.5187679919766648 | total time: 15.518301963806152
iter: 100 |  loss: 0.00741082476451993 | eval_loss: -0.534

iter: 790 |  loss: 0.0047329990193247795 | eval_loss: -0.5812895666146722 | total time: 136.43110990524292
iter: 800 |  loss: 0.0036884888540953398 | eval_loss: -0.5826054316003243 | total time: 138.15145301818848
iter: 810 |  loss: 0.0033088403288275003 | eval_loss: -0.583918753480867 | total time: 139.8732466697693
iter: 820 |  loss: 0.004080140497535467 | eval_loss: -0.5869681911663832 | total time: 141.60026860237122
iter: 830 |  loss: 0.0038615968078374863 | eval_loss: -0.5882275520909288 | total time: 143.32365322113037
iter: 840 |  loss: 0.0028580192010849714 | eval_loss: -0.5895909607843822 | total time: 145.045170545578
iter: 850 |  loss: 0.003692772937938571 | eval_loss: -0.5919652214762431 | total time: 146.76451539993286
iter: 860 |  loss: 0.0038655688986182213 | eval_loss: -0.5923445659252985 | total time: 148.48615169525146
iter: 870 |  loss: 0.0034922335762530565 | eval_loss: -0.5933538784910527 | total time: 150.21140503883362
iter: 880 |  loss: 0.0036783358082175255 | 

iter: 1560 |  loss: 0.0016806183848530054 | eval_loss: -0.6525502720536444 | total time: 269.0271461009979
iter: 1570 |  loss: 0.0022661180701106787 | eval_loss: -0.6541977773125967 | total time: 270.74621629714966
iter: 1580 |  loss: 0.0029354013968259096 | eval_loss: -0.6554997386831072 | total time: 272.47027134895325
iter: 1590 |  loss: 0.0023317858576774597 | eval_loss: -0.6564703260113747 | total time: 274.1956329345703
iter: 1600 |  loss: 0.0021261456422507763 | eval_loss: -0.6569218468374863 | total time: 275.91583609580994
iter: 1610 |  loss: 0.00231386860832572 | eval_loss: -0.657668173897084 | total time: 277.63439774513245
iter: 1620 |  loss: 0.00210010027512908 | eval_loss: -0.6579131662512361 | total time: 279.35463285446167
iter: 1630 |  loss: 0.002644770545884967 | eval_loss: -0.6588260388387378 | total time: 281.07910346984863
iter: 1640 |  loss: 0.002053616801276803 | eval_loss: -0.6593681105284775 | total time: 282.7997989654541
iter: 1650 |  loss: 0.0029947673901915

iter: 2330 |  loss: 0.0018880777060985565 | eval_loss: -0.6748906689794709 | total time: 401.7221920490265
iter: 2340 |  loss: 0.0022962994407862425 | eval_loss: -0.6742880608502293 | total time: 403.4359350204468
iter: 2350 |  loss: 0.0020686716306954622 | eval_loss: -0.6733909754197054 | total time: 405.1471948623657
iter: 2360 |  loss: 0.002032740041613579 | eval_loss: -0.672873944972715 | total time: 406.8593225479126
iter: 2370 |  loss: 0.0019829433877021074 | eval_loss: -0.6728691272419919 | total time: 408.57242155075073
iter: 2380 |  loss: 0.0020899411756545305 | eval_loss: -0.6724012714372731 | total time: 410.2849233150482
iter: 2390 |  loss: 0.0022958226036280394 | eval_loss: -0.6717560836590973 | total time: 412.0000238418579
iter: 2400 |  loss: 0.002227973425760865 | eval_loss: -0.6708866067043986 | total time: 413.71356201171875
iter: 2410 |  loss: 0.0026839878410100937 | eval_loss: -0.6701878678936289 | total time: 415.4250829219818
iter: 2420 |  loss: 0.0026509214658290

iter: 3100 |  loss: 0.0026881685480475426 | eval_loss: -0.6356920699170402 | total time: 534.6024484634399
iter: 3110 |  loss: 0.0028933477587997913 | eval_loss: -0.63494881106241 | total time: 536.3207051753998
iter: 3120 |  loss: 0.0023707104846835136 | eval_loss: -0.6344484682928513 | total time: 538.040664434433
iter: 3130 |  loss: 0.0026107251178473234 | eval_loss: -0.6339988174277728 | total time: 539.7624499797821
iter: 3140 |  loss: 0.0020035412162542343 | eval_loss: -0.63377149733072 | total time: 541.4850966930389
iter: 3150 |  loss: 0.0029705222696065903 | eval_loss: -0.6335198482337298 | total time: 543.2089223861694
iter: 3160 |  loss: 0.002157457871362567 | eval_loss: -0.6335055297739257 | total time: 544.9294562339783
iter: 3170 |  loss: 0.0028420337475836277 | eval_loss: -0.6331867519280131 | total time: 546.6498072147369
iter: 3180 |  loss: 0.0023309742100536823 | eval_loss: -0.6329793170399153 | total time: 548.3695130348206
iter: 3190 |  loss: 0.0021435623057186604 |

iter: 3880 |  loss: 0.003227266948670149 | eval_loss: -0.6253746255739464 | total time: 669.026852607727
iter: 3890 |  loss: 0.002913880627602339 | eval_loss: -0.6250395621235904 | total time: 670.7491643428802
iter: 3900 |  loss: 0.002028072252869606 | eval_loss: -0.6250655876085403 | total time: 672.4715733528137
iter: 3910 |  loss: 0.0031193478498607874 | eval_loss: -0.6250529047397204 | total time: 674.1931204795837
iter: 3920 |  loss: 0.002325603971257806 | eval_loss: -0.6250681876735835 | total time: 675.917402267456
iter: 3930 |  loss: 0.00293358345516026 | eval_loss: -0.6249762978173949 | total time: 677.6438064575195
iter: 3940 |  loss: 0.0034209792502224445 | eval_loss: -0.6248953597540392 | total time: 679.3661711215973
iter: 3950 |  loss: 0.0027281739749014378 | eval_loss: -0.6249394715260119 | total time: 681.0909879207611
iter: 3960 |  loss: 0.002814891282469034 | eval_loss: -0.6250830762635331 | total time: 682.8141086101532
iter: 3970 |  loss: 0.002519759349524975 | eva

(4219, 0.6791421218641134, 2218)