# Smart Meters Using Semi-Quantum Key Distribution (SQKD)

In [1]:
import os
from qiskit import *
import math
import numpy as np
import random

In [2]:
# Qiskit Aer 'qasm_simulator' backend; sqkd() executes its circuits on this
# backend with shots=1.
simulator = Aer.get_backend('qasm_simulator')

## File handling

In [3]:
# Directory holding the control centre's per-device-type record files
# (lines of `devicename:password` appended by addNewDevice).
_DATABASE_DIR = './database'

# control centre <-> edge node records
cc_en_db_path = f'{_DATABASE_DIR}/control_center_edge_node_database.txt'

# control centre <-> smart meter records
cc_sm_db_path = f'{_DATABASE_DIR}/control_center_smart_meter_database.txt'

In [4]:
def addNewDevice(file_path, devicename=None, password=None):
    '''
    Append a `devicename:password` record to the given database file.

    file_path(str): path to file on which devicename and password are stored
    devicename(str): devicename of the new device; must not contain ':' or a line break
    password(str): password of the new device; must not contain a line break

    return str: "Successfully registered" when the record was written,
    otherwise a message naming the missing or invalid argument
    (nothing is written in that case)
    '''
    if file_path is None:
        return "File path None"

    if devicename is None:
        return "Username None"

    if password is None:
        return "Password None"

    devicename, password = str(devicename), str(password)

    # Records are one per line and fields are separated by ':'; a ':' in the
    # name or a line break in either field would break that format for readers.
    if any(ch in devicename for ch in ':\r\n'):
        return "Invalid devicename"

    if any(ch in password for ch in '\r\n'):
        return "Invalid password"

    with open(file_path, "a") as file:
        file.write(f'{devicename}:{password}\n')
    return "Successfully registered"

In [5]:
def loginDevice(file_path, devicename, password):
    '''
    Check a devicename/password pair against the records in a database file.

    file_path(str): path to file on which devicename and password are stored
    devicename(str): devicename of the device
    password(str): password of the device

    return bool: True if some `devicename:password` record matches exactly,
    else False (also False when an argument is None or the file does not exist)
    '''
    if devicename is None or password is None or file_path is None:
        return False

    try:
        with open(file_path, "r") as file:
            for line in file:
                # Split at the first ':' only, so a ':' inside the password is
                # kept; lines without a ':' (e.g. blank lines) are skipped.
                # Iterating lines also keeps a final record that has no
                # trailing newline.
                _devicename, sep, _password = line.rstrip('\n').partition(':')
                if sep and devicename == _devicename and password == _password:
                    return True
    except FileNotFoundError:
        # No database file yet means no device has been registered.
        return False
    return False

## SQKD

In [76]:
# ---- SQKD protocol parameters ----

n = 2          # length (in bits) of the secret key established per device
delta = 1/8    # margin used in the qubit count N = ceil(8*n*(1+delta))
p_ctrl = 0.5   # CTRL error-rate threshold (rates must stay below it)
p_test = 0.5   # Z error-rate threshold for the SIFT test bits

# dictionary to store secret keys
secret_keys = dict()

In [75]:
def encodeSQKD(basis, message_bit, circuit, idx):
    '''
    Prepare qubit `idx` of `circuit` so that it encodes one message bit.

    basis(int): 0 = computational basis, 1 = Hadamard basis
    message_bit(int): 0 is encoded as |0> or |+>, 1 as |1> or |->
    circuit(QuantumCircuit): the circuit used for communication
    idx(int): index of the qubit carrying this message bit

    return QuantumCircuit: the same circuit object, with gates appended
    '''
    flip = message_bit == 1
    to_hadamard_basis = basis == 1

    # X first (|0> -> |1>), then H moves the state into the Hadamard basis
    if flip:
        circuit.x(idx)
    if to_hadamard_basis:
        circuit.h(idx)
    return circuit

In [8]:
# def sift(decesion, circuit, idx):
# 	'''
# 	decesion(int): 0 (CTRL) and 1 (SIFT)
#     circuit(QuantumCircuit): the circuit used for communication
#     idx(int): index of the message_bit
# 	'''
# 	if decesion == 1:
# 		circuit.measure(idx, idx)
# 	return circuit

In [142]:
def _single_shot_bits(circuit):
    '''
    Execute `circuit` once (shots=1) on the module-level `simulator` backend.

    return str: the measured classical bits, reversed so that character i is
    classical bit i (Qiskit count keys list the highest-index bit first).
    '''
    counts = execute(circuit, backend=simulator, shots=1).result().get_counts(circuit)
    return next(iter(counts))[::-1]


def _error_rate(errors, total):
    '''
    return float | None: errors/total, or None when total is 0 (nothing was tested).
    '''
    return errors / total if total else None


def sqkd(device, devicename):
    '''
    Run one round of semi-quantum key distribution between the control centre
    and a device, and store the resulting n-bit secret key for that device.

    device(str): "SMART_METER" or "EDGE_NODE"; selects the database file
    devicename(str): name of device

    Intermediate protocol values are printed. On success the key is stored via
    addNewDevice and "SECRET KEY SAVED" is printed; otherwise "SQKD FAILED"
    (possibly followed by a reason) is printed.

    raises ValueError: if `device` is not "SMART_METER" or "EDGE_NODE"
    '''
    # Validate before any simulation so an unknown device type cannot produce
    # a key that is never stored.
    if device not in ("SMART_METER", "EDGE_NODE"):
        raise ValueError(f'unknown device {device!r}; expected "SMART_METER" or "EDGE_NODE"')

    # number of qubits sent to the device
    N = math.ceil(8*n*(1+delta))

    # message bits and encoding bases chosen by the control centre (Alice);
    # basis 0 = computational (Z), basis 1 = Hadamard (X)
    message = np.random.randint(2, size=N)
    print(f'{message=}')
    cc_basis = np.random.randint(2, size=N)
    print(f'{cc_basis=}')

    circuit = QuantumCircuit(N, N)
    for idx, (basis, message_bit) in enumerate(zip(cc_basis, message)):
        circuit = encodeSQKD(basis, message_bit, circuit, idx)
    circuit.barrier()

    # device decision per qubit: 0 = CTRL (reflect back), 1 = SIFT (measure)
    device_decisions = np.random.randint(2, size=N)
    print(f'{device_decisions=}')

    ctrl_qubits = [idx for idx, decision in enumerate(device_decisions) if decision == 0]
    print(f'{ctrl_qubits=}')
    sift_qubits = [idx for idx, decision in enumerate(device_decisions) if decision == 1]
    print(f'{sift_qubits=}')

    # SIFT qubits encoded in the Z basis: absent noise or eavesdropping, the
    # device's computational-basis measurement reproduces the message bit, so
    # only these are usable as test/key bits.
    z_sift = [idx for idx in sift_qubits if cc_basis[idx] == 0]
    print(f'{z_sift=}')

    # SIFT: the device measures its qubits in the computational basis
    for idx in sift_qubits:
        circuit.measure(idx, idx)
    device_measured_value = _single_shot_bits(circuit)
    circuit.barrier()
    print(f'{device_measured_value=}')

    # TODO: reorder the reflected (CTRL) qubits before the control centre measures them

    # CTRL: the control centre measures each reflected qubit in its encoding basis
    for idx in ctrl_qubits:
        if cc_basis[idx] == 1:
            circuit.h(idx)
        circuit.measure(idx, idx)
    ctrl_measured_value = _single_shot_bits(circuit)
    print(f'{ctrl_measured_value=}')

    # count CTRL errors separately for Z- and X-encoded reflected qubits
    z_error_ctrl = x_error_ctrl = len_z = len_x = 0
    for idx in ctrl_qubits:
        mismatch = message[idx] != int(ctrl_measured_value[idx])
        if cc_basis[idx] == 0:
            if mismatch:
                print("z")
                print(idx)
                z_error_ctrl += 1
            len_z += 1
        elif cc_basis[idx] == 1:
            if mismatch:
                print("x")
                print(idx)
                x_error_ctrl += 1
            len_x += 1

    z_error_rate_ctrl = _error_rate(z_error_ctrl, len_z)
    x_error_rate_ctrl = _error_rate(x_error_ctrl, len_x)
    print(f'{z_error_rate_ctrl=}')
    print(f'{x_error_rate_ctrl=}')

    # A basis with no CTRL qubits yields no error estimate; abort instead of
    # dividing by zero or accepting an untested channel.
    if z_error_rate_ctrl is None or x_error_rate_ctrl is None:
        print("SQKD FAILED")
        return
    if z_error_rate_ctrl >= p_ctrl or x_error_rate_ctrl >= p_ctrl:
        print("SQKD FAILED")
        return

    # n Z-SIFT qubits are spent as test bits and at least n more are needed
    # for the key. Checking up front also guarantees the sampling below can
    # always draw n distinct indices.
    if len(z_sift) < 2 * n:
        print("SQKD FAILED")
        return

    test_bits = set(random.sample(z_sift, n))
    print(f'{test_bits=}')

    # remaining Z-SIFT qubits: candidates for the secret key
    v = [idx for idx in z_sift if idx not in test_bits]
    print(f'{v=}')

    # SIFT test: compare the device's result with the bit the control centre
    # actually sent. The CTRL run re-simulates these qubits independently, so
    # its value is not the reference the control centre holds.
    z_error_test = sum(
        int(device_measured_value[idx]) != int(message[idx]) for idx in test_bits
    )
    z_error_rate_test = _error_rate(z_error_test, len(test_bits))
    print(f'{z_error_rate_test=}')
    if z_error_rate_test is None or z_error_rate_test >= p_test:
        print("SQKD FAILED")
        return

    sk = "".join(device_measured_value[idx] for idx in v[:n])
    print(f'{sk=}')

    file_path = cc_sm_db_path if device == "SMART_METER" else cc_en_db_path
    status = addNewDevice(file_path, devicename, sk)
    # only report success when the record was actually written
    if status != "Successfully registered":
        print(f"SQKD FAILED: {status}")
        return
    print("SECRET KEY SAVED")

## Test cells

### Add Smart Meter

In [10]:
# example smart meter identity
smart_meter_devicename, smart_meter_password = "sm1", "pass1"

### Verify Smart Meter

In [11]:
# same smart meter identity, re-declared so this cell runs on its own
smart_meter_devicename, smart_meter_password = "sm1", "pass1"

In [147]:
# run one SQKD round for the smart meter; on success its key is stored
sqkd(device="SMART_METER", devicename=smart_meter_devicename)

message=array([1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1])
cc_basis=array([1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0])
device_decisions=array([0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1])
ctrl_qubits=[0, 3, 5, 6, 8, 9, 11, 12, 14, 16]
sift_qubits=[1, 2, 4, 7, 10, 13, 15, 17]
z_sift=[1, 10, 13, 15, 17]
device_measured_value='010000000010000001'
ctrl_measured_value='111001010111101011'
z_error_ctrl/len_z=0.0
x_error_ctrl/len_x=0.0
test_bits={13, 15}
v=[1, 10, 17]
z_error_test/len(test_bits)=0.0
sk='11'
SECRET KEY SAVED


### Add Edge Node

In [13]:
# example edge node identity
edge_node_devicename, edge_node_password = "en1", "pass2"

### Verify Edge Node

In [14]:
# same edge node identity, re-declared so this cell runs on its own
edge_node_devicename, edge_node_password = "en1", "pass2"