## Import modules

In [1]:
import os
import subprocess
import random
import warnings
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import unitary_group
import pandas as pd
import yaml
from tqdm.notebook import tqdm
import itertools

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.callbacks import MetricEvaluator
import qucumber.utils.unitaries as unitaries
import qucumber.utils.training_statistics as ts
import qucumber.utils.cplx as cplx
import qucumber.utils.data as data
from qucumber.observables import ObservableBase, to_pm1
from qucumber.observables.pauli import flip_spin
import qucumber

from qulacs.gate import Pauli

%load_ext autoreload
%autoreload 2
import utils
import gate
import measurement
import target_circuit
import dataset

## Load parameters

In [2]:
# All tunable experiment settings live in params_setting.yaml next to this notebook.
with open('./params_setting.yaml', 'r') as yml:
    params = yaml.safe_load(yml)

# quantum circuit parameter
circuit_name = params["circuit_info"]["circuit_name"]
n_qubit = params["circuit_info"]["n_qubit"]
state_class = params["circuit_info"]["state_class"]
error_model = params["circuit_info"]["error_model"]
error_rate = params["circuit_info"]["error_rate"]
each_n_shot = params["circuit_info"]["each_n_shot"]

# RBM architecture parameter
num_visible = params["architecture_info"]["n_visible_unit"]
num_hidden = params["architecture_info"]["n_hidden_unit"]
num_aux = params["architecture_info"]["n_aux_unit"]

# train parameter
lr = params["train_info"]["lr"]
pbs = params["train_info"]["positive_batch_size"]
nbs = params["train_info"]["negative_batch_size"]
n_gibbs_step = params["train_info"]["n_gibbs_step"]
period = 1  # metrics are evaluated (and the model checkpointed) every `period` epochs
epoch = params["train_info"]["n_epoch"]
lr_drop_epoch = params["train_info"]["lr_drop_epoch"]
lr_drop_factor = params["train_info"]["lr_drop_factor"]
use_gpu = params["train_info"]["use_gpu"]
seed = params["train_info"]["seed"]

# sampling parameter
n_sampling = params["sampling_info"]["n_sample"]
n_copy = params["sampling_info"]["n_copy"]

# data path info
# "local": paths relative to the notebook; "colab": paths under a mounted Google Drive.
environment = "local"
if environment == "local":
    data_root = "."
elif environment == "colab":
    from google.colab import drive
    drive.mount("/content/drive/")
    drive_path = "/content/drive/MyDrive/NQS4VD/GHZ"
    data_root = drive_path
else:
    # Fail here instead of with a NameError on an undefined path in a later cell.
    raise ValueError(f"unknown environment {environment!r}; expected 'local' or 'colab'")

# NOTE(review): `100*error_rate` is formatted as a raw float, so some rates pick up
# float noise (e.g. 100*0.07 -> 7.000000000000001). Left unchanged so that existing
# directories written by this and related notebooks are still found.
noise_dir = f"{n_qubit}-qubit/{error_model}/p={100*error_rate}%"
shot_dir = f"each_{each_n_shot}_shot"
train_data_path = f"{data_root}/{circuit_name}/data/{noise_dir}/{shot_dir}/"
target_state_path = f"{data_root}/{circuit_name}/target_state/{state_class}/{noise_dir}/"
model_path = f"{data_root}/{circuit_name}/model/{state_class}/{noise_dir}/{shot_dir}/"
train_log_path = f"{data_root}/{circuit_name}/train_log/{noise_dir}/{shot_dir}/"


def seed_settings(seed=42, gpu=False):
    """Seed every random number generator used by this notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch, then calls
    `qucumber.set_random_seed`.

    :param seed: integer seed applied to every generator.
    :param gpu: forwarded to `qucumber.set_random_seed`; pass True when training on a
        GPU so the CUDA generator is seeded too. Defaults to False (CPU only).
    """
    random.seed(seed)
    # NOTE: setting PYTHONHASHSEED at runtime only affects child processes; the hash
    # randomization of the already-running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    qucumber.set_random_seed(seed, cpu=True, gpu=gpu)

# Seed the GPU generator as well whenever training runs on the GPU.
seed_settings(seed=seed, gpu=use_gpu)

## generate dataset

In [3]:
# Build the target GHZ state (with the configured error model/rate) and store its
# density matrix so the training cells can load it as the reference state.
target_state = target_circuit.GHZ(n_qubit, state_class, error_model, error_rate)
utils.save_density_matrix(target_state, target_state_path)
# Simulate measurements of the target state and write the measurement patterns and
# outcomes to `train_data_path`. NOTE(review): presumably `each_n_shot` shots per
# pattern; the log below shows 27 X/Y/Z patterns for 3 qubits — confirm in dataset.generate.
meas_pattern_df, train_df = dataset.generate(target_state, n_qubit, error_model, each_n_shot)
dataset.save(meas_pattern_df, train_df, train_data_path)

0it [00:00, ?it/s]

measurement pattern 1/27 : ('X', 'X', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 31%|████████████████████▊                                              | 310/1000 [00:00<00:00, 3097.26it/s][A
 62%|█████████████████████████████████████████▌                         | 620/1000 [00:00<00:00, 3045.10it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3001.34it/s][A
1it [00:00,  2.97it/s]

measurement pattern 2/27 : ('X', 'X', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|████████████████████                                               | 300/1000 [00:00<00:00, 2992.41it/s][A
 60%|████████████████████████████████████████▏                          | 600/1000 [00:00<00:00, 2938.96it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2913.13it/s][A
2it [00:00,  2.92it/s]

measurement pattern 3/27 : ('X', 'X', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|███████████████████▊                                               | 296/1000 [00:00<00:00, 2948.23it/s][A
 60%|███████████████████████████████████████▉                           | 597/1000 [00:00<00:00, 2972.44it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3031.98it/s][A
3it [00:01,  2.95it/s]

measurement pattern 4/27 : ('X', 'Y', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|████████████████████▍                                              | 305/1000 [00:00<00:00, 3048.25it/s][A
 61%|████████████████████████████████████████▊                          | 610/1000 [00:00<00:00, 2865.23it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2896.80it/s][A
4it [00:01,  2.92it/s]

measurement pattern 5/27 : ('X', 'Y', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|████████████████████▏                                              | 301/1000 [00:00<00:00, 3006.48it/s][A
 60%|████████████████████████████████████████▎                          | 602/1000 [00:00<00:00, 2789.82it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2848.80it/s][A
5it [00:01,  2.88it/s]

measurement pattern 6/27 : ('X', 'Y', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▎                                               | 288/1000 [00:00<00:00, 2873.70it/s][A
 58%|██████████████████████████████████████▌                            | 576/1000 [00:00<00:00, 2813.76it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2802.76it/s][A
6it [00:02,  2.84it/s]

measurement pattern 7/27 : ('X', 'Z', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 32%|█████████████████████                                              | 315/1000 [00:00<00:00, 3148.30it/s][A
 63%|██████████████████████████████████████████▏                        | 630/1000 [00:00<00:00, 2962.13it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2922.33it/s][A
7it [00:02,  2.86it/s]

measurement pattern 8/27 : ('X', 'Z', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|████████████████████                                               | 299/1000 [00:00<00:00, 2985.17it/s][A
 60%|████████████████████████████████████████                           | 598/1000 [00:00<00:00, 2926.60it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2946.67it/s][A
8it [00:02,  2.88it/s]

measurement pattern 9/27 : ('X', 'Z', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 32%|█████████████████████▏                                             | 316/1000 [00:00<00:00, 3159.29it/s][A
 63%|██████████████████████████████████████████▎                        | 632/1000 [00:00<00:00, 3078.87it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2993.62it/s][A
9it [00:03,  2.90it/s]

measurement pattern 10/27 : ('Y', 'X', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 26%|█████████████████▍                                                 | 260/1000 [00:00<00:00, 2595.91it/s][A
 54%|███████████████████████████████████▊                               | 535/1000 [00:00<00:00, 2673.63it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2744.10it/s][A
10it [00:03,  2.84it/s]

measurement pattern 11/27 : ('Y', 'X', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▏                                               | 286/1000 [00:00<00:00, 2853.73it/s][A
 57%|██████████████████████████████████████▎                            | 572/1000 [00:00<00:00, 2795.82it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2813.43it/s][A
11it [00:03,  2.82it/s]

measurement pattern 12/27 : ('Y', 'X', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▍                                               | 291/1000 [00:00<00:00, 2903.41it/s][A
 58%|██████████████████████████████████████▉                            | 582/1000 [00:00<00:00, 2623.92it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2699.24it/s][A
12it [00:04,  2.78it/s]

measurement pattern 13/27 : ('Y', 'Y', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▏                                               | 286/1000 [00:00<00:00, 2857.92it/s][A
 57%|██████████████████████████████████████▎                            | 572/1000 [00:00<00:00, 2522.27it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2654.15it/s][A
13it [00:04,  2.73it/s]

measurement pattern 14/27 : ('Y', 'Y', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▏                                               | 287/1000 [00:00<00:00, 2868.75it/s][A
 57%|██████████████████████████████████████▍                            | 574/1000 [00:00<00:00, 2842.15it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2811.02it/s][A
14it [00:04,  2.75it/s]

measurement pattern 15/27 : ('Y', 'Y', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▏                                               | 287/1000 [00:00<00:00, 2863.94it/s][A
 57%|██████████████████████████████████████▍                            | 574/1000 [00:00<00:00, 2805.93it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2768.99it/s][A
15it [00:05,  2.74it/s]

measurement pattern 16/27 : ('Y', 'Z', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 28%|██████████████████▋                                                | 278/1000 [00:00<00:00, 2775.23it/s][A
 57%|██████████████████████████████████████▎                            | 571/1000 [00:00<00:00, 2858.49it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2829.07it/s][A
16it [00:05,  2.76it/s]

measurement pattern 17/27 : ('Y', 'Z', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 29%|███████████████████▎                                               | 288/1000 [00:00<00:00, 2878.84it/s][A
 58%|██████████████████████████████████████▌                            | 576/1000 [00:00<00:00, 2805.57it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2793.38it/s][A
17it [00:06,  2.76it/s]

measurement pattern 18/27 : ('Y', 'Z', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 28%|███████████████████                                                | 285/1000 [00:00<00:00, 2843.08it/s][A
 57%|██████████████████████████████████████▏                            | 570/1000 [00:00<00:00, 2830.32it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2804.25it/s][A
18it [00:06,  2.77it/s]

measurement pattern 19/27 : ('Z', 'X', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 31%|████████████████████▊                                              | 311/1000 [00:00<00:00, 3106.17it/s][A
 62%|█████████████████████████████████████████▋                         | 622/1000 [00:00<00:00, 2987.74it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3021.15it/s][A
19it [00:06,  2.83it/s]

measurement pattern 20/27 : ('Z', 'X', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 28%|██████████████████▌                                                | 277/1000 [00:00<00:00, 2763.44it/s][A
 56%|█████████████████████████████████████▊                             | 564/1000 [00:00<00:00, 2817.80it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2843.46it/s][A
20it [00:07,  2.82it/s]

measurement pattern 21/27 : ('Z', 'X', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 32%|█████████████████████▋                                             | 323/1000 [00:00<00:00, 3227.18it/s][A
 65%|███████████████████████████████████████████▎                       | 646/1000 [00:00<00:00, 3112.43it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3098.90it/s][A
21it [00:07,  2.89it/s]

measurement pattern 22/27 : ('Z', 'Y', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|███████████████████▊                                               | 295/1000 [00:00<00:00, 2948.15it/s][A
 59%|███████████████████████████████████████▌                           | 590/1000 [00:00<00:00, 2849.52it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2879.60it/s][A
22it [00:07,  2.88it/s]

measurement pattern 23/27 : ('Z', 'Y', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 28%|██████████████████▉                                                | 282/1000 [00:00<00:00, 2819.85it/s][A
 57%|██████████████████████████████████████▏                            | 570/1000 [00:00<00:00, 2854.96it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2834.12it/s][A
23it [00:08,  2.86it/s]

measurement pattern 24/27 : ('Z', 'Y', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 30%|███████████████████▊                                               | 295/1000 [00:00<00:00, 2947.87it/s][A
 59%|███████████████████████████████████████▌                           | 590/1000 [00:00<00:00, 2906.76it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2920.30it/s][A
24it [00:08,  2.87it/s]

measurement pattern 25/27 : ('Z', 'Z', 'X')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 32%|█████████████████████▎                                             | 318/1000 [00:00<00:00, 3174.24it/s][A
 64%|██████████████████████████████████████████▌                        | 636/1000 [00:00<00:00, 3010.08it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3067.35it/s][A
25it [00:08,  2.91it/s]

measurement pattern 26/27 : ('Z', 'Z', 'Y')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 31%|████████████████████▊                                              | 311/1000 [00:00<00:00, 3107.60it/s][A
 62%|█████████████████████████████████████████▋                         | 622/1000 [00:00<00:00, 2987.65it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2997.99it/s][A
26it [00:09,  2.93it/s]

measurement pattern 27/27 : ('Z', 'Z', 'Z')



  0%|                                                                               | 0/1000 [00:00<?, ?it/s][A
 31%|████████████████████▊                                              | 310/1000 [00:00<00:00, 3099.77it/s][A
 62%|█████████████████████████████████████████▌                         | 620/1000 [00:00<00:00, 3073.72it/s][A
100%|██████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3093.13it/s][A
27it [00:09,  2.85it/s]


## load dataset

In [4]:
# Reload the training data and the stored target density matrix from disk.
# meas_result / meas_label feed `fit(data=..., input_bases=...)`; target_rho and
# meas_pattern feed the MetricEvaluator (`target=`, `bases=`).
# NOTE(review): exact types/shapes are defined by utils.load_dataset_DM — confirm there.
meas_result, target_rho, meas_label, meas_pattern = utils.load_dataset_DM(train_data_path, target_state_path)

## callback settings

In [5]:
# Number of the next checkpoint; incremented by `save_model` on every metric evaluation.
n_on_epoch = 1


def save_model(nn_state, **kwargs):
    """Checkpoint `nn_state` to `model_path` as `epoch{E}_model.pt`.

    Called from `F_ideal_train`, i.e. once per MetricEvaluator evaluation. The evaluator
    fires every `period` epochs, so the k-th call corresponds to epoch `k * period`
    (the same as the call count when `period == 1`).
    """
    global n_on_epoch
    os.makedirs(model_path, exist_ok=True)
    nn_state.save(model_path + f"epoch{n_on_epoch * period}_model.pt")
    n_on_epoch = n_on_epoch + 1


def uhlmann_fidelity(rho, sigma):
    """Return F(rho, sigma) = (Tr sqrt(sqrt(rho) @ sigma @ sqrt(rho)))**2.

    Only the real part of the trace is kept before squaring. sqrt(rho) is computed
    once and reused on both sides of the sandwich product.
    """
    sqrt_rho = sqrtm(rho)
    root_trace = np.trace(sqrtm(sqrt_rho @ sigma @ sqrt_rho))
    return root_trace.real ** 2


def F_ideal_train(nn_state, **kwargs):
    """Fidelity between the noiseless GHZ state and the RBM's density matrix.

    Side effect: also checkpoints the model through `save_model`, so one model file
    is written per metric evaluation.
    """
    save_model(nn_state)
    ideal_state = target_circuit.GHZ(n_qubit, state_class, "ideal", error_rate)
    train_state = utils.get_density_matrix(nn_state)
    return uhlmann_fidelity(ideal_state, train_state)


def F_noisy_train(nn_state, **kwargs):
    """Fidelity between the noisy target GHZ state (configured error model/rate) and
    the RBM's density matrix."""
    noisy_state = target_circuit.GHZ(n_qubit, state_class, error_model, error_rate)
    train_state = utils.get_density_matrix(nn_state)
    return uhlmann_fidelity(noisy_state, train_state)


def F_ideal_mevec(nn_state, **kwargs):
    """Real part of <v| rho_ideal |v>, where v is the vector returned by
    `utils.get_max_eigen_vector` for the RBM's density matrix (presumably the
    eigenvector with the largest eigenvalue — confirm in utils)."""
    ideal_state = target_circuit.GHZ(n_qubit, state_class, "ideal", error_rate)
    train_state = utils.get_density_matrix(nn_state)
    max_eigen_state = utils.get_max_eigen_vector(train_state)
    overlap = max_eigen_state.T.conjugate() @ ideal_state @ max_eigen_state
    return overlap.real


def create_callback(nn_state):
    """Build the callback list passed to `nn_state.fit`.

    A single MetricEvaluator runs every `period` epochs. `target`, `bases` and `space`
    are extra keyword arguments that MetricEvaluator forwards to each metric (the custom
    metrics above ignore them via **kwargs; presumably `ts.KL` consumes them).
    Also resets the checkpoint counter so a fresh training run starts numbering again.
    """
    global n_on_epoch
    n_on_epoch = 1

    metric_dict = {
        "ideal_train": F_ideal_train,
        "noisy_train": F_noisy_train,
        "ideal_mevec": F_ideal_mevec,
        "KL_Divergence": ts.KL,
    }
    space = nn_state.generate_hilbert_space()
    callbacks = [
        MetricEvaluator(
            period,
            metric_dict,
            target = target_rho,
            bases = meas_pattern,
            verbose = True,
            space = space,
        )
    ]

    return callbacks

In [6]:
# Fresh DensityMatrix RBM sized from the YAML config; basis-rotation unitaries come
# from `unitaries.create_dict()`, and `use_gpu` selects the device.
nn_state = DensityMatrix(
    num_visible=num_visible,
    num_hidden=num_hidden,
    num_aux=num_aux,
    unitary_dict=unitaries.create_dict(),
    gpu=use_gpu,
)
# Metric evaluator (and checkpointing) bound to this particular state.
callbacks = create_callback(nn_state)

## train

In [None]:
# Train the RBM on the measurement records with Adadelta; StepLR scales the learning
# rate by `lr_drop_factor` every `lr_drop_epoch` scheduler steps.
nn_state.fit(
    # training data: outcomes, the basis of each record, and the set of bases
    data=meas_result,
    input_bases=meas_label,
    bases=meas_pattern,
    # optimisation schedule
    epochs=epoch,
    lr=lr,
    optimizer=torch.optim.Adadelta,
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_args={"gamma": lr_drop_factor, "step_size": lr_drop_epoch},
    # batch sizes and number of Gibbs steps for the negative phase
    pos_batch_size=pbs,
    neg_batch_size=nbs,
    k=n_gibbs_step,
    # monitoring
    callbacks=callbacks,
    time=True,
)

Epoch: 1	ideal_train = 0.431050	noisy_train = 0.582207	ideal_mevec = 0.515121	KL_Divergence = 0.352396
Epoch: 2	ideal_train = 0.425227	noisy_train = 0.626822	ideal_mevec = 0.569028	KL_Divergence = 0.130827
Epoch: 3	ideal_train = 0.459158	noisy_train = 0.639415	ideal_mevec = 0.553979	KL_Divergence = 0.153176
Epoch: 4	ideal_train = 0.461429	noisy_train = 0.656081	ideal_mevec = 0.646839	KL_Divergence = 0.088146
Epoch: 5	ideal_train = 0.467399	noisy_train = 0.667897	ideal_mevec = 0.992896	KL_Divergence = 0.070607
Epoch: 6	ideal_train = 0.471788	noisy_train = 0.670245	ideal_mevec = 0.884132	KL_Divergence = 0.070946
Epoch: 7	ideal_train = 0.478334	noisy_train = 0.674618	ideal_mevec = 0.840206	KL_Divergence = 0.068404
Epoch: 8	ideal_train = 0.480399	noisy_train = 0.676550	ideal_mevec = 0.939291	KL_Divergence = 0.067738
Epoch: 9	ideal_train = 0.486001	noisy_train = 0.686975	ideal_mevec = 0.763833	KL_Divergence = 0.072758
Epoch: 10	ideal_train = 0.494400	noisy_train = 0.689414	ideal_mevec = 0.6

## Save training log

In [None]:
# Persist the metric history recorded by the MetricEvaluator as a CSV.
os.makedirs(train_log_path, exist_ok = True)
metric_log = callbacks[0]
n_recorded = len(metric_log["ideal_train"])
train_log_df = pd.DataFrame({
    # Metrics are recorded on epochs period, 2*period, ...; deriving the column from the
    # number of recorded values keeps it aligned with the metric columns even if
    # training was stopped before `epoch` epochs completed.
    "epoch": period * np.arange(1, n_recorded + 1),
    "F_ideal_train": metric_log["ideal_train"],
    "F_noisy_train": metric_log["noisy_train"],
    # NOTE(review): "meve" looks truncated from "mevec"; kept so existing readers of
    # train_log.csv keep working.
    "F_ideal_meve": metric_log["ideal_mevec"],
    "KL_Divergence": metric_log["KL_Divergence"],
})
train_log_df.to_csv(train_log_path + "train_log.csv", index=False)