In [None]:
import sys
import torch
import tntorch as tn
import numpy as np
sys.path.append('..')
import lib

# --- Experiment configuration -------------------------------------------------
n_qubits = 5
# Ground-truth state to reconstruct, drawn by the project helper
# ``lib.randomMixedState``; presumably a 2**n_qubits x 2**n_qubits density
# matrix — TODO confirm against lib.
rho = lib.randomMixedState(2 ** n_qubits)
dim = rho.shape[0]  # NOTE(review): not referenced in the code visible here

# Sizes forwarded to lib.generate_dataset; by name, number of measurement
# operators and measurements per operator — TODO confirm against lib.
projectors_cnt = 10000
measurements_cnt = 1000
# TT rank used for the trainable tntorch tensors below.
tensor_rank = 2**n_qubits
# Number of training rows sampled per loss() call.
batch_size = 100
# Learning rate forwarded to lib.tn_optimize.
lr = 1e-3

train_size = projectors_cnt * measurements_cnt  # NOTE(review): unused in the visible code
# Train and test sets are generated independently from the same ``rho``.
train_X, train_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
test_X, test_y = lib.generate_dataset(rho, projectors_cnt, measurements_cnt)
# Targets are cast to float64 so residuals are computed in double precision.
train_y, test_y = [x.astype('float64') for x in [train_y, test_y]]

# Random initial estimate; its real and imaginary parts become the trainable
# TT tensors (requires_grad=True) optimised by lib.tn_optimize.
sigma = lib.simulator.randomMixedState(rho.shape[0])
sigma_real, sigma_imag = [tn.Tensor(x, ranks_tt=tensor_rank, requires_grad=True) for x in [np.real(sigma), np.imag(sigma)]]

epoches = 10  # NOTE(review): not passed to lib.tn_optimize in the visible code

def trace(tensor):
    """Return the trace taken over the last two axes of ``tensor``.

    A 2-D input yields the scalar ``sum_i tensor[i, i]``. An input with leading
    (batch) axes, e.g. a stack of matrices of shape ``(N, d, d)``, yields one
    trace per matrix, i.e. a result of shape ``tensor.shape[:-2]``.
    The trailing two axes are assumed to be square.

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions.
    """
    ndim = len(tensor.shape)
    if ndim < 2:
        raise ValueError(
            f"trace() needs at least 2 dimensions, got shape {tuple(tensor.shape)}"
        )
    if ndim == 2:
        return sum([tensor[i, i] for i in range(tensor.shape[0])])
    # Batched case: walk the diagonal of the two trailing axes while keeping
    # every leading axis, so (N, d, d), (B, N, d, d), ... all work.
    return sum([tensor[..., i, i] for i in range(tensor.shape[-2])])

def cholesky(sigma_real, sigma_imag):
    """Map an unconstrained complex matrix to a unit-trace density matrix.

    Given the real and imaginary parts ``R`` and ``I`` of a square matrix
    ``X = R + iI``, return the real and imaginary parts of

        sigma = X^dagger X / Tr(X^dagger X),

    which is Hermitian and positive semidefinite with unit trace for any
    non-zero ``X``. Despite the name, no Cholesky factorisation is computed;
    ``X`` simply acts as an unconstrained factor of ``sigma``.

    ``A.dot(B, k=1)`` contracts the leading index of both operands, i.e. it
    computes ``A^T B``. With ``X^dagger = R^T - i I^T`` this gives

        Re(X^dagger X) = R^T R + I^T I
        Im(X^dagger X) = R^T I - I^T R

    Args:
        sigma_real: real part ``R`` of the factor (a tntorch tensor here).
        sigma_imag: imaginary part ``I`` of the factor.

    Returns:
        Tuple ``(real, imag)`` of the normalised matrix.
    """
    gram_real = sigma_real.dot(sigma_real, k=1) + sigma_imag.dot(sigma_imag, k=1)
    gram_imag = sigma_real.dot(sigma_imag, k=1) - sigma_imag.dot(sigma_real, k=1)
    # Size the trace from the product itself (not a global such as ``rho``)
    # so the normalisation is correct for any input dimension. The local is
    # named ``tr`` to avoid shadowing the module-level ``trace`` helper.
    tr = sum([gram_real[i, i] for i in range(gram_real.shape[0])])
    return gram_real / tr, gram_imag / tr
    
# Baseline score of the random initial estimate ``sigma`` on the test set:
# the summed squared residual between Re Tr(E_m sigma) and the test targets
# (assuming test_X stacks the operators E_m as (N, d, d) — TODO confirm).
# NOTE(review): despite its name this is a squared-error loss, not a trace.
_initial_pred = np.real(trace(test_X.dot(sigma)))
initial_trace = sum((_initial_pred - test_y) ** 2)
del _initial_pred  # leave only ``initial_trace`` in the module namespace

def loss(sigma_real, sigma_imag):
    """Minibatch squared-error loss of the current state estimate.

    The parameters are first turned into a normalised density matrix by
    ``cholesky``. ``batch_size`` training rows are then drawn uniformly with
    replacement (``np.random.choice`` default). For every operator ``E_m`` with
    target ``y_m``, the residual

        <Re E_m, Re sigma> + <Im E_m, Im sigma> - y_m

    is squared and accumulated, where ``<., .>`` is the full-contraction
    ``.dot``. Because sigma is Hermitian, the two inner products equal
    Re Tr(E_m sigma).

    Returns the *sum* (not the mean) of the squared residuals.
    """
    est_re, est_im = cholesky(sigma_real, sigma_imag)
    picked = np.random.choice(np.arange(train_X.shape[0]), batch_size)
    operators = train_X[picked]
    targets = train_y[picked].astype('float64')

    total = 0
    for op, target in zip(operators, targets):
        op_re = tn.Tensor(np.real(op))
        op_im = tn.Tensor(np.imag(op))
        residual = op_re.dot(est_re) + op_im.dot(est_im) - target
        total += residual ** 2
    return total
    

def eval_loss(sigma_real, sigma_imag):  # any score function can be used here
    """Validation score: negative fidelity between the estimate and ``rho``.

    The parameters are mapped through ``cholesky``. Each part is then detached
    from the autograd graph, moved to the CPU, converted to numpy, and
    recombined into one complex matrix before calling ``lib.fidelity``. The
    sign is flipped, presumably so that ``lib.tn_optimize`` can treat lower as
    better — TODO confirm.
    """
    re_np, im_np = [
        part.torch().detach().cpu().numpy()
        for part in cholesky(sigma_real, sigma_imag)
    ]
    estimate = re_np + 1j * im_np
    return -lib.fidelity(estimate, rho)

# Report the baseline test error, then train the real/imag TT parameters.
print('Trace before: %f' % (initial_trace,))

lib.tn_optimize(
    [sigma_real, sigma_imag],
    loss,
    eval_loss,
    tol=0,
    patience=2000,
    print_freq=10,
    lr=lr,
)