In [None]:
import sys
import torch
import tntorch as tn
import numpy as np
sys.path.append('..')
import lib

# --- Experiment configuration and data --------------------------------------
# NOTE(review): the random draws below (rho, the two datasets, sigma) happen in
# this exact order; reordering them would change the sampled values for a given
# RNG state.
n_qubits = 2
# Target density matrix of size 2**n_qubits (presumably a random mixed state,
# per the helper's name — confirm in lib).
rho = lib.randomMixedState(2 ** n_qubits)
dim = rho.shape[0]

projectors_cnt = 10000
measurements_cnt = 1000
# TT rank passed to tn.Tensor(ranks_tt=...) when parametrising sigma below.
tensor_rank = 2**n_qubits
batch_size = 100  # number of training rows sampled per loss() evaluation
lr = 1e-3

# NOTE(review): train_size is not referenced anywhere in the visible cell.
train_size = projectors_cnt * measurements_cnt
# Presumably (measurement operators, observed outcomes) pairs — confirm
# against lib.generate_dataset. Train and test sets are drawn independently.
train_X, train_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
test_X, test_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
# Cast the targets to float64 once, up front.
train_y, test_y = [x.astype('float64') for x in [train_y, test_y]]

# Initial guess; its real and imaginary parts become two separate trainable
# tntorch tensors (requires_grad=True) that the optimizer updates.
sigma = lib.simulator.randomMixedState(rho.shape[0])
sigma_real, sigma_imag = [tn.Tensor(x, ranks_tt=tensor_rank, requires_grad=True) for x in [np.real(sigma), np.imag(sigma)]]

# NOTE(review): epoches is not referenced anywhere in the visible cell.
epoches = 10

def trace(tensor):
    """Return the trace of a matrix, or one trace per matrix in a stack.

    Only ``.shape`` and integer/slice indexing are used, so NumPy arrays (and
    array-likes with the same indexing semantics) are accepted.

    Args:
        tensor: a 2-D matrix, or a 3-D stack of matrices whose first axis
            indexes the batch.

    Returns:
        For 2-D input, the sum of the main diagonal. For 3-D input, the sum of
        ``tensor[:, i, i]`` over the diagonal, i.e. one trace per batch item.

    Raises:
        ValueError: if ``tensor`` is neither 2-D nor 3-D.
    """
    rank = len(tensor.shape)
    if rank not in (2, 3):
        # Fail loudly instead of silently returning None.
        raise ValueError("trace() expects a 2-D or 3-D tensor, got %d-D" % rank)
    # Walk the main diagonal of the trailing two axes. Using the shorter side
    # matches numpy.trace for non-square input instead of raising IndexError.
    diag_len = min(tensor.shape[-2], tensor.shape[-1])
    if rank == 2:
        return sum(tensor[i, i] for i in range(diag_len))
    return sum(tensor[:, i, i] for i in range(diag_len))

def cholesky(sigma_real, sigma_imag):
    """Map unconstrained real/imaginary factors to a trace-normalised matrix.

    Builds the real and imaginary parts of a Gram-like product of the factor
    ``A = sigma_real + 1j * sigma_imag`` with itself, then divides both parts
    by the sum of the real part's diagonal, so the result has unit trace.

    NOTE(review): assuming ``x.dot(y, k=1)`` contracts the first mode of both
    operands (tntorch semantics, i.e. ``x^T y``), the result is ``A^H A``.
    That makes it Hermitian positive semi-definite — confirm against tntorch.

    Args:
        sigma_real: real part of the factor (a tntorch tensor).
        sigma_imag: imaginary part of the factor (a tntorch tensor).

    Returns:
        ``(real_part, imag_part)`` of the normalised matrix.
    """
    gram_real = sigma_real.dot(sigma_real, k=1) + sigma_imag.dot(sigma_imag, k=1)
    gram_imag = sigma_real.dot(sigma_imag, k=1) - sigma_imag.dot(sigma_real, k=1)
    # Size the diagonal walk from the operand itself rather than the global
    # `rho`, and avoid shadowing the module-level `trace` helper.
    diag_len = gram_real.shape[0]
    tr = sum(gram_real[i, i] for i in range(diag_len))
    return gram_real / tr, gram_imag / tr
    
# Baseline error of the initial guess `sigma` on the test set: the sum of
# squared residuals between Re(trace(X @ sigma)) and the observed test_y.
# (Despite the name, this is an error value, not a matrix trace.)
initial_trace = sum(
    (np.real(trace(test_X.dot(sigma))) - test_y) ** 2
)

def loss(sigma_real, sigma_imag):
    """Return the mini-batch sum of squared prediction errors.

    The raw factors are first normalised with ``cholesky``. Then ``batch_size``
    rows of ``(train_X, train_y)`` are drawn uniformly at random *with
    replacement*. For each row the prediction is
    ``<Re E, Re sigma> + <Im E, Im sigma>``, and its squared deviation from
    the observed value is accumulated. The sum is not normalised.

    Args:
        sigma_real: real part of the unconstrained factor (tntorch tensor).
        sigma_imag: imaginary part of the unconstrained factor (tntorch tensor).

    Returns:
        The accumulated squared error (differentiable w.r.t. the factors).
    """
    sigma_real, sigma_imag = cholesky(sigma_real, sigma_imag)
    res = 0
    # Passing the row count directly draws the same indices as
    # np.random.choice(np.arange(n), ...) without allocating an n-element
    # array on every optimisation step.
    idx = np.random.choice(train_X.shape[0], batch_size)
    # train_y is already cast to float64 when the dataset is built, so no
    # per-call astype copy is needed here.
    for E_m, y_m in zip(train_X[idx], train_y[idx]):
        E_real, E_imag = [tn.Tensor(x) for x in [np.real(E_m), np.imag(E_m)]]
        # NOTE(review): if E_m and sigma are Hermitian and a bare `.dot` is
        # the full elementwise inner product, this equals Re Tr(E_m @ sigma).
        res += (E_real.dot(sigma_real) + E_imag.dot(sigma_imag) - y_m) ** 2
    return res
    

def eval_loss(sigma_real, sigma_imag):
    """Score the current estimate by its negative fidelity to ``rho``.

    Any score function can be used here. The factors are normalised with
    ``cholesky``. Each part is then detached, moved to CPU and turned into a
    NumPy array, and the parts are combined into one complex matrix. The
    fidelity is negated, so a smaller value means a higher fidelity.
    """
    dense_parts = [
        part.torch().detach().cpu().numpy()
        for part in cholesky(sigma_real, sigma_imag)
    ]
    estimate = dense_parts[0] + 1j * dense_parts[1]
    return -lib.fidelity(estimate, rho)

# Report the baseline error of the initial guess, then launch training.
print('Trace before: %f' % initial_trace)

lib.tn_optimize(
    [sigma_real, sigma_imag],  # trainable TT factors
    loss,
    eval_loss,
    tol=0,
    patience=2000,
    print_freq=10,
    lr=lr,
)

Trace before: 143.777468
Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers
iter: 0 |  loss: 3.404710054397583 | eval_loss: -0.5682538140116127 | total time: 0.1655430793762207
iter: 10 |  loss: 2.415814161300659 | eval_loss: -0.5744294592289965 | total time: 1.5378386974334717
iter: 20 |  loss: 2.05648136138916 | eval_loss: -0.5754942838429137 | total time: 2.927067279815674
iter: 30 |  loss: 1.9720617532730103 | eval_loss: -0.5767641284400137 | total time: 4.329918622970581
iter: 40 |  loss: 2.907911539077759 | eval_loss: -0.5795045900502244 | total time: 5.7087273597717285
iter: 50 |  loss: 2.2552924156188965 | eval_loss: -0.5812878208407846 | total time: 7.078037261962891
iter: 60 |  loss: 2.868164539337158 | eval_loss: -0.5832308044111137 | total time: 8.448633193969727
iter: 70 |  loss: 2.441403865814209 | eval_loss: -0.5868910162663914 | total time: 9.830496311187744
iter: 80 |  loss: 2.823009729385376 | eval_loss: -0.58931510407

iter: 790 |  loss: 1.2415612936019897 | eval_loss: -0.7907115469991314 | total time: 109.48016619682312
iter: 800 |  loss: 1.0152640342712402 | eval_loss: -0.7912939516785907 | total time: 110.84853672981262
iter: 810 |  loss: 1.1604152917861938 | eval_loss: -0.7918761437690199 | total time: 112.21671056747437
iter: 820 |  loss: 1.438362717628479 | eval_loss: -0.7928265724077037 | total time: 113.60474872589111
iter: 830 |  loss: 1.2187021970748901 | eval_loss: -0.7942319298068583 | total time: 115.00127816200256
iter: 840 |  loss: 1.5330846309661865 | eval_loss: -0.7948381245564962 | total time: 116.39521169662476
iter: 850 |  loss: 1.3045424222946167 | eval_loss: -0.7962562846011156 | total time: 117.77198910713196
iter: 860 |  loss: 1.5787254571914673 | eval_loss: -0.7972480663510277 | total time: 119.14124464988708
iter: 870 |  loss: 1.254487156867981 | eval_loss: -0.798452750533263 | total time: 120.50940084457397
iter: 880 |  loss: 0.8339923024177551 | eval_loss: -0.8003955636978

iter: 1580 |  loss: 0.32819613814353943 | eval_loss: -0.9196394865294223 | total time: 221.6062126159668
iter: 1590 |  loss: 0.38671764731407166 | eval_loss: -0.9211350185495409 | total time: 223.04896354675293
iter: 1600 |  loss: 0.31673768162727356 | eval_loss: -0.9233236719040947 | total time: 224.4914629459381
iter: 1610 |  loss: 0.3345455527305603 | eval_loss: -0.9240471351934308 | total time: 225.93574833869934
iter: 1620 |  loss: 0.3686099350452423 | eval_loss: -0.9245673030969499 | total time: 227.3827645778656
iter: 1630 |  loss: 0.27797791361808777 | eval_loss: -0.9260467536269171 | total time: 228.82793259620667
iter: 1640 |  loss: 0.30448469519615173 | eval_loss: -0.9269629241628983 | total time: 230.27480030059814
iter: 1650 |  loss: 0.29327499866485596 | eval_loss: -0.928059432211033 | total time: 231.72073793411255
iter: 1660 |  loss: 0.24896782636642456 | eval_loss: -0.9302579526170696 | total time: 233.1659836769104
iter: 1670 |  loss: 0.2954065501689911 | eval_loss: -

iter: 2360 |  loss: 0.05409759655594826 | eval_loss: -0.9738260210121757 | total time: 331.23059916496277
iter: 2370 |  loss: 0.05783643200993538 | eval_loss: -0.974127681185943 | total time: 332.6145975589752
iter: 2380 |  loss: 0.04466232284903526 | eval_loss: -0.9746984842076022 | total time: 334.00460982322693
iter: 2390 |  loss: 0.05922986939549446 | eval_loss: -0.9749340655480802 | total time: 335.43933272361755
iter: 2400 |  loss: 0.045596666634082794 | eval_loss: -0.9753326237961117 | total time: 336.8190236091614
iter: 2410 |  loss: 0.05637525022029877 | eval_loss: -0.9759253759390576 | total time: 338.21006536483765
iter: 2420 |  loss: 0.04459298774600029 | eval_loss: -0.9762416411028355 | total time: 339.5972363948822
iter: 2430 |  loss: 0.04398251324892044 | eval_loss: -0.9765185109197932 | total time: 340.9701862335205
iter: 2440 |  loss: 0.048574015498161316 | eval_loss: -0.9768256641990294 | total time: 342.33889389038086
iter: 2450 |  loss: 0.04081076756119728 | eval_lo

iter: 3140 |  loss: 0.01643720082938671 | eval_loss: -0.9940165661406418 | total time: 439.3362112045288
iter: 3150 |  loss: 0.019799163565039635 | eval_loss: -0.9941362086714838 | total time: 440.7151641845703
iter: 3160 |  loss: 0.017619550228118896 | eval_loss: -0.9943523489024639 | total time: 442.1126003265381
iter: 3170 |  loss: 0.018629662692546844 | eval_loss: -0.9944569484283344 | total time: 443.5041518211365
iter: 3180 |  loss: 0.016518641263246536 | eval_loss: -0.9945536350991021 | total time: 444.8807997703552
iter: 3190 |  loss: 0.018780820071697235 | eval_loss: -0.9947167654476268 | total time: 446.3025896549225
iter: 3200 |  loss: 0.01799134537577629 | eval_loss: -0.9948283308581394 | total time: 447.72137117385864
iter: 3210 |  loss: 0.019367488101124763 | eval_loss: -0.9949021348445627 | total time: 449.1456959247589
iter: 3220 |  loss: 0.01963702030479908 | eval_loss: -0.995038229613914 | total time: 450.59009981155396
iter: 3230 |  loss: 0.02046029642224312 | eval_l

iter: 3920 |  loss: 0.022502640262246132 | eval_loss: -0.9987958617848886 | total time: 552.0493190288544
iter: 3930 |  loss: 0.015180538408458233 | eval_loss: -0.998834164212885 | total time: 553.4936912059784
iter: 3940 |  loss: 0.014646490104496479 | eval_loss: -0.9988767235441665 | total time: 554.9398934841156
iter: 3950 |  loss: 0.015363380312919617 | eval_loss: -0.9988754412890944 | total time: 556.3886637687683
iter: 3960 |  loss: 0.02269819565117359 | eval_loss: -0.9988816889989015 | total time: 557.8334929943085
iter: 3970 |  loss: 0.01884341612458229 | eval_loss: -0.998913576748749 | total time: 559.2797520160675
iter: 3980 |  loss: 0.015454558655619621 | eval_loss: -0.9989231384587326 | total time: 560.7283456325531
iter: 3990 |  loss: 0.016774943098425865 | eval_loss: -0.9989239650140039 | total time: 562.173574924469
iter: 4000 |  loss: 0.018961144611239433 | eval_loss: -0.9989357232996161 | total time: 563.6198356151581
iter: 4010 |  loss: 0.017207378521561623 | eval_los

iter: 4700 |  loss: 0.018930708989501 | eval_loss: -0.9997686906212988 | total time: 665.2502422332764
iter: 4710 |  loss: 0.019967617467045784 | eval_loss: -0.9997542837050595 | total time: 666.6994445323944
iter: 4720 |  loss: 0.020121442154049873 | eval_loss: -0.9997310992588931 | total time: 668.1533980369568
iter: 4730 |  loss: 0.014940804801881313 | eval_loss: -0.999757954393641 | total time: 669.6045167446136
iter: 4740 |  loss: 0.019571736454963684 | eval_loss: -0.9997618254401995 | total time: 671.0514943599701
iter: 4750 |  loss: 0.01773892156779766 | eval_loss: -0.9997552449923245 | total time: 672.5056927204132
iter: 4760 |  loss: 0.023564336821436882 | eval_loss: -0.9997709692850776 | total time: 673.9557802677155
iter: 4770 |  loss: 0.012261240743100643 | eval_loss: -0.9997940559990716 | total time: 675.4099299907684
iter: 4780 |  loss: 0.017275072634220123 | eval_loss: -0.9998079474756975 | total time: 676.8726308345795
iter: 4790 |  loss: 0.017312824726104736 | eval_los

iter: 5480 |  loss: 0.02303851582109928 | eval_loss: -0.9998561291056874 | total time: 778.3890686035156
iter: 5490 |  loss: 0.021586904302239418 | eval_loss: -0.9998704995084744 | total time: 779.8355829715729
iter: 5500 |  loss: 0.017218785360455513 | eval_loss: -0.9998753646405679 | total time: 781.2774674892426
iter: 5510 |  loss: 0.013053630478680134 | eval_loss: -0.9998746110966842 | total time: 782.7234995365143
iter: 5520 |  loss: 0.01604197360575199 | eval_loss: -0.9998845770527062 | total time: 784.172126531601
iter: 5530 |  loss: 0.017853500321507454 | eval_loss: -0.9998855073095001 | total time: 785.6142032146454
iter: 5540 |  loss: 0.019230064004659653 | eval_loss: -0.9998863328900935 | total time: 787.0608406066895
iter: 5550 |  loss: 0.020116442814469337 | eval_loss: -0.9998891364806661 | total time: 788.5024540424347
iter: 5560 |  loss: 0.018591152504086494 | eval_loss: -0.9998703767934352 | total time: 789.942455291748
iter: 5570 |  loss: 0.02068791538476944 | eval_los

iter: 6260 |  loss: 0.02115466445684433 | eval_loss: -0.9999095611095579 | total time: 891.0714311599731
iter: 6270 |  loss: 0.018589694052934647 | eval_loss: -0.9998986234984498 | total time: 892.5084755420685
iter: 6280 |  loss: 0.01740545779466629 | eval_loss: -0.999913027694696 | total time: 893.9478526115417
iter: 6290 |  loss: 0.013118299655616283 | eval_loss: -0.9999414828496721 | total time: 895.3935732841492
iter: 6300 |  loss: 0.022544795647263527 | eval_loss: -0.9999406442753922 | total time: 896.8373577594757
iter: 6310 |  loss: 0.013956772163510323 | eval_loss: -0.9998941498849416 | total time: 898.2793281078339
iter: 6320 |  loss: 0.015886228531599045 | eval_loss: -0.9999386678101263 | total time: 899.7190141677856
iter: 6330 |  loss: 0.021397972479462624 | eval_loss: -0.9999497557572921 | total time: 901.1626243591309
iter: 6340 |  loss: 0.01622798852622509 | eval_loss: -0.9999460060191345 | total time: 902.6006684303284
iter: 6350 |  loss: 0.013096433132886887 | eval_lo

iter: 7040 |  loss: 0.016816692426800728 | eval_loss: -0.9996161573143929 | total time: 1003.7152941226959
iter: 7050 |  loss: 0.012929052114486694 | eval_loss: -0.9996509580838797 | total time: 1005.1613664627075
iter: 7060 |  loss: 0.022397279739379883 | eval_loss: -0.9996887316491977 | total time: 1006.6021001338959
iter: 7070 |  loss: 0.018551819026470184 | eval_loss: -0.9996815990640089 | total time: 1008.0473194122314
iter: 7080 |  loss: 0.020773207768797874 | eval_loss: -0.9996449035469394 | total time: 1009.4874546527863
iter: 7090 |  loss: 0.02055175229907036 | eval_loss: -0.9996186534717001 | total time: 1010.9278755187988
iter: 7100 |  loss: 0.020637867972254753 | eval_loss: -0.9996735325771039 | total time: 1012.3894610404968
iter: 7110 |  loss: 0.016416851431131363 | eval_loss: -0.999653917851847 | total time: 1013.838892698288
iter: 7120 |  loss: 0.022129591554403305 | eval_loss: -0.9996406818455 | total time: 1015.2924745082855
iter: 7130 |  loss: 0.01898905262351036 | e

iter: 7810 |  loss: 0.020701812580227852 | eval_loss: -0.999639359963147 | total time: 1115.656299829483
iter: 7820 |  loss: 0.01941697485744953 | eval_loss: -0.999698566257402 | total time: 1117.1049728393555
iter: 7830 |  loss: 0.017733367159962654 | eval_loss: -0.9997144582168158 | total time: 1118.5567893981934
iter: 7840 |  loss: 0.015565105713903904 | eval_loss: -0.9997438548809047 | total time: 1120.021585226059
iter: 7850 |  loss: 0.020632566884160042 | eval_loss: -0.9997155495471954 | total time: 1121.4782209396362
