In [None]:
import sys
import torch
import tntorch as tn
import numpy as np
sys.path.append('..')
import lib

# Ground-truth state to reconstruct: a random mixed state on `n_qubits`
# qubits (eval_loss scores the reconstruction by fidelity against `rho`).
n_qubits = 5
rho = lib.randomMixedState(2 ** n_qubits)
dim = rho.shape[0]

# Dataset and optimisation hyper-parameters.
projectors_cnt = 10000
measurements_cnt = 1000
tensor_rank = 2 ** n_qubits  # passed as `ranks_tt` when factorising the initial guess
batch_size = 100  # number of training samples drawn per `loss` call
lr = 1e-2  # learning rate handed to lib.tn_optimize

train_size = projectors_cnt * measurements_cnt  # not referenced in the visible cells

# Two separate calls to the same generator for training and evaluation data
# (presumably independent random draws -- confirm in lib.generate_dataset).
train_X, train_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
test_X, test_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
train_y = train_y.astype('float64')
test_y = test_y.astype('float64')

# Random initial guess of the same dimension as `rho`, split into its real and
# imaginary parts and wrapped as gradient-tracking tntorch tensors.
sigma = lib.simulator.randomMixedState(dim)
sigma_real = tn.Tensor(np.real(sigma), ranks_tt=tensor_rank, requires_grad=True)
sigma_imag = tn.Tensor(np.imag(sigma), ranks_tt=tensor_rank, requires_grad=True)

epoches = 10  # not referenced in the visible cells

def trace(tensor):
    """Return the trace over the last two axes of ``tensor``.

    A 2-D input yields a single value ``sum_i tensor[i, i]``.  An input with
    more axes is treated as a batch of matrices indexed by the leading axes
    and yields one trace per batch entry (e.g. shape ``(N, d, d)`` -> ``(N,)``).

    Only ``.shape``, tuple indexing with integers/slices and ``+`` are used,
    so NumPy arrays and torch tensors both work.

    Like ``np.trace``, a rectangular trailing matrix contributes only its main
    diagonal, i.e. ``min(rows, cols)`` entries.

    Raises:
        ValueError: if ``tensor`` has fewer than two axes.
    """
    shape = tensor.shape
    ndim = len(shape)
    if ndim < 2:
        raise ValueError(f"trace needs at least 2 axes, got shape {tuple(shape)}")
    # Never index past the shorter side of the trailing matrix.
    n_diag = min(shape[-2], shape[-1])
    # Full slices over every batch axis; empty for a plain 2-D matrix, so that
    # case indexes exactly as tensor[i, i].
    leading = (slice(None),) * (ndim - 2)
    return sum(tensor[leading + (i, i)] for i in range(n_diag))

def cholesky(sigma_real, sigma_imag):
    """Map unconstrained real/imaginary factors to a Gram-product matrix.

    Returns the pair
        (R.dot(R, k=1) + I.dot(I, k=1),  R.dot(I, k=1) - I.dot(R, k=1))
    with R = ``sigma_real`` and I = ``sigma_imag``.

    NOTE(review): assuming tntorch's ``dot(other, k=1)`` contracts the leading
    mode (result[j, l] = sum_i a[i, j] * b[i, l], i.e. a.T @ b), this is the
    real and imaginary part of A^H A with A = R + 1j*I.  That matrix is
    Hermitian positive semi-definite by construction.  Despite the name, no
    Cholesky decomposition is computed.  Confirm the dot convention against
    the tntorch docs.
    """
    rr = sigma_real.dot(sigma_real, k=1)
    ii = sigma_imag.dot(sigma_imag, k=1)
    ri = sigma_real.dot(sigma_imag, k=1)
    ir = sigma_imag.dot(sigma_real, k=1)
    gram_real = rr + ii
    gram_imag = ri - ir
    return gram_real, gram_imag
    
# Baseline objective before training: summed squared error between the
# predicted outcome Re Tr(E @ sigma) and the observed value, over the whole
# test set, evaluated at the random initial guess `sigma`.
# Assumes test_X stacks the measurement operators as (N, d, d), so
# test_X.dot(sigma) is the batch of E @ sigma products -- confirm.
# NOTE(review): the optimised state is cholesky(sigma_real, sigma_imag), a
# Gram product built from sigma's parts, not `sigma` itself.  This baseline is
# therefore not exactly the optimiser's starting loss; confirm intent.
# np.sum is vectorised and uses pairwise summation, unlike the builtin sum.
initial_trace = np.sum((np.real(trace(test_X.dot(sigma))) - test_y) ** 2)

def loss(sigma_real, sigma_imag):
    """Summed squared prediction error on a random minibatch of training data.

    The factors are first mapped through ``cholesky``.  Then ``batch_size``
    indices into ``train_X`` are drawn uniformly (with replacement, NumPy's
    default for ``np.random.choice``).  For each drawn operator E_m the
    prediction is

        E_re . G_re + E_im . G_im

    where E_re/E_im are tntorch tensors of E_m's real/imaginary parts and
    (G_re, G_im) is the ``cholesky`` output.  This presumably equals
    Re Tr(E_m G) for Hermitian E_m when tntorch's ``dot`` without ``k`` fully
    contracts two matrices -- confirm.

    The result is the plain sum of squared errors, not averaged.  An earlier
    variant divided it by ``initial_trace * train_X.shape[0]``.
    """
    gram_real, gram_imag = cholesky(sigma_real, sigma_imag)
    # An int population is sampled exactly as np.arange(n) would be.
    batch = np.random.choice(train_X.shape[0], batch_size)
    targets = train_y[batch].astype('float64')

    def squared_error(E_m, y_m):
        # Split the (complex) operator into real and imaginary tntorch tensors.
        E_re = tn.Tensor(np.real(E_m))
        E_im = tn.Tensor(np.imag(E_m))
        predicted = E_re.dot(gram_real) + E_im.dot(gram_imag)
        return (predicted - y_m) ** 2

    return sum(squared_error(E_m, y_m) for E_m, y_m in zip(train_X[batch], targets))
    

def eval_loss(sigma_real, sigma_imag):  # any score function can be used here
    """Score the current estimate by its negated fidelity with the true ``rho``.

    The factors are mapped through ``cholesky``.  Each part is detached from
    the autograd graph, moved to the CPU, and converted to NumPy.  The parts
    are combined into one complex matrix and passed to ``lib.fidelity``.

    The fidelity is negated so that, like ``loss``, smaller is better
    (presumably what ``lib.tn_optimize`` expects -- confirm).
    NOTE(review): the estimate is not trace-normalised here; whether
    lib.fidelity normalises its arguments is not visible -- confirm.
    """
    real_part, imag_part = (
        part.torch().detach().cpu().numpy()
        for part in cholesky(sigma_real, sigma_imag)
    )
    estimate = real_part + 1j * imag_part
    return -lib.fidelity(estimate, rho)

# Report the pre-training baseline.  Despite the "Trace" label, the printed
# value is initial_trace, a summed squared error.
print('Trace before: %f' % initial_trace)

# Optimise both factor tensors jointly.  The tol/patience/print_freq
# semantics are defined inside lib.tn_optimize.
lib.tn_optimize(
    [sigma_real, sigma_imag],
    loss,
    eval_loss,
    tol=0,
    patience=2000,
    print_freq=10,
    lr=lr,
)

Trace before: 0.938808
Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers
iter: 0 |  loss: 0.08874133974313736 | eval_loss: -0.02754437245247337 | total time: 0.14484190940856934
iter: 10 |  loss: 0.08633831143379211 | eval_loss: -0.04033354781538823 | total time: 1.9317750930786133
iter: 20 |  loss: 0.07312344759702682 | eval_loss: -0.07845271617422812 | total time: 3.6664958000183105
iter: 30 |  loss: 0.05695733055472374 | eval_loss: -0.1525715364328853 | total time: 5.414903402328491
iter: 40 |  loss: 0.010470649227499962 | eval_loss: -0.43321363487366465 | total time: 7.195456027984619
iter: 50 |  loss: 0.017762944102287292 | eval_loss: -0.6205210311944537 | total time: 8.978639364242554
iter: 60 |  loss: 0.015383871272206306 | eval_loss: -0.6312848503211329 | total time: 10.76115083694458
iter: 70 |  loss: 0.011574653908610344 | eval_loss: -0.525165009543168 | total time: 12.543503761291504
iter: 80 |  loss: 0.00664135068655014 | e

iter: 770 |  loss: 0.004340838640928268 | eval_loss: -0.8733201846127894 | total time: 138.45091199874878
iter: 780 |  loss: 0.0030167477671056986 | eval_loss: -0.8505103937981096 | total time: 140.24404764175415
iter: 790 |  loss: 0.0034729137551039457 | eval_loss: -0.8751263661192529 | total time: 142.0338270664215
iter: 800 |  loss: 0.004943095147609711 | eval_loss: -0.8553857821428792 | total time: 143.82728171348572
iter: 810 |  loss: 0.0035133841447532177 | eval_loss: -0.8588593208348365 | total time: 145.61876821517944
iter: 820 |  loss: 0.00345264351926744 | eval_loss: -0.8681178012098715 | total time: 147.4099633693695
iter: 830 |  loss: 0.004296812694519758 | eval_loss: -0.8860691623676348 | total time: 149.20021390914917
iter: 840 |  loss: 0.004018399398773909 | eval_loss: -0.8514700572304287 | total time: 150.99178814888
iter: 850 |  loss: 0.004870549309998751 | eval_loss: -0.8568206361032376 | total time: 152.7818684577942
iter: 860 |  loss: 0.00449265306815505 | eval_loss

iter: 1550 |  loss: 0.003081946400925517 | eval_loss: -0.8966865675372728 | total time: 279.38720178604126
iter: 1560 |  loss: 0.0037293981295078993 | eval_loss: -0.8894694949814895 | total time: 281.17805218696594
iter: 1570 |  loss: 0.003804683918133378 | eval_loss: -0.9228226966695829 | total time: 282.96497225761414
iter: 1580 |  loss: 0.005864571314305067 | eval_loss: -0.9398315663663958 | total time: 284.7494287490845
iter: 1590 |  loss: 0.003315431997179985 | eval_loss: -0.8664111185036384 | total time: 286.53613471984863
iter: 1600 |  loss: 0.004378436133265495 | eval_loss: -0.8837799206388222 | total time: 288.3217566013336
iter: 1610 |  loss: 0.005467335227876902 | eval_loss: -0.8691079872579413 | total time: 290.10939836502075
iter: 1620 |  loss: 0.003729358548298478 | eval_loss: -0.8687375548803116 | total time: 291.8944857120514
iter: 1630 |  loss: 0.0030398941598832607 | eval_loss: -0.9110254800441528 | total time: 293.67876839637756
iter: 1640 |  loss: 0.0041457135230302