<p align='center'><img src="img/qash.png" width="250"></p>

# Qash - QKDC (Quantum Key Derivation Circuits)
- key derivation and hashing using quantum computations
- can produce deterministic or non-deterministic outputs
- configuration is set in a config.toml file
- key determinism depends on the number of shots
    - shots=None or shots=0 gives deterministic output
    - shots >= 1 gives non-deterministic output
    - deterministic:
        - the same input produces the same output on every run
        - requires analytic mode (shots=None) and a backend that supports statevector mode
        - only works with simulators; QPUs currently do not support statevector mode
    - non-deterministic:
        - the same input does not produce the same output on each run
        - this is the only mode supported on physical (real) hardware
        - simulators cannot provide true randomness, only pseudorandomness
- gradient calculation
    - calculates the gradient of the quantum circuit output
    - generates multiple unique keys/hashes from a single circuit execution
    - lowers the likelihood of hash collisions
    - useful for key derivation
    - very difficult to reverse without the original (pre-gradient) values
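The difference between analytic (shots=None) and sampled (shots >= 1) results can be illustrated without any quantum SDK. The sketch below is a plain-Python stand-in, not the notebook's circuits, and its function names are hypothetical: it computes the variance of a single Pauli-Z measurement exactly from an assumed probability `p1` of measuring |1⟩, then estimates the same quantity from simulated shots. The exact value repeats on every call, while the shot-based estimate changes unless the pseudorandom generator is seeded, which mirrors the simulator behaviour described above.

```python
import random

def exact_z_variance(p1: float) -> float:
    """Analytic (shots=None style) variance of Pauli-Z when P(|1>) = p1."""
    mean = (1 - p1) - p1          # <Z> = P(0) - P(1), eigenvalues +1 and -1
    return 1.0 - mean ** 2        # Var(Z) = <Z^2> - <Z>^2, and Z^2 = I

def sampled_z_variance(p1: float, shots: int, rng: random.Random) -> float:
    """Finite-shot estimate: draw +1/-1 outcomes and take the sample variance."""
    samples = [-1 if rng.random() < p1 else 1 for _ in range(shots)]
    mean = sum(samples) / shots
    return 1.0 - mean ** 2        # mean of the squared +/-1 samples is exactly 1

p1 = 0.3
print(exact_z_variance(p1), exact_z_variance(p1))     # same value on every call
print(sampled_z_variance(p1, 1000, random.Random()),   # unseeded: differs between runs
      sampled_z_variance(p1, 1000, random.Random()))
```

Seeding the generator (for example `random.Random(7)`) makes the sampled estimate repeatable, but it remains pseudorandom rather than truly random.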

In [1]:
import pennylane as qml
from jax import jit, random
from jax import numpy as jnp
import sys, os
import qkdc_helper as helper
import qkdc_electron as electron
import qkdc_photon as photon 
from functools import partial
from quantuminspire.credentials import enable_account
from qiskit_aqt_provider import AQTProvider

### Set Environment Variables
- kokkos (lightning.kokkos)
    - OMP_PROC_BIND='spread' must be set
    - OMP_PLACES='threads' is recommended for best performance

In [2]:
os.environ["OMP_PROC_BIND"] = "spread"
os.environ["OMP_PLACES"] = "threads"

### Load API Key (Quantum Inspire)
- the config.toml file stores the API key(s); it is loaded before the Quantum Inspire account is enabled

In [3]:
try:
    qml.default_config.load(".config/pennylane/config.toml")
    enable_account(qml.default_config['quantuminspire']['global']['api_key'])
except Exception as e: 
    print(e)

### Turn on x64 Float Mode
- double precision (float64) mode must be enabled: JAX defaults to 32-bit floats, and the double-precision output processing used below relies on float64 values

In [4]:
x64_mode = True   # double precision float mode for JAX 
helper.x64Switch(x64_mode)

jax float64 = enabled
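JAX's own switch for double precision is `jax.config.update("jax_enable_x64", True)`; whether `helper.x64Switch` wraps exactly this call is an assumption, since the helper's source is not shown. Why precision matters once outputs are turned into hash strings can be shown with the standard library alone: the sketch round-trips the first value of output 3 (0.363196) through 32-bit and 64-bit IEEE 754 encodings.

```python
import struct

def round_trip(value: float, fmt: str) -> float:
    """Pack a float as IEEE 754 bytes ('<f' = 32-bit, '<d' = 64-bit) and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]

value = 0.363196                      # first value of output 3 in this notebook
as_double = round_trip(value, "<d")   # float64: bit-exact round trip
as_single = round_trip(value, "<f")   # float32: rounded to about 7 significant digits

print(as_double == value)             # True
print(as_single == value)             # False
print(abs(as_single - value) < 1e-7)  # True: close, but not bit-identical
```

Because a hash is derived from the exact bits, two values that differ only after the seventh digit would still produce different strings in double precision but could collapse to the same one in single precision.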


### Check whether JAX is using CPU or GPU

In [5]:
helper.getBackend()

platform = gpu
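The implementation of `helper.getBackend` is not shown. JAX exposes the same information through `jax.default_backend()`, which returns the platform name of the default device (such as `'cpu'` or `'gpu'`). The sketch below wraps that call with an import guard so it also runs where JAX is not installed; it is an illustrative equivalent, not the helper's code.

```python
def jax_platform() -> str:
    """Return the default JAX backend name (e.g. 'cpu' or 'gpu'), or 'unavailable'."""
    try:
        import jax
    except ImportError:
        return "unavailable"          # JAX is not installed in this environment
    return jax.default_backend()      # platform of the default device

print(f"platform = {jax_platform()}")
```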


### Process Data
- pepper is an optional array of floats representing a quantum device state; this state is applied to make the hash more effective (similar in purpose to a salt)
    - if no pepper is used (pepper=None), superposition is used to initialize the starting state
    - it can be applied alongside an optional salt value (the salt is applied the same way as in any other key-derivation program)
    - a salt can also be applied without a pepper
- if pepper is not set, the circuits handle it accordingly (they still run)
- padding lets a string be interpreted as a longer value
    - total length = original string length + padding length (e.g. 6 = 5 + 1)
    - hash size varies with string length (use padding to control hash length)
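The printed arrays below suggest that `helper.createAndPad` maps each character to its Unicode code point as a float ("test" becomes `[116. 101. 115. 116.]`). The sketch is a hypothetical stand-in built on that observation; the padding value the real helper appends is not shown, so `pad_value=0.0` is an assumption. It also rejects a negative pad length, which is what `total_char_len - len(text)` yields when a string is longer than the target.

```python
def create_and_pad(text: str, pad_len: int, pad_value: float = 0.0) -> list[float]:
    """Hypothetical stand-in for helper.createAndPad.

    Converts each character to its code point as a float, then appends
    pad_len copies of pad_value (assumed; the real padding value is not shown).
    """
    if pad_len < 0:
        raise ValueError("pad_len must be >= 0: the string is longer than the target length")
    return [float(ord(ch)) for ch in text] + [pad_value] * pad_len

print(create_and_pad("test", 0))         # [116.0, 101.0, 115.0, 116.0]
print(len(create_and_pad("hello", 1)))   # 6 = 5 + 1
```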

In [6]:
pepper = jnp.array([0.0, 0.0, 0.0, 0.0])   # user defined array of floats that help randomize hash

text = "test"
text2 = "tess"
total_char_len = 4    # total desired char length to base hash on

pad_len1 = total_char_len - len(text)   # pad length for 1st string
pad_len2 = total_char_len - len(text2)  # pad length for 2nd string

text_arr = helper.createAndPad(text, pad_len1)    # create and pad data (1st string/array)
text_arr2 = helper.createAndPad(text2, pad_len2)  # create and pad data (2nd string/array)

num_wires = len(text_arr)   # calculate number of wires needed for quantum devices

print(f"test data (string): {text}")
print(f"test data (jax.numpy): {text_arr}")
print("...")
print(f"test data 2 (string): {text2}")
print(f"test data 2 (jax.numpy): {text_arr2}")
print("...")
if jnp.array_equal(text_arr, text_arr2):   # compare values element-wise (`.all` without a call compares bound methods)
    print("processed string inputs are equal")
else:
    print("processed string inputs are different")

test data (string): test
test data (jax.numpy): [116. 101. 115. 116.]
...
test data 2 (string): tess
test data 2 (jax.numpy): [116. 101. 115. 115.]
...
processed string inputs are different


### Superconductor Circuit
- this circuit is meant to run on superconducting QPUs
    - it is also compatible with trapped-ion devices (IonQ, AQT)
- measuring variance instead of expectation value gives better compatibility with various quantum devices and simulators
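Because Pauli-Z has eigenvalues ±1, its square is the identity, so the variance returned by `qml.var(qml.PauliZ(...))` is determined by the expectation value and always lies in [0, 1]. The first value of output 3 further down is consistent with this: it equals 1 − 0.798², i.e. a sample mean of 0.798 over the 1000 shots used there (this is arithmetic on the printed output, not a statement taken from the circuit code).

$$
\operatorname{Var}(Z) = \langle Z^2 \rangle - \langle Z \rangle^2 = 1 - \langle Z \rangle^2,
\qquad
1 - 0.798^2 = 1 - 0.636804 = 0.363196
$$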

In [7]:
@partial(jit, static_argnames=['device','shots'])
def qxHashCirq(input, pepper, seed, device, shots=None):
    key = random.PRNGKey(seed)
    if device == 'default.qubit.jax':
        qdev = qml.device(device, wires=num_wires, prng_key=key, shots=shots)
    elif device == 'aqt-local':
        aqt_backend = AQTProvider("").get_backend("offline_simulator_no_noise")
        qdev = qml.device('qiskit.remote', wires=num_wires, backend=aqt_backend, shots=shots)
    elif device == 'qiskit.aer':
        backend = helper.chooseBackend(shots)
        qdev = qml.device(device, wires=num_wires, shots=shots, backend=backend)
    else:
        qdev = qml.device(device, wires=num_wires, shots=shots)

    @qml.qnode(qdev, interface="jax")
    def cirq(input, pepper, key):
        if pepper is not None:
            electron.angleEmbed(input,pepper)
        else:
            electron.superPos(input)
        electron.rotLoop(input)
        electron.singleX(input)
        electron.rotLoop(input)
        electron.strongTangle(input, key)
        electron.rotLoop(input)
        return [qml.var(qml.PauliZ(wires=i)) for i in range(num_wires)]
    
    return cirq(input, pepper, key)

### Photonic Circuit
- this circuit is meant to run on photonic QPUs and is only compatible with photonic processors
- Note: only works with Fock backends; for a Gaussian implementation see [GausQash](https://github.com/TimeMelt/GausQash)

In [8]:
def qxBerryCirq(input, pepper):
    berry_device = qml.device('strawberryfields.fock', wires=num_wires, cutoff_dim=2)

    @qml.qnode(berry_device, interface="jax")
    def cirq(input, pepper):
        photon.prepareCohState(input, pepper)
        photon.thermalState(input)
        photon.photonRotate(input)
        photon.displaceStep(input)
        photon.beamSplit(input)
        photon.crossKerr(input)
        photon.cubicPhase(input)
        return [qml.expval(qml.NumberOperator(wires=i)) for i in range(num_wires)]

    return cirq(input, pepper)
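A Fock backend represents each optical mode in a truncated photon-number basis; `qxBerryCirq` uses `cutoff_dim=2`, so each mode is limited to 0 or 1 photons. The plain-Python sketch below (independent of Strawberry Fields) estimates how much of a coherent state fits under a given cutoff, using the standard result that a coherent state's photon number is Poisson distributed with mean |α|². The amplitudes that `photon.prepareCohState` actually uses are not shown here, so the α values are illustrative.

```python
import math

def fock_mass_below_cutoff(alpha: complex, cutoff_dim: int) -> float:
    """Probability that a coherent state |alpha> has fewer than cutoff_dim photons.

    Photon number is Poisson distributed with mean |alpha|^2, so this is the
    share of the state kept when each mode is truncated to |0>..|cutoff_dim-1>.
    """
    mean = abs(alpha) ** 2
    return sum(math.exp(-mean) * mean ** n / math.factorial(n) for n in range(cutoff_dim))

print(fock_mass_below_cutoff(0.0, 2))   # 1.0: the vacuum fits entirely
print(fock_mass_below_cutoff(1.0, 2))   # about 0.7358: roughly 26% lies above the cutoff
```

A larger `cutoff_dim` keeps more of the state at the cost of a bigger simulation, since the Fock space grows with the cutoff for every mode.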

### Define Parameters
- output_mode: dictionary selecting whether output is encoded as 'hex' or 'base64'
- seed: integer seed for the strong entanglement layer (superconductor circuit only)
- float_mode: dictionary selecting whether circuit output is processed as single-precision or double-precision floats

In [9]:
seed = 10    # seed for strong entanglement interaction (superconductor circuit only)
output_mode = {'hex':'hex', 'base64':'base64'}   
float_mode = {'double':'double', 'single':'single'}
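To show what the `output_mode` and `float_mode` choices mean for the resulting string, the sketch below packs a list of floats as little-endian IEEE 754 bytes and encodes them as hex or base64. It is illustrative only: `helper.processOutput` is not shown, and its layout clearly differs (the notebook's double-precision hashes are 60 hex characters, not 64), but the relative effects are the same: single precision gives shorter strings, and base64 is more compact than hex.

```python
import base64
import struct

def encode_floats(values: list[float], mode: str = "hex", precision: str = "double") -> str:
    """Pack floats as little-endian IEEE 754 bytes and encode them as text.

    mode: 'hex' or 'base64' (mirrors output_mode); precision: 'single' (float32)
    or 'double' (float64) (mirrors float_mode). Not helper.processOutput.
    """
    code = {"single": "f", "double": "d"}[precision]
    raw = struct.pack(f"<{len(values)}{code}", *values)
    if mode == "hex":
        return raw.hex()
    if mode == "base64":
        return base64.b64encode(raw).decode("ascii")
    raise ValueError(f"unknown output mode: {mode!r}")

outputs = [0.363196, 0.605616, 0.540316, 0.168256]  # output 3 values from this notebook
print(len(encode_floats(outputs, "hex", "single")))     # 32: 16 bytes, 2 hex chars each
print(len(encode_floats(outputs, "hex", "double")))     # 64: 32 bytes, 2 hex chars each
print(len(encode_floats(outputs, "base64", "double")))  # 44: 4 chars per 3 bytes, padded
```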

### Execute Photonic Circuit
- output processing only works in double precision mode 
    - float_mode["double"]

In [10]:
# photonic hash for 1st string 
output1 = qxBerryCirq(text_arr, pepper)
output_string1 = helper.processOutput(output1, output_mode['hex'], float_mode['double'])
gradient1 = helper.calcGradHash(output1, 'hex', 'double')

# photonic hash for 2nd string
output2 = qxBerryCirq(text_arr2, pepper)
output_string2 = helper.processOutput(output2, output_mode['hex'], float_mode['double'])
gradient2 = helper.calcGradHash(output2, 'hex', 'double')

# output to console
print(f"strawberry output 1: {output1}")
print(f"strawberry output 1 (length): {len(output1)}")
print(f"strawberry string 1: {output_string1}")
print(f"strawberry grad string 1: {gradient1}")
print("...")
print(f"strawberry output 2: {output2}")
print(f"strawberry output 2 (length): {len(output2)}")
print(f"strawberry string 2: {output_string2}")
print(f"strawberry grad string 2: {gradient2}")

print('')
if output_string1 == output_string2:
    print("strawberry string values are equal")
else:
    print("strawberry string values are different")

print("...")

# compare the raw outputs element-wise (works for lists of scalar arrays or a single array)
if jnp.array_equal(jnp.asarray(output1), jnp.asarray(output2)):
    print("strawberry raw output values are equal")
else:
    print("strawberry raw output values are different")

strawberry output 1: [Array(-6.75006293e-27, dtype=float64), Array(-6.19952588e-27, dtype=float64), Array(-8.34212986e-27, dtype=float64), Array(-8.32685351e-28, dtype=float64)]
strawberry output 1 (length): 4
strawberry string 1: 1cf043e4a7153ae7fbb68a9752d03af022d297a5e23b3ab9c53689432b26
strawberry grad string 1: 5f3a882ecc964047e7d9e6fd06294052c64415bee8a54071230c2d21a2a6
...
strawberry output 2: [Array(-4.2687751e-28, dtype=float64), Array(-3.79789255e-28, dtype=float64), Array(-4.9956685e-28, dtype=float64), Array(-2.05732392e-28, dtype=float64)]
strawberry output 2 (length): 4
strawberry string 2: 6c24e65d91073aa782003f57252a3aaeebf3e59ef5443a9977eb747f107a
strawberry grad string 2: b5c6748657ac403da00d3422c4c540466ab947153c1f4066cad18db4e52c

strawberry string values are different
...
strawberry raw output values are different


### Execute Superconductor Circuit
- output processing has two modes
    - single precision float 
        - shorter hashes
        - shot number does not matter
        - float_mode["single"]
    - double precision float
        - longer hashes
        - some devices require multiple shots for proper output
        - float_mode["double"]
- simulator: dictionary used to determine which quantum device to use

In [11]:
simulator = {
    "qiskit": "qiskit.aer",
    "cirq": "cirq.simulator",
    "jax": "default.qubit.jax",
    "ionq": "ionq.simulator",
    "qinspire": "quantuminspire.qi",
    "aqt-local": "aqt-local",
    "nvidia": "lightning.gpu",
    "kokkos": "lightning.kokkos",
}

# superconductor hash for 1st string
output3 = qxHashCirq(text_arr, pepper, seed, simulator["qiskit"], 1000)
output_string3 = helper.processOutput(output3, output_mode['hex'], float_mode['double'])
gradient3 = helper.calcGradHash(output3, 'hex', 'double')

# superconductor hash for 2nd string
output4 = qxHashCirq(text_arr2, pepper, seed, simulator["cirq"], 1000)
output_string4 = helper.processOutput(output4, output_mode['hex'], float_mode['double'])
gradient4 = helper.calcGradHash(output4, 'hex', 'double')

# output to console
print(f"pennylane output 3: {output3}")
print(f"pennylane output 3 (length): {len(output3)}")
print(f"pennylane string 3: {output_string3}")
print(f"pennylane grad string 3: {gradient3}")
print("...")
print(f"pennylane output 4: {output4}")
print(f"pennylane output 4 (length): {len(output4)}")
print(f"pennylane string 4: {output_string4}")
print(f"pennylane grad string 4: {gradient4}")

print('')
if output_string3 == output_string4:
    print("pennylane string values are equal")
else:
    print("pennylane string values are different") 
    
print("...")

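# compare the raw outputs element-wise (works for lists of scalar arrays or a single array)
if jnp.array_equal(jnp.asarray(output3), jnp.asarray(output4)):
    print("pennylane raw output values are equal")
else:
    print("pennylane raw output values are different")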

pennylane output 3: [Array(0.363196, dtype=float64), Array(0.605616, dtype=float64), Array(0.540316, dtype=float64), Array(0.168256, dtype=float64)]
pennylane output 3 (length): 4
pennylane string 3: 28e8a71de69a404e47e28240b782404b040b780346da4030d35a858793de
pennylane grad string 3: 673a6210cfa240491455c8d0d824405a8572490b7bfe4064a862cb63186f
...
pennylane output 4: [Array(0.0784, dtype=float64), Array(0.561756, dtype=float64), Array(0.649536, dtype=float64), Array(0.804636, dtype=float64)]
pennylane output 4 (length): 4
pennylane string 4: 5c28f5c28f59404c167a0f9096b940503d07c84b5dcd40541dab9f559b3d
pennylane grad string 4: 9b95e6e6d0bf4056508f2ed1f5e34051d00090f5594c4058360e9214684f

pennylane string values are different
...
pennylane raw output values are different


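### Display Output Sizes
- display the string length, in-memory size, and corresponding bit count of each output
    - "byte size" here is the value reported by sys.getsizeof, which includes CPython's per-object overhead, so it is larger than the encoded hash itself ("bit length" is that value × 8)
- sizes vary with data size; use the padding option to control hash length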
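`sys.getsizeof` measures the Python `str` object, including a fixed header whose size depends on the Python version and build, rather than the information the hash carries. For a hex string, each character encodes 4 bits, so the payload can be measured directly. The helper name below is hypothetical; the sample input is string 3 from this notebook.

```python
import sys

def hash_size_report(hex_hash: str) -> dict[str, int]:
    """Compare a hex hash's encoded payload size with its Python object size."""
    payload = bytes.fromhex(hex_hash)              # raises ValueError for odd length or non-hex
    return {
        "hex_chars": len(hex_hash),
        "payload_bytes": len(payload),             # 2 hex chars per byte
        "payload_bits": len(payload) * 8,
        "sys_getsizeof": sys.getsizeof(hex_hash),  # includes CPython object overhead
    }

report = hash_size_report("28e8a71de69a404e47e28240b782404b040b780346da4030d35a858793de")
print(report["payload_bytes"], report["payload_bits"])    # 30 240
print(report["sys_getsizeof"] > report["payload_bytes"])  # True
```

By this measure the 60-character hashes below carry 240 bits each, well under the 872 "bit length" derived from the object size.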

In [12]:
# output hash lengths
print(f"string 1 length: {len(output_string1)}")
print(f"string 2 length: {len(output_string2)}")
print(f"string 3 length: {len(output_string3)}")
print(f"string 4 length: {len(output_string4)}")

print("...")

# output hash byte sizes
print(f"string 1 byte size: {sys.getsizeof(output_string1)}")
print(f"string 2 byte size: {sys.getsizeof(output_string2)}")
print(f"string 3 byte size: {sys.getsizeof(output_string3)}")
print(f"string 4 byte size: {sys.getsizeof(output_string4)}")

print("...")

# output hash bit lengths
print(f"string 1 bit length: {sys.getsizeof(output_string1)*8}")
print(f"string 2 bit length: {sys.getsizeof(output_string2)*8}")
print(f"string 3 bit length: {sys.getsizeof(output_string3)*8}")
print(f"string 4 bit length: {sys.getsizeof(output_string4)*8}")

string 1 length: 60
string 2 length: 60
string 3 length: 60
string 4 length: 60
...
string 1 byte size: 109
string 2 byte size: 109
string 3 byte size: 109
string 4 byte size: 109
...
string 1 bit length: 872
string 2 bit length: 872
string 3 bit length: 872
string 4 bit length: 872
