# Lightweight and Efficient Protocol Based on ECDH for Securing Smart Grid Communication Infrastructure 

<h5>Authored By - Ali Peivandizadeh, Rafe Alasem, Mostafa Farhadi Moghadam, Haitham Y. Adarbah, Behzad Molavi, Amirhossein Mohajerzadeh, Afzel Noore</h5>

This notebook implements a **lightweight mutual authentication and session key agreement protocol** for Smart Grid networks. It uses elliptic-curve Diffie–Hellman (ECDH) key exchange on the secp256r1 curve and is designed to resist impersonation, replay, and man-in-the-middle attacks.  
It demonstrates efficient key exchange and authentication suitable for resource-constrained smart meters.


#### Referenced Research Paper:  
[Lightweight and Efficient Protocol Based on ECDH for Securing Smart Grid Communication Infrastructure](https://www.researchgate.net/publication/393426876_Lightweight_and_Efficient_Protocol_Based_on_ECDH_for_Securing_Smart_Grid_Communication_Infrastructure)


In [3]:
import hmac
import os
import secrets
import time
import tracemalloc
from hashlib import sha256

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from tinyec import registry
from tinyec.ec import Point

# secp256r1 (NIST P-256) curve shared by the smart meter and the server.
curve = registry.get_curve("secp256r1")
# Order n of the base point G (tinyec stores it as curve.field.n). Private
# scalars are drawn from [1, n - 1] by random_scalar() and products of
# scalars are reduced modulo this value.
curve_order = curve.field.n

# Helper functions

def H(*args) -> int:
    """Hash a sequence of str / int / bytes values to a 256-bit integer (SHA-256).

    Each argument is framed as a 1-byte type tag, a 4-byte big-endian payload
    length, and the payload itself:

    * str: UTF-8 bytes;
    * int: minimal big-endian bytes (at least one byte);
    * bytes / bytearray: the raw bytes.

    The framing makes the encoding injective. Plain concatenation would map
    different argument lists to the same digest input, e.g. ("ab", "c") vs
    ("a", "bc") or (1, 2) vs (258,). With framing, those produce different
    hashes.

    Returns:
        The SHA-256 digest interpreted as a big-endian integer.

    Raises:
        TypeError: if an argument is of any other type.
        OverflowError: if an int argument is negative.
    """
    parts = []
    for arg in args:
        if isinstance(arg, str):
            tag, payload = b"s", arg.encode()
        elif isinstance(arg, int):
            tag, payload = b"i", arg.to_bytes((arg.bit_length() + 7) // 8 or 1, "big")
        elif isinstance(arg, (bytes, bytearray)):
            tag, payload = b"b", bytes(arg)
        else:
            raise TypeError(f"Unsupported type for hashing: {type(arg)}")
        parts.append(tag + len(payload).to_bytes(4, "big") + payload)
    digest = sha256(b"".join(parts)).digest()
    return int.from_bytes(digest, "big")

def random_scalar() -> int:
    """Return a uniformly random private scalar k with 1 <= k <= curve_order - 1."""
    # randbelow(m) yields 0..m-1; shifting by one excludes the zero scalar.
    span = curve_order - 1
    offset = secrets.randbelow(span)
    return 1 + offset

def point_to_bytes(P):
    """Serialize an affine point as x || y, each a 32-byte big-endian integer (64 bytes total)."""
    return b"".join(coord.to_bytes(32, "big") for coord in (P.x, P.y))

def bytes_to_point(b):
    """Decode a 64-byte x || y encoding into a validated point on ``curve``.

    The bytes come from the other party, so they are treated as untrusted.
    tinyec's Point constructor does not reject coordinates that fail the
    curve equation (it only warns). Feeding such a point into ECDH enables
    invalid-curve attacks, so the point is validated explicitly here.

    Raises:
        ValueError: if *b* is not exactly 64 bytes, a coordinate is not
            reduced modulo the field prime, or (x, y) is not on the curve.
    """
    if len(b) != 64:
        raise ValueError(f"Expected a 64-byte point encoding, got {len(b)} bytes")
    x = int.from_bytes(b[:32], 'big')
    y = int.from_bytes(b[32:], 'big')
    p = curve.field.p
    # Reject non-canonical coordinates (>= p) as well as off-curve points.
    if not (x < p and y < p and curve.on_curve(x, y)):
        raise ValueError("Encoded point is not on the curve")
    return Point(curve, x, y)

def encrypt_gcm(key, plaintext):
    """Encrypt *plaintext* with AES-GCM under *key* using a fresh random 12-byte nonce.

    Returns:
        nonce (12 bytes) || authentication tag || ciphertext. This is the
        layout that ``decrypt_gcm`` parses.
    """
    iv = os.urandom(12)
    enc = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend()).encryptor()
    body = enc.update(plaintext)
    body += enc.finalize()
    # The tag is only available after finalize().
    return b"".join((iv, enc.tag, body))

def decrypt_gcm(key, data):
    """Decrypt the nonce (12) || tag (16) || ciphertext layout produced by ``encrypt_gcm``.

    The tag is verified in finalize(). A tampered message causes the
    cryptography library to raise instead of returning plaintext.
    """
    nonce, tag, ciphertext = data[:12], data[12:28], data[28:]
    dec = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend()).decryptor()
    return b"".join((dec.update(ciphertext), dec.finalize()))


# Smart Meter

class SmartMeter:
    """Smart-meter (client) side of the protocol.

    Call order: ``register`` -> ``store_registration_data`` (with the
    server's reply) -> ``initiate_auth`` -> ``process_server_response``.

    Attributes:
        ID_SM: device identifier.
        PVSM: long-term private scalar in [1, curve_order - 1].
        PUSM: long-term public point PVSM * G.
        AC: authentication code issued by the server at registration (int).
        K_sym: AES-GCM key issued by the server at registration (bytes).
        session_key: session key (int), set once the server is authenticated.
    """

    def __init__(self, ID_SM):
        self.ID_SM = ID_SM
        self.PVSM = random_scalar()
        self.PUSM = self.PVSM * curve.g
        self.AC = None
        self.K_sym = None
        self.session_key = None

    def register(self):
        """Build the registration message sent to the server.

        Ni = H(PVSM, li) * ci mod n, with fresh random ci and li that are not
        retained.

        Returns:
            {"ID_SM": ..., "Ni": int, "PUSM": 64-byte point encoding}.
        """
        ci = random_scalar()
        li = random_scalar()
        h_val = H(self.PVSM, li)
        Ni = (h_val * ci) % curve_order
        return {
            "ID_SM": self.ID_SM,
            "Ni": Ni,
            "PUSM": point_to_bytes(self.PUSM),
        }

    def store_registration_data(self, AC, K_sym_hex):
        """Keep the server-issued AC (as int) and symmetric key (from hex)."""
        self.AC = int(AC)
        self.K_sym = bytes.fromhex(K_sym_hex)

    def initiate_auth(self):
        """Build the encrypted first authentication message.

        Requires ``store_registration_data`` to have been called; otherwise
        H() raises TypeError on the missing AC.

        Returns:
            (ciphertext, ephemeral_data):
            - ciphertext is ``encrypt_gcm(K_sym, M1 || fi*G || Ui || A1)``,
              with field sizes 32 / 64 / 32 / 8 bytes.
            - ephemeral_data holds M1, fi, di, ei and A1 for
              ``process_server_response``.
        """
        di = random_scalar()
        ei = random_scalar()
        fi = random_scalar()

        Ui = H(self.PVSM, di, ei)
        fi_p = fi * curve.g
        fi_p_bytes = point_to_bytes(fi_p)
        S = H(self.AC, Ui, int.from_bytes(fi_p_bytes, 'big'))

        # Integer Unix timestamp; the server uses it for freshness checks.
        A1 = int(time.time())
        M1 = H(S, self.AC, Ui, A1)

        payload = (
            M1.to_bytes(32, 'big') +
            fi_p_bytes +
            Ui.to_bytes(32, 'big') +
            A1.to_bytes(8, 'big')
        )
        ciphertext = encrypt_gcm(self.K_sym, payload)

        print(f"\n[SM] Initiating authentication for {self.ID_SM}...")

        ephemeral_data = {
            "M1": M1,
            "fi": fi,
            "di": di,
            "ei": ei,
            "A1": A1
        }

        return ciphertext, ephemeral_data

    def process_server_response(self, encrypted_response, record_from_reg_phase, M1, fi, di, ei, A1):
        """Authenticate the server's reply and derive the session key.

        Args:
            encrypted_response: reply from ``Server.process_authentication``
                (nonce || tag || ciphertext), or None if the server rejected
                the request.
            record_from_reg_phase: the registration reply dict, holding "AC"
                and the hex-encoded "K_sym".
            M1, fi, di, ei: ephemeral values returned by ``initiate_auth``.
            A1: timestamp from ``initiate_auth``. It is unused here and kept
                only for API compatibility.

        Returns:
            True if the server's M2 verifies, and sets ``self.session_key``.
            False if the response is missing, fails GCM authentication, is
            malformed, carries an invalid point, or M2 does not match.
        """
        print("\n[SM] Processing server response...")

        if encrypted_response is None:
            return self._server_auth_failed()

        K_sym = bytes.fromhex(record_from_reg_phase["K_sym"])
        AC = record_from_reg_phase["AC"]

        try:
            plaintext = decrypt_gcm(K_sym, encrypted_response)
        except (InvalidTag, ValueError):
            # Forged/tampered ciphertext, or a nonce/tag prefix too short for GCM.
            return self._server_auth_failed()

        # Expected layout: Vi (32) || M2 (32) || gi*G (64).
        if len(plaintext) != 128:
            return self._server_auth_failed()

        Vi = int.from_bytes(plaintext[0:32], 'big')
        M2_received_bytes = plaintext[32:64]
        try:
            gi_p = bytes_to_point(plaintext[64:128])
        except ValueError:
            return self._server_auth_failed()

        Ui = H(self.PVSM, di, ei)
        fi_p_bytes = point_to_bytes(fi * curve.g)
        S = H(AC, Ui, int.from_bytes(fi_p_bytes, 'big'))
        R_prime = H(Vi, AC, Ui, S)

        # ECDH: fi * (gi * G) equals the server's gi * (fi * G).
        shared_secret_x_int = (fi * gi_p).x

        SK_prime = H(AC, S, M1, shared_secret_x_int)
        M2_prime = H(AC, S, R_prime, SK_prime)

        print(f"[SM] Computed session key SK_prime: {SK_prime}")

        # Compare the fixed-size encodings in constant time so the timing
        # does not reveal how much of M2 matched.
        if not hmac.compare_digest(M2_prime.to_bytes(32, 'big'), M2_received_bytes):
            return self._server_auth_failed()

        print("[SM] Server authenticated successfully. Session key established.")
        self.session_key = SK_prime
        return True

    @staticmethod
    def _server_auth_failed():
        """Report a failed server authentication and return False."""
        print("[SM] ❌ Server authentication failed.")
        return False


# Server

class Server:
    """Server side of the protocol: device registration and authentication.

    Args:
        curve_order: kept for interface compatibility. The class draws
            scalars through ``random_scalar`` and reduces with the
            module-level ``curve_order``.
        max_clock_skew: maximum accepted |server time - A1| in seconds for an
            authentication request. Accepted M1 values are also remembered
            for this long, to detect replays.
    """

    def __init__(self, curve_order, *, max_clock_skew=30):
        self.PVserver = random_scalar()
        self.registered_devices = {}
        self.session_key = None
        self.max_clock_skew = max_clock_skew

    def process_registration(self, reg_msg):
        """Register a smart meter and issue its authentication code and AES key.

        Args:
            reg_msg: dict from ``SmartMeter.register`` with "ID_SM", "Ni" and
                the 64-byte public-key encoding under "PUSM".

        Returns:
            {"AC": int, "K_sym": hex string of a fresh 32-byte key}. The same
            values, plus PUSM, Ni and Ti, are stored in
            ``registered_devices[ID_SM]``.
        """
        ID_SM = reg_msg["ID_SM"]
        Ni = reg_msg["Ni"]
        PUSM_bytes = reg_msg["PUSM"]
        PUSM = bytes_to_point(PUSM_bytes)

        ai = random_scalar()
        bi = random_scalar()
        Pv_ai_point = (self.PVserver * ai) * curve.g
        combined = (Pv_ai_point.x + bi) % curve_order
        Ti = H(combined)
        AC = H(Ti, ID_SM, Ni)
        K_sym = secrets.token_bytes(32)

        self.registered_devices[ID_SM] = {
            "PUSM": PUSM,
            "Ni": Ni,
            "Ti": Ti,
            "AC": AC,
            "K_sym": K_sym
        }

        print(f"[Server] Registered device {ID_SM}.")
        return {
            "AC": AC,
            "K_sym": K_sym.hex()
        }

    def process_authentication(self, ID_SM, encrypted_msg):
        """Verify a smart meter's first message and build the encrypted reply.

        The request is rejected, returning None, if any of these hold:
        - the device is unknown;
        - the ciphertext fails GCM authentication or is malformed;
        - fi*G is not a valid curve point;
        - M1 does not verify;
        - A1 is more than ``max_clock_skew`` seconds from the server clock;
        - the same M1 was already accepted inside that window (replay).

        Returns:
            ``encrypt_gcm(K_sym, Vi || M2 || gi*G)`` (32 / 32 / 64 bytes) on
            success, otherwise None.
        """
        print(f"\n[Server] Processing authentication for {ID_SM}...")

        record = self.registered_devices.get(ID_SM)
        if record is None:
            return self._client_auth_failed()
        K_sym = record["K_sym"]
        AC = record["AC"]

        try:
            plaintext = decrypt_gcm(K_sym, encrypted_msg)
        except (InvalidTag, ValueError):
            return self._client_auth_failed()

        # Expected layout: M1 (32) || fi*G (64) || Ui (32) || A1 (8).
        if len(plaintext) != 136:
            return self._client_auth_failed()

        M1_bytes = plaintext[0:32]
        M1 = int.from_bytes(M1_bytes, 'big')
        fi_p_bytes = plaintext[32:96]
        Ui = int.from_bytes(plaintext[96:128], 'big')
        A1 = int.from_bytes(plaintext[128:136], 'big')

        try:
            fi_p = bytes_to_point(fi_p_bytes)
        except ValueError:
            return self._client_auth_failed()

        S_server = H(AC, Ui, int.from_bytes(fi_p_bytes, 'big'))
        M1_server = H(S_server, AC, Ui, A1)

        # Constant-time comparison of the fixed-size 32-byte encodings.
        if not hmac.compare_digest(M1_server.to_bytes(32, 'big'), M1_bytes):
            return self._client_auth_failed()

        # A1 is covered by M1, so it is trusted only after M1 verifies.
        if not self._is_fresh(record, M1, A1):
            return self._client_auth_failed()

        print("[Server]  Smart Meter authentication succeeded.")

        gi = random_scalar()
        zi = random_scalar()
        wi = random_scalar()
        temp = (self.PVserver * zi) % curve_order
        Vi = H(temp, wi)

        gi_p = gi * curve.g
        gi_p_bytes = point_to_bytes(gi_p)
        R = H(Vi, AC, Ui, S_server)

        # ECDH: gi * (fi * G) equals the meter's fi * (gi * G).
        shared_secret_x_int = (gi * fi_p).x

        SK = H(AC, S_server, M1, shared_secret_x_int)
        M2 = H(AC, S_server, R, SK)

        print(f"[Server] Computed session key SK: {SK}")

        payload = Vi.to_bytes(32, 'big') + M2.to_bytes(32, 'big') + gi_p_bytes
        encrypted_response = encrypt_gcm(K_sym, payload)

        record.update(
            S=S_server,
            Ui=Ui,
            fi_p=fi_p,
            R=R,
            SK=SK,
            gi_p=gi_p,
            M2=M2,
            Vi=Vi,
        )
        self.session_key = SK

        return encrypted_response

    def _is_fresh(self, record, M1, A1):
        """Return True if A1 is inside the clock-skew window and M1 was not seen before.

        Accepted M1 values are kept in ``record["seen_M1"]`` (M1 -> A1).
        Entries whose timestamp has left the window are pruned. A replay of a
        pruned entry carries that same old A1, so it already fails the
        timestamp check.
        """
        now = int(time.time())
        if abs(now - A1) > self.max_clock_skew:
            return False
        seen = {
            m: ts
            for m, ts in record.get("seen_M1", {}).items()
            if now - ts <= self.max_clock_skew
        }
        record["seen_M1"] = seen
        if M1 in seen:
            return False
        seen[M1] = A1
        return True

    @staticmethod
    def _client_auth_failed():
        """Report a rejected smart-meter authentication and return None."""
        print("[Server] ❌ Smart Meter authentication failed.")
        return None


In [4]:

# Benchmark one end-to-end run: registration plus the three-message
# authentication exchange. tracemalloc is active while the clock runs, so the
# reported time includes its tracing overhead.
print("\n\n==== BENCHMARKING ====\n")
tracemalloc.start()
t0 = time.perf_counter()
try:
    # Reuse the module-level `curve` / `curve_order` defined with the helpers
    # (the classes read those globals). Rebuilding them here would only add
    # curve construction to the measured time.
    server = Server(curve_order=curve_order)
    smart_meter = SmartMeter(ID_SM="SM001")

    # Registration phase
    reg_msg = smart_meter.register()
    reg_reply = server.process_registration(reg_msg)

    # Smart meter stores AC and the symmetric key
    smart_meter.store_registration_data(
        AC=reg_reply["AC"],
        K_sym_hex=reg_reply["K_sym"],
    )

    # Authentication phase, step 1: smart meter -> server
    encrypted_msg, ephemeral_data = smart_meter.initiate_auth()

    # Step 2: server verifies M1 and replies (None means it rejected the meter)
    encrypted_response = server.process_authentication("SM001", encrypted_msg)

    # Ephemeral values needed for step 3
    M1 = ephemeral_data["M1"]
    fi = ephemeral_data["fi"]
    di = ephemeral_data["di"]
    ei = ephemeral_data["ei"]
    A1 = ephemeral_data["A1"]

    # Step 3: smart meter verifies the server's reply. If the server already
    # rejected the request, stop here instead of handing None to step 3.
    if encrypted_response is None:
        success = False
    else:
        success = smart_meter.process_server_response(
            encrypted_response,
            reg_reply,
            M1,
            fi,
            di,
            ei,
            A1,
        )

    if success:
        print(" Mutual authentication done. Secure session established.")
        print(f"Smart Meter Session Key: {smart_meter.session_key}")
        print(f"Server Session Key:      {server.session_key}")
        print("Keys match  ", smart_meter.session_key == server.session_key)
    else:
        print("❌ Authentication failed.")

    t1 = time.perf_counter()
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Stop tracing even if a protocol step raises.
    tracemalloc.stop()

print("\n==== BENCHMARK RESULTS ====")
print(f"Total execution time: {t1 - t0:.4f} seconds")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Current memory usage after run: {current / 1024:.2f} KB")





==== BENCHMARKING ====

[Server] Registered device SM001.

[SM] Initiating authentication for SM001...

[Server] Processing authentication for SM001...
[Server]  Smart Meter authentication succeeded.
[Server] Computed session key SK: 82798786234533191035176045163643910852762285727073688329739515913726615007181

[SM] Processing server response...
[SM] Computed session key SK_prime: 82798786234533191035176045163643910852762285727073688329739515913726615007181
[SM] Server authenticated successfully. Session key established.
 Mutual authentication done. Secure session established.
Smart Meter Session Key: 82798786234533191035176045163643910852762285727073688329739515913726615007181
Server Session Key:      82798786234533191035176045163643910852762285727073688329739515913726615007181
Keys match   True

==== BENCHMARK RESULTS ====
Total execution time: 12.5102 seconds
Peak memory usage: 179.07 KB
Current memory usage after run: 154.76 KB
