<p align='center'><img src="img/qash.png" width="250"></p>

# Qash - QKDC (Quantum Key Derivation Circuits)
- key derivation using quantum computations

In [1]:
import pennylane as qml
from jax import jit, random
from jax import numpy as jnp
import sys
import qkdc_helper as helper
import qkdc_electron as electron
import qkdc_photon as photon 
from functools import partial

### Turn on x64 Float Mode
- double precision mode is required when using many of the JAX features, such as JIT

In [2]:
x64_mode = True   # double precision float mode for JAX 
helper.x64Switch(x64_mode)

jax float64 = enabled


### Check whether JAX is using CPU or GPU

In [3]:
helper.getBackend()

platform = gpu


### Process Data 
- pepper is an optional array of floats that represent a quantum device state, this state is applied in order to make the hash more effective (similar purpose to a salt)
- pepper can be applied alongside an optional salt value (salt is applied the same as any othe key derivation program)
- if pepper is not set than circuits will handle it accordingly (circuits will still run)

In [4]:
pepper = jnp.array([0.0, 0.0, 0.0, 0.0])   # user defined array of floats that help randomize hash

text = "test"
text2 = "tess"
total_char_len = 6    # total desired char length to base hash on

pad_len1 = total_char_len - len(text)   # pad length for 1st string
pad_len2 = total_char_len - len(text2)  # pad length for 2nd string

text_arr = helper.createAndPad(text, pad_len1)    # create and pad data (1st string/array)
text_arr2 = helper.createAndPad(text2, pad_len2)  # create and pad data (2nd string/array)

num_wires = len(text_arr)   # calculate number of wires needed for quantum devices

print(f"test data (string): {text}")
print(f"test data (jax.numpy): {text_arr}")
print("...")
print(f"test data 2 (string): {text2}")
print(f"test data 2 (jax.numpy): {text_arr2}")
print("...")
if text_arr2.all == text_arr.all:
    print("processed string inputs are equal")
else:
    print("processed string inputs are different")

test data (string): test
test data (jax.numpy): [116.         101.         115.         116.         -28.25539552
  51.18201262]
...
test data 2 (string): tess
test data 2 (jax.numpy): [116.         101.         115.         115.         -28.25539552
  51.18201262]
...
processed string inputs are different


### SuperConductor Circuit
- this circuit is meant to run on superconducting QPUs 
    - trapped-ion QPU compatibility is currently in testing
- using variance measurement instead of expectation value allows for better compatibility with various devices/simulators
    - better compatibility with google/cirq devices

In [5]:
@partial(jit, static_argnames=['device'])
def qxHashCirq(input, pepper, seed, device):
    key = random.PRNGKey(seed)
    if device == 'default.qubit.jax':
        qdev = qml.device(device, wires=num_wires, prng_key=key)
    else:
        qdev = qml.device(device, wires=num_wires)

    @qml.qnode(qdev, interface="jax")
    def cirq(input, pepper, key):
        if pepper is not None:
            electron.angleEmbed(input,pepper)
        else:
            electron.superPos(input)
        electron.rotLoop(input)
        electron.singleX(input)
        electron.rotLoop(input)
        electron.strongTangle(input, key)
        electron.rotLoop(input)
        return [qml.var(qml.PauliZ(wires=i)) for i in range(num_wires)]
    
    return cirq(input, pepper, key)

### Photonic Circuit
- this circuit is meant to run on photonic QPUs (only compatible with photonic processors)
- Note: only works with fock backends, for gaussian implementation see [GausQash](https://github.com/TimeMelt/GausQash)

In [6]:
def qxBerryCirq(input, pepper):
    berry_device = qml.device('strawberryfields.fock', wires=num_wires, cutoff_dim=2)

    @qml.qnode(berry_device, interface="jax")
    def cirq(input, pepper):
        photon.prepareCohState(input, pepper)
        photon.thermalState(input)
        photon.photonRotate(input)
        photon.displaceStep(input)
        photon.beamSplit(input)
        photon.crossKerr(input)
        photon.cubicPhase(input)
        return [qml.expval(qml.NumberOperator(wires=i)) for i in range(num_wires)]

    return cirq(input, pepper)

### Define Parameters

In [7]:
output_mode = 'hex'   # can be 'hex' or 'base64'
seed = 10    # seed for strong entanglement interaction (superconductor circuit only)

### Execute Photonic Circuit

In [8]:
# photonic hash for 1st string 
output1 = qxBerryCirq(text_arr, pepper)
output_string1 = helper.processOutput(output1, output_mode)

# photonic hash for 2nd string
output2 = qxBerryCirq(text_arr2, pepper)
output_string2 = helper.processOutput(output2, output_mode)

# output to console
print(f"stawberry output 1 (length): {len(output1)}")
print(f"strawberry string 1: {output_string1}")
if output_mode == 'hex':
    print("...")
print(f"strawaberry output 2 (length): {len(output2)}")
print(f"strawberry string 2: {output_string2}")

print('')
if output_string1 == output_string2:
    print("strawberry string values are equal")
else:
    print("strawberry string values are different")
    
print("...")

if output1 == output2:
    print("strawberry raw output  values are equal")
else:
    print("strawberry raw output values are different")    

stawberry output 1 (length): 6
strawberry string 1: 443a8822527d2e8e4c87b9119c812e82a965dfaa0a852e495a2914ba9d742e99f9432cf7bde12e74f5789a58800e
...
strawaberry output 2 (length): 6
strawberry string 2: b439d134ee242e3ecc8807dd0ea02e31c6b42de67ff32e212c743590760c2e4ba5caab0680132e30065eea0eabaf

strawberry string values are different
...
strawberry raw output values are different


### Execute SuperConductor Circuit

In [9]:
simulator = {
    "cirq": "cirq.simulator",
    "jax": "default.qubit.jax",
}

# superconductor hash for 1st string
output3 = qxHashCirq(text_arr, pepper, seed, simulator["cirq"])
output_string3 = helper.processOutput(output3, output_mode)

# superconductor hash for 2nd string
output4 = qxHashCirq(text_arr2, pepper, seed, simulator["cirq"])
output_string4 = helper.processOutput(output4, output_mode)

# output to console
print(f"pennylane output 3 (length): {len(output3)}")
print(f"pennylane string 3: {output_string3}")
if output_mode == 'hex':
    print("...")
print(f"pennylane output 4 (length): {len(output4)}")
print(f"pennylane string 4: {output_string4}")

print('')
if output_string3 == output_string4:
    print("pennylane string values are equal")
else:
    print("pennylane string values are different")
    
print("...")

if output3 == output4:
    print("pennylane raw output  values are equal")
else:
    print("pennylane raw output values are different") 

pennylane output 3 (length): 6
pennylane string 3: 34595c1f17a8404bdc16080ba1044056a7c82cf34f0040235053ec7b813640393f2ee70ef2844055dbbef22bb6ac
...
pennylane output 4 (length): 6
pennylane string 4: 59cfc48b4dfe404f2a8eb323c2f040585ca199a375f04054233d6457a5c04039447e2053a7764055d6d5c928560f

pennylane string values are different
...
pennylane raw output values are different


### Display Output Sizes

In [10]:
# output hash lengths
print(f"string 1 length: {len(output_string1)}")
print(f"string 2 length: {len(output_string2)}")
print(f"string 3 length: {len(output_string3)}")
print(f"string 4 length: {len(output_string4)}")

print("...")

# output hash byte sizes
print(f"string 1 byte size: {sys.getsizeof(output_string1)}")
print(f"string 2 byte size: {sys.getsizeof(output_string2)}")
print(f"string 3 byte size: {sys.getsizeof(output_string3)}")
print(f"string 4 byte size: {sys.getsizeof(output_string4)}")

print("...")

# output hash bit lengths
print(f"string 1 bit length: {sys.getsizeof(output_string1)*8}")
print(f"string 2 bit length: {sys.getsizeof(output_string2)*8}")
print(f"string 3 bit length: {sys.getsizeof(output_string3)*8}")
print(f"string 4 bit length: {sys.getsizeof(output_string4)*8}")

string 1 length: 92
string 2 length: 92
string 3 length: 92
string 4 length: 92
...
string 1 byte size: 141
string 2 byte size: 141
string 3 byte size: 141
string 4 byte size: 141
...
string 1 bit length: 1128
string 2 bit length: 1128
string 3 bit length: 1128
string 4 bit length: 1128
