<p align='center'><img src="img/qash.png" width="250"></p>

# Qash - QKDC (Quantum Key Derivation Circuits)
- key derivation using quantum computations
- configurations set using config.toml file

In [1]:
import pennylane as qml
from jax import jit, random
from jax import numpy as jnp
import sys
import qkdc_helper as helper
import qkdc_electron as electron
import qkdc_photon as photon 
from functools import partial
from quantuminspire.credentials import enable_account
from qiskit_aqt_provider import AQTProvider

### Load API Key (Quantum Inspire)
- config.toml file is used to configure/store API key

In [2]:
try:
    qml.default_config.load(".config/pennylane/config.toml")
    enable_account(qml.default_config['quantuminspire']['global']['api_key'])
except Exception as e: 
    print(e)

### Turn on x64 Float Mode
- double precision mode is required when using many of the JAX features, such as JIT

In [3]:
x64_mode = True   # double precision float mode for JAX 
helper.x64Switch(x64_mode)

jax float64 = enabled


### Check whether JAX is using CPU or GPU

In [4]:
helper.getBackend()

platform = gpu


### Process Data 
- pepper is an optional array of floats that represent a quantum device state, this state is applied in order to make the hash more effective (similar purpose to a salt)
- pepper can be applied alongside an optional salt value (salt is applied the same as any othe key derivation program)
- if pepper is not set than circuits will handle it accordingly (circuits will still run)
- use of padding allows for string to be interpreted as longer value
    - 6 = 5 + 1 (total length = original string length + padding length)
    - hashes vary in size based on string length (use padding to control length of hashes)

In [5]:
pepper = jnp.array([0.0, 0.0, 0.0, 0.0])   # user defined array of floats that help randomize hash

text = "test"
text2 = "tess"
total_char_len = 5    # total desired char length to base hash on

pad_len1 = total_char_len - len(text)   # pad length for 1st string
pad_len2 = total_char_len - len(text2)  # pad length for 2nd string

text_arr = helper.createAndPad(text, pad_len1)    # create and pad data (1st string/array)
text_arr2 = helper.createAndPad(text2, pad_len2)  # create and pad data (2nd string/array)

num_wires = len(text_arr)   # calculate number of wires needed for quantum devices

print(f"test data (string): {text}")
print(f"test data (jax.numpy): {text_arr}")
print("...")
print(f"test data 2 (string): {text2}")
print(f"test data 2 (jax.numpy): {text_arr2}")
print("...")
if text_arr2.all == text_arr.all:
    print("processed string inputs are equal")
else:
    print("processed string inputs are different")

test data (string): test
test data (jax.numpy): [116.         101.         115.         116.         -28.25539552]
...
test data 2 (string): tess
test data 2 (jax.numpy): [116.         101.         115.         115.         -28.25539552]
...
processed string inputs are different


### SuperConductor Circuit
- this circuit is meant to run on superconducting QPUs 
    - is also compatible with trapped-ion devices (IonQ, AQT)
- using variance measurement instead of expectation value allows for better compatibility with various devices/simulators
    - better compatibility with various quantum devices

In [6]:
@partial(jit, static_argnames=['device'])
def qxHashCirq(input, pepper, seed, device):
    key = random.PRNGKey(seed)
    if device == 'default.qubit.jax':
        qdev = qml.device(device, wires=num_wires, prng_key=key)
    elif device == 'aqt-local':
        aqt_backend = AQTProvider("").get_backend("offline_simulator_no_noise")
        qdev = qml.device('qiskit.remote', wires=num_wires, backend=aqt_backend, shots=200)
    elif device == 'qiskit.aer' or device == 'ionq.simulator' or device == 'quantuminspire.qi':
        qdev = qml.device(device, wires=num_wires, shots=500)
    else:
        qdev = qml.device(device, wires=num_wires)

    @qml.qnode(qdev, interface="jax")
    def cirq(input, pepper, key):
        if pepper is not None:
            electron.angleEmbed(input,pepper)
        else:
            electron.superPos(input)
        electron.rotLoop(input)
        electron.singleX(input)
        electron.rotLoop(input)
        electron.strongTangle(input, key)
        electron.rotLoop(input)
        return [qml.var(qml.PauliZ(wires=i)) for i in range(num_wires)]
    
    return cirq(input, pepper, key)

### Photonic Circuit
- this circuit is meant to run on photonic QPUs (only compatible with photonic processors)
- Note: only works with fock backends, for gaussian implementation see [GausQash](https://github.com/TimeMelt/GausQash)

In [7]:
def qxBerryCirq(input, pepper):
    berry_device = qml.device('strawberryfields.fock', wires=num_wires, cutoff_dim=2)

    @qml.qnode(berry_device, interface="jax")
    def cirq(input, pepper):
        photon.prepareCohState(input, pepper)
        photon.thermalState(input)
        photon.photonRotate(input)
        photon.displaceStep(input)
        photon.beamSplit(input)
        photon.crossKerr(input)
        photon.cubicPhase(input)
        return [qml.expval(qml.NumberOperator(wires=i)) for i in range(num_wires)]

    return cirq(input, pepper)

### Define Parameters
- output_mode: whether to output hash in 'hex' or 'base64'
- seed: integer for strong entanglement layer (superconductor circuit only)
- float_mode: dictionary with values to process circuit output as either single precision or double precision floats

In [8]:
output_mode = 'hex'   # can be 'hex' or 'base64'
seed = 10    # seed for strong entanglement interaction (superconductor circuit only)
float_mode = {
    'double': 'double',
    'single': 'single',
}

### Execute Photonic Circuit
- output processing only works in double precision mode 
    - float_mode["double"]

In [9]:
# photonic hash for 1st string 
output1 = qxBerryCirq(text_arr, pepper)
output_string1 = helper.processOutput(output1, output_mode, float_mode['double'])

# photonic hash for 2nd string
output2 = qxBerryCirq(text_arr2, pepper)
output_string2 = helper.processOutput(output2, output_mode, float_mode['double'])

# output to console
print(f"stawberry output 1 (length): {len(output1)}")
print(f"strawberry string 1: {output_string1}")
if output_mode == 'hex':
    print("...")
print(f"strawaberry output 2 (length): {len(output2)}")
print(f"strawberry string 2: {output_string2}")

print('')
if output_string1 == output_string2:
    print("strawberry string values are equal")
else:
    print("strawberry string values are different")
    
print("...")

if output1 == output2:
    print("strawberry raw output  values are equal")
else:
    print("strawberry raw output values are different")    

stawberry output 1 (length): 5
strawberry string 1: e4836bacb55436668f58be0edfed3669e9b184c71b3b362d20a807cfc898367008b638ed1120
...
strawaberry output 2 (length): 5
strawberry string 2: 1eeda8be3b9d362385d27c3706f83625c17fa66d269e361e425f085a90eb362bad1ce7cba9df

strawberry string values are different
...
strawberry raw output values are different


### Execute SuperConductor Circuit
- output processing has 2 modes
    - single precision float 
        - shorter hashes
        - shot number does not matter
        - float_mode["single"]
    - double precision float
        - longer hashes
        - some device require multiple shots for proper output
        - float_mode["double"]
- simulator: dictionary used to determine which quantum device to use

In [10]:
simulator = {
    "qiskit": "qiskit.aer",
    "cirq": "cirq.simulator",
    "jax": "default.qubit.jax",
    "ionq": "ionq.simulator",
    "qinspire": "quantuminspire.qi",
    "aqt-local": "aqt-local",
}

# superconductor hash for 1st string
output3 = qxHashCirq(text_arr, pepper, seed, simulator["qiskit"])
output_string3 = helper.processOutput(output3, output_mode, float_mode['double'])

# superconductor hash for 2nd string
output4 = qxHashCirq(text_arr2, pepper, seed, simulator["cirq"])
output_string4 = helper.processOutput(output4, output_mode, float_mode['double'])

# output to console
print(f"pennylane output 3 (length): {len(output3)}")
print(f"pennylane string 3: {output_string3}")
if output_mode == 'hex':
    print("...")
print(f"pennylane output 4 (length): {len(output4)}")
print(f"pennylane string 4: {output_string4}")

print('')
if output_string3 == output_string4:
    print("pennylane string values are equal")
else:
    print("pennylane string values are different")
    
print("...")

if output3 == output4:
    print("pennylane raw output  values are equal")
else:
    print("pennylane raw output values are different") 

pennylane output 3 (length): 5
pennylane string 3: 59b3d07c84b34050970a3d70a3d74058fd70a3d70a3d40368f5c28f5c2904022bd3c36113405
...
pennylane output 4 (length): 5
pennylane string 4: 1905f453206e404f0c291f7215894058fd8d4f7e645c404432fca3762bae402081d839e5a8bc

pennylane string values are different
...
pennylane raw output values are different


### Display Output Sizes
- sizes vary based on data size, use padding option to control hash length

In [11]:
# output hash lengths
print(f"string 1 length: {len(output_string1)}")
print(f"string 2 length: {len(output_string2)}")
print(f"string 3 length: {len(output_string3)}")
print(f"string 4 length: {len(output_string4)}")

print("...")

# output hash byte sizes
print(f"string 1 byte size: {sys.getsizeof(output_string1)}")
print(f"string 2 byte size: {sys.getsizeof(output_string2)}")
print(f"string 3 byte size: {sys.getsizeof(output_string3)}")
print(f"string 4 byte size: {sys.getsizeof(output_string4)}")

print("...")

# output hash bit lengths
print(f"string 1 bit length: {sys.getsizeof(output_string1)*8}")
print(f"string 2 bit length: {sys.getsizeof(output_string2)*8}")
print(f"string 3 bit length: {sys.getsizeof(output_string3)*8}")
print(f"string 4 bit length: {sys.getsizeof(output_string4)*8}")

string 1 length: 76
string 2 length: 76
string 3 length: 76
string 4 length: 76
...
string 1 byte size: 125
string 2 byte size: 125
string 3 byte size: 125
string 4 byte size: 125
...
string 1 bit length: 1000
string 2 bit length: 1000
string 3 bit length: 1000
string 4 bit length: 1000
