Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State sample improvement #5

Open
wants to merge 57 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
af09251
Adding a white space
alonkukl Aug 9, 2023
f4c3fe2
Adding the single edge tensor - find better name..
alonkukl Aug 15, 2023
b90baf4
also in the init file
alonkukl Aug 15, 2023
34b0c7e
fixning a typo bug
alonkukl Aug 15, 2023
c2be18e
Declearing the correct reqired type
alonkukl Aug 15, 2023
d62a78a
initial implementation of the QFactorQG
alonkukl Aug 15, 2023
167d8cd
Adding more assertions
alonkukl Aug 15, 2023
88d25be
implemented the use of validations states
alonkukl Aug 15, 2023
458fcd3
Adding a stoping critiria using the val cost
alonkukl Aug 15, 2023
c4bdf8a
Refactoring the generation of random states
alonkukl Sep 7, 2023
93a4ce7
Adding a seperate file for the state sampling
alonkukl Sep 12, 2023
c1b9873
Removing from the main file
alonkukl Sep 25, 2023
9f413c9
Before pre-commits fixs
alonkukl Oct 3, 2023
52afa01
restoring qfactor_jax
alonkukl Oct 3, 2023
efcef6f
Merge branch 'main' into state-sample-improvment
alonkukl Oct 3, 2023
9307ffa
Fixing all pre-commit issues
alonkukl Oct 3, 2023
232c139
Fixing pre-commit issues
alonkukl Oct 3, 2023
3c44d6f
Adding the example for the sampling case
alonkukl Oct 3, 2023
c97e967
Fixing some typos, and adding missing methods
alonkukl Oct 3, 2023
c62c979
pre=commit fixes
alonkukl Oct 3, 2023
0af6e81
Toffoli example now working
alonkukl Oct 3, 2023
9a3fc46
Fixing pre-commit
alonkukl Oct 3, 2023
1f8f274
Now also testing the toffoli with sampling
alonkukl Oct 3, 2023
3da5fe0
Updating the copyrights
alonkukl Oct 6, 2023
758fb93
Updating the package name
alonkukl Oct 6, 2023
87a3b1e
Fix a bug with type casting
alonkukl Oct 16, 2023
bd6eb06
Adding example that demonstrates the speed boost
alonkukl Oct 16, 2023
a6e6fef
Fixing the wrong convertion
alonkukl Oct 16, 2023
855d8d6
Renaming the Class
alonkukl Oct 19, 2023
688af23
some typos fix
alonkukl Oct 23, 2023
a9cd7b3
Adding plateau detection
alonkukl Oct 24, 2023
c1e73e7
pre-commit fixes
alonkukl Oct 25, 2023
5a10d00
Adding params to the file
alonkukl Oct 25, 2023
681a188
Merge branch 'main' of https://github.com/BQSKit/bqskit-gpu into stat…
alonkukl Oct 25, 2023
2ea0756
Fixing the name of the pakcage
alonkukl Oct 27, 2023
14f798a
fixing the exported name
alonkukl Oct 27, 2023
32227eb
Using dims
alonkukl Oct 28, 2023
2ddef1a
Update default params
alonkukl Nov 12, 2023
3ce906d
Saving thr amount the training states needed for instantiation
alonkukl Dec 8, 2023
16e4ea5
Changing the name of the logger, so the runtime will FWD it
alonkukl Dec 18, 2023
e908891
Stopping the instantiation if trying to use more 2^num_qudits states
alonkukl Dec 18, 2023
770a2ea
Pre-commit fixes
alonkukl Dec 19, 2023
d455977
Change print to log debug
alonkukl Dec 19, 2023
260a8ba
1. Recursively split by 2 the num_starts when OMM
alonkukl Dec 19, 2023
c3b224c
pre-commit fixes
alonkukl Dec 19, 2023
256f863
Adding some doc
alonkukl Feb 15, 2024
34f4df0
Updating the references and year
alonkukl Jul 1, 2024
d936935
Merge remote-tracking branch 'origin/main' into state-sample-improvment
alonkukl Jul 1, 2024
dabb9dd
Ignoring a type error, this needs to be fixed...
alonkukl Jul 1, 2024
cef123a
pro-commit fix
alonkukl Jul 1, 2024
076bc7b
Enabling 64bit to have good accuracy
alonkukl Jul 1, 2024
558a074
Updating the qasms used for the reinst example
alonkukl Jul 2, 2024
e3c1d99
pre-commit fix
alonkukl Jul 2, 2024
04d1f55
renaming
alonkukl Jul 2, 2024
f5c4a00
Removing the sh script
alonkukl Jul 2, 2024
759d13d
Setting the JAX env to only use CPUs
alonkukl Jul 2, 2024
069b0f5
Updating year
alonkukl Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023,
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024,
U.S. Federal Government and the Government of Israel. All rights reserved.

Redistribution and use in source and binary forms, with or without
Expand Down
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# QFactor implementation on GPUs using JAX
`bqskit-qfactor-jax` is a Python package that implements circuit instantiation with [QFactor](https://arxiv.org/abs/2306.08152) on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism.
# QFactor and QFactor-Sample implementations on GPUs using JAX
`bqskit-qfactor-jax` is a Python package that implements circuit instantiation using the [QFactor](https://ieeexplore.ieee.org/abstract/document/10313638) and [QFactor-Sample](https://arxiv.org/abs/2405.12866) algorithms on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism.

## Installation
`bqskit-qfactor-jax` is available for Python 3.8+ on Linux.
Expand Down Expand Up @@ -34,7 +34,11 @@ echo quit | nvidia-cuda-mps-control
```

# References
Kukliansky, Alon, et al. "QFactor:A Domain-Specific Optimizer for Quantum Circuit Instantiation." arXiv preprint [arXiv:2306.08152](https://arxiv.org/abs/2306.08152) (2023).
If you are using QFactor-JAX please cite:
Kukliansky, Alon, et al. "QFactor: A Domain-Specific Optimizer for Quantum Circuit Instantiation." 2023 IEEE International Conference on Quantum Computing and Engineering (QCE). Vol. 1. IEEE, 2023. [Link](https://ieeexplore.ieee.org/abstract/document/10313638).

If you are using QFactor-Sample please cite:
Kukliansky, Alon, et al. "Leveraging Quantum Machine Learning Generalization to Significantly Speed-up Quantum Compilation" arXiv preprint [arXiv:2405.12866](https://arxiv.org/abs/2405.12866) (2024).

## License
The software in this repository is licensed under a **BSD free software
Expand All @@ -45,5 +49,5 @@ for more information.

## Copyright

Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023,
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024,
U.S. Federal Government and the Government of Israel. All rights reserved.
87 changes: 87 additions & 0 deletions examples/adder63_10q_block_28.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[10];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
cx q[5], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[5], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[5], q[7];
cx q[4], q[5];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(0.0, 0.0, 0.7853981633974483) q[7];
cx q[4], q[5];
cx q[6], q[7];
u3(0.0, 0.0, 5.497787143782138) q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[4], q[5];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[2], q[4];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 6.283185307179586) q[6];
cx q[7], q[8];
u3(0.0, 0.0, 0.7853981633974483) q[2];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
cx q[0], q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[0];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
cx q[0], q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[0];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
cx q[3], q[4];
cx q[7], q[9];
cx q[0], q[2];
cx q[1], q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, -3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[6], q[7];
u3(0.0, 0.0, 7.0685834705770345) q[9];
cx q[1], q[3];
u3(1.5707963267948966, 2.356194490192345, -3.141592653589793) q[6];
u3(1.5707963267948966, -2.356194490192345, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[7], q[9];
cx q[1], q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
cx q[6], q[7];
u3(0.0, 0.0, 5.497787143782138) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
141 changes: 141 additions & 0 deletions examples/compare_qfactor_sample_to_qfactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from __future__ import annotations

import argparse
import time

from bqskit import enable_logging
from bqskit.compiler import CompilationTask
from bqskit.compiler import Compiler
from bqskit.ir.circuit import Circuit
from bqskit.passes import ToVariablePass

from qfactorjax.qfactor import QFactorJax
from qfactorjax.qfactor_sample_jax import QFactorSampleJax

enable_logging(verbose=True)


parser = argparse.ArgumentParser(
description='Comparing the re-instantiation run time of QFactor-JAX and '
'QFactor-Sample-JAX. Running on adder63_10q_block_28.qasm the '
'difference is X4, and for heisenberg64_10q_block_104.qasm '
'the difference is X10 and QFactor-JAX doesn\'t find a solution. '
'For vqe12_10q_block145.qasm the difference is X34.',
)

parser.add_argument('--input_qasm', type=str, required=True)
parser.add_argument('--multistarts', type=int, default=32)
parser.add_argument('--max_iters', type=int, default=6000)
parser.add_argument('--dist_tol', type=float, default=1e-8)
parser.add_argument('--num_params_coef', type=int, default=1)
parser.add_argument('--exact_amount_of_sample_states', type=int)
parser.add_argument('--overtrain_relative_threshold', type=float, default=0.1)


params = parser.parse_args()


print(params)

file_name = params.input_qasm
dist_tol_requested = params.dist_tol
num_mutlistarts = params.multistarts
max_iters = params.max_iters

num_params_coef = params.num_params_coef

exact_amount_of_sample_states = params.exact_amount_of_sample_states
overtrain_relative_threshold = params.overtrain_relative_threshold


instantiate_options = {
'multistarts': num_mutlistarts,
}


qfactor_gpu_instantiator = QFactorJax(

dist_tol=dist_tol_requested, # Stopping criteria for distance

max_iters=100000, # Maximum number of iterations
min_iters=1, # Minimum number of iterations

# One step plateau detection -
# diff_tol_a + diff_tol_r ∗ |c(i)| <= |c(i)|-|c(i-1)|
diff_tol_a=0.0, # Stopping criteria for distance change
diff_tol_r=1e-10, # Relative criteria for distance change

# Long plateau detection -
# diff_tol_step_r*|c(i-diff_tol_step)| <= |c(i)|-|c(i-diff_tol_step)|
diff_tol_step_r=0.1, # The relative improvement expected
diff_tol_step=200, # The interval in which to check the improvement

# Regularization parameter - [0.0 - 1.0]
# Increase to overcome local minima at the price of longer compute
beta=0.0,
)


qfactor_sample_gpu_instantiator = QFactorSampleJax(

dist_tol=dist_tol_requested, # Stopping criteria for distance

max_iters=max_iters, # Maximum number of iterations
min_iters=6, # Minimum number of iterations

# Regularization parameter - [0.0 - 1.0]
# Increase to overcome local minima at the price of longer compute
beta=0.0,

amount_of_validation_states=2,
# indicates the ratio between the sum of parameters in the circuits to the
# sample size.
diff_tol_r=1e-4,
num_params_coef=num_params_coef,
overtrain_relative_threshold=overtrain_relative_threshold,
exact_amount_of_states_to_train_on=exact_amount_of_sample_states,
)


print(
f'Will use {file_name} {dist_tol_requested = } {num_mutlistarts = }'
f' {num_params_coef = }',
)

orig_10q_block_cir = Circuit.from_file(f'{file_name}')

with Compiler(num_workers=1) as compiler:
task = CompilationTask(orig_10q_block_cir, [ToVariablePass()])
task_id = compiler.submit(task)
orig_10q_block_cir_vu = compiler.result(task_id)


tic = time.perf_counter()
target = orig_10q_block_cir_vu.get_unitary()
time_to_simulate_circ = time.perf_counter() - tic
print(f'Time to simulate was {time_to_simulate_circ}')

tic = time.perf_counter()
orig_10q_block_cir_vu.instantiate(
target, multistarts=num_mutlistarts, method=qfactor_sample_gpu_instantiator,
)
sample_inst_time = time.perf_counter() - tic
inst_sample_dist_from_target = orig_10q_block_cir_vu.get_unitary(
).get_distance_from(target, 1)

print(
f'QFactor-Sample-JAX {sample_inst_time = } '
f'{inst_sample_dist_from_target = }'
f' {num_params_coef = }',
)

tic = time.perf_counter()
orig_10q_block_cir_vu.instantiate(
target, multistarts=num_mutlistarts, method=qfactor_gpu_instantiator,
)
full_inst_time = time.perf_counter() - tic
inst_dist_from_target = orig_10q_block_cir_vu.get_unitary().get_distance_from(
target, 1,
)

print(f'QFactor-JAX {full_inst_time = } {inst_dist_from_target = }')
8 changes: 7 additions & 1 deletion examples/gate_deletion_syth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from timeit import default_timer as timer

from bqskit import Circuit
from bqskit import enable_logging
from bqskit.compiler import Compiler
from bqskit.passes import ForEachBlockPass
from bqskit.passes import QuickPartitioner
Expand All @@ -18,6 +19,9 @@
from qfactorjax.qfactor import QFactorJax


enable_logging()


def run_gate_del_flow_example(
amount_of_workers: int = 10,
) -> tuple[Circuit, Circuit, float]:
Expand Down Expand Up @@ -102,7 +106,9 @@ def run_gate_del_flow_example(

if __name__ == '__main__':

in_circuit, out_circuit, run_time = run_gate_del_flow_example()
in_circuit, out_circuit, run_time = run_gate_del_flow_example(
amount_of_workers=1,
)

print(
f'Partitioning + Synthesis took {run_time}'
Expand Down
57 changes: 57 additions & 0 deletions examples/heisenberg64_10q_block_104.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[10];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[1];
cx q[0], q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[9];
u3(0.0, 0.0, 0.02) q[1];
cx q[0], q[1];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[0];
u3(0.0, 1.406583, -1.406583) q[1];
cx q[1], q[2];
u3(0.0, 0.0, 0.02) q[2];
cx q[1], q[2];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[1];
u3(0.0, 1.406583, -1.406583) q[2];
cx q[2], q[3];
u3(0.0, 0.0, 0.02) q[3];
cx q[2], q[3];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[2];
u3(0.0, 1.406583, -1.406583) q[3];
cx q[3], q[4];
u3(0.0, 0.0, 0.02) q[4];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[3];
u3(0.0, 1.406583, -1.406583) q[4];
cx q[4], q[5];
u3(0.0, 0.0, 0.02) q[5];
cx q[4], q[5];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[4];
u3(0.0, 1.406583, -1.406583) q[5];
cx q[5], q[6];
u3(0.0, 0.0, 0.02) q[6];
cx q[5], q[6];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[5];
u3(0.0, 1.406583, -1.406583) q[6];
cx q[6], q[7];
u3(0.0, 0.0, 0.02) q[7];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[6];
u3(0.0, 1.406583, -1.406583) q[7];
cx q[7], q[8];
u3(0.0, 0.0, 0.02) q[8];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[7];
u3(0.0, 1.406583, -1.406583) q[8];
cx q[8], q[9];
u3(0.0, 0.0, 0.02) q[9];
cx q[8], q[9];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[8];
u3(0.0, 1.406583, -1.406583) q[9];
Loading
Loading