## Import library

In [118]:
import os
import subprocess
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
from tqdm.notebook import tqdm
import itertools

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.nn_states import ComplexWaveFunction
from qucumber.callbacks import MetricEvaluator
import qucumber.utils.unitaries as unitaries
import qucumber.utils.training_statistics as ts
import qucumber.utils.cplx as cplx
import qucumber.utils.data as data
from qucumber.observables import ObservableBase, to_pm1
from qucumber.observables.pauli import flip_spin
import qucumber

from qulacs.gate import Pauli

with open('./params_setting.yaml', 'r') as yml:
    params = yaml.safe_load(yml)
    
# quantum circuit parameter
n_qubit = params["circuit_info"]["n_qubit"]
n_data = params["circuit_info"]["n_data"]
each_n_shot = int(n_data / 3**n_qubit)
state_name = params["circuit_info"]["state_name"]
error_model = params["circuit_info"]["error_model"]
error_rate = params["circuit_info"]["error_rate"]
# RBM architecture parameter
n_visible_unit = params["architecture_info"]["n_visible_unit"]
n_hidden_unit = params["architecture_info"]["n_hidden_unit"] 
n_aux_unit = params["architecture_info"]["n_aux_unit"]
# train parameter
lr = params["train_info"]["lr"]
pbs = params["train_info"]["positive_batch_size"]
nbs = params["train_info"]["negative_batch_size"]
n_gibbs_step = params["train_info"]["n_gibbs_step"]
period = 1
epoch = params["train_info"]["n_epoch"]
lr_drop_epoch = params["train_info"]["lr_drop_epoch"]
lr_drop_factor = params["train_info"]["lr_drop_factor"]
seed = params["train_info"]["seed"]
# sampling parameter
n_sampling = params["sampling_info"]["n_sample"]
n_copy = params["sampling_info"]["n_copy"]
# data path info
train_data_path = f"./data/{noise_model}/error_prob_{100*error_rate}%/num_of_data_{n_data}/"
ideal_state_path = f"./target_state/"

# settings
## warnings
warnings.simplefilter('ignore')

## seaborn layout
sns.set()
sns.set_style("white")

## seed
def seed_settings(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    qucumber.set_random_seed(seed, cpu=True, gpu=False)

seed_settings(seed=seed)

## caluculate ideal state

In [119]:
# calculate ideal state
is_ideal_state_file = os.path.exists(ideal_state_path)
if is_ideal_state_file:
    print("ideal state data is exsisted !")
else:
    print("caluculate ideal state data ...")
    subprocess.run("python caluculate_ideal_state.py", shell=True)
    print("ideal state data is ready !")

ideal state data is exsisted !


## generate dataset

In [123]:
# generate train data
is_train_data_file = os.path.exists(train_data_path)
if is_train_data_file:
    print("train data is exsisted !")
else:
    print("generate directries & train data ...")
    os.makedirs(train_data_path, exist_ok = True)
    subprocess.run("python generate_dataset.py", shell=True)
    print("train data is ready !")

generate directries & train data ...


0it [00:00, ?it/s]
  0%|          | 0/1111 [00:00<?, ?it/s][A
  1%|          | 6/1111 [00:00<00:19, 57.63it/s][A

measurement pattern 0 : ('X', 'X')



  1%|          | 12/1111 [00:00<00:18, 57.94it/s][A
  2%|▏         | 20/1111 [00:00<00:17, 64.13it/s][A
  3%|▎         | 28/1111 [00:00<00:16, 67.44it/s][A
  3%|▎         | 36/1111 [00:00<00:15, 70.13it/s][A
  4%|▍         | 44/1111 [00:00<00:15, 71.11it/s][A
  5%|▍         | 52/1111 [00:00<00:14, 72.26it/s][A
  5%|▌         | 60/1111 [00:00<00:14, 71.59it/s][A
  6%|▌         | 68/1111 [00:00<00:14, 72.01it/s][A
  7%|▋         | 76/1111 [00:01<00:14, 71.32it/s][A
  8%|▊         | 84/1111 [00:01<00:14, 72.54it/s][A
  8%|▊         | 92/1111 [00:01<00:15, 67.31it/s][A
  9%|▉         | 100/1111 [00:01<00:14, 69.62it/s][A
 10%|▉         | 108/1111 [00:01<00:14, 71.23it/s][A
 10%|█         | 116/1111 [00:01<00:13, 72.41it/s][A
 11%|█         | 124/1111 [00:01<00:13, 71.26it/s][A
 12%|█▏        | 132/1111 [00:01<00:15, 62.55it/s][A
 13%|█▎        | 139/1111 [00:02<00:15, 63.41it/s][A
 13%|█▎        | 146/1111 [00:02<00:15, 61.25it/s][A
 14%|█▍        | 154/1111 [00:02<00:14

measurement pattern 1 : ('X', 'Y')



  1%|▏         | 16/1111 [00:00<00:14, 74.22it/s][A
  2%|▏         | 24/1111 [00:00<00:14, 74.71it/s][A
  3%|▎         | 32/1111 [00:00<00:14, 75.07it/s][A
  4%|▎         | 40/1111 [00:00<00:14, 75.48it/s][A
  4%|▍         | 48/1111 [00:00<00:14, 75.68it/s][A
  5%|▌         | 56/1111 [00:00<00:13, 75.95it/s][A
  6%|▌         | 64/1111 [00:00<00:13, 76.07it/s][A
  6%|▋         | 72/1111 [00:00<00:13, 76.59it/s][A
  7%|▋         | 80/1111 [00:01<00:13, 76.53it/s][A
  8%|▊         | 88/1111 [00:01<00:13, 75.90it/s][A
  9%|▊         | 96/1111 [00:01<00:13, 76.04it/s][A
  9%|▉         | 104/1111 [00:01<00:13, 75.86it/s][A
 10%|█         | 112/1111 [00:01<00:13, 76.04it/s][A
 11%|█         | 120/1111 [00:01<00:13, 76.00it/s][A
 12%|█▏        | 128/1111 [00:01<00:12, 75.87it/s][A
 12%|█▏        | 136/1111 [00:01<00:13, 74.96it/s][A
 13%|█▎        | 144/1111 [00:01<00:13, 70.67it/s][A
 14%|█▎        | 152/1111 [00:02<00:13, 69.60it/s][A
 14%|█▍        | 160/1111 [00:02<00:13

measurement pattern 2 : ('X', 'Z')



  1%|▏         | 16/1111 [00:00<00:14, 76.42it/s][A
  2%|▏         | 24/1111 [00:00<00:15, 69.00it/s][A
  3%|▎         | 32/1111 [00:00<00:15, 71.93it/s][A
  4%|▎         | 40/1111 [00:00<00:14, 73.63it/s][A
  4%|▍         | 48/1111 [00:00<00:14, 75.04it/s][A
  5%|▌         | 56/1111 [00:00<00:14, 74.96it/s][A
  6%|▌         | 64/1111 [00:00<00:13, 76.11it/s][A
  6%|▋         | 72/1111 [00:00<00:13, 76.46it/s][A
  7%|▋         | 80/1111 [00:01<00:13, 77.24it/s][A
  8%|▊         | 88/1111 [00:01<00:13, 76.86it/s][A
  9%|▊         | 96/1111 [00:01<00:13, 77.06it/s][A
  9%|▉         | 104/1111 [00:01<00:13, 76.31it/s][A
 10%|█         | 112/1111 [00:01<00:13, 76.60it/s][A
 11%|█         | 120/1111 [00:01<00:12, 76.73it/s][A
 12%|█▏        | 128/1111 [00:01<00:12, 76.50it/s][A
 12%|█▏        | 136/1111 [00:01<00:12, 76.48it/s][A
 13%|█▎        | 144/1111 [00:01<00:12, 76.46it/s][A
 14%|█▎        | 152/1111 [00:02<00:12, 76.72it/s][A
 14%|█▍        | 160/1111 [00:02<00:12

measurement pattern 3 : ('Y', 'X')



  1%|▏         | 16/1111 [00:00<00:14, 77.27it/s][A
  2%|▏         | 24/1111 [00:00<00:14, 75.73it/s][A
  3%|▎         | 32/1111 [00:00<00:14, 75.40it/s][A
  4%|▎         | 40/1111 [00:00<00:14, 76.29it/s][A
  4%|▍         | 48/1111 [00:00<00:15, 69.06it/s][A
  5%|▌         | 56/1111 [00:00<00:14, 71.12it/s][A
  6%|▌         | 64/1111 [00:00<00:14, 73.17it/s][A
  6%|▋         | 72/1111 [00:00<00:14, 74.19it/s][A
  7%|▋         | 80/1111 [00:01<00:13, 74.90it/s][A
  8%|▊         | 88/1111 [00:01<00:13, 74.99it/s][A
  9%|▊         | 96/1111 [00:01<00:13, 75.37it/s][A
  9%|▉         | 104/1111 [00:01<00:13, 75.72it/s][A
 10%|█         | 112/1111 [00:01<00:13, 75.97it/s][A
 11%|█         | 120/1111 [00:01<00:13, 76.21it/s][A
 12%|█▏        | 128/1111 [00:01<00:12, 76.46it/s][A
 12%|█▏        | 136/1111 [00:01<00:12, 76.82it/s][A
 13%|█▎        | 144/1111 [00:01<00:12, 76.71it/s][A
 14%|█▎        | 152/1111 [00:02<00:12, 76.25it/s][A
 14%|█▍        | 160/1111 [00:02<00:12

measurement pattern 4 : ('Y', 'Y')



  1%|          | 8/1111 [00:00<00:15, 72.42it/s][A
  1%|▏         | 16/1111 [00:00<00:15, 71.03it/s][A
  2%|▏         | 24/1111 [00:00<00:15, 72.35it/s][A
  3%|▎         | 32/1111 [00:00<00:14, 73.43it/s][A
  4%|▎         | 40/1111 [00:00<00:14, 73.98it/s][A
  4%|▍         | 48/1111 [00:00<00:14, 73.94it/s][A
  5%|▌         | 56/1111 [00:00<00:15, 67.76it/s][A
  6%|▌         | 63/1111 [00:00<00:15, 66.81it/s][A
  6%|▋         | 71/1111 [00:01<00:15, 68.95it/s][A
  7%|▋         | 79/1111 [00:01<00:14, 70.33it/s][A
  8%|▊         | 87/1111 [00:01<00:14, 70.83it/s][A
  9%|▊         | 95/1111 [00:01<00:14, 72.22it/s][A
  9%|▉         | 103/1111 [00:01<00:13, 72.80it/s][A
 10%|▉         | 111/1111 [00:01<00:13, 73.25it/s][A
 11%|█         | 119/1111 [00:01<00:13, 74.05it/s][A
 11%|█▏        | 127/1111 [00:01<00:13, 74.56it/s][A
 12%|█▏        | 135/1111 [00:01<00:13, 74.68it/s][A
 13%|█▎        | 143/1111 [00:01<00:12, 75.02it/s][A
 14%|█▎        | 151/1111 [00:02<00:12, 

measurement pattern 5 : ('Y', 'Z')



  1%|▏         | 16/1111 [00:00<00:13, 79.33it/s][A
  2%|▏         | 24/1111 [00:00<00:14, 76.56it/s][A
  3%|▎         | 32/1111 [00:00<00:13, 77.63it/s][A
  4%|▎         | 40/1111 [00:00<00:13, 77.76it/s][A
  4%|▍         | 48/1111 [00:00<00:14, 74.25it/s][A
  5%|▌         | 56/1111 [00:00<00:14, 73.71it/s][A
  6%|▌         | 64/1111 [00:00<00:13, 75.01it/s][A
  6%|▋         | 72/1111 [00:00<00:14, 72.40it/s][A
  7%|▋         | 80/1111 [00:01<00:14, 69.25it/s][A
  8%|▊         | 88/1111 [00:01<00:14, 71.84it/s][A
  9%|▊         | 97/1111 [00:01<00:13, 73.82it/s][A
  9%|▉         | 105/1111 [00:01<00:13, 75.05it/s][A
 10%|█         | 114/1111 [00:01<00:13, 76.04it/s][A
 11%|█         | 122/1111 [00:01<00:12, 76.63it/s][A
 12%|█▏        | 130/1111 [00:01<00:12, 77.00it/s][A
 12%|█▏        | 138/1111 [00:01<00:12, 76.77it/s][A
 13%|█▎        | 146/1111 [00:01<00:12, 76.90it/s][A
 14%|█▍        | 154/1111 [00:02<00:12, 77.15it/s][A
 15%|█▍        | 162/1111 [00:02<00:12

measurement pattern 6 : ('Z', 'X')



  1%|▏         | 16/1111 [00:00<00:14, 76.44it/s][A
  2%|▏         | 24/1111 [00:00<00:14, 77.36it/s][A
  3%|▎         | 32/1111 [00:00<00:13, 77.57it/s][A
  4%|▎         | 40/1111 [00:00<00:33, 31.95it/s][A
  4%|▍         | 46/1111 [00:01<00:42, 24.89it/s][A
  5%|▍         | 51/1111 [00:01<00:40, 26.14it/s][A
  5%|▍         | 55/1111 [00:01<00:51, 20.35it/s][A
  5%|▌         | 58/1111 [00:02<01:01, 17.24it/s][A
  5%|▌         | 61/1111 [00:02<00:57, 18.32it/s][A
  6%|▌         | 64/1111 [00:02<01:00, 17.37it/s][A
  6%|▋         | 71/1111 [00:02<00:45, 22.95it/s][A
  7%|▋         | 74/1111 [00:02<00:50, 20.51it/s][A
  7%|▋         | 79/1111 [00:02<00:40, 25.23it/s][A
  8%|▊         | 86/1111 [00:03<00:30, 33.26it/s][A
  8%|▊         | 92/1111 [00:03<00:26, 38.07it/s][A
  9%|▉         | 99/1111 [00:03<00:22, 44.15it/s][A
 10%|▉         | 107/1111 [00:03<00:19, 51.76it/s][A
 10%|█         | 115/1111 [00:03<00:17, 58.55it/s][A
 11%|█         | 123/1111 [00:03<00:15, 64.

measurement pattern 7 : ('Z', 'Y')



  1%|          | 8/1111 [00:00<00:14, 76.97it/s][A
  1%|▏         | 16/1111 [00:00<00:14, 77.93it/s][A
  2%|▏         | 24/1111 [00:00<00:14, 77.29it/s][A
  3%|▎         | 32/1111 [00:00<00:14, 76.85it/s][A
  4%|▎         | 40/1111 [00:00<00:14, 75.62it/s][A
  4%|▍         | 48/1111 [00:00<00:13, 76.70it/s][A
  5%|▌         | 56/1111 [00:00<00:13, 77.53it/s][A
  6%|▌         | 64/1111 [00:00<00:13, 77.31it/s][A
  7%|▋         | 73/1111 [00:00<00:13, 77.99it/s][A
  7%|▋         | 81/1111 [00:01<00:13, 78.04it/s][A
  8%|▊         | 89/1111 [00:01<00:13, 78.44it/s][A
  9%|▊         | 97/1111 [00:01<00:12, 78.29it/s][A
  9%|▉         | 105/1111 [00:01<00:13, 73.95it/s][A
 10%|█         | 113/1111 [00:01<00:14, 69.93it/s][A
 11%|█         | 121/1111 [00:01<00:13, 71.95it/s][A
 12%|█▏        | 129/1111 [00:01<00:13, 74.01it/s][A
 12%|█▏        | 137/1111 [00:01<00:13, 74.78it/s][A
 13%|█▎        | 145/1111 [00:01<00:12, 75.55it/s][A
 14%|█▍        | 153/1111 [00:02<00:12, 

measurement pattern 8 : ('Z', 'Z')



  1%|          | 8/1111 [00:00<00:13, 79.74it/s][A
  2%|▏         | 17/1111 [00:00<00:13, 80.32it/s][A
  2%|▏         | 26/1111 [00:00<00:13, 80.04it/s][A
  3%|▎         | 35/1111 [00:00<00:13, 80.35it/s][A
  4%|▍         | 44/1111 [00:00<00:13, 80.61it/s][A
  5%|▍         | 53/1111 [00:00<00:13, 81.01it/s][A
  6%|▌         | 62/1111 [00:00<00:12, 80.79it/s][A
  6%|▋         | 71/1111 [00:00<00:12, 80.81it/s][A
  7%|▋         | 80/1111 [00:00<00:12, 80.85it/s][A
  8%|▊         | 89/1111 [00:01<00:12, 81.18it/s][A
  9%|▉         | 98/1111 [00:01<00:12, 81.00it/s][A
 10%|▉         | 107/1111 [00:01<00:12, 81.54it/s][A
 10%|█         | 116/1111 [00:01<00:12, 80.74it/s][A
 11%|█▏        | 125/1111 [00:01<00:12, 80.53it/s][A
 12%|█▏        | 134/1111 [00:01<00:12, 80.41it/s][A
 13%|█▎        | 143/1111 [00:01<00:12, 80.61it/s][A
 14%|█▎        | 152/1111 [00:01<00:11, 80.56it/s][A
 14%|█▍        | 161/1111 [00:01<00:11, 80.89it/s][A
 15%|█▌        | 170/1111 [00:02<00:11,

train data is ready !


## load dataset

In [124]:
meas_pattern_path = train_data_path + "/measurement_pattern.txt"
meas_label_path = train_data_path + "/measurement_label.txt"
meas_result_path = train_data_path + "/measurement_result.txt"
ideal_rho_re_path = ideal_state_path + "/rho_real.txt"
ideal_rho_im_path = ideal_state_path + "/rho_imag.txt"
meas_result, ideal_rho, meas_label, meas_pattern = data.load_data_DM(meas_result_path,
                                                                     ideal_rho_re_path,
                                                                     ideal_rho_im_path,
                                                                     meas_label_path,
                                                                     meas_pattern_path)

## build RBM architecture

In [125]:
nn_state_dm = DensityMatrix(
    num_visible = n_visible_unit, 
    num_hidden = n_hidden_unit, 
    num_aux = n_aux_unit, 
    unitary_dict = unitaries.create_dict(),
    gpu = False
)

## estimate observable expectation value

In [126]:
class GeneralPauliDistill(ObservableBase):
    def __init__(self, pauli_dict: dict, m: int) -> None:
        self.name = "distilled_pauli"
        self.symbol = "distilled_general_pauli"
        self.pauli_dict = pauli_dict
        self.num_copy = m
        
    def apply(self, nn_state, samples):
        """
        This function calcualte <x1 x2 ... xm | rho^{\otimes m} O | xm x1 x2 ... xm-1> / <x1 x2 ... xm | rho^{\otimes m} | x1 x2 ... xm>
        where O acts only on the first register.
        """
        
        # [num_sample, num_visible_node]
        # samples = [s1, s2, s3 ... sN]
        #  where num_sample = N, and si is num_visible_node-bits
        samples = samples.to(device=nn_state.device)
        
        num_sample, num_visible_node = samples.shape
        
        # [num_sample, num_visible_node * num_copy]
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        #  each row is num_copy*num_visible_node bits the above example is for num_copy=3
        samples_array = []
        for copy_index in range(self.num_copy):
            rolled_samples = torch.roll(samples, shifts=copy_index, dims=0)
            samples_array.append(rolled_samples)
        samples_array = torch.hstack(samples_array)
        assert(samples_array.shape[0] == num_sample)
        assert(samples_array.shape[1] == num_visible_node * self.num_copy)
        
        # roll second dim of [num_sample, num_visible_node * num_copy] by num_visible_node
        # swapped_samples_array = [[sN-1 s1 sN], [sN s2 s1], [s1 s3 s2],.. [sN-2 sN sN-1]]
        swapped_samples_array = torch.roll(samples_array, shifts = num_visible_node, dims=1)

        # pick copy of first block
        #  first_block_sample = [sN-1, sN, s1, s2, ... sN-2]
        first_block_sample = swapped_samples_array[:, :num_visible_node].clone()

        # calculate coefficient for first block [num_samples, 0:num_visible_node]
        total_prod = cplx.make_complex(torch.ones_like(samples[:,0]), torch.zeros_like(samples[:,0]))
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            coeff = to_pm1(first_block_sample[:, index])
            if pauli == "Z":
                coeff = cplx.make_complex(coeff, torch.zeros_like(coeff))
                total_prod = cplx.elementwise_mult(coeff, total_prod)
            elif pauli == "Y":
                coeff = cplx.make_complex(torch.zeros_like(coeff), coeff)
                total_prod = cplx.elementwise_mult(coeff, total_prod)
        
        # flip samples for for first block [num_samples, 0:num_visible_node]
        # first_block_sample -> [OsN-1, OsN, Os1, Os2, ... OsN-2]
        #  where Osi is bit array after Pauli bit-flips 
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            if pauli in ["X", "Y"]:
                first_block_sample = flip_spin(index, first_block_sample)


        # store flipped first block
        swapped_samples_array[:, :num_visible_node] = first_block_sample

        # calculate product of coefficients
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        # swapped_samples_array = [[OsN-1 s1 sN], [OsN s2 s1], [Os1 s3 s2],.. [OsN-2 sN sN-1]]
        """
        total_prod = [
            <s1 sN sN-1 | rho^{\otimes 3} | OsN-1 s1 sN> / <s1 sN sN-1 | rho^{\otimes 3} | s1 sN sN-1> , 
            <s2 s1 sN   | rho^{\otimes 3} | OsN s2 s1>   / <s2 s1 sN   | rho^{\otimes 3} | s2 s1 sN> , 
            <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1> , 

        e.g. 
        <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1>
         = <s3 | rho | Os1> <s2 | rho | s3> < s1| rho | s2> / (<s3 | rho | s3> <s2 | rho | s2> < s1| rho | s1>)
         =  (<s3 | rho | Os1> / <s3 | rho | s3>)
          * (<s2 | rho | s3> / <s2 | rho | s2> )
          * (< s1| rho | s2> / < s1| rho | s1>)
         
        importance_sampling_numerator(s3, Os1)  provides <s3 | rho | Os1>
        importance_sampling_denominator(s3)     provides <s3 | rho | s3>
        """
        for copy_index in range(self.num_copy):
            st = copy_index * samples.shape[1]
            en = (copy_index+1) * samples.shape[1]
            # numerator is []
            numerator = nn_state.importance_sampling_numerator(swapped_samples_array[:, st:en], samples_array[:, st:en])
            denominator = nn_state.importance_sampling_denominator(samples_array[:, st:en])
            values = cplx.elementwise_division(numerator, denominator)
            total_prod = cplx.elementwise_mult(total_prod, values)

        value = cplx.real(total_prod)
        return value

def calculate_distilled_expectation_value(pauli_dict: dict, num_samples: int, num_copies: int):
    obs_num = GeneralPauliDistill(pauli_dict, num_copies)
    obs_div = GeneralPauliDistill({}, num_copies)
    num_stat = obs_num.statistics(nn_state_dm, num_samples=num_samples)
    div_stat = obs_div.statistics(nn_state_dm, num_samples=num_samples)

    from uncertainties import ufloat
    num = ufloat(num_stat["mean"], num_stat["std_error"])
    div = ufloat(div_stat["mean"], div_stat["std_error"])
    val = num/div
    result_dict = {"mean": val.n , "std_error": val.s, "num_samples": num_samples, "num_copies": num_copies}
    return result_dict

def get_density_matrix(nn_state):
    space = nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)
    tensor = nn_state.rho(space, space)/Z
    matrix = cplx.numpy(tensor)
    return matrix

def get_max_eigvec(matrix):
    e_val, e_vec = np.linalg.eigh(matrix)
    me_val = e_val[-1]
    me_vec = e_vec[:,-1]
    return me_vec

def get_eigvec(nn_state, obs, space, **kwargs):
    dm = get_density_matrix(nn_state)
    ev = get_max_eigvec(dm)
    ev = np.atleast_2d(ev)
    val = ev@obs@ev.T.conj()
    val = val[0,0].real
    return val

def observable_XX():
    target_list = [0, 1]
    pauli_index = [1, 1] # 1:X , 2:Y, 3:Z
    gate = Pauli(target_list, pauli_index) # = X_1 X_2
    return gate.get_matrix()

def observable_XZ():
    target_list = [0, 1]
    pauli_index = [1, 3] # 1:X , 2:Y, 3:Z
    gate = Pauli(target_list, pauli_index) # = X_1 Z_2
    return gate.get_matrix()

def observable_XX_ev(nn_state, **kwargs):
    obs_stat = calculate_distilled_expectation_value({0: "X", 1: "X"}, n_sampling, n_copy)
    return obs_stat["mean"]

def observable_XZ_ev(nn_state, **kwargs):
    obs_stat = calculate_distilled_expectation_value({0: "X", 1: "Z"}, n_sampling, n_copy)
    return obs_stat["mean"]

## callback setting 

In [127]:
def create_callback_dm(nn_state):
    metric_dict = {
        "Fidelity": ts.fidelity,
        "KL_Divergence": ts.KL,
        "Observable_XX_ev": observable_XX_ev,
        "Observable_XZ_ev": observable_XZ_ev,
    }

    space = nn_state.generate_hilbert_space()
    callbacks = [
        MetricEvaluator(
            period,
            metric_dict,
            target = ideal_rho,
            bases = meas_pattern,
            verbose = True,
            space = space,
        )
    ]
    return callbacks

callbacks = create_callback_dm(nn_state_dm)

## train

In [128]:
nn_state_dm.fit(
    data = meas_result,
    input_bases = meas_label,
    epochs = epoch,
    pos_batch_size = pbs,
    neg_batch_size = nbs,
    lr = lr,
    k = n_gibbs_step,
    bases = meas_pattern,
    callbacks = callbacks,
    time = True,
    optimizer = torch.optim.Adadelta,
    scheduler = torch.optim.lr_scheduler.StepLR,
    scheduler_args = {"step_size": lr_drop_epoch, "gamma": lr_drop_factor},
)

Epoch: 1	Fidelity = 0.318441	KL_Divergence = 0.582472	Observable_XX_ev = 0.659360	Observable_XZ_ev = 0.060344
Epoch: 2	Fidelity = 0.310015	KL_Divergence = 0.576633	Observable_XX_ev = 0.729100	Observable_XZ_ev = -0.028863
Epoch: 3	Fidelity = 0.341978	KL_Divergence = 0.585131	Observable_XX_ev = 0.793302	Observable_XZ_ev = 0.103953
Epoch: 4	Fidelity = 0.355473	KL_Divergence = 0.595958	Observable_XX_ev = 0.755153	Observable_XZ_ev = -0.126818
Epoch: 5	Fidelity = 0.319420	KL_Divergence = 0.556398	Observable_XX_ev = 0.708853	Observable_XZ_ev = 0.123612
Epoch: 6	Fidelity = 0.350800	KL_Divergence = 0.570337	Observable_XX_ev = 0.765336	Observable_XZ_ev = 0.005617
Epoch: 7	Fidelity = 0.324102	KL_Divergence = 0.546107	Observable_XX_ev = 0.684925	Observable_XZ_ev = -0.136316
Epoch: 8	Fidelity = 0.324120	KL_Divergence = 0.540745	Observable_XX_ev = 0.638128	Observable_XZ_ev = -0.143803
Epoch: 9	Fidelity = 0.357888	KL_Divergence = 0.556146	Observable_XX_ev = 0.769205	Observable_XZ_ev = -0.005232
Epoch

## save model & train log

In [86]:
# save model

# save train log
train_log_df = pd.DataFrame()
train_log_df["epoch"] = np.arange(CFG.period, CFG.epochs + 1, CFG.period)
train_log_df["Fidelity"] = callbacks[0]["Fidelity"]
train_log_df["KL_Divergence"] = callbacks[0]["KL_Divergence"]
train_log_df["Observable_XX_ev"] = callbacks[0]["Observable_XX_ev"]
train_log_df["Observable_XZ_ev"] = callbacks[0]["Observable_XZ_ev"]

NameError: name 'CFG' is not defined