## Import library

In [20]:
import os
import subprocess
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
from tqdm.notebook import tqdm
import itertools

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.nn_states import ComplexWaveFunction
from qucumber.callbacks import MetricEvaluator
import qucumber.utils.unitaries as unitaries
import qucumber.utils.training_statistics as ts
import qucumber.utils.cplx as cplx
import qucumber.utils.data as data
from qucumber.observables import ObservableBase, to_pm1
from qucumber.observables.pauli import flip_spin
import qucumber

from qulacs.gate import Pauli

with open('./params_setting.yaml', 'r') as yml:
    params = yaml.safe_load(yml)
    
# quantum circuit parameter
n_qubit = params["circuit_info"]["n_qubit"]
each_n_shot = params["circuit_info"]["each_n_shot"]
state_name = params["circuit_info"]["state_name"]
error_model = params["circuit_info"]["error_model"]
error_rate = params["circuit_info"]["error_rate"]
# RBM architecture parameter
n_visible_unit = params["architecture_info"]["n_visible_unit"]
n_hidden_unit = params["architecture_info"]["n_hidden_unit"] 
n_aux_unit = params["architecture_info"]["n_aux_unit"]
# train parameter
lr = params["train_info"]["lr"]
pbs = params["train_info"]["positive_batch_size"]
nbs = params["train_info"]["negative_batch_size"]
n_gibbs_step = params["train_info"]["n_gibbs_step"]
period = 1
epoch = params["train_info"]["n_epoch"]
lr_drop_epoch = params["train_info"]["lr_drop_epoch"]
lr_drop_factor = params["train_info"]["lr_drop_factor"]
seed = params["train_info"]["seed"]
# sampling parameter
n_sampling = params["sampling_info"]["n_sample"]
n_copy = params["sampling_info"]["n_copy"]
# data path info
train_data_path = f"./data/{error_model}/error_prob_{100*error_rate}%/num_of_data_{each_n_shot}/"
ideal_state_path = f"./target_state/"

# settings
## warnings
warnings.simplefilter('ignore')

## seaborn layout
sns.set()
sns.set_style("white")

## seed
def seed_settings(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    qucumber.set_random_seed(seed, cpu=True, gpu=False)

seed_settings(seed=seed)

## caluculate ideal state

In [27]:
# calculate ideal state
is_ideal_state_file = os.path.exists(ideal_state_path)
if is_ideal_state_file:
    print("ideal state data is exsisted !")
else:
    print("caluculate ideal state data ...")
    os.makedirs(ideal_state_path, exist_ok = True)
    subprocess.run("python caluculate_ideal_state.py", shell=True)
    print("ideal state data is ready !")

caluculate ideal state data ...
ideal state data is ready !


## generate dataset

In [22]:
# generate train data
is_train_data_file = os.path.exists(train_data_path)
if is_train_data_file:
    print("train data is exsisted !")
else:
    print("generate directries & train data ...")
    os.makedirs(train_data_path, exist_ok = True)
    subprocess.run("python generate_dataset.py", shell=True)
    print("train data is ready !")

generate directries & train data ...


0it [00:00, ?it/s]
  0%|          | 0/1000 [00:00<?, ?it/s][A
  0%|          | 5/1000 [00:00<00:23, 42.49it/s][A

measurement pattern 0 : ('X', 'X', 'X')



  1%|          | 10/1000 [00:00<00:22, 44.60it/s][A
  2%|▏         | 15/1000 [00:00<00:21, 46.67it/s][A
  2%|▏         | 21/1000 [00:00<00:20, 48.50it/s][A
  3%|▎         | 26/1000 [00:00<00:19, 48.96it/s][A
  3%|▎         | 31/1000 [00:00<00:19, 49.01it/s][A
  4%|▎         | 36/1000 [00:00<00:19, 49.17it/s][A
  4%|▍         | 41/1000 [00:00<00:19, 49.38it/s][A
  5%|▍         | 46/1000 [00:00<00:19, 48.26it/s][A
  5%|▌         | 52/1000 [00:01<00:19, 49.20it/s][A
  6%|▌         | 57/1000 [00:01<00:19, 48.91it/s][A
  6%|▋         | 63/1000 [00:01<00:18, 49.53it/s][A
  7%|▋         | 68/1000 [00:01<00:19, 48.14it/s][A
  7%|▋         | 73/1000 [00:01<00:19, 47.14it/s][A
  8%|▊         | 78/1000 [00:01<00:19, 46.37it/s][A
  8%|▊         | 83/1000 [00:01<00:19, 47.13it/s][A
  9%|▉         | 88/1000 [00:01<00:19, 47.69it/s][A
  9%|▉         | 93/1000 [00:01<00:18, 48.23it/s][A
 10%|▉         | 99/1000 [00:02<00:18, 48.87it/s][A
 10%|█         | 105/1000 [00:02<00:18, 49.43

measurement pattern 1 : ('X', 'X', 'Y')



  0%|          | 5/1000 [00:00<00:20, 48.37it/s][A
  1%|          | 10/1000 [00:00<00:20, 47.45it/s][A
  2%|▏         | 15/1000 [00:00<00:21, 46.52it/s][A
  2%|▏         | 20/1000 [00:00<00:21, 46.20it/s][A
  2%|▎         | 25/1000 [00:00<00:20, 46.45it/s][A
  3%|▎         | 30/1000 [00:00<00:20, 46.50it/s][A
  4%|▎         | 35/1000 [00:00<00:20, 47.20it/s][A
  4%|▍         | 40/1000 [00:00<00:20, 47.76it/s][A
  5%|▍         | 46/1000 [00:00<00:19, 48.51it/s][A
  5%|▌         | 52/1000 [00:01<00:19, 49.21it/s][A
  6%|▌         | 57/1000 [00:01<00:19, 49.38it/s][A
  6%|▌         | 62/1000 [00:01<00:19, 49.27it/s][A
  7%|▋         | 67/1000 [00:01<00:19, 48.59it/s][A
  7%|▋         | 72/1000 [00:01<00:19, 48.57it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 48.86it/s][A
  8%|▊         | 82/1000 [00:01<00:18, 49.05it/s][A
  9%|▉         | 88/1000 [00:01<00:18, 49.37it/s][A
  9%|▉         | 93/1000 [00:01<00:18, 49.38it/s][A
 10%|▉         | 98/1000 [00:02<00:18, 49.52it

measurement pattern 2 : ('X', 'X', 'Z')



  1%|          | 11/1000 [00:00<00:20, 49.37it/s][A
  2%|▏         | 17/1000 [00:00<00:19, 49.84it/s][A
  2%|▏         | 22/1000 [00:00<00:19, 49.67it/s][A
  3%|▎         | 28/1000 [00:00<00:19, 50.01it/s][A
  3%|▎         | 34/1000 [00:00<00:19, 50.06it/s][A
  4%|▍         | 40/1000 [00:00<00:19, 50.16it/s][A
  5%|▍         | 46/1000 [00:00<00:19, 49.82it/s][A
  5%|▌         | 51/1000 [00:01<00:19, 49.87it/s][A
  6%|▌         | 56/1000 [00:01<00:19, 48.32it/s][A
  6%|▌         | 61/1000 [00:01<00:19, 47.37it/s][A
  7%|▋         | 66/1000 [00:01<00:19, 47.92it/s][A
  7%|▋         | 72/1000 [00:01<00:19, 48.65it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 48.66it/s][A
  8%|▊         | 82/1000 [00:01<00:19, 46.64it/s][A
  9%|▊         | 87/1000 [00:01<00:19, 46.49it/s][A
  9%|▉         | 92/1000 [00:01<00:19, 47.21it/s][A
 10%|▉         | 97/1000 [00:02<00:19, 45.77it/s][A
 10%|█         | 102/1000 [00:02<00:19, 46.11it/s][A
 11%|█         | 107/1000 [00:02<00:18, 47.0

measurement pattern 3 : ('X', 'Y', 'X')



  0%|          | 4/1000 [00:00<00:26, 38.22it/s][A
  1%|          | 8/1000 [00:00<00:27, 36.08it/s][A
  1%|▏         | 13/1000 [00:00<00:24, 39.49it/s][A
  2%|▏         | 18/1000 [00:00<00:22, 42.76it/s][A
  2%|▏         | 23/1000 [00:00<00:22, 44.20it/s][A
  3%|▎         | 28/1000 [00:00<00:21, 45.71it/s][A
  3%|▎         | 33/1000 [00:00<00:20, 46.86it/s][A
  4%|▍         | 38/1000 [00:00<00:20, 47.74it/s][A
  4%|▍         | 44/1000 [00:00<00:19, 48.59it/s][A
  5%|▌         | 50/1000 [00:01<00:19, 49.21it/s][A
  6%|▌         | 56/1000 [00:01<00:18, 49.70it/s][A
  6%|▌         | 62/1000 [00:01<00:18, 50.13it/s][A
  7%|▋         | 68/1000 [00:01<00:18, 50.18it/s][A
  7%|▋         | 74/1000 [00:01<00:18, 49.80it/s][A
  8%|▊         | 80/1000 [00:01<00:18, 50.05it/s][A
  9%|▊         | 86/1000 [00:01<00:18, 48.99it/s][A
  9%|▉         | 91/1000 [00:01<00:19, 47.29it/s][A
 10%|▉         | 96/1000 [00:02<00:20, 43.94it/s][A
 10%|█         | 101/1000 [00:02<00:24, 37.20it

measurement pattern 4 : ('X', 'Y', 'Y')



  0%|          | 5/1000 [00:00<00:20, 49.10it/s][A
  1%|          | 11/1000 [00:00<00:19, 49.80it/s][A
  2%|▏         | 16/1000 [00:00<00:20, 48.79it/s][A
  2%|▏         | 21/1000 [00:00<00:20, 47.20it/s][A
  3%|▎         | 26/1000 [00:00<00:20, 46.49it/s][A
  3%|▎         | 31/1000 [00:00<00:24, 39.68it/s][A
  4%|▎         | 36/1000 [00:00<00:23, 41.87it/s][A
  4%|▍         | 41/1000 [00:00<00:21, 43.71it/s][A
  5%|▍         | 47/1000 [00:01<00:20, 45.78it/s][A
  5%|▌         | 52/1000 [00:01<00:20, 46.70it/s][A
  6%|▌         | 57/1000 [00:01<00:19, 47.45it/s][A
  6%|▌         | 62/1000 [00:01<00:19, 47.80it/s][A
  7%|▋         | 67/1000 [00:01<00:19, 48.05it/s][A
  7%|▋         | 72/1000 [00:01<00:19, 48.36it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 48.62it/s][A
  8%|▊         | 82/1000 [00:01<00:19, 47.57it/s][A
  9%|▊         | 87/1000 [00:01<00:19, 47.49it/s][A
  9%|▉         | 92/1000 [00:01<00:19, 47.76it/s][A
 10%|▉         | 97/1000 [00:02<00:18, 48.16it

measurement pattern 5 : ('X', 'Y', 'Z')



  1%|          | 9/1000 [00:00<00:24, 40.79it/s][A
  1%|▏         | 14/1000 [00:00<00:22, 43.78it/s][A
  2%|▏         | 19/1000 [00:00<00:21, 44.64it/s][A
  2%|▏         | 24/1000 [00:00<00:21, 44.89it/s][A
  3%|▎         | 29/1000 [00:00<00:22, 43.29it/s][A
  3%|▎         | 34/1000 [00:00<00:21, 44.61it/s][A
  4%|▍         | 39/1000 [00:00<00:21, 45.52it/s][A
  4%|▍         | 44/1000 [00:01<00:26, 35.57it/s][A
  5%|▍         | 49/1000 [00:01<00:25, 37.74it/s][A
  5%|▌         | 54/1000 [00:01<00:23, 40.00it/s][A
  6%|▌         | 59/1000 [00:01<00:23, 40.29it/s][A
  6%|▋         | 64/1000 [00:01<00:22, 41.07it/s][A
  7%|▋         | 69/1000 [00:01<00:21, 42.67it/s][A
  7%|▋         | 74/1000 [00:01<00:21, 43.81it/s][A
  8%|▊         | 79/1000 [00:01<00:20, 44.95it/s][A
  8%|▊         | 84/1000 [00:01<00:19, 46.34it/s][A
  9%|▉         | 89/1000 [00:02<00:19, 47.26it/s][A
  9%|▉         | 94/1000 [00:02<00:19, 45.87it/s][A
 10%|▉         | 99/1000 [00:02<00:19, 45.42it

measurement pattern 6 : ('X', 'Z', 'X')



  1%|          | 6/1000 [00:00<00:19, 51.91it/s][A
  1%|          | 12/1000 [00:00<00:19, 51.58it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.53it/s][A
  2%|▏         | 24/1000 [00:00<00:18, 51.71it/s][A
  3%|▎         | 30/1000 [00:00<00:18, 51.50it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.50it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 51.41it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 51.60it/s][A
  5%|▌         | 54/1000 [00:01<00:18, 51.65it/s][A
  6%|▌         | 60/1000 [00:01<00:18, 51.97it/s][A
  7%|▋         | 66/1000 [00:01<00:17, 51.96it/s][A
  7%|▋         | 72/1000 [00:01<00:17, 51.69it/s][A
  8%|▊         | 78/1000 [00:01<00:17, 51.47it/s][A
  8%|▊         | 84/1000 [00:01<00:17, 51.60it/s][A
  9%|▉         | 90/1000 [00:01<00:17, 51.16it/s][A
 10%|▉         | 96/1000 [00:01<00:17, 51.41it/s][A
 10%|█         | 102/1000 [00:01<00:17, 51.45it/s][A
 11%|█         | 108/1000 [00:02<00:17, 51.71it/s][A
 11%|█▏        | 114/1000 [00:02<00:17, 51.5

measurement pattern 7 : ('X', 'Z', 'Y')


  1%|          | 6/1000 [00:00<00:20, 49.20it/s][A
  1%|          | 11/1000 [00:00<00:21, 45.13it/s][A
  2%|▏         | 16/1000 [00:00<00:21, 45.77it/s][A
  2%|▏         | 22/1000 [00:00<00:20, 47.64it/s][A
  3%|▎         | 27/1000 [00:00<00:21, 46.13it/s][A
  3%|▎         | 32/1000 [00:00<00:22, 43.54it/s][A
  4%|▎         | 37/1000 [00:00<00:21, 45.39it/s][A
  4%|▍         | 43/1000 [00:00<00:20, 47.35it/s][A
  5%|▍         | 48/1000 [00:01<00:19, 47.66it/s][A
  5%|▌         | 53/1000 [00:01<00:19, 48.16it/s][A
  6%|▌         | 59/1000 [00:01<00:19, 49.23it/s][A
  6%|▋         | 65/1000 [00:01<00:18, 50.00it/s][A
  7%|▋         | 71/1000 [00:01<00:18, 50.12it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 49.65it/s][A
  8%|▊         | 82/1000 [00:01<00:18, 49.48it/s][A
  9%|▉         | 88/1000 [00:01<00:18, 50.04it/s][A
  9%|▉         | 94/1000 [00:01<00:18, 50.13it/s][A
 10%|█         | 100/1000 [00:02<00:18, 49.88it/s][A
 10%|█         | 105/1000 [00:02<00:17, 49.78i

measurement pattern 8 : ('X', 'Z', 'Z')



  1%|          | 12/1000 [00:00<00:19, 50.97it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.20it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 49.12it/s][A
  3%|▎         | 29/1000 [00:00<00:20, 47.42it/s][A
  4%|▎         | 35/1000 [00:00<00:19, 48.74it/s][A
  4%|▍         | 41/1000 [00:00<00:19, 49.63it/s][A
  5%|▍         | 47/1000 [00:00<00:18, 50.22it/s][A
  5%|▌         | 53/1000 [00:01<00:18, 50.49it/s][A
  6%|▌         | 59/1000 [00:01<00:18, 50.74it/s][A
  6%|▋         | 65/1000 [00:01<00:18, 51.03it/s][A
  7%|▋         | 71/1000 [00:01<00:18, 51.00it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 50.78it/s][A
  8%|▊         | 83/1000 [00:01<00:17, 51.09it/s][A
  9%|▉         | 89/1000 [00:01<00:17, 51.14it/s][A
 10%|▉         | 95/1000 [00:01<00:17, 51.21it/s][A
 10%|█         | 101/1000 [00:01<00:17, 51.28it/s][A
 11%|█         | 107/1000 [00:02<00:17, 51.37it/s][A
 11%|█▏        | 113/1000 [00:02<00:17, 51.43it/s][A
 12%|█▏        | 119/1000 [00:02<00:17, 51

measurement pattern 9 : ('Y', 'X', 'X')



  1%|          | 12/1000 [00:00<00:20, 47.11it/s][A
  2%|▏         | 17/1000 [00:00<00:20, 48.00it/s][A
  2%|▏         | 22/1000 [00:00<00:20, 48.57it/s][A
  3%|▎         | 27/1000 [00:00<00:19, 48.79it/s][A
  3%|▎         | 32/1000 [00:00<00:20, 47.75it/s][A
  4%|▎         | 37/1000 [00:00<00:20, 47.52it/s][A
  4%|▍         | 42/1000 [00:00<00:20, 47.74it/s][A
  5%|▍         | 48/1000 [00:00<00:19, 48.48it/s][A
  5%|▌         | 53/1000 [00:01<00:19, 48.35it/s][A
  6%|▌         | 58/1000 [00:01<00:19, 48.81it/s][A
  6%|▋         | 63/1000 [00:01<00:19, 48.63it/s][A
  7%|▋         | 68/1000 [00:01<00:19, 48.35it/s][A
  7%|▋         | 74/1000 [00:01<00:18, 49.16it/s][A
  8%|▊         | 80/1000 [00:01<00:18, 49.70it/s][A
  8%|▊         | 85/1000 [00:01<00:19, 47.30it/s][A
  9%|▉         | 90/1000 [00:01<00:19, 47.38it/s][A
 10%|▉         | 95/1000 [00:01<00:18, 47.67it/s][A
 10%|█         | 100/1000 [00:02<00:18, 48.04it/s][A
 10%|█         | 105/1000 [00:02<00:18, 48.1

measurement pattern 10 : ('Y', 'X', 'Y')



  1%|          | 6/1000 [00:00<00:19, 50.88it/s][A
  1%|          | 12/1000 [00:00<00:19, 49.77it/s][A
  2%|▏         | 17/1000 [00:00<00:20, 48.63it/s][A
  2%|▏         | 22/1000 [00:00<00:20, 48.57it/s][A
  3%|▎         | 27/1000 [00:00<00:19, 48.88it/s][A
  3%|▎         | 33/1000 [00:00<00:19, 49.55it/s][A
  4%|▍         | 39/1000 [00:00<00:19, 49.98it/s][A
  4%|▍         | 45/1000 [00:00<00:18, 50.38it/s][A
  5%|▌         | 51/1000 [00:01<00:18, 50.32it/s][A
  6%|▌         | 57/1000 [00:01<00:19, 48.84it/s][A
  6%|▌         | 62/1000 [00:01<00:19, 47.54it/s][A
  7%|▋         | 67/1000 [00:01<00:19, 47.77it/s][A
  7%|▋         | 72/1000 [00:01<00:19, 48.09it/s][A
  8%|▊         | 78/1000 [00:01<00:18, 48.88it/s][A
  8%|▊         | 84/1000 [00:01<00:18, 49.53it/s][A
  9%|▉         | 90/1000 [00:01<00:18, 49.96it/s][A
 10%|▉         | 96/1000 [00:01<00:17, 50.33it/s][A
 10%|█         | 102/1000 [00:02<00:17, 50.59it/s][A
 11%|█         | 108/1000 [00:02<00:17, 49.79

measurement pattern 11 : ('Y', 'X', 'Z')



  1%|          | 6/1000 [00:00<00:19, 51.45it/s][A
  1%|          | 12/1000 [00:00<00:19, 51.17it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.06it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 51.27it/s][A
  3%|▎         | 30/1000 [00:00<00:19, 50.95it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.11it/s][A
  4%|▍         | 42/1000 [00:00<00:19, 48.92it/s][A
  5%|▍         | 47/1000 [00:00<00:19, 48.80it/s][A
  5%|▌         | 53/1000 [00:01<00:19, 49.57it/s][A
  6%|▌         | 59/1000 [00:01<00:18, 50.19it/s][A
  6%|▋         | 65/1000 [00:01<00:18, 50.54it/s][A
  7%|▋         | 71/1000 [00:01<00:18, 50.79it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 50.98it/s][A
  8%|▊         | 83/1000 [00:01<00:17, 51.15it/s][A
  9%|▉         | 89/1000 [00:01<00:17, 51.25it/s][A
 10%|▉         | 95/1000 [00:01<00:17, 51.25it/s][A
 10%|█         | 101/1000 [00:01<00:17, 51.33it/s][A
 11%|█         | 107/1000 [00:02<00:17, 51.36it/s][A
 11%|█▏        | 113/1000 [00:02<00:17, 51.3

measurement pattern 12 : ('Y', 'Y', 'X')



  1%|          | 12/1000 [00:00<00:19, 50.92it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.08it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 50.96it/s][A
  3%|▎         | 30/1000 [00:00<00:18, 51.27it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.09it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 51.03it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 50.72it/s][A
  5%|▌         | 54/1000 [00:01<00:18, 50.63it/s][A
  6%|▌         | 60/1000 [00:01<00:18, 50.69it/s][A
  7%|▋         | 66/1000 [00:01<00:18, 50.85it/s][A
  7%|▋         | 72/1000 [00:01<00:18, 50.88it/s][A
  8%|▊         | 78/1000 [00:01<00:18, 50.89it/s][A
  8%|▊         | 84/1000 [00:01<00:17, 51.05it/s][A
  9%|▉         | 90/1000 [00:01<00:17, 51.05it/s][A
 10%|▉         | 96/1000 [00:01<00:17, 51.02it/s][A
 10%|█         | 102/1000 [00:02<00:17, 51.04it/s][A
 11%|█         | 108/1000 [00:02<00:17, 50.65it/s][A
 11%|█▏        | 114/1000 [00:02<00:17, 50.92it/s][A
 12%|█▏        | 120/1000 [00:02<00:17, 50

measurement pattern 13 : ('Y', 'Y', 'Y')



  0%|          | 5/1000 [00:00<00:20, 49.70it/s][A
  1%|          | 11/1000 [00:00<00:19, 50.15it/s][A
  2%|▏         | 17/1000 [00:00<00:19, 50.49it/s][A
  2%|▏         | 23/1000 [00:00<00:19, 50.18it/s][A
  3%|▎         | 29/1000 [00:00<00:19, 50.41it/s][A
  4%|▎         | 35/1000 [00:00<00:19, 50.78it/s][A
  4%|▍         | 41/1000 [00:00<00:18, 50.90it/s][A
  5%|▍         | 47/1000 [00:00<00:18, 50.88it/s][A
  5%|▌         | 53/1000 [00:01<00:18, 50.89it/s][A
  6%|▌         | 59/1000 [00:01<00:18, 50.77it/s][A
  6%|▋         | 65/1000 [00:01<00:18, 50.73it/s][A
  7%|▋         | 71/1000 [00:01<00:18, 50.92it/s][A
  8%|▊         | 77/1000 [00:01<00:18, 50.63it/s][A
  8%|▊         | 83/1000 [00:01<00:18, 50.80it/s][A
  9%|▉         | 89/1000 [00:01<00:17, 50.82it/s][A
 10%|▉         | 95/1000 [00:01<00:17, 50.81it/s][A
 10%|█         | 101/1000 [00:01<00:17, 50.69it/s][A
 11%|█         | 107/1000 [00:02<00:17, 50.82it/s][A
 11%|█▏        | 113/1000 [00:02<00:17, 50.8

measurement pattern 14 : ('Y', 'Y', 'Z')



  1%|          | 11/1000 [00:00<00:19, 49.76it/s][A
  2%|▏         | 16/1000 [00:00<00:19, 49.79it/s][A
  2%|▏         | 22/1000 [00:00<00:19, 50.49it/s][A
  3%|▎         | 28/1000 [00:00<00:19, 50.72it/s][A
  3%|▎         | 34/1000 [00:00<00:19, 50.73it/s][A
  4%|▍         | 40/1000 [00:00<00:18, 50.54it/s][A
  5%|▍         | 46/1000 [00:00<00:18, 50.69it/s][A
  5%|▌         | 52/1000 [00:01<00:18, 50.87it/s][A
  6%|▌         | 58/1000 [00:01<00:18, 50.90it/s][A
  6%|▋         | 64/1000 [00:01<00:18, 51.03it/s][A
  7%|▋         | 70/1000 [00:01<00:18, 51.04it/s][A
  8%|▊         | 76/1000 [00:01<00:18, 51.07it/s][A
  8%|▊         | 82/1000 [00:01<00:17, 51.03it/s][A
  9%|▉         | 88/1000 [00:01<00:17, 51.24it/s][A
  9%|▉         | 94/1000 [00:01<00:17, 50.70it/s][A
 10%|█         | 100/1000 [00:01<00:17, 51.00it/s][A
 11%|█         | 106/1000 [00:02<00:17, 51.16it/s][A
 11%|█         | 112/1000 [00:02<00:18, 48.39it/s][A
 12%|█▏        | 118/1000 [00:02<00:17, 49

measurement pattern 15 : ('Y', 'Z', 'X')



  1%|          | 11/1000 [00:00<00:19, 50.23it/s][A
  2%|▏         | 17/1000 [00:00<00:19, 49.58it/s][A
  2%|▏         | 22/1000 [00:00<00:19, 49.00it/s][A
  3%|▎         | 28/1000 [00:00<00:19, 49.99it/s][A
  3%|▎         | 33/1000 [00:00<00:19, 49.70it/s][A
  4%|▍         | 38/1000 [00:00<00:19, 48.94it/s][A
  4%|▍         | 43/1000 [00:00<00:19, 48.27it/s][A
  5%|▍         | 49/1000 [00:00<00:19, 49.27it/s][A
  6%|▌         | 55/1000 [00:01<00:18, 49.82it/s][A
  6%|▌         | 61/1000 [00:01<00:18, 50.30it/s][A
  7%|▋         | 67/1000 [00:01<00:18, 50.39it/s][A
  7%|▋         | 73/1000 [00:01<00:18, 50.58it/s][A
  8%|▊         | 79/1000 [00:01<00:18, 50.79it/s][A
  8%|▊         | 85/1000 [00:01<00:17, 50.86it/s][A
  9%|▉         | 91/1000 [00:01<00:17, 50.99it/s][A
 10%|▉         | 97/1000 [00:01<00:17, 50.69it/s][A
 10%|█         | 103/1000 [00:02<00:17, 50.79it/s][A
 11%|█         | 109/1000 [00:02<00:17, 51.11it/s][A
 12%|█▏        | 115/1000 [00:02<00:17, 51.

measurement pattern 16 : ('Y', 'Z', 'Y')



  1%|          | 12/1000 [00:00<00:19, 50.97it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.15it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 51.08it/s][A
  3%|▎         | 30/1000 [00:00<00:18, 51.18it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.13it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 50.84it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 50.79it/s][A
  5%|▌         | 54/1000 [00:01<00:18, 50.87it/s][A
  6%|▌         | 60/1000 [00:01<00:18, 50.91it/s][A
  7%|▋         | 66/1000 [00:01<00:18, 50.85it/s][A
  7%|▋         | 72/1000 [00:01<00:18, 50.91it/s][A
  8%|▊         | 78/1000 [00:01<00:18, 51.15it/s][A
  8%|▊         | 84/1000 [00:01<00:17, 51.15it/s][A
  9%|▉         | 90/1000 [00:01<00:17, 51.23it/s][A
 10%|▉         | 96/1000 [00:01<00:17, 50.74it/s][A
 10%|█         | 102/1000 [00:02<00:17, 50.96it/s][A
 11%|█         | 108/1000 [00:02<00:17, 51.01it/s][A
 11%|█▏        | 114/1000 [00:02<00:17, 51.21it/s][A
 12%|█▏        | 120/1000 [00:02<00:17, 51

measurement pattern 17 : ('Y', 'Z', 'Z')



  0%|          | 5/1000 [00:00<00:21, 45.31it/s][A
  1%|          | 10/1000 [00:00<00:22, 44.78it/s][A
  2%|▏         | 15/1000 [00:00<00:22, 44.61it/s][A
  2%|▏         | 20/1000 [00:00<00:21, 45.03it/s][A
  2%|▎         | 25/1000 [00:00<00:21, 45.51it/s][A
  3%|▎         | 30/1000 [00:00<00:20, 46.22it/s][A
  4%|▎         | 35/1000 [00:00<00:21, 45.75it/s][A
  4%|▍         | 40/1000 [00:00<00:20, 46.38it/s][A
  4%|▍         | 45/1000 [00:00<00:20, 46.40it/s][A
  5%|▌         | 50/1000 [00:01<00:20, 46.75it/s][A
  6%|▌         | 55/1000 [00:01<00:20, 46.90it/s][A
  6%|▌         | 60/1000 [00:01<00:20, 46.90it/s][A
  6%|▋         | 65/1000 [00:01<00:22, 41.00it/s][A
  7%|▋         | 70/1000 [00:01<00:22, 40.78it/s][A
  8%|▊         | 75/1000 [00:01<00:21, 42.58it/s][A
  8%|▊         | 80/1000 [00:01<00:20, 44.34it/s][A
  8%|▊         | 85/1000 [00:01<00:20, 45.47it/s][A
  9%|▉         | 90/1000 [00:01<00:19, 45.79it/s][A
 10%|▉         | 95/1000 [00:02<00:20, 44.79it

measurement pattern 18 : ('Z', 'X', 'X')



  0%|          | 5/1000 [00:00<00:21, 46.51it/s][A
  1%|          | 10/1000 [00:00<00:20, 47.66it/s][A
  2%|▏         | 15/1000 [00:00<00:23, 42.30it/s][A
  2%|▏         | 20/1000 [00:00<00:23, 41.86it/s][A
  2%|▎         | 25/1000 [00:00<00:23, 41.59it/s][A
  3%|▎         | 30/1000 [00:00<00:23, 41.24it/s][A
  4%|▎         | 35/1000 [00:00<00:22, 42.71it/s][A
  4%|▍         | 40/1000 [00:00<00:21, 44.20it/s][A
  4%|▍         | 45/1000 [00:01<00:21, 45.37it/s][A
  5%|▌         | 50/1000 [00:01<00:20, 45.96it/s][A
  6%|▌         | 55/1000 [00:01<00:20, 46.54it/s][A
  6%|▌         | 60/1000 [00:01<00:20, 46.37it/s][A
  6%|▋         | 65/1000 [00:01<00:20, 46.46it/s][A
  7%|▋         | 70/1000 [00:01<00:19, 46.88it/s][A
  8%|▊         | 75/1000 [00:01<00:19, 46.95it/s][A
  8%|▊         | 80/1000 [00:01<00:19, 47.27it/s][A
  8%|▊         | 85/1000 [00:01<00:19, 47.22it/s][A
  9%|▉         | 90/1000 [00:01<00:19, 47.55it/s][A
 10%|▉         | 95/1000 [00:02<00:19, 47.35it

measurement pattern 19 : ('Z', 'X', 'Y')



  1%|          | 10/1000 [00:00<00:21, 46.81it/s][A
  2%|▏         | 15/1000 [00:00<00:20, 47.23it/s][A
  2%|▏         | 20/1000 [00:00<00:20, 46.95it/s][A
  2%|▎         | 25/1000 [00:00<00:20, 46.83it/s][A
  3%|▎         | 30/1000 [00:00<00:20, 46.70it/s][A
  4%|▎         | 35/1000 [00:00<00:20, 46.80it/s][A
  4%|▍         | 40/1000 [00:00<00:20, 46.67it/s][A
  4%|▍         | 45/1000 [00:00<00:20, 46.98it/s][A
  5%|▌         | 50/1000 [00:01<00:20, 46.91it/s][A
  6%|▌         | 55/1000 [00:01<00:20, 47.03it/s][A
  6%|▌         | 60/1000 [00:01<00:19, 47.08it/s][A
  6%|▋         | 65/1000 [00:01<00:19, 46.91it/s][A
  7%|▋         | 70/1000 [00:01<00:19, 47.05it/s][A
  8%|▊         | 75/1000 [00:01<00:19, 47.19it/s][A
  8%|▊         | 80/1000 [00:01<00:19, 46.53it/s][A
  8%|▊         | 85/1000 [00:01<00:19, 47.04it/s][A
  9%|▉         | 90/1000 [00:01<00:19, 47.03it/s][A
 10%|▉         | 95/1000 [00:02<00:19, 47.13it/s][A
 10%|█         | 100/1000 [00:02<00:19, 46.70

measurement pattern 20 : ('Z', 'X', 'Z')



  1%|          | 12/1000 [00:00<00:19, 51.34it/s][A
  2%|▏         | 18/1000 [00:00<00:18, 51.69it/s][A
  2%|▏         | 24/1000 [00:00<00:18, 51.61it/s][A
  3%|▎         | 30/1000 [00:00<00:18, 51.61it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.63it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 51.73it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 51.67it/s][A
  5%|▌         | 54/1000 [00:01<00:19, 49.32it/s][A
  6%|▌         | 59/1000 [00:01<00:20, 44.88it/s][A
  6%|▋         | 64/1000 [00:01<00:21, 44.39it/s][A
  7%|▋         | 69/1000 [00:01<00:20, 45.72it/s][A
  8%|▊         | 75/1000 [00:01<00:19, 47.41it/s][A
  8%|▊         | 81/1000 [00:01<00:18, 48.68it/s][A
  9%|▊         | 87/1000 [00:01<00:18, 49.73it/s][A
  9%|▉         | 93/1000 [00:01<00:18, 50.29it/s][A
 10%|▉         | 99/1000 [00:02<00:18, 48.25it/s][A
 10%|█         | 104/1000 [00:02<00:20, 43.92it/s][A
 11%|█         | 109/1000 [00:02<00:19, 45.43it/s][A
 12%|█▏        | 115/1000 [00:02<00:18, 47.

measurement pattern 21 : ('Z', 'Y', 'X')



  1%|          | 9/1000 [00:00<00:23, 41.89it/s][A
  1%|▏         | 14/1000 [00:00<00:21, 45.04it/s][A
  2%|▏         | 19/1000 [00:00<00:21, 46.52it/s][A
  2%|▏         | 24/1000 [00:00<00:20, 47.34it/s][A
  3%|▎         | 30/1000 [00:00<00:19, 48.66it/s][A
  4%|▎         | 36/1000 [00:00<00:19, 49.59it/s][A
  4%|▍         | 42/1000 [00:00<00:19, 50.04it/s][A
  5%|▍         | 47/1000 [00:00<00:19, 49.55it/s][A
  5%|▌         | 52/1000 [00:01<00:19, 48.11it/s][A
  6%|▌         | 57/1000 [00:01<00:19, 48.18it/s][A
  6%|▌         | 62/1000 [00:01<00:19, 47.90it/s][A
  7%|▋         | 67/1000 [00:01<00:19, 48.01it/s][A
  7%|▋         | 72/1000 [00:01<00:19, 46.71it/s][A
  8%|▊         | 77/1000 [00:01<00:19, 46.27it/s][A
  8%|▊         | 82/1000 [00:01<00:19, 47.09it/s][A
  9%|▊         | 87/1000 [00:01<00:19, 47.91it/s][A
  9%|▉         | 92/1000 [00:01<00:19, 46.89it/s][A
 10%|▉         | 97/1000 [00:02<00:19, 45.50it/s][A
 10%|█         | 102/1000 [00:02<00:19, 45.67i

measurement pattern 22 : ('Z', 'Y', 'Y')



  1%|          | 6/1000 [00:00<00:19, 50.98it/s][A
  1%|          | 12/1000 [00:00<00:19, 50.82it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 50.81it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 50.67it/s][A
  3%|▎         | 30/1000 [00:00<00:19, 50.67it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 50.89it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 51.07it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 51.21it/s][A
  5%|▌         | 54/1000 [00:01<00:18, 50.93it/s][A
  6%|▌         | 60/1000 [00:01<00:18, 51.08it/s][A
  7%|▋         | 66/1000 [00:01<00:18, 50.82it/s][A
  7%|▋         | 72/1000 [00:01<00:18, 50.91it/s][A
  8%|▊         | 78/1000 [00:01<00:18, 49.63it/s][A
  8%|▊         | 83/1000 [00:01<00:18, 48.61it/s][A
  9%|▉         | 88/1000 [00:01<00:18, 48.27it/s][A
  9%|▉         | 93/1000 [00:01<00:18, 48.14it/s][A
 10%|▉         | 98/1000 [00:01<00:18, 47.94it/s][A
 10%|█         | 103/1000 [00:02<00:19, 46.19it/s][A
 11%|█         | 108/1000 [00:02<00:20, 43.22

measurement pattern 23 : ('Z', 'Y', 'Z')



  1%|          | 12/1000 [00:00<00:20, 47.07it/s][A
  2%|▏         | 17/1000 [00:00<00:20, 47.70it/s][A
  2%|▏         | 22/1000 [00:00<00:20, 46.57it/s][A
  3%|▎         | 27/1000 [00:00<00:20, 47.49it/s][A
  3%|▎         | 32/1000 [00:00<00:20, 48.06it/s][A
  4%|▎         | 37/1000 [00:00<00:19, 48.38it/s][A
  4%|▍         | 43/1000 [00:00<00:19, 49.00it/s][A
  5%|▍         | 48/1000 [00:00<00:19, 48.45it/s][A
  5%|▌         | 54/1000 [00:01<00:19, 49.01it/s][A
  6%|▌         | 59/1000 [00:01<00:19, 48.19it/s][A
  6%|▋         | 65/1000 [00:01<00:19, 48.82it/s][A
  7%|▋         | 70/1000 [00:01<00:19, 48.41it/s][A
  8%|▊         | 75/1000 [00:01<00:18, 48.83it/s][A
  8%|▊         | 80/1000 [00:01<00:20, 44.92it/s][A
  8%|▊         | 85/1000 [00:01<00:20, 45.15it/s][A
  9%|▉         | 90/1000 [00:01<00:19, 45.60it/s][A
 10%|▉         | 95/1000 [00:02<00:19, 46.47it/s][A
 10%|█         | 100/1000 [00:02<00:19, 46.92it/s][A
 10%|█         | 105/1000 [00:02<00:19, 47.0

measurement pattern 24 : ('Z', 'Z', 'X')



  1%|          | 12/1000 [00:00<00:19, 50.30it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 50.17it/s][A
  2%|▏         | 24/1000 [00:00<00:21, 46.35it/s][A
  3%|▎         | 29/1000 [00:00<00:21, 45.69it/s][A
  3%|▎         | 34/1000 [00:00<00:20, 46.79it/s][A
  4%|▍         | 39/1000 [00:00<00:20, 46.37it/s][A
  4%|▍         | 44/1000 [00:00<00:20, 46.10it/s][A
  5%|▍         | 49/1000 [00:01<00:20, 46.60it/s][A
  5%|▌         | 54/1000 [00:01<00:20, 46.11it/s][A
  6%|▌         | 59/1000 [00:01<00:20, 46.00it/s][A
  6%|▋         | 64/1000 [00:01<00:19, 46.95it/s][A
  7%|▋         | 69/1000 [00:01<00:19, 47.82it/s][A
  7%|▋         | 74/1000 [00:01<00:19, 47.73it/s][A
  8%|▊         | 80/1000 [00:01<00:18, 48.77it/s][A
  8%|▊         | 85/1000 [00:01<00:19, 47.83it/s][A
  9%|▉         | 91/1000 [00:01<00:18, 48.81it/s][A
 10%|▉         | 96/1000 [00:02<00:18, 49.12it/s][A
 10%|█         | 101/1000 [00:02<00:18, 49.13it/s][A
 11%|█         | 107/1000 [00:02<00:18, 49.5

measurement pattern 25 : ('Z', 'Z', 'Y')



  1%|          | 12/1000 [00:00<00:19, 51.17it/s][A
  2%|▏         | 18/1000 [00:00<00:19, 51.09it/s][A
  2%|▏         | 24/1000 [00:00<00:19, 50.94it/s][A
  3%|▎         | 30/1000 [00:00<00:19, 50.87it/s][A
  4%|▎         | 36/1000 [00:00<00:18, 51.10it/s][A
  4%|▍         | 42/1000 [00:00<00:18, 51.23it/s][A
  5%|▍         | 48/1000 [00:00<00:18, 51.28it/s][A
  5%|▌         | 54/1000 [00:01<00:18, 51.48it/s][A
  6%|▌         | 60/1000 [00:01<00:18, 51.39it/s][A
  7%|▋         | 66/1000 [00:01<00:18, 51.58it/s][A
  7%|▋         | 72/1000 [00:01<00:18, 51.31it/s][A
  8%|▊         | 78/1000 [00:01<00:18, 49.16it/s][A
  8%|▊         | 83/1000 [00:01<00:19, 46.71it/s][A
  9%|▉         | 88/1000 [00:01<00:19, 45.99it/s][A
  9%|▉         | 93/1000 [00:01<00:20, 44.77it/s][A
 10%|▉         | 98/1000 [00:02<00:20, 44.78it/s][A
 10%|█         | 103/1000 [00:02<00:20, 43.57it/s][A
 11%|█         | 108/1000 [00:02<00:20, 42.95it/s][A
 11%|█▏        | 113/1000 [00:02<00:21, 42.

measurement pattern 26 : ('Z', 'Z', 'Z')



  1%|          | 12/1000 [00:00<00:19, 51.30it/s][A
  2%|▏         | 18/1000 [00:00<00:20, 47.63it/s][A
  2%|▏         | 23/1000 [00:00<00:22, 43.69it/s][A
  3%|▎         | 28/1000 [00:00<00:21, 44.29it/s][A
  3%|▎         | 33/1000 [00:00<00:21, 45.65it/s][A
  4%|▍         | 38/1000 [00:00<00:21, 44.30it/s][A
  4%|▍         | 43/1000 [00:00<00:20, 45.86it/s][A
  5%|▍         | 49/1000 [00:01<00:19, 47.57it/s][A
  5%|▌         | 54/1000 [00:01<00:19, 47.95it/s][A
  6%|▌         | 60/1000 [00:01<00:19, 49.03it/s][A
  7%|▋         | 66/1000 [00:01<00:18, 49.92it/s][A
  7%|▋         | 71/1000 [00:01<00:18, 49.77it/s][A
  8%|▊         | 76/1000 [00:01<00:19, 47.49it/s][A
  8%|▊         | 82/1000 [00:01<00:18, 48.85it/s][A
  9%|▉         | 88/1000 [00:01<00:18, 49.39it/s][A
  9%|▉         | 94/1000 [00:01<00:18, 49.97it/s][A
 10%|▉         | 99/1000 [00:02<00:19, 45.68it/s][A
 10%|█         | 104/1000 [00:02<00:19, 45.24it/s][A
 11%|█         | 109/1000 [00:02<00:19, 45.9

train data is ready !


## load dataset

In [23]:
meas_pattern_path = train_data_path + "/measurement_pattern.txt"
meas_label_path = train_data_path + "/measurement_label.txt"
meas_result_path = train_data_path + "/measurement_result.txt"
ideal_rho_re_path = ideal_state_path + "/rho_real.txt"
ideal_rho_im_path = ideal_state_path + "/rho_imag.txt"
meas_result, ideal_rho, meas_label, meas_pattern = data.load_data_DM(meas_result_path,
                                                                     ideal_rho_re_path,
                                                                     ideal_rho_im_path,
                                                                     meas_label_path,
                                                                     meas_pattern_path)

OSError: ./target_state//rho_real.txt not found.

## build RBM architecture

In [None]:
nn_state_dm = DensityMatrix(
    num_visible = n_visible_unit, 
    num_hidden = n_hidden_unit, 
    num_aux = n_aux_unit, 
    unitary_dict = unitaries.create_dict(),
    gpu = False
)

## estimate observable expectation value

In [None]:
class GeneralPauliDistill(ObservableBase):
    def __init__(self, pauli_dict: dict, m: int) -> None:
        self.name = "distilled_pauli"
        self.symbol = "distilled_general_pauli"
        self.pauli_dict = pauli_dict
        self.num_copy = m
        
    def apply(self, nn_state, samples):
        """
        This function calcualte <x1 x2 ... xm | rho^{\otimes m} O | xm x1 x2 ... xm-1> / <x1 x2 ... xm | rho^{\otimes m} | x1 x2 ... xm>
        where O acts only on the first register.
        """
        
        # [num_sample, num_visible_node]
        # samples = [s1, s2, s3 ... sN]
        #  where num_sample = N, and si is num_visible_node-bits
        samples = samples.to(device=nn_state.device)
        
        num_sample, num_visible_node = samples.shape
        
        # [num_sample, num_visible_node * num_copy]
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        #  each row is num_copy*num_visible_node bits the above example is for num_copy=3
        samples_array = []
        for copy_index in range(self.num_copy):
            rolled_samples = torch.roll(samples, shifts=copy_index, dims=0)
            samples_array.append(rolled_samples)
        samples_array = torch.hstack(samples_array)
        assert(samples_array.shape[0] == num_sample)
        assert(samples_array.shape[1] == num_visible_node * self.num_copy)
        
        # roll second dim of [num_sample, num_visible_node * num_copy] by num_visible_node
        # swapped_samples_array = [[sN-1 s1 sN], [sN s2 s1], [s1 s3 s2],.. [sN-2 sN sN-1]]
        swapped_samples_array = torch.roll(samples_array, shifts = num_visible_node, dims=1)

        # pick copy of first block
        #  first_block_sample = [sN-1, sN, s1, s2, ... sN-2]
        first_block_sample = swapped_samples_array[:, :num_visible_node].clone()

        # calculate coefficient for first block [num_samples, 0:num_visible_node]
        total_prod = cplx.make_complex(torch.ones_like(samples[:,0]), torch.zeros_like(samples[:,0]))
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            coeff = to_pm1(first_block_sample[:, index])
            if pauli == "Z":
                coeff = cplx.make_complex(coeff, torch.zeros_like(coeff))
                total_prod = cplx.elementwise_mult(coeff, total_prod)
            elif pauli == "Y":
                coeff = cplx.make_complex(torch.zeros_like(coeff), coeff)
                total_prod = cplx.elementwise_mult(coeff, total_prod)
        
        # flip samples for for first block [num_samples, 0:num_visible_node]
        # first_block_sample -> [OsN-1, OsN, Os1, Os2, ... OsN-2]
        #  where Osi is bit array after Pauli bit-flips 
        for index, pauli in self.pauli_dict.items():
            assert(index < num_visible_node)
            if pauli in ["X", "Y"]:
                first_block_sample = flip_spin(index, first_block_sample)


        # store flipped first block
        swapped_samples_array[:, :num_visible_node] = first_block_sample

        # calculate product of coefficients
        # samples_array = [[s1 sN sN-1], [s2 s1 sN], [s3 s2 s1],.. [sN sN-1 sN-2]]
        # swapped_samples_array = [[OsN-1 s1 sN], [OsN s2 s1], [Os1 s3 s2],.. [OsN-2 sN sN-1]]
        """
        total_prod = [
            <s1 sN sN-1 | rho^{\otimes 3} | OsN-1 s1 sN> / <s1 sN sN-1 | rho^{\otimes 3} | s1 sN sN-1> , 
            <s2 s1 sN   | rho^{\otimes 3} | OsN s2 s1>   / <s2 s1 sN   | rho^{\otimes 3} | s2 s1 sN> , 
            <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1> , 

        e.g. 
        <s3 s2 s1   | rho^{\otimes 3} | Os1 s3 s2>   / <s3 s2 s1   | rho^{\otimes 3} | s3 s2 s1>
         = <s3 | rho | Os1> <s2 | rho | s3> < s1| rho | s2> / (<s3 | rho | s3> <s2 | rho | s2> < s1| rho | s1>)
         =  (<s3 | rho | Os1> / <s3 | rho | s3>)
          * (<s2 | rho | s3> / <s2 | rho | s2> )
          * (< s1| rho | s2> / < s1| rho | s1>)
         
        importance_sampling_numerator(s3, Os1)  provides <s3 | rho | Os1>
        importance_sampling_denominator(s3)     provides <s3 | rho | s3>
        """
        for copy_index in range(self.num_copy):
            st = copy_index * samples.shape[1]
            en = (copy_index+1) * samples.shape[1]
            # numerator is []
            numerator = nn_state.importance_sampling_numerator(swapped_samples_array[:, st:en], samples_array[:, st:en])
            denominator = nn_state.importance_sampling_denominator(samples_array[:, st:en])
            values = cplx.elementwise_division(numerator, denominator)
            total_prod = cplx.elementwise_mult(total_prod, values)

        value = cplx.real(total_prod)
        return value

def calculate_distilled_expectation_value(pauli_dict: dict, num_samples: int, num_copies: int):
    obs_num = GeneralPauliDistill(pauli_dict, num_copies)
    obs_div = GeneralPauliDistill({}, num_copies)
    num_stat = obs_num.statistics(nn_state_dm, num_samples=num_samples)
    div_stat = obs_div.statistics(nn_state_dm, num_samples=num_samples)

    from uncertainties import ufloat
    num = ufloat(num_stat["mean"], num_stat["std_error"])
    div = ufloat(div_stat["mean"], div_stat["std_error"])
    val = num/div
    result_dict = {"mean": val.n , "std_error": val.s, "num_samples": num_samples, "num_copies": num_copies}
    return result_dict

def get_density_matrix(nn_state):
    space = nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)
    tensor = nn_state.rho(space, space)/Z
    matrix = cplx.numpy(tensor)
    return matrix

def get_max_eigvec(matrix):
    e_val, e_vec = np.linalg.eigh(matrix)
    me_val = e_val[-1]
    me_vec = e_vec[:,-1]
    return me_vec

def get_eigvec(nn_state, obs, space, **kwargs):
    dm = get_density_matrix(nn_state)
    ev = get_max_eigvec(dm)
    ev = np.atleast_2d(ev)
    val = ev@obs@ev.T.conj()
    val = val[0,0].real
    return val

def observable_XXX():
    target_list = [0, 1, 2]
    pauli_index = [1, 1, 1] # 1:X , 2:Y, 3:Z
    gate = Pauli(target_list, pauli_index) # = X_1 X_2
    return gate.get_matrix()

def observable_XZZ():
    target_list = [0, 1, 2]
    pauli_index = [1, 3, 3] # 1:X , 2:Y, 3:Z
    gate = Pauli(target_list, pauli_index) # = X_1 Z_2
    return gate.get_matrix()

def observable_XXX_ev(nn_state, **kwargs):
    obs_stat = calculate_distilled_expectation_value({0: "X", 1: "X", 2: "X"}, n_sampling, n_copy)
    return obs_stat["mean"]

def observable_XZZ_ev(nn_state, **kwargs):
    obs_stat = calculate_distilled_expectation_value({0: "X", 1: "Z", 2: "Z"}, n_sampling, n_copy)
    return obs_stat["mean"]

## callback setting 

In [None]:
def create_callback_dm(nn_state):
    metric_dict = {
        "Fidelity": ts.fidelity,
        "KL_Divergence": ts.KL,
        "Observable_XXX_ev": observable_XXX_ev,
        "Observable_XZZ_ev": observable_XZZ_ev,
    }

    space = nn_state.generate_hilbert_space()
    callbacks = [
        MetricEvaluator(
            period,
            metric_dict,
            target = ideal_rho,
            bases = meas_pattern,
            verbose = True,
            space = space,
        )
    ]
    return callbacks

callbacks = create_callback_dm(nn_state_dm)

## train

In [None]:
nn_state_dm.fit(
    data = meas_result,
    input_bases = meas_label,
    epochs = epoch,
    pos_batch_size = pbs,
    neg_batch_size = nbs,
    lr = lr,
    k = n_gibbs_step,
    bases = meas_pattern,
    callbacks = callbacks,
    time = True,
    optimizer = torch.optim.Adadelta,
    scheduler = torch.optim.lr_scheduler.StepLR,
    scheduler_args = {"step_size": lr_drop_epoch, "gamma": lr_drop_factor},
)

In [None]:
fidelities = callbacks[0]["Fidelity"]
KLs = callbacks[0]["KL_Divergence"]
epoch_range = np.arange(period, epoch + 1, period)

fig, axs = plt.subplots(nrows = 1, ncols = 2, figsize = (16, 5))
ax = axs[0]
ax.plot(epoch_range, fidelities, "o", color = "C0", markeredgecolor = "black")
ax.set_ylabel(r"Fidelity")
ax.set_xlabel(r"Epoch")
ax.set_ylim(0.00, 1.00)

ax = axs[1]
ax.plot(epoch_range, KLs, "o", color = "C1", markeredgecolor = "black")
ax.set_ylabel(r"KL Divergence")
ax.set_xlabel(r"Epoch")

## save model & train log

In [None]:
# save model
nn_state_dm.save("./exp001/model.pt")
# save train log
train_log_df = pd.DataFrame()
train_log_df["epoch"] = np.arange(period, epoch+1, period)
train_log_df["Fidelity"] = callbacks[0]["Fidelity"]
train_log_df["KL_Divergence"] = callbacks[0]["KL_Divergence"]
train_log_df["Observable_XX_ev"] = callbacks[0]["Observable_XX_ev"]
train_log_df["Observable_XZ_ev"] = callbacks[0]["Observable_XZ_ev"]
train_log_df.to_csv("./exp001/train_log.csv", index=False)