## Import modules

In [1]:
import os
import subprocess
import random
import warnings
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import unitary_group
import pandas as pd
import yaml
from tqdm.notebook import tqdm
import itertools

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.callbacks import MetricEvaluator
import qucumber.utils.unitaries as unitaries
import qucumber.utils.training_statistics as ts
import qucumber.utils.cplx as cplx
import qucumber.utils.data as data
from qucumber.observables import ObservableBase, to_pm1
from qucumber.observables.pauli import flip_spin
import qucumber

from qulacs.gate import Pauli

%load_ext autoreload
%autoreload 2
import utils
import gate
import measurement
import target_circuit
import dataset

## Set parameters

In [8]:
with open('./params_setting.yaml', 'r') as yml:
    params = yaml.safe_load(yml)

# quantum circuit parameter
circuit_name = params["circuit_info"]["circuit_name"]
n_qubit = params["circuit_info"]["n_qubit"]
state_class = params["circuit_info"]["state_class"]
error_model = params["circuit_info"]["error_model"]
error_rate = params["circuit_info"]["error_rate"]
each_n_shot = params["circuit_info"]["each_n_shot"]

# RBM architecture parameter
num_visible = params["architecture_info"]["n_visible_unit"]
num_hidden = params["architecture_info"]["n_hidden_unit"]
num_aux = params["architecture_info"]["n_aux_unit"]

# train parameter
lr = params["train_info"]["lr"]
pbs = params["train_info"]["positive_batch_size"]
nbs = params["train_info"]["negative_batch_size"]
n_gibbs_step = params["train_info"]["n_gibbs_step"]
period = 1  # metric-evaluation period (epochs) handed to MetricEvaluator; not read from the YAML
epoch = params["train_info"]["n_epoch"]
lr_drop_epoch = params["train_info"]["lr_drop_epoch"]
lr_drop_factor = params["train_info"]["lr_drop_factor"]
use_gpu = params["train_info"]["use_gpu"]
seed = params["train_info"]["seed"]

# sampling parameter
n_sampling = params["sampling_info"]["n_sample"]
n_copy = params["sampling_info"]["n_copy"]

# data path info
# "local" stores everything under the working directory; "colab" mounts Google
# Drive and stores everything under drive_path instead.
environment = "local"
if environment == "local":
    base_dir = "."
elif environment == "colab":
    from google.colab import drive
    drive.mount("/content/drive/")
    drive_path = "/content/drive/MyDrive/NQS4VD/GHZ"
    base_dir = drive_path
else:
    # Fail here instead of with a NameError at the first later use of a path.
    raise ValueError(f"unknown environment {environment!r}; expected 'local' or 'colab'")

# Sub-directory layout shared by both environments.
# NOTE: `100*error_rate` uses the default float repr, which can carry rounding
# digits (e.g. 100*0.07 -> 7.000000000000001); it is kept unchanged so that
# directories written by earlier runs are still found.
noise_dir = f"{error_model}/p={100*error_rate}%"
train_data_path = f"{base_dir}/{circuit_name}/data/{n_qubit}-qubit/{noise_dir}/each_{each_n_shot}_shot/"
target_state_path = f"{base_dir}/{circuit_name}/target_state/{state_class}/{n_qubit}-qubit/{noise_dir}/"
model_path = f"{base_dir}/{circuit_name}/model/{state_class}/{n_qubit}-qubit/{noise_dir}/each_{each_n_shot}_shot/"
train_log_path = f"{base_dir}/{circuit_name}/train_log/{n_qubit}-qubit/{noise_dir}/each_{each_n_shot}_shot/"

def seed_settings(seed=42, gpu=False):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: integer seed applied to Python's `random`, NumPy, torch and qucumber.
        gpu: forwarded to `qucumber.set_random_seed` so the CUDA generator is
            seeded as well when training on GPU. Defaults to False, which is the
            previous CPU-only behaviour.
    """
    random.seed(seed)
    # Setting PYTHONHASHSEED at runtime only affects subprocesses launched later;
    # the running interpreter's hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    qucumber.set_random_seed(seed, cpu=True, gpu=gpu)

# Seed the GPU generator too whenever the model will be built with gpu=use_gpu.
seed_settings(seed=seed, gpu=use_gpu)

## generate dataset

In [3]:
# Build the GHZ target state for the configured error model / rate and store it
# under target_state_path so the "load dataset" step can read it back.
target_state = target_circuit.GHZ(
    n_qubit, state_class, error_model, error_rate
)
utils.save_density_matrix(target_state, target_state_path)

# Generate measurement data for every measurement pattern (243 = 3**5 patterns in
# the logged run; presumably each_n_shot shots each) and store both tables under
# train_data_path.
(meas_pattern_df, train_df) = dataset.generate(
    target_state, n_qubit, error_model, each_n_shot
)
dataset.save(meas_pattern_df, train_df, train_data_path)

0it [00:00, ?it/s]

measurement pattern 1/243 : ('X', 'X', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 2/243 : ('X', 'X', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 3/243 : ('X', 'X', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 4/243 : ('X', 'X', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 5/243 : ('X', 'X', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 6/243 : ('X', 'X', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 7/243 : ('X', 'X', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 8/243 : ('X', 'X', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 9/243 : ('X', 'X', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 10/243 : ('X', 'X', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 11/243 : ('X', 'X', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 12/243 : ('X', 'X', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 13/243 : ('X', 'X', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 14/243 : ('X', 'X', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 15/243 : ('X', 'X', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 16/243 : ('X', 'X', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 17/243 : ('X', 'X', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 18/243 : ('X', 'X', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 19/243 : ('X', 'X', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 20/243 : ('X', 'X', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 21/243 : ('X', 'X', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 22/243 : ('X', 'X', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 23/243 : ('X', 'X', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 24/243 : ('X', 'X', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 25/243 : ('X', 'X', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 26/243 : ('X', 'X', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 27/243 : ('X', 'X', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 28/243 : ('X', 'Y', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 29/243 : ('X', 'Y', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 30/243 : ('X', 'Y', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 31/243 : ('X', 'Y', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 32/243 : ('X', 'Y', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 33/243 : ('X', 'Y', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 34/243 : ('X', 'Y', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 35/243 : ('X', 'Y', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 36/243 : ('X', 'Y', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 37/243 : ('X', 'Y', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 38/243 : ('X', 'Y', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 39/243 : ('X', 'Y', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 40/243 : ('X', 'Y', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 41/243 : ('X', 'Y', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 42/243 : ('X', 'Y', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 43/243 : ('X', 'Y', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 44/243 : ('X', 'Y', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 45/243 : ('X', 'Y', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 46/243 : ('X', 'Y', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 47/243 : ('X', 'Y', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 48/243 : ('X', 'Y', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 49/243 : ('X', 'Y', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 50/243 : ('X', 'Y', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 51/243 : ('X', 'Y', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 52/243 : ('X', 'Y', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 53/243 : ('X', 'Y', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 54/243 : ('X', 'Y', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 55/243 : ('X', 'Z', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 56/243 : ('X', 'Z', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 57/243 : ('X', 'Z', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 58/243 : ('X', 'Z', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 59/243 : ('X', 'Z', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 60/243 : ('X', 'Z', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 61/243 : ('X', 'Z', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 62/243 : ('X', 'Z', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 63/243 : ('X', 'Z', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 64/243 : ('X', 'Z', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 65/243 : ('X', 'Z', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 66/243 : ('X', 'Z', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 67/243 : ('X', 'Z', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 68/243 : ('X', 'Z', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 69/243 : ('X', 'Z', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 70/243 : ('X', 'Z', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 71/243 : ('X', 'Z', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 72/243 : ('X', 'Z', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 73/243 : ('X', 'Z', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 74/243 : ('X', 'Z', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 75/243 : ('X', 'Z', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 76/243 : ('X', 'Z', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 77/243 : ('X', 'Z', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 78/243 : ('X', 'Z', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 79/243 : ('X', 'Z', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 80/243 : ('X', 'Z', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 81/243 : ('X', 'Z', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 82/243 : ('Y', 'X', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 83/243 : ('Y', 'X', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 84/243 : ('Y', 'X', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 85/243 : ('Y', 'X', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 86/243 : ('Y', 'X', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 87/243 : ('Y', 'X', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 88/243 : ('Y', 'X', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 89/243 : ('Y', 'X', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 90/243 : ('Y', 'X', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 91/243 : ('Y', 'X', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 92/243 : ('Y', 'X', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 93/243 : ('Y', 'X', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 94/243 : ('Y', 'X', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 95/243 : ('Y', 'X', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 96/243 : ('Y', 'X', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 97/243 : ('Y', 'X', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 98/243 : ('Y', 'X', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 99/243 : ('Y', 'X', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 100/243 : ('Y', 'X', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 101/243 : ('Y', 'X', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 102/243 : ('Y', 'X', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 103/243 : ('Y', 'X', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 104/243 : ('Y', 'X', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 105/243 : ('Y', 'X', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 106/243 : ('Y', 'X', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 107/243 : ('Y', 'X', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 108/243 : ('Y', 'X', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 109/243 : ('Y', 'Y', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 110/243 : ('Y', 'Y', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 111/243 : ('Y', 'Y', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 112/243 : ('Y', 'Y', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 113/243 : ('Y', 'Y', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 114/243 : ('Y', 'Y', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 115/243 : ('Y', 'Y', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 116/243 : ('Y', 'Y', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 117/243 : ('Y', 'Y', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 118/243 : ('Y', 'Y', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 119/243 : ('Y', 'Y', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 120/243 : ('Y', 'Y', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 121/243 : ('Y', 'Y', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 122/243 : ('Y', 'Y', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 123/243 : ('Y', 'Y', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 124/243 : ('Y', 'Y', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 125/243 : ('Y', 'Y', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 126/243 : ('Y', 'Y', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 127/243 : ('Y', 'Y', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 128/243 : ('Y', 'Y', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 129/243 : ('Y', 'Y', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 130/243 : ('Y', 'Y', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 131/243 : ('Y', 'Y', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 132/243 : ('Y', 'Y', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 133/243 : ('Y', 'Y', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 134/243 : ('Y', 'Y', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 135/243 : ('Y', 'Y', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 136/243 : ('Y', 'Z', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 137/243 : ('Y', 'Z', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 138/243 : ('Y', 'Z', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 139/243 : ('Y', 'Z', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 140/243 : ('Y', 'Z', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 141/243 : ('Y', 'Z', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 142/243 : ('Y', 'Z', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 143/243 : ('Y', 'Z', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 144/243 : ('Y', 'Z', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 145/243 : ('Y', 'Z', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 146/243 : ('Y', 'Z', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 147/243 : ('Y', 'Z', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 148/243 : ('Y', 'Z', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 149/243 : ('Y', 'Z', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 150/243 : ('Y', 'Z', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 151/243 : ('Y', 'Z', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 152/243 : ('Y', 'Z', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 153/243 : ('Y', 'Z', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 154/243 : ('Y', 'Z', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 155/243 : ('Y', 'Z', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 156/243 : ('Y', 'Z', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 157/243 : ('Y', 'Z', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 158/243 : ('Y', 'Z', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 159/243 : ('Y', 'Z', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 160/243 : ('Y', 'Z', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 161/243 : ('Y', 'Z', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 162/243 : ('Y', 'Z', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 163/243 : ('Z', 'X', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 164/243 : ('Z', 'X', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 165/243 : ('Z', 'X', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 166/243 : ('Z', 'X', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 167/243 : ('Z', 'X', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 168/243 : ('Z', 'X', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 169/243 : ('Z', 'X', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 170/243 : ('Z', 'X', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 171/243 : ('Z', 'X', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 172/243 : ('Z', 'X', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 173/243 : ('Z', 'X', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 174/243 : ('Z', 'X', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 175/243 : ('Z', 'X', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 176/243 : ('Z', 'X', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 177/243 : ('Z', 'X', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 178/243 : ('Z', 'X', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 179/243 : ('Z', 'X', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 180/243 : ('Z', 'X', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 181/243 : ('Z', 'X', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 182/243 : ('Z', 'X', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 183/243 : ('Z', 'X', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 184/243 : ('Z', 'X', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 185/243 : ('Z', 'X', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 186/243 : ('Z', 'X', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 187/243 : ('Z', 'X', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 188/243 : ('Z', 'X', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 189/243 : ('Z', 'X', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 190/243 : ('Z', 'Y', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 191/243 : ('Z', 'Y', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 192/243 : ('Z', 'Y', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 193/243 : ('Z', 'Y', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 194/243 : ('Z', 'Y', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 195/243 : ('Z', 'Y', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 196/243 : ('Z', 'Y', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 197/243 : ('Z', 'Y', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 198/243 : ('Z', 'Y', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 199/243 : ('Z', 'Y', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 200/243 : ('Z', 'Y', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 201/243 : ('Z', 'Y', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 202/243 : ('Z', 'Y', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 203/243 : ('Z', 'Y', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 204/243 : ('Z', 'Y', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 205/243 : ('Z', 'Y', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 206/243 : ('Z', 'Y', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 207/243 : ('Z', 'Y', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 208/243 : ('Z', 'Y', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 209/243 : ('Z', 'Y', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 210/243 : ('Z', 'Y', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 211/243 : ('Z', 'Y', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 212/243 : ('Z', 'Y', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 213/243 : ('Z', 'Y', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 214/243 : ('Z', 'Y', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 215/243 : ('Z', 'Y', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 216/243 : ('Z', 'Y', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 217/243 : ('Z', 'Z', 'X', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 218/243 : ('Z', 'Z', 'X', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 219/243 : ('Z', 'Z', 'X', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 220/243 : ('Z', 'Z', 'X', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 221/243 : ('Z', 'Z', 'X', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 222/243 : ('Z', 'Z', 'X', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 223/243 : ('Z', 'Z', 'X', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 224/243 : ('Z', 'Z', 'X', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 225/243 : ('Z', 'Z', 'X', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 226/243 : ('Z', 'Z', 'Y', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 227/243 : ('Z', 'Z', 'Y', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 228/243 : ('Z', 'Z', 'Y', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 229/243 : ('Z', 'Z', 'Y', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 230/243 : ('Z', 'Z', 'Y', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 231/243 : ('Z', 'Z', 'Y', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 232/243 : ('Z', 'Z', 'Y', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 233/243 : ('Z', 'Z', 'Y', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 234/243 : ('Z', 'Z', 'Y', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 235/243 : ('Z', 'Z', 'Z', 'X', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 236/243 : ('Z', 'Z', 'Z', 'X', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 237/243 : ('Z', 'Z', 'Z', 'X', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 238/243 : ('Z', 'Z', 'Z', 'Y', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 239/243 : ('Z', 'Z', 'Z', 'Y', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 240/243 : ('Z', 'Z', 'Z', 'Y', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 241/243 : ('Z', 'Z', 'Z', 'Z', 'X')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 242/243 : ('Z', 'Z', 'Z', 'Z', 'Y')


  0%|          | 0/1000 [00:00<?, ?it/s]

measurement pattern 243/243 : ('Z', 'Z', 'Z', 'Z', 'Z')


  0%|          | 0/1000 [00:00<?, ?it/s]

## load dataset

In [9]:
# Read back the measurement data and target density matrix from the
# train-data and target-state directories configured above.
(meas_result, target_rho, meas_label, meas_pattern) = utils.load_dataset_DM(
    train_data_path,
    target_state_path,
)

## callback settings

In [10]:
n_on_epoch = 1  # 1-based checkpoint index; reset whenever this cell is re-run

def save_model(nn_state, **kwargs):
    """Save `nn_state` as `epoch{n}_model.pt` under `model_path`, then bump the index.

    The index counts calls rather than real epochs. It only matches the training
    epoch if this is invoked exactly once per epoch and the cell is re-run before
    each new training run.
    """
    global n_on_epoch
    os.makedirs(model_path, exist_ok=True)
    checkpoint_file = model_path + f"epoch{n_on_epoch}_model.pt"
    nn_state.save(checkpoint_file)
    n_on_epoch += 1

def _uhlmann_fidelity(reference, state):
    """Return (Re Tr sqrt(sqrt(reference) @ state @ sqrt(reference)))**2.

    `sqrtm(reference)` is computed once and reused on both sides of the
    product (the original metrics evaluated it twice per call).
    """
    sqrt_ref = sqrtm(reference)
    F = np.trace(sqrtm(sqrt_ref @ state @ sqrt_ref))
    return (F.real)**2


def F_ideal_train(nn_state, **kwargs):
    """Fidelity between the "ideal" GHZ state and the model's density matrix.

    Side effect: checkpoints the model through `save_model` before computing the
    metric, so one checkpoint is written per evaluation of this metric.
    """
    save_model(nn_state)
    ideal_state = target_circuit.GHZ(n_qubit, state_class, "ideal", error_rate)
    train_state = utils.get_density_matrix(nn_state)
    return _uhlmann_fidelity(ideal_state, train_state)


def F_noisy_train(nn_state, **kwargs):
    """Fidelity between the configured noisy GHZ state and the model's density matrix."""
    noisy_state = target_circuit.GHZ(n_qubit, state_class, error_model, error_rate)
    train_state = utils.get_density_matrix(nn_state)
    return _uhlmann_fidelity(noisy_state, train_state)
    
def F_ideal_mevec(nn_state, **kwargs):
    """Real part of v^dagger @ rho_ideal @ v, with v from the model's density matrix.

    `v` is whatever `utils.get_max_eigen_vector` returns for the model's density
    matrix (presumably the eigenvector of the largest eigenvalue), and
    `rho_ideal` is the "ideal" GHZ state.
    NOTE(review): if `v` is a (d, 1) column this yields a (1, 1) array rather
    than a scalar - confirm the shape returned by get_max_eigen_vector.
    """
    rho_ideal = target_circuit.GHZ(n_qubit, state_class, "ideal", error_rate)
    rho_model = utils.get_density_matrix(nn_state)
    top_vec = utils.get_max_eigen_vector(rho_model)
    bra = top_vec.T.conjugate()
    overlap = bra @ rho_ideal @ top_vec
    return overlap.real

def create_callback(nn_state):
    """Return a one-element list holding a MetricEvaluator for `nn_state`.

    The evaluator runs every `period` epochs, prints its values (verbose=True),
    and evaluates the metrics in dict order. "ideal_train" therefore runs first,
    and it also saves a checkpoint.
    """
    metrics = {
        "ideal_train": F_ideal_train,
        "noisy_train": F_noisy_train,
        "ideal_mevec": F_ideal_mevec,
        "KL_Divergence": ts.KL,
    }
    evaluator = MetricEvaluator(
        period,
        metrics,
        target=target_rho,
        bases=meas_pattern,
        verbose=True,
        space=nn_state.generate_hilbert_space(),
    )
    return [evaluator]

In [11]:
# RBM density-matrix ansatz sized from the YAML architecture settings.
nn_state = DensityMatrix(
    num_visible=num_visible,
    num_hidden=num_hidden,
    num_aux=num_aux,
    unitary_dict=unitaries.create_dict(),
    gpu=use_gpu,
)
callbacks = create_callback(nn_state)

## train

In [None]:
# Train with Adadelta; StepLR multiplies the learning rate by lr_drop_factor
# every lr_drop_epoch scheduler steps (presumably one step per epoch in qucumber).
nn_state.fit(
    data=meas_result,
    input_bases=meas_label,
    epochs=epoch,
    pos_batch_size=pbs,
    neg_batch_size=nbs,
    lr=lr,
    k=n_gibbs_step,  # number of Gibbs sampling steps
    bases=meas_pattern,
    callbacks=callbacks,  # MetricEvaluator; its "ideal_train" metric also saves checkpoints
    time=True,
    optimizer=torch.optim.Adadelta,
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_args={"step_size": lr_drop_epoch, "gamma": lr_drop_factor},
)

Epoch: 1	ideal_train = 0.361920	noisy_train = 0.528212	ideal_mevec = 0.489635	KL_Divergence = 0.125390
Epoch: 2	ideal_train = 0.423057	noisy_train = 0.559070	ideal_mevec = 0.520877	KL_Divergence = 0.041580
Epoch: 3	ideal_train = 0.437252	noisy_train = 0.565528	ideal_mevec = 0.501041	KL_Divergence = 0.044265
Epoch: 4	ideal_train = 0.443716	noisy_train = 0.565243	ideal_mevec = 0.501194	KL_Divergence = 0.038650
Epoch: 5	ideal_train = 0.432488	noisy_train = 0.565814	ideal_mevec = 0.498736	KL_Divergence = 0.042207
Epoch: 6	ideal_train = 0.436882	noisy_train = 0.568344	ideal_mevec = 0.500207	KL_Divergence = 0.036975
Epoch: 7	ideal_train = 0.445802	noisy_train = 0.566131	ideal_mevec = 0.498790	KL_Divergence = 0.040333
Epoch: 8	ideal_train = 0.457371	noisy_train = 0.564230	ideal_mevec = 0.500021	KL_Divergence = 0.142459
Epoch: 9	ideal_train = 0.440467	noisy_train = 0.570834	ideal_mevec = 0.500313	KL_Divergence = 0.037748
Epoch: 10	ideal_train = 0.425336	noisy_train = 0.574405	ideal_mevec = 0.4

## save train log

In [None]:
# Write the metric history collected by the MetricEvaluator to train_log.csv.
os.makedirs(train_log_path, exist_ok=True)
metric_log = callbacks[0]

# Use the epochs the evaluator actually recorded instead of rebuilding them with
# np.arange(1, epoch+1, period). That reconstruction disagrees with the metric
# arrays when training is interrupted early, or when the evaluation schedule
# does not start at epoch 1 (period > 1). Pandas then raises a length-mismatch
# error or silently mislabels the epochs.
train_log_df = pd.DataFrame({
    "epoch": metric_log.epochs,
    "F_ideal_train": metric_log["ideal_train"],
    "F_noisy_train": metric_log["noisy_train"],
    # NOTE(review): this column name lacks the trailing "c" of "ideal_mevec";
    # kept as-is so existing consumers of train_log.csv keep working.
    "F_ideal_meve": metric_log["ideal_mevec"],
    "KL_Divergence": metric_log["KL_Divergence"],
})
train_log_df.to_csv(os.path.join(train_log_path, "train_log.csv"), index=False)