## Import library

In [48]:
import os
import subprocess
import random
import warnings
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import unitary_group
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
from tqdm.notebook import tqdm
import itertools

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.nn_states import ComplexWaveFunction
from qucumber.callbacks import MetricEvaluator
import qucumber.utils.unitaries as unitaries
import qucumber.utils.training_statistics as ts
import qucumber.utils.cplx as cplx
import qucumber.utils.data as data
from qucumber.observables import ObservableBase, to_pm1
from qucumber.observables.pauli import flip_spin
import qucumber

from qulacs.gate import Pauli

## settings

In [49]:
# Read the experiment configuration once; every setting below comes from it.
with open("./params_setting.yaml", "r") as config_file:
    params = yaml.safe_load(config_file)

# quantum circuit parameter
circuit_info = params["circuit_info"]
n_qubit = circuit_info["n_qubit"]
each_n_shot = circuit_info["each_n_shot"]
state_name = circuit_info["state_name"]
error_model = circuit_info["error_model"]
error_rate = circuit_info["error_rate"]

# RBM architecture parameter
architecture_info = params["architecture_info"]
n_visible_unit = architecture_info["n_visible_unit"]
n_hidden_unit = architecture_info["n_hidden_unit"]
n_aux_unit = architecture_info["n_aux_unit"]

# train parameter
train_info = params["train_info"]
lr = train_info["lr"]
pbs = train_info["positive_batch_size"]
nbs = train_info["negative_batch_size"]
n_gibbs_step = train_info["n_gibbs_step"]
# Not read from the YAML file: passed to MetricEvaluator and used as the
# step of the epoch axis when plotting / logging.
period = 1
epoch = train_info["n_epoch"]
lr_drop_epoch = train_info["lr_drop_epoch"]
lr_drop_factor = train_info["lr_drop_factor"]
seed = train_info["seed"]

# sampling parameter
sampling_info = params["sampling_info"]
n_sampling = sampling_info["n_sample"]
n_copy = sampling_info["n_copy"]

# data path info
# "local" keeps everything next to the notebook; "colab" mounts Google Drive
# and stores the same directory layout below `drive_path`.
environment = "local"
if environment == "local":
    base_path = "."
elif environment == "colab":
    from google.colab import drive
    drive.mount("/content/drive/")
    drive_path = "/content/drive/MyDrive/NQS4QEM/Bell"
    # Join with "/" rather than appending "./data/...", which used to produce
    # ".../Bell./data/..." (a directory literally named "Bell.").
    base_path = drive_path
else:
    raise ValueError(f"unknown environment: {environment!r} (expected 'local' or 'colab')")

# Both paths end with "/" because later cells append file names to them.
train_data_path = f"{base_path}/data/{error_model}/error_rate_{100*error_rate}%/num_of_data_{each_n_shot}/"
target_state_path = f"{base_path}/target_state/{error_model}/error_rate_{100*error_rate}%/"

# settings
## warnings
# Installs a global "ignore" filter: no warning of any category is shown
# for the rest of the session.
warnings.simplefilter('ignore')

## seaborn layout
# Apply seaborn's default theme first, then switch to the "white" style.
sns.set()
sns.set_style("white")

## seed
def seed_settings(seed=42):
    """Seed every random number source used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch, exports ``PYTHONHASHSEED``
    to the process environment, and seeds QuCumber for CPU only.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # GPU seeding is switched off explicitly; the model below is built with gpu=False.
    qucumber.set_random_seed(seed, cpu=True, gpu=False)


seed_settings(seed=seed)

## utility

In [50]:
def get_density_matrix(nn_state):
    """Return the normalized density matrix of ``nn_state`` as a NumPy array.

    The matrix is evaluated on the full Hilbert-space basis generated by the
    state itself and divided by the state's normalization constant.
    """
    basis = nn_state.generate_hilbert_space()
    partition = nn_state.normalization(basis)
    normalized_rho = nn_state.rho(basis, basis) / partition
    # Convert QuCumber's complex tensor representation to a NumPy array.
    return cplx.numpy(normalized_rho)

def get_max_eigvec(matrix):
    """Return the eigenvector of ``matrix`` belonging to its largest eigenvalue.

    ``np.linalg.eigh`` is used, so ``matrix`` is treated as Hermitian and the
    eigenvalues come back in ascending order; the last column of the
    eigenvector matrix is therefore the dominant eigenvector.

    Args:
        matrix: Square Hermitian (or real symmetric) array.

    Returns:
        1-D array holding the dominant eigenvector.
    """
    # The eigenvalues themselves are not needed, only their ordering.
    _, eigenvectors = np.linalg.eigh(matrix)
    return eigenvectors[:, -1]

def get_eigvec(nn_state, obs, space, **kwargs):
    """Return the expectation value of ``obs`` in the dominant eigenvector of the state.

    Despite its name, this does not return an eigenvector: it extracts the
    dominant eigenvector ``v`` of the density matrix of ``nn_state`` and
    returns the real part of ``v^dagger @ obs @ v``.

    Args:
        nn_state: State whose density matrix is diagonalized.
        obs: Observable as a dense matrix in the same basis.
        space: Unused; kept so the signature stays unchanged for callers.
        **kwargs: Unused.

    Returns:
        Real part of the expectation value as a float.
    """
    dominant = get_max_eigvec(get_density_matrix(nn_state))
    row = np.atleast_2d(dominant)  # shape (1, dim), i.e. v^T
    # <v|O|v> = v^dagger O v: the conjugate belongs on the left factor.
    # The previous form (row @ obs @ row.T.conj()) evaluated v^T O v^*, which
    # only agrees with <v|O|v> when obs is real (e.g. XX, XZ, ZZ).
    expectation = row.conj() @ obs @ row.T
    return expectation[0, 0].real

def observable_XX():
    """Return the matrix of the Pauli operator X on qubit 0 times X on qubit 1."""
    qubits = [0, 1]
    pauli_ids = [1, 1]  # 1:X , 2:Y, 3:Z
    return Pauli(qubits, pauli_ids).get_matrix()

def observable_XZ():
    """Return the matrix of the Pauli operator X on qubit 0 times Z on qubit 1."""
    qubits = [0, 1]
    pauli_ids = [1, 3]  # 1:X , 2:Y, 3:Z
    return Pauli(qubits, pauli_ids).get_matrix()

def observable_ZZ_ev(nn_state, **kwargs):
    """Return the distilled expectation value of Z on unit 0 times Z on unit 1.

    ``nn_state`` and ``**kwargs`` are accepted but not used here; the sample
    count and number of copies come from the module-level ``n_sampling`` and
    ``n_copy``.
    """
    zz_term = {0: "Z", 1: "Z"}
    return calculate_distilled_expectation_value(zz_term, n_sampling, n_copy)["mean"]

def observable_1dtfi_ev(nn_state, **kwargs):
    """Return the sum of the distilled expectation values of ZZ, XI and IX.

    ``nn_state`` and ``**kwargs`` are accepted but not used here; the sample
    count and number of copies come from the module-level ``n_sampling`` and
    ``n_copy``. The three terms are added with unit coefficients.
    """
    pauli_terms = (
        {0: "Z", 1: "Z"},
        {0: "X", 1: "I"},
        {0: "I", 1: "X"},
    )
    # Terms are evaluated one after another, in the order listed above.
    term_means = [
        calculate_distilled_expectation_value(term, n_sampling, n_copy)["mean"]
        for term in pauli_terms
    ]
    return term_means[0] + term_means[1] + term_means[2]

## quantum circuit

In [51]:
def Rz(n_qubit, target_qubit_idx, theta):
    """Return the ``n_qubit``-qubit matrix of a Z rotation on a single qubit.

    The single-qubit gate is ``diag(exp(-i*theta/2), exp(+i*theta/2))``; it
    is placed at position ``target_qubit_idx`` of a Kronecker product whose
    other factors are 2x2 identities (qubit 0 is the leftmost factor).

    Args:
        n_qubit: Total number of qubits.
        target_qubit_idx: Index of the rotated qubit. If it is outside
            ``range(n_qubit)`` no factor is rotated and the identity results.
        theta: Rotation angle in radians.

    Returns:
        Array of shape ``(2**n_qubit, 2**n_qubit)``.
    """
    identity = np.eye(2)
    # The lower-right entry must be exp(+i*theta/2). The previous
    # np.cos(1j*theta/2) equals cosh(theta/2), which made the gate non-unitary.
    local_rz = np.array([[np.exp(-1j * theta / 2), 0], [0, np.exp(1j * theta / 2)]])

    def factor(qubit_idx):
        return local_rz if qubit_idx == target_qubit_idx else identity

    mat = factor(0)
    for qubit_idx in range(1, n_qubit):
        mat = np.kron(mat, factor(qubit_idx))

    return mat

def global_depolarizing(state, n_qubit, error_rate):
    """Mix ``state`` with the identity scaled by ``trace(state) / 2**n_qubit``.

    Returns ``(1 - error_rate) * state`` plus ``error_rate`` times that
    scaled identity, so the trace of ``state`` is preserved.
    """
    dim = 2**n_qubit
    kept_part = (1 - error_rate) * state
    mixed_part = error_rate * np.trace(state) * np.eye(dim) / dim
    return kept_part + mixed_part
    
def unitary(state, n_qubit, theta, target_qubit_idx):
    """Conjugate the density matrix ``state`` with a Z rotation on one qubit.

    Returns ``R @ state @ R^dagger`` where
    ``R = Rz(n_qubit, target_qubit_idx, theta)``.
    """
    rotation = Rz(n_qubit, target_qubit_idx, theta)
    return rotation @ state @ rotation.T.conjugate()

def init_state(n_qubit, state_name):
    """Return the all-zero computational basis state |00...0>.

    Args:
        n_qubit: Number of qubits.
        state_name: ``"density_matrix"`` returns the projector
            |00...0><00...0| of shape ``(2**n_qubit, 2**n_qubit)``; any other
            value returns the column vector of shape ``(2**n_qubit, 1)``.

    Returns:
        Integer-valued NumPy array (the same dtype the original construction
        produced).
    """
    # Allocate the vector in one step instead of growing it with np.append,
    # which copied the whole array on every iteration.
    ket = np.zeros((2**n_qubit, 1), dtype=int)
    ket[0, 0] = 1

    if state_name == "density_matrix":
        return ket @ ket.T.conjugate()  # |00...0><00...0|

    return ket

def Random_unitary(n_qubit, state_name, error_model, error_rate):
    """Prepare U|00...0> for a seeded random unitary U, optionally with noise.

    The unitary is drawn with ``random_state=seed`` (the module-level seed),
    so every call with the same ``n_qubit`` produces the same ``U``.

    Args:
        n_qubit: Number of qubits.
        state_name: ``"state_vector"`` or ``"density_matrix"``.
        error_model: ``"ideal"`` for either representation; additionally
            ``"depolarizing"``, ``"unitary"`` or ``"depolarizing&unitary"``
            for ``"density_matrix"``.
        error_rate: Depolarizing strength; its square root is used as the
            Z-rotation angle applied to every qubit in the unitary error models.

    Returns:
        The state vector or density matrix after the circuit (and noise).

    Raises:
        ValueError: If the ``state_name`` / ``error_model`` combination is
            not supported (this used to surface as ``UnboundLocalError``).
    """
    U = unitary_group.rvs(dim=2**n_qubit, random_state=seed)

    if state_name == "state_vector":
        if error_model != "ideal":
            raise ValueError(
                f"error_model {error_model!r} is not supported for state_name 'state_vector'"
            )
        return U @ init_state(n_qubit, state_name)

    if state_name != "density_matrix":
        raise ValueError(f"unknown state_name: {state_name!r}")

    supported_models = ("ideal", "depolarizing", "unitary", "depolarizing&unitary")
    if error_model not in supported_models:
        raise ValueError(f"unknown error_model: {error_model!r}")

    state = init_state(n_qubit, state_name)
    state = U @ state @ U.T.conjugate()

    # Noise order matches the original branches: depolarizing first, then the
    # coherent Z rotations on every qubit.
    if error_model in ("depolarizing", "depolarizing&unitary"):
        state = global_depolarizing(state, n_qubit, error_rate)
    if error_model in ("unitary", "depolarizing&unitary"):
        for qubit_idx in range(n_qubit):
            state = unitary(state, n_qubit, np.sqrt(error_rate), qubit_idx)

    return state

## generate dataset

In [52]:
# generate train data
# os.path.exists() does not expand wildcards, so the old check for
# "<dir>/*.txt" was always False and the data was regenerated on every run.
# Look for at least one .txt file in the directory instead.
is_train_data_file = os.path.isdir(train_data_path) and any(
    file_name.endswith(".txt") for file_name in os.listdir(train_data_path)
)
if is_train_data_file:
    print("train data is exsisted !")
else:
    print("generate directries & train data ...")
    os.makedirs(train_data_path, exist_ok=True)
    os.makedirs(target_state_path, exist_ok=True)
    # Argument list instead of a shell string; check=True raises if the
    # generator fails instead of reporting that the data is ready.
    subprocess.run(["python", "gen_dataset.py"], check=True)
    print("train data is ready !")

generate directries & train data ...


0it [00:00, ?it/s]
  0%|          | 0/1000 [00:00<?, ?it/s][A
  2%|▎         | 25/1000 [00:00<00:03, 245.74it/s][A

measurement pattern 1 : ('X', 'X')



  5%|▌         | 54/1000 [00:00<00:03, 265.78it/s][A
  8%|▊         | 81/1000 [00:00<00:03, 257.91it/s][A
 11%|█         | 110/1000 [00:00<00:03, 266.62it/s][A
 14%|█▎        | 137/1000 [00:00<00:03, 267.39it/s][A
 16%|█▋        | 165/1000 [00:00<00:03, 271.07it/s][A
 19%|█▉        | 193/1000 [00:00<00:03, 267.38it/s][A
 22%|██▏       | 222/1000 [00:00<00:02, 272.14it/s][A
 25%|██▌       | 250/1000 [00:00<00:02, 273.49it/s][A
 28%|██▊       | 278/1000 [00:01<00:02, 271.09it/s][A
 31%|███       | 308/1000 [00:01<00:02, 277.05it/s][A
 34%|███▎      | 336/1000 [00:01<00:02, 274.88it/s][A
 36%|███▋      | 364/1000 [00:01<00:02, 275.92it/s][A
 39%|███▉      | 392/1000 [00:01<00:02, 273.53it/s][A
 42%|████▏     | 421/1000 [00:01<00:02, 277.39it/s][A
 45%|████▍     | 449/1000 [00:01<00:02, 271.67it/s][A
 48%|████▊     | 477/1000 [00:01<00:01, 266.82it/s][A
 50%|█████     | 505/1000 [00:01<00:01, 268.54it/s][A
 53%|█████▎    | 534/1000 [00:01<00:01, 272.90it/s][A
 56%|█████▌

measurement pattern 2 : ('X', 'Y')



  6%|▌         | 56/1000 [00:00<00:03, 271.73it/s][A
  8%|▊         | 84/1000 [00:00<00:03, 262.32it/s][A
 11%|█         | 111/1000 [00:00<00:04, 213.83it/s][A
 14%|█▎        | 137/1000 [00:00<00:03, 227.29it/s][A
 16%|█▋        | 163/1000 [00:00<00:03, 236.18it/s][A
 19%|█▉        | 190/1000 [00:00<00:03, 246.21it/s][A
 22%|██▏       | 218/1000 [00:00<00:03, 253.97it/s][A
 24%|██▍       | 245/1000 [00:00<00:02, 256.85it/s][A
 27%|██▋       | 272/1000 [00:01<00:02, 258.25it/s][A
 30%|███       | 300/1000 [00:01<00:02, 263.93it/s][A
 33%|███▎      | 328/1000 [00:01<00:02, 267.50it/s][A
 36%|███▌      | 356/1000 [00:01<00:02, 269.96it/s][A
 38%|███▊      | 384/1000 [00:01<00:02, 270.07it/s][A
 41%|████      | 412/1000 [00:01<00:02, 271.95it/s][A
 44%|████▍     | 440/1000 [00:01<00:02, 271.21it/s][A
 47%|████▋     | 468/1000 [00:01<00:01, 273.02it/s][A
 50%|████▉     | 496/1000 [00:01<00:01, 273.18it/s][A
 52%|█████▏    | 524/1000 [00:02<00:01, 269.65it/s][A
 55%|█████▌

measurement pattern 3 : ('X', 'Z')



  2%|▏         | 24/1000 [00:00<00:04, 239.43it/s][A
  5%|▌         | 53/1000 [00:00<00:03, 264.96it/s][A
  8%|▊         | 81/1000 [00:00<00:03, 269.29it/s][A
 11%|█         | 109/1000 [00:00<00:03, 272.06it/s][A
 14%|█▍        | 138/1000 [00:00<00:03, 277.09it/s][A
 17%|█▋        | 166/1000 [00:00<00:03, 275.59it/s][A
 20%|█▉        | 195/1000 [00:00<00:02, 278.30it/s][A
 22%|██▏       | 223/1000 [00:00<00:02, 277.32it/s][A
 25%|██▌       | 252/1000 [00:00<00:02, 280.32it/s][A
 28%|██▊       | 281/1000 [00:01<00:02, 277.46it/s][A
 31%|███       | 309/1000 [00:01<00:02, 274.67it/s][A
 34%|███▍      | 338/1000 [00:01<00:02, 278.52it/s][A
 37%|███▋      | 366/1000 [00:01<00:02, 278.56it/s][A
 40%|███▉      | 395/1000 [00:01<00:02, 279.98it/s][A
 42%|████▏     | 424/1000 [00:01<00:02, 274.68it/s][A
 45%|████▌     | 452/1000 [00:01<00:02, 209.16it/s][A
 48%|████▊     | 476/1000 [00:01<00:02, 212.79it/s][A
 50%|█████     | 502/1000 [00:01<00:02, 223.73it/s][A
 53%|█████▎ 

measurement pattern 4 : ('Y', 'X')


  6%|▌         | 58/1000 [00:00<00:03, 282.17it/s][A
  9%|▊         | 87/1000 [00:00<00:03, 268.60it/s][A
 11%|█▏        | 114/1000 [00:00<00:03, 264.97it/s][A
 14%|█▍        | 142/1000 [00:00<00:03, 267.60it/s][A
 17%|█▋        | 170/1000 [00:00<00:03, 271.37it/s][A
 20%|█▉        | 198/1000 [00:00<00:02, 270.60it/s][A
 23%|██▎       | 226/1000 [00:00<00:02, 268.77it/s][A
 26%|██▌       | 255/1000 [00:00<00:02, 271.49it/s][A
 28%|██▊       | 283/1000 [00:01<00:02, 267.90it/s][A
 31%|███       | 310/1000 [00:01<00:02, 267.19it/s][A
 34%|███▎      | 337/1000 [00:01<00:02, 266.42it/s][A
 36%|███▋      | 364/1000 [00:01<00:02, 266.85it/s][A
 39%|███▉      | 392/1000 [00:01<00:02, 267.89it/s][A
 42%|████▏     | 420/1000 [00:01<00:02, 270.13it/s][A
 45%|████▍     | 448/1000 [00:01<00:02, 271.70it/s][A
 48%|████▊     | 476/1000 [00:01<00:01, 272.17it/s][A
 50%|█████     | 504/1000 [00:01<00:01, 253.64it/s][A
 53%|█████▎    | 530/1000 [00:01<00:01, 253.30it/s][A
 56%|█████▌ 

measurement pattern 5 : ('Y', 'Y')



  3%|▎         | 28/1000 [00:00<00:03, 275.51it/s][A
  6%|▌         | 56/1000 [00:00<00:03, 275.67it/s][A
  8%|▊         | 84/1000 [00:00<00:03, 275.54it/s][A
 11%|█         | 112/1000 [00:00<00:03, 273.92it/s][A
 14%|█▍        | 140/1000 [00:00<00:03, 274.17it/s][A
 17%|█▋        | 168/1000 [00:00<00:03, 271.37it/s][A
 20%|█▉        | 196/1000 [00:00<00:02, 273.49it/s][A
 22%|██▏       | 224/1000 [00:00<00:02, 273.09it/s][A
 25%|██▌       | 252/1000 [00:00<00:02, 273.95it/s][A
 28%|██▊       | 280/1000 [00:01<00:02, 272.17it/s][A
 31%|███       | 309/1000 [00:01<00:02, 275.74it/s][A
 34%|███▎      | 337/1000 [00:01<00:02, 274.17it/s][A
 36%|███▋      | 365/1000 [00:01<00:02, 274.34it/s][A
 39%|███▉      | 393/1000 [00:01<00:02, 273.94it/s][A
 42%|████▏     | 421/1000 [00:01<00:02, 271.49it/s][A
 45%|████▍     | 449/1000 [00:01<00:02, 268.76it/s][A
 48%|████▊     | 476/1000 [00:01<00:01, 267.51it/s][A
 50%|█████     | 503/1000 [00:01<00:01, 254.15it/s][A
 53%|█████▎ 

measurement pattern 6 : ('Y', 'Z')



  6%|▌         | 58/1000 [00:00<00:03, 272.98it/s][A
  9%|▊         | 86/1000 [00:00<00:03, 274.16it/s][A
 12%|█▏        | 115/1000 [00:00<00:03, 278.35it/s][A
 14%|█▍        | 143/1000 [00:00<00:03, 256.79it/s][A
 17%|█▋        | 169/1000 [00:00<00:03, 252.61it/s][A
 20%|█▉        | 197/1000 [00:00<00:03, 259.85it/s][A
 23%|██▎       | 226/1000 [00:00<00:02, 266.47it/s][A
 25%|██▌       | 253/1000 [00:00<00:02, 256.40it/s][A
 28%|██▊       | 279/1000 [00:01<00:02, 249.80it/s][A
 31%|███       | 306/1000 [00:01<00:02, 253.93it/s][A
 34%|███▎      | 335/1000 [00:01<00:02, 261.99it/s][A
 36%|███▌      | 362/1000 [00:01<00:02, 262.74it/s][A
 39%|███▉      | 390/1000 [00:01<00:02, 267.38it/s][A
 42%|████▏     | 417/1000 [00:01<00:02, 267.12it/s][A
 44%|████▍     | 445/1000 [00:01<00:02, 268.36it/s][A
 47%|████▋     | 473/1000 [00:01<00:01, 269.90it/s][A
 50%|█████     | 501/1000 [00:01<00:01, 269.47it/s][A
 53%|█████▎    | 530/1000 [00:01<00:01, 273.70it/s][A
 56%|█████▌

measurement pattern 7 : ('Z', 'X')



  5%|▌         | 54/1000 [00:00<00:03, 253.15it/s][A
  8%|▊         | 80/1000 [00:00<00:03, 253.50it/s][A
 11%|█         | 106/1000 [00:00<00:03, 253.59it/s][A
 13%|█▎        | 132/1000 [00:00<00:03, 254.04it/s][A
 16%|█▌        | 158/1000 [00:00<00:03, 255.88it/s][A
 18%|█▊        | 184/1000 [00:00<00:03, 255.57it/s][A
 21%|██        | 212/1000 [00:00<00:03, 260.27it/s][A
 24%|██▍       | 240/1000 [00:00<00:02, 265.80it/s][A
 27%|██▋       | 268/1000 [00:01<00:02, 269.54it/s][A
 30%|██▉       | 297/1000 [00:01<00:02, 275.35it/s][A
 32%|███▎      | 325/1000 [00:01<00:02, 265.20it/s][A
 35%|███▌      | 354/1000 [00:01<00:02, 271.55it/s][A
 38%|███▊      | 383/1000 [00:01<00:02, 275.35it/s][A
 41%|████      | 411/1000 [00:01<00:02, 274.81it/s][A
 44%|████▍     | 440/1000 [00:01<00:02, 276.60it/s][A
 47%|████▋     | 468/1000 [00:01<00:01, 275.35it/s][A
 50%|████▉     | 496/1000 [00:01<00:01, 275.55it/s][A
 52%|█████▏    | 524/1000 [00:01<00:01, 269.53it/s][A
 55%|█████▌

measurement pattern 8 : ('Z', 'Y')



  3%|▎         | 28/1000 [00:00<00:03, 270.97it/s][A
  6%|▌         | 56/1000 [00:00<00:03, 270.38it/s][A
  8%|▊         | 84/1000 [00:00<00:03, 262.88it/s][A
 11%|█         | 111/1000 [00:00<00:03, 245.51it/s][A
 14%|█▍        | 139/1000 [00:00<00:03, 256.58it/s][A
 17%|█▋        | 167/1000 [00:00<00:03, 261.67it/s][A
 20%|█▉        | 196/1000 [00:00<00:02, 268.20it/s][A
 22%|██▏       | 224/1000 [00:00<00:02, 270.07it/s][A
 25%|██▌       | 252/1000 [00:00<00:02, 264.17it/s][A
 28%|██▊       | 280/1000 [00:01<00:02, 266.71it/s][A
 31%|███       | 309/1000 [00:01<00:02, 272.80it/s][A
 34%|███▎      | 337/1000 [00:01<00:02, 271.89it/s][A
 36%|███▋      | 365/1000 [00:01<00:02, 273.43it/s][A
 39%|███▉      | 393/1000 [00:01<00:02, 272.29it/s][A
 42%|████▏     | 422/1000 [00:01<00:02, 275.62it/s][A
 45%|████▌     | 450/1000 [00:01<00:02, 273.68it/s][A
 48%|████▊     | 478/1000 [00:01<00:01, 275.21it/s][A
 51%|█████     | 506/1000 [00:01<00:01, 274.28it/s][A
 53%|█████▎ 

measurement pattern 9 : ('Z', 'Z')



  6%|▌         | 58/1000 [00:00<00:03, 278.43it/s][A
  9%|▊         | 86/1000 [00:00<00:03, 266.63it/s][A
 11%|█▏        | 113/1000 [00:00<00:03, 244.91it/s][A
 14%|█▍        | 139/1000 [00:00<00:03, 249.22it/s][A
 17%|█▋        | 168/1000 [00:00<00:03, 259.38it/s][A
 20%|█▉        | 196/1000 [00:00<00:03, 265.59it/s][A
 22%|██▎       | 225/1000 [00:00<00:02, 273.06it/s][A
 25%|██▌       | 253/1000 [00:00<00:02, 271.48it/s][A
 28%|██▊       | 282/1000 [00:01<00:02, 275.90it/s][A
 31%|███       | 311/1000 [00:01<00:02, 278.84it/s][A
 34%|███▍      | 339/1000 [00:01<00:02, 276.46it/s][A
 37%|███▋      | 367/1000 [00:01<00:02, 275.80it/s][A
 40%|███▉      | 395/1000 [00:01<00:02, 275.45it/s][A
 42%|████▏     | 423/1000 [00:01<00:02, 276.07it/s][A
 45%|████▌     | 451/1000 [00:01<00:02, 273.48it/s][A
 48%|████▊     | 480/1000 [00:01<00:01, 277.84it/s][A
 51%|█████     | 508/1000 [00:01<00:01, 277.74it/s][A
 54%|█████▎    | 536/1000 [00:01<00:01, 273.32it/s][A
 56%|█████▋

train data is ready !


## load dataset

In [53]:
# Measurement files inside the training-data directory.
meas_result_path = train_data_path + "/measurement_result.txt"
meas_label_path = train_data_path + "/measurement_label.txt"
meas_pattern_path = train_data_path + "/measurement_pattern.txt"
# Target density matrix files (real and imaginary part, judging by the names).
target_rho_re_path = target_state_path + "/rho_real.txt"
target_rho_im_path = target_state_path + "/rho_imag.txt"

# Argument order expected by load_data_DM as used in this notebook:
# results, target real part, target imaginary part, labels, patterns.
dataset_paths = (
    meas_result_path,
    target_rho_re_path,
    target_rho_im_path,
    meas_label_path,
    meas_pattern_path,
)
meas_result, target_rho, meas_label, meas_pattern = data.load_data_DM(*dataset_paths)

## build RBM architecture

In [54]:
# Build the density-matrix RBM; all layer sizes come from params_setting.yaml.
nn_state_dm = DensityMatrix(
    num_visible = n_visible_unit, 
    num_hidden = n_hidden_unit, 
    num_aux = n_aux_unit, 
    # Unitaries from QuCumber's create_dict(); presumably the basis rotations
    # for the measurement bases used in the dataset. TODO confirm.
    unitary_dict = unitaries.create_dict(),
    gpu = False
)

## NQS for VD

In [55]:
class GeneralPauliDistill(ObservableBase):
    r"""Pauli-string observable evaluated on ``m`` cyclically permuted copies of a state.

    ``pauli_dict`` maps a visible-unit index to ``"X"``, ``"Y"`` or ``"Z"``.
    Indices that are absent, or mapped to any other letter (e.g. ``"I"``),
    are left untouched, i.e. treated as identity. With an empty
    ``pauli_dict`` only the cyclic permutation of the copies is evaluated.
    """

    def __init__(self, pauli_dict: dict, m: int) -> None:
        """Store the Pauli string and the number of copies ``m``."""
        self.name = "distilled_pauli"
        self.symbol = "distilled_general_pauli"
        self.pauli_dict = pauli_dict
        self.num_copy = m

    def apply(self, nn_state, samples):
        r"""
        This function calculates <x1 x2 ... xm | rho^{\otimes m} O | xm x1 x2 ... xm-1> / <x1 x2 ... xm | rho^{\otimes m} | x1 x2 ... xm>
        where O acts only on the first register.

        Returns the real part of that ratio for every row of ``samples``.
        """

        # [num_sample, num_visible_node]
        # samples = [s1, s2, s3 ... sN]
        #  where num_sample = N, and si is num_visible_node-bits
        samples = samples.to(device=nn_state.device)

        num_sample, num_visible_node = samples.shape

        # [num_sample, num_visible_node * num_copy]
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        #  each row is num_copy*num_visible_node bits the above example is for num_copy=3
        samples_array = []
        for copy_index in range(self.num_copy):
            rolled_samples = torch.roll(samples, shifts=copy_index, dims=0)
            samples_array.append(rolled_samples)
        samples_array = torch.hstack(samples_array)
        assert(samples_array.shape[0] == num_sample)
        assert(samples_array.shape[1] == num_visible_node * self.num_copy)

        # roll second dim of [num_sample, num_visible_node * num_copy] by num_visible_node
        # swapped_samples_array = [[sN-1 s1 sN], [sN s2 s1], [s1 s3 s2],.. [sN-2 sN sN-1]]
        swapped_samples_array = torch.roll(samples_array, shifts = num_visible_node, dims=1)

        # pick copy of first block
        #  first_block_sample = [sN-1, sN, s1, s2, ... sN-2]
        first_block_sample = swapped_samples_array[:, :num_visible_node].clone()

        # calculate coefficient for first block [num_samples, 0:num_visible_node]
        # (the coefficients are read before any bit is flipped below)
        total_prod = cplx.make_complex(torch.ones_like(samples[:,0]), torch.zeros_like(samples[:,0]))
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            coeff = to_pm1(first_block_sample[:, index])
            if pauli == "Z":
                # Z contributes a real +-1 factor
                coeff = cplx.make_complex(coeff, torch.zeros_like(coeff))
                total_prod = cplx.elementwise_mult(coeff, total_prod)
            elif pauli == "Y":
                # Y contributes a purely imaginary +-1 factor (and a flip below)
                coeff = cplx.make_complex(torch.zeros_like(coeff), coeff)
                total_prod = cplx.elementwise_mult(coeff, total_prod)

        # flip samples for first block [num_samples, 0:num_visible_node]
        # first_block_sample -> [OsN-1, OsN, Os1, Os2, ... OsN-2]
        #  where Osi is bit array after Pauli bit-flips
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            if pauli in ["X", "Y"]:
                first_block_sample = flip_spin(index, first_block_sample)


        # store flipped first block
        swapped_samples_array[:, :num_visible_node] = first_block_sample

        # calculate product of coefficients
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        # swapped_samples_array = [[OsN-1 s1 sN], [OsN s2 s1], [Os1 s3 s2],.. [OsN-2 sN sN-1]]
        #
        # total_prod = [
        #     <s1 sN sN-1 | rho^{\otimes 3} | OsN-1 s1 sN> / <s1 sN sN-1 | rho^{\otimes 3} | s1 sN sN-1> ,
        #     <s2 s1 sN   | rho^{\otimes 3} | OsN s2 s1>   / <s2 s1 sN   | rho^{\otimes 3} | s2 s1 sN> ,
        #     <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1> ,
        #
        # e.g.
        # <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1>
        #  = <s3 | rho | Os1> <s2 | rho | s3> < s1| rho | s2> / (<s3 | rho | s3> <s2 | rho | s2> < s1| rho | s1>)
        #  =  (<s3 | rho | Os1> / <s3 | rho | s3>)
        #   * (<s2 | rho | s3> / <s2 | rho | s2> )
        #   * (< s1| rho | s2> / < s1| rho | s1>)
        #
        # importance_sampling_numerator(s3, Os1)  provides <s3 | rho | Os1>
        # importance_sampling_denominator(s3)     provides <s3 | rho | s3>
        for copy_index in range(self.num_copy):
            # column range of this copy inside the concatenated sample rows
            st = copy_index * samples.shape[1]
            en = (copy_index+1) * samples.shape[1]
            numerator = nn_state.importance_sampling_numerator(swapped_samples_array[:, st:en], samples_array[:, st:en])
            denominator = nn_state.importance_sampling_denominator(samples_array[:, st:en])
            values = cplx.elementwise_division(numerator, denominator)
            total_prod = cplx.elementwise_mult(total_prod, values)

        value = cplx.real(total_prod)
        return value

def calculate_distilled_expectation_value(pauli_dict: dict, num_samples: int, num_copies: int, nn_state=None):
    """Estimate the distilled expectation value of a Pauli string.

    Two ``GeneralPauliDistill`` observables are sampled: one with
    ``pauli_dict`` (numerator) and one with an empty Pauli string
    (denominator). Their means are divided, and the standard errors are
    propagated through the ratio with the ``uncertainties`` package.

    Args:
        pauli_dict: Maps a visible-unit index to a Pauli letter.
        num_samples: Number of samples drawn for each of the two estimates.
        num_copies: Number of state copies used by the observable.
        nn_state: State to sample from. Defaults to the module-level
            ``nn_state_dm``, which was previously the only option.

    Returns:
        Dict with keys ``"mean"``, ``"std_error"``, ``"num_samples"`` and
        ``"num_copies"``.
    """
    from uncertainties import ufloat

    if nn_state is None:
        nn_state = nn_state_dm

    obs_num = GeneralPauliDistill(pauli_dict, num_copies)
    obs_div = GeneralPauliDistill({}, num_copies)
    # Numerator is sampled before the denominator, as in the original code.
    num_stat = obs_num.statistics(nn_state, num_samples=num_samples)
    div_stat = obs_div.statistics(nn_state, num_samples=num_samples)

    num = ufloat(num_stat["mean"], num_stat["std_error"])
    div = ufloat(div_stat["mean"], div_stat["std_error"])
    val = num/div
    result_dict = {"mean": val.n , "std_error": val.s, "num_samples": num_samples, "num_copies": num_copies}
    return result_dict

## callback setting 

In [57]:
def ideal_fidelity(nn_state, **kwargs):
    """Return the fidelity between the learned state and the noiseless target.

    Computes ``(Re Tr sqrt( sqrt(t) @ r @ sqrt(t) ))**2`` where ``t`` is the
    state from ``Random_unitary`` with the ``"ideal"`` error model and ``r``
    is the density matrix of ``nn_state``. ``**kwargs`` is unused.
    """
    target = Random_unitary(n_qubit, state_name, "ideal", error_rate)
    learned = get_density_matrix(nn_state)
    root_target = sqrtm(target)
    overlap = np.trace(sqrtm(root_target @ learned @ root_target))
    return overlap.real ** 2

def noisy_fidelity(nn_state, **kwargs):
    """Return the fidelity between the learned state and the noisy target.

    Computes ``(Re Tr sqrt( sqrt(t) @ r @ sqrt(t) ))**2`` where ``t`` is the
    state from ``Random_unitary`` with the configured ``error_model`` and
    ``r`` is the density matrix of ``nn_state``. ``**kwargs`` is unused.
    """
    target = Random_unitary(n_qubit, state_name, error_model, error_rate)
    learned = get_density_matrix(nn_state)
    root_target = sqrtm(target)
    overlap = np.trace(sqrtm(root_target @ learned @ root_target))
    return overlap.real ** 2
    
def max_eigen_fidelity(nn_state, **kwargs):
    """Return the overlap of the learned state's dominant eigenvector with the noiseless target.

    Computes the real part of ``v^dagger @ t @ v`` where ``v`` is the
    dominant eigenvector of the density matrix of ``nn_state`` and ``t`` is
    the state from ``Random_unitary`` with the ``"ideal"`` error model.
    ``**kwargs`` is unused.
    """
    target = Random_unitary(n_qubit, state_name, "ideal", error_rate)
    dominant = get_max_eigvec(get_density_matrix(nn_state))
    overlap = dominant.T.conjugate() @ target @ dominant
    return overlap.real

def create_callback_dm(nn_state):
    """Build the list of training callbacks for the density-matrix model.

    Returns a single-element list holding a ``MetricEvaluator`` that
    evaluates the fidelity and KL metrics every ``period`` epochs against
    the module-level ``target_rho`` and ``meas_pattern``.
    """
    metrics = {
        "Fidelity": ts.fidelity,
        "Ideal_fidelity": ideal_fidelity,
        "Noisy_fidelity": noisy_fidelity,
        "Max_eigen_fidelity": max_eigen_fidelity,
        "KL_Divergence": ts.KL,
        #"ObservableZZ_ev": observable_ZZ_ev,
    }
    evaluator = MetricEvaluator(
        period,
        metrics,
        target=target_rho,
        bases=meas_pattern,
        verbose=True,
        # Hilbert-space basis is generated once here and reused by the evaluator.
        space=nn_state.generate_hilbert_space(),
    )
    return [evaluator]


callbacks = create_callback_dm(nn_state_dm)

## train

In [58]:
# Train the density-matrix RBM on the loaded measurement data.
# All hyperparameters come from the settings cell (params_setting.yaml).
nn_state_dm.fit(
    data = meas_result,
    input_bases = meas_label,
    epochs = epoch,
    pos_batch_size = pbs,
    neg_batch_size = nbs,
    lr = lr,
    k = n_gibbs_step,
    bases = meas_pattern,
    # MetricEvaluator built by create_callback_dm; prints the metrics each period.
    callbacks = callbacks,
    time = True,
    optimizer = torch.optim.Adadelta,
    # StepLR multiplies the learning rate by `gamma` every `step_size` epochs.
    scheduler = torch.optim.lr_scheduler.StepLR,
    scheduler_args = {"step_size": lr_drop_epoch, "gamma": lr_drop_factor},
)

Epoch: 1	Fidelity = 0.809501	Ideal_fidelity = 0.872215	Noisy_fidelity = 0.899723	Max_eigen_fidelity = 0.899401	KL_Divergence = 0.434793
Epoch: 2	Fidelity = 0.826246	Ideal_fidelity = 0.880780	Noisy_fidelity = 0.908981	Max_eigen_fidelity = 0.930734	KL_Divergence = 0.353533
Epoch: 3	Fidelity = 0.843348	Ideal_fidelity = 0.891173	Noisy_fidelity = 0.918340	Max_eigen_fidelity = 0.942341	KL_Divergence = 0.340485
Epoch: 4	Fidelity = 0.852896	Ideal_fidelity = 0.896139	Noisy_fidelity = 0.923524	Max_eigen_fidelity = 0.958826	KL_Divergence = 0.346220
Epoch: 5	Fidelity = 0.870808	Ideal_fidelity = 0.907264	Noisy_fidelity = 0.933171	Max_eigen_fidelity = 0.968091	KL_Divergence = 0.323714
Epoch: 6	Fidelity = 0.886323	Ideal_fidelity = 0.917532	Noisy_fidelity = 0.941447	Max_eigen_fidelity = 0.968738	KL_Divergence = 0.304103
Epoch: 7	Fidelity = 0.904868	Ideal_fidelity = 0.928429	Noisy_fidelity = 0.951246	Max_eigen_fidelity = 0.985463	KL_Divergence = 0.316082
Epoch: 8	Fidelity = 0.918352	Ideal_fidelity = 0.

In [None]:
# Metric histories recorded by the MetricEvaluator callback during training.
fidelities = callbacks[0]["Fidelity"]
KLs = callbacks[0]["KL_Divergence"]
# Epochs at which the metrics were evaluated: period, 2*period, ..., epoch.
epoch_range = np.arange(period, epoch + 1, period)

# Two panels side by side: fidelity on the left, KL divergence on the right.
fig, axs = plt.subplots(nrows = 1, ncols = 2, figsize = (16, 5))
ax = axs[0]
ax.plot(epoch_range, fidelities, "o", color = "C0", markeredgecolor = "black")
ax.set_ylabel(r"Fidelity")
ax.set_xlabel(r"Epoch")
# Fix the fidelity axis to the range [0, 1].
ax.set_ylim(0.00, 1.00)

ax = axs[1]
ax.plot(epoch_range, KLs, "o", color = "C1", markeredgecolor = "black")
ax.set_ylabel(r"KL Divergence")
ax.set_xlabel(r"Epoch")

## save model & train log

# save model
# Make sure the experiment directory exists; saving into a missing
# directory fails.
os.makedirs("./exp003", exist_ok=True)
nn_state_dm.save("./exp003/model.pt")
# save train log
# One row per evaluated epoch, using the same epoch axis as the plot above.
train_log_df = pd.DataFrame()
train_log_df["epoch"] = np.arange(period, epoch+1, period)
train_log_df["Fidelity"] = callbacks[0]["Fidelity"]
train_log_df["KL_Divergence"] = callbacks[0]["KL_Divergence"]
#train_log_df["Observable_ZZ_ev"] = callbacks[0]["Observable_ZZ_ev"]
#train_log_df["Observable_XZ_ev"] = callbacks[0]["Observable_XZ_ev"]
# NOTE(review): the CSV export is disabled, so the log above is built but
# never written to disk; re-enable the next line if the file is wanted.
#train_log_df.to_csv("./exp003/train_log.csv", index=False)