# SGAK: A Robust ECC-Based Authenticated Key Exchange Protocol for Smart Grid Networks

<h5>Authored by: Akber Ali Khan, Vinod Kumar, Ramakant Prasad, M. Javed Idrisi</h5>

This notebook implements an **ECC-based authenticated key exchange protocol** for smart grid networks involving a **User (UA)**, a **Trusted Authority (TA)**, and a **User Device (UB)**.
It covers system setup, user registration and key generation, mutual authentication, session key agreement, and session key update.

#### Referenced Research Paper:  
[SGAK: A Robust ECC-Based Authenticated Key Exchange Protocol for Smart Grid Networks](https://ieeexplore.ieee.org/document/10613397)

In [3]:
import secrets
import time
from hashlib import sha256
from tinyec import registry
from Crypto.Cipher import AES
from tinyec.ec import Point

# Helper Functions

# Domain parameters shared by every party (TA, UA and UB): NIST P-256.
curve = registry.get_curve("secp256r1")

def H(*args) -> bytes:
    """Return the SHA-256 digest of the concatenated encodings of ``args``.

    Each argument is encoded on its own, and the pieces are concatenated
    with no separator or length prefix:

    * ``str``   -> ``str.encode()`` (UTF-8)
    * ``int``   -> minimal big-endian bytes (``0`` encodes to ``b""``)
    * ``tuple`` -> items 0 and 1 as 32-byte big-endian integers
    * ``bytes`` -> used unchanged

    Raises:
        TypeError: if an argument has any other type.
    """
    def _encode(item):
        # Turn one argument into bytes according to the rules above.
        if isinstance(item, str):
            return item.encode()
        if isinstance(item, int):
            width = (item.bit_length() + 7) // 8
            return item.to_bytes(width, 'big')
        if isinstance(item, tuple):
            return item[0].to_bytes(32, 'big') + item[1].to_bytes(32, 'big')
        if isinstance(item, bytes):
            return item
        raise TypeError(f"Unsupported type for hashing: {type(item)}")

    return sha256(b"".join(_encode(item) for item in args)).digest()

def xor_bytes(a: bytes, b: bytes) -> bytes:
    """Return the bytewise XOR of ``a`` and ``b``.

    The result is as long as the shorter input. Extra bytes of the longer
    input are dropped silently, so callers must pad both operands to the
    same length themselves.
    """
    return bytes(map(lambda left, right: left ^ right, a, b))

def point_to_bytes(P):
    """Serialize point ``P`` as 64 bytes: 32-byte big-endian ``x`` then ``y``."""
    return b"".join(coord.to_bytes(32, 'big') for coord in (P.x, P.y))

def bytes_to_point(data):
    """Decode a 64-byte ``x || y`` encoding into a validated point on ``curve``.

    This is the inverse of ``point_to_bytes``. The input comes from protocol
    messages received from the peer, so it is validated before use. tinyec's
    ``Point`` constructor only warns about off-curve coordinates, and an
    accepted off-curve point would expose the later scalar multiplications
    to invalid-curve attacks.

    Raises:
        ValueError: if ``data`` is not exactly 64 bytes, if a coordinate is
            not reduced modulo the field prime, or if ``(x, y)`` does not
            satisfy the curve equation.
    """
    if len(data) != 64:
        raise ValueError(f"Expected 64 bytes for a curve point, got {len(data)}")
    x = int.from_bytes(data[:32], 'big')
    y = int.from_bytes(data[32:], 'big')
    p = curve.field.p
    # Unreduced coordinates would alias valid ones modulo p, so reject them
    # outright before running the curve-equation check.
    if x >= p or y >= p or not curve.on_curve(x, y):
        raise ValueError("Decoded coordinates are not a valid point on the curve")
    return Point(curve, x, y)

def encrypt(key, plaintext):
    """AES-GCM-encrypt ``plaintext`` under ``key``.

    Returns:
        ``nonce || tag || ciphertext``. The nonce is the random one the
        cipher object generates itself. ``decrypt`` expects the nonce and
        the tag to be 16 bytes each (the pycryptodome GCM defaults), so
        confirm that assumption if either length is ever customized.
    """
    gcm = AES.new(key, AES.MODE_GCM)
    body, mac = gcm.encrypt_and_digest(plaintext)
    return b"".join((gcm.nonce, mac, body))

def decrypt(key, ciphertext):
    """Reverse ``encrypt``: split ``nonce(16) || tag(16) || ct`` and decrypt it.

    Raises:
        ValueError: from pycryptodome's ``decrypt_and_verify`` when the tag
            does not verify (wrong key or tampered data).
    """
    nonce, tag, body = ciphertext[:16], ciphertext[16:32], ciphertext[32:]
    gcm = AES.new(key, AES.MODE_GCM, nonce=nonce)
    return gcm.decrypt_and_verify(body, tag)


# Trusted Authority

class TrustedAuthority:
    """Trusted Authority (TA): holds the master secret and registers users.

    On construction it draws a master secret ``M`` from ``[1, n-1]`` over
    the module-level ``curve`` and publishes ``PK_TA = M * g``.
    """

    def __init__(self):
        self.curve = curve
        self.q = self.curve.field.p   # field prime
        self.r = self.curve.a         # curve coefficient a
        self.s = self.curve.b         # curve coefficient b
        self.g = self.curve.g         # base point
        self.n = self.curve.field.n   # order of the base point

        # NOTE(review): self.xi is drawn but never read by this class;
        # register_user draws its own per-user secret instead.
        self.xi = secrets.randbelow(self.n - 1) + 1
        self.M = secrets.randbelow(self.n - 1) + 1
        self.PK_TA = self.g * self.M

        print(f"[TA] Initialized. PK_TA = {self.PK_TA}")

    def get_public_params(self):
        """Return the public system parameters as a plain dict.

        The dict holds the curve equation string, ``q``, ``r``, ``s``, the
        base point ``g`` and ``PK_TA`` as ``(x, y)`` tuples, and the hash
        name.
        """
        return {
            "curve_eq": f"v² = u³ + {self.r}u + {self.s} mod {self.q}",
            "q": self.q,
            "r": self.r,
            "s": self.s,
            "g": (self.g.x, self.g.y),
            "H": "SHA-256",
            "PK_TA": (self.PK_TA.x, self.PK_TA.y)
        }

    def register_user(self, IDi, TR1, TR2, T_window):
        """Issue registration material ``(Yi, Xi)`` to user ``IDi``.

        Args:
            IDi: the user's identity string.
            TR1: timestamp the user attached to the request.
            TR2: the TA's receive timestamp.
            T_window: maximum allowed ``|TR2 - TR1|`` in seconds.

        Returns:
            ``{"Yi": int, "Xi": Point}``, where ``Yi = xi + M*di (mod n)``,
            ``Xi = xi*g`` and ``di = H(IDi, Xi.x, Xi.y) mod n``, so that
            ``Yi*g == Xi + di*PK_TA``. Returns ``None`` if the request falls
            outside the freshness window.
        """
        # The freshness window applies in both directions. A request stamped
        # in the future (TR1 > TR2) must not bypass the check.
        if abs(TR2 - TR1) > T_window:
            print("[TA] Registration rejected: timestamp window exceeded.")
            return None

        xi = secrets.randbelow(self.n - 1) + 1
        Xi = self.g * xi
        di = int.from_bytes(H(IDi, Xi.x, Xi.y), 'big') % self.n
        Yi = (xi + self.M * di) % self.n

        # NOTE(review): this logs Yi, the user's secret scalar. That is fine
        # for a demo but must not be done in production.
        print(f"[TA] Registered user {IDi}. Xi = {Xi}, Yi = {Yi}")
        return {
            "Yi": Yi,
            "Xi": Xi
        }


# User Agent (UA)

class User:
    """User agent (UA): registers with the TA and runs login and key update with UB.

    Args:
        IDi: this user's identity string.
        TA: the ``TrustedAuthority`` whose curve, base point ``g``, order
            ``n`` and public key ``PK_TA`` this user reuses.
    """

    def __init__(self, IDi, TA: TrustedAuthority):
        self.IDi = IDi
        self.TA = TA
        self.curve = TA.curve
        self.g = TA.g
        self.n = TA.n
        self.PK_TA = TA.PK_TA

    def send_registration_request(self):
        """Stamp the request with the current time and return ``(IDi, TR1)``."""
        self.TR1 = int(time.time())
        return self.IDi, self.TR1

    def receive_registration_reply(self, reply, TR2):
        """Check the TA's reply via ``Yi*g == Xi + di*PK_TA``.

        Args:
            reply: dict with ``"Yi"`` (int) and ``"Xi"`` (point) from the TA.
            TR2: the TA's receive timestamp. It is accepted but not used here.

        Returns:
            ``True`` if verification succeeds, in which case
            ``self.PKi = Yi*g`` is stored. ``False`` otherwise.
        """
        Yi = reply["Yi"]
        Xi = reply["Xi"]
        di = int.from_bytes(H(self.IDi, Xi.x, Xi.y), 'big') % self.n

        left = self.g * Yi
        right = Xi + (self.PK_TA * di)

        if left == right:
            print(f"[User:{self.IDi}] Registration verified.")
            self.PKi = left
            return True
        else:
            print(f"[User:{self.IDi}] Registration FAILED.")
            return False

    def initiate_login(self, XA, PKB, IDB):
        """Build the first login message for device ``IDB``.

        The UA draws an ephemeral ``a`` and computes ``MA = a*g``. It masks
        its identity with ``h(MA || PK_TA || T1)`` and encrypts
        ``IDA1 || MA`` under ``K1 = h(PKB || XA || T1)[:16]``.

        Returns:
            ``{"E1": bytes, "T1": int}``.
        """
        print(f"[UA] Initiating login to {IDB}")
        self.XA = XA
        self.PKB = PKB
        self.IDB = IDB

        self.a = secrets.randbelow(self.n - 1) + 1
        self.MA = self.a * self.g
        T1 = int(time.time())

        print(f"[UA] Random a = {self.a}")

        hval = H(point_to_bytes(self.MA), point_to_bytes(self.PK_TA), T1)
        # Pad the ID to the 32-byte mask. xor_bytes truncates longer IDs.
        IDA_bytes = self.IDi.encode().ljust(len(hval), b'\x00')
        IDA1 = xor_bytes(IDA_bytes, hval)

        K1 = H(point_to_bytes(self.PKB), point_to_bytes(self.XA), T1)[:16]
        plaintext = IDA1 + point_to_bytes(self.MA)
        E1 = encrypt(K1, plaintext)

        return {"E1": E1, "T1": T1}

    def finalize_login(self, E2, T3, T1):
        """Verify UB's login reply and derive UA's session key.

        Args:
            E2: encrypted ``IDB1 || N || MB`` produced by the device.
            T3: the device's reply timestamp.
            T1: the timestamp UA used in ``initiate_login``.

        Returns:
            The 32-byte session key ``SKA``. Returns ``None`` in three cases:
            ``T3`` is more than 30 s old, the recovered peer ID differs from
            the ``IDB`` given to ``initiate_login``, or the authenticator
            ``N`` does not match.

        Raises:
            ValueError: propagated from ``decrypt`` when the GCM tag does
                not verify.
        """
        # Freshness check, using the same 30 s window the device applies to
        # T1 and that both sides apply during key update.
        T4 = int(time.time())
        if T4 - T3 > 30:
            print("[UA] Login rejected: timestamp expired.")
            return None

        K2 = H(point_to_bytes(self.MA), T3, self.IDi)[:16]
        plaintext = decrypt(K2, E2)

        IDB1 = plaintext[:32]
        N = plaintext[32:64]
        MB_bytes = plaintext[64:]
        MB = bytes_to_point(MB_bytes)

        hval = H(point_to_bytes(self.MA), point_to_bytes(MB), T3)
        IDB_bytes = xor_bytes(IDB1, hval)
        IDB = IDB_bytes.decode().strip('\x00')

        print(f"[UA] Recovered peer ID: {IDB}")

        # Mutual authentication: the reply must come from the device this
        # user asked to log in to.
        if IDB != self.IDB:
            print("[UA] Authentication failed: unexpected peer ID.")
            return None

        N_star = H(
            point_to_bytes(self.MA),
            point_to_bytes(MB),
            self.IDi,
            IDB,
            T1
        )

        if N_star != N:
            print("[UA] Authentication failed: N mismatch.")
            return None

        shared = (self.a * MB).x
        SKA = H(self.IDi, IDB, N, point_to_bytes(self.MA), point_to_bytes(MB), shared, T3)
        print(f"[UA] Session key established.")

        return SKA

    def initiate_key_update(self, PKB, XA):
        """Start the key-update phase with a fresh ephemeral ``anew``.

        The message layout mirrors ``initiate_login``: the masked ID and
        ``MnewA`` are encrypted under ``h(PKB || XA || Tnew1)[:16]``.

        Returns:
            ``{"Enew1": bytes, "Tnew1": int}``.
        """
        self.anew = secrets.randbelow(self.n - 1) + 1
        self.MnewA = self.anew * self.g
        Tnew1 = int(time.time())

        # Compute h(MnewA || PK_TA || Tnew1) to mask the identity.
        hval = H(point_to_bytes(self.MnewA), point_to_bytes(self.PK_TA), Tnew1)

        IDA_bytes = self.IDi.encode().ljust(len(hval), b'\x00')
        IDnewA1 = xor_bytes(IDA_bytes, hval)

        Knew1 = H(point_to_bytes(PKB), point_to_bytes(XA), Tnew1)[:16]

        plaintext = IDnewA1 + point_to_bytes(self.MnewA)
        Enew1 = encrypt(Knew1, plaintext)

        print(f"[UA] Initiating Key Update, anew = {self.anew}")
        return {
            "Enew1": Enew1,
            "Tnew1": Tnew1
        }

    def finalize_key_update(self, msg, Tnew1):
        """Verify the device's key-update reply and derive the new session key.

        Args:
            msg: dict with ``"Enew2"`` (encrypted ``IDnewB1 || Nnew || MnewB``)
                and ``"Tnew3"`` (the device's timestamp).
            Tnew1: the timestamp UA used in ``initiate_key_update``.

        Returns:
            The 32-byte new session key. Returns ``None`` if ``Tnew3`` is
            more than 30 s old or the authenticator ``Nnew`` does not match.
        """
        Enew2 = msg["Enew2"]
        Tnew3 = msg["Tnew3"]
        Tnew4 = int(time.time())

        if Tnew4 - Tnew3 > 30:
            print("[UA] Key update rejected: timestamp expired.")
            return None

        IDstar_newA = self.IDi  # From the previous session or as returned by UB

        # NOTE(review): Knew2 depends only on Tnew3 (sent alongside Enew2)
        # and UA's identity, so it gives little confidentiality. Confirm this
        # against the paper.
        Knew2 = H(Tnew3, IDstar_newA)[:16]
        plaintext = decrypt(Knew2, Enew2)

        IDnewB1 = plaintext[:32]
        Nnew = plaintext[32:64]
        MnewB_bytes = plaintext[64:]
        MnewB = bytes_to_point(MnewB_bytes)

        hval2 = H(point_to_bytes(self.MnewA), point_to_bytes(MnewB), Tnew3)
        IDnewB_bytes = xor_bytes(IDnewB1, hval2)
        IDstar_newB = IDnewB_bytes.decode().strip("\x00")

        Nstar_new = H(
            point_to_bytes(self.MnewA),
            point_to_bytes(MnewB),
            self.IDi,
            IDstar_newB,
            Tnew1
        )

        if Nstar_new != Nnew:
            print("[UA] Key update failed: N mismatch.")
            return None

        shared_point = self.anew * MnewB
        SKnewA = H(
            self.IDi,
            IDstar_newB,
            Nstar_new,
            point_to_bytes(self.MnewA),
            point_to_bytes(MnewB),
            shared_point.x,
            Tnew3
        )

        print(f"[UA] Key update completed. New session key established.")
        return SKnewA



# User Device (UB)

class UserDevice:
    """User Device (UB): answers login and key-update requests from a User.

    Args:
        IDB: this device's identity string.
        PKB: public point that, together with ``XA``, derives the request key.
        XA: point shared with the initiating user for that key derivation.
    """

    def __init__(self, IDB, PKB, XA):
        self.IDB = IDB
        self.PKB = PKB
        self.XA = XA
        self.curve = curve
        self.g = self.curve.g
        self.n = self.curve.field.n

    def process_login(self, msg):
        """Handle UA's login message and build the authenticated reply.

        Args:
            msg: dict with ``"E1"`` (AES-GCM blob) and ``"T1"`` (int timestamp).

        Returns:
            ``None`` if ``T1`` is more than 30 s old. Otherwise a dict with
            the reply blob ``"E2"``, its timestamp ``"T3"`` and this device's
            session key ``"SKB"``. The key is returned only so the demo can
            compare both sides.
        """
        E1, T1 = msg["E1"], msg["T1"]

        now = int(time.time())
        if now - T1 > 30:
            print("[UB] Login rejected: timestamp expired.")
            return None

        request_key = H(point_to_bytes(self.PKB), point_to_bytes(self.XA), T1)[:16]
        opened = decrypt(request_key, E1)
        masked_id, MA_bytes = opened[:32], opened[32:]
        MA = bytes_to_point(MA_bytes)

        # Unmask UA's identity with h(MA || PKB || T1).
        id_mask = H(MA_bytes, point_to_bytes(self.PKB), T1)
        IDA = xor_bytes(masked_id, id_mask).decode().strip("\x00")
        print(f"[UB] Recovered peer ID: {IDA}")

        self.b = secrets.randbelow(self.n - 1) + 1
        MB = self.b * self.g
        print(f"[UB] Random b = {self.b}")

        ma_enc = point_to_bytes(MA)
        mb_enc = point_to_bytes(MB)
        N = H(ma_enc, mb_enc, IDA, self.IDB, T1)

        shared = (self.b * MA).x
        T3 = int(time.time())
        SKB = H(IDA, self.IDB, N, ma_enc, mb_enc, shared, T3)

        padded_id = self.IDB.encode().ljust(len(N), b'\x00')
        IDB1 = xor_bytes(padded_id, H(ma_enc, mb_enc, T3))

        reply_key = H(ma_enc, T3, IDA)[:16]
        E2 = encrypt(reply_key, IDB1 + N + mb_enc)

        print("[UB] Session key established.")
        return {"E2": E2, "T3": T3, "SKB": SKB}

    def process_key_update(self, msg):
        """Handle UA's key-update request and build the reply.

        Args:
            msg: dict with ``"Enew1"`` (AES-GCM blob) and ``"Tnew1"`` (int).

        Returns:
            ``None`` if ``Tnew1`` is more than 30 s old. Otherwise a dict
            with ``"Enew2"`` and ``"Tnew3"`` for UA, plus the demo-only
            values ``"SKnewB"``, ``"MnewB"``, ``"Nnew"`` and ``"IDstar_newA"``.
        """
        Enew1, Tnew1 = msg["Enew1"], msg["Tnew1"]

        if int(time.time()) - Tnew1 > 30:
            print("[UB] Key update rejected: timestamp expired.")
            return None

        update_key = H(point_to_bytes(self.PKB), point_to_bytes(self.XA), Tnew1)[:16]
        opened = decrypt(update_key, Enew1)
        masked_id, MnewA_bytes = opened[:32], opened[32:]
        MnewA = bytes_to_point(MnewA_bytes)

        id_mask = H(MnewA_bytes, point_to_bytes(self.PKB), Tnew1)
        IDstar_newA = xor_bytes(masked_id, id_mask).decode().strip("\x00")

        self.bnew = secrets.randbelow(self.n - 1) + 1
        MnewB = self.bnew * self.g

        Tnew3 = int(time.time())
        mna_enc = point_to_bytes(MnewA)
        mnb_enc = point_to_bytes(MnewB)
        Nnew = H(mna_enc, mnb_enc, IDstar_newA, self.IDB, Tnew1)

        shared_point = self.bnew * MnewA
        SKnewB = H(IDstar_newA, self.IDB, Nnew, mna_enc, mnb_enc, shared_point.x, Tnew3)

        own_mask = H(mna_enc, mnb_enc, Tnew3)
        IDnewB1 = xor_bytes(self.IDB.encode().ljust(len(own_mask), b'\x00'), own_mask)

        # NOTE(review): this key depends only on Tnew3 (returned in clear)
        # and UA's identity. Confirm against the paper.
        reply_key = H(Tnew3, IDstar_newA)[:16]
        Enew2 = encrypt(reply_key, IDnewB1 + Nnew + mnb_enc)

        print(f"[UB] Processed Key Update, bnew = {self.bnew}")

        return {
            "Enew2": Enew2,
            "Tnew3": Tnew3,
            "SKnewB": SKnewB,
            "MnewB": MnewB,
            "Nnew": Nnew,
            "IDstar_newA": IDstar_newA,
        }



In [4]:
import time
import tracemalloc

print("\n\n==== BENCHMARKING ====\n")

tracemalloc.start()
try:
    t0 = time.perf_counter()

    # Step 1: Setup Trusted Authority (TA)
    TA = TrustedAuthority()

    # Step 2: Register User (UA). Stop here with a clear error instead of
    # failing later on a None reply.
    UA = User("User", TA)
    IDi, TR1 = UA.send_registration_request()
    reply = TA.register_user(IDi, TR1, TR1 + 10, T_window=60)
    if reply is None or not UA.receive_registration_reply(reply, TR1 + 10):
        raise RuntimeError("User registration failed")

    # Step 3: Prepare UserDevice (UB) with TA public key
    PKB = TA.PK_TA      # use TA public key as UB's public key
    # Draw the scalar from [1, n-1] like every other secret here. A zero
    # scalar gives the point at infinity, which has no x/y to serialize.
    XA = UA.g * (secrets.randbelow(UA.n - 1) + 1)
    UB = UserDevice("UserDevice", PKB, XA)

    # Step 4: UA initiates login
    msg1 = UA.initiate_login(XA, PKB, "UserDevice")

    # Step 5: UB processes the login message
    msg2 = UB.process_login(msg1)
    if msg2 is None:
        raise RuntimeError("UB rejected the login request")

    # Step 6: UA finalizes login
    SKA = UA.finalize_login(msg2["E2"], msg2["T3"], msg1["T1"])
    SKB = msg2["SKB"]
    if SKA is None:
        raise RuntimeError("UA could not complete the login")

    # Step 7: Output result
    print(f"Session Key (UA): {SKA.hex()}")
    print(f"Session Key (UB): {SKB.hex()}")
    print("Keys match:", SKA == SKB)

    # UA initiates key update
    msg_update_1 = UA.initiate_key_update(PKB, XA)

    # UB processes key update
    msg_update_2 = UB.process_key_update(msg_update_1)
    if msg_update_2 is None:
        raise RuntimeError("UB rejected the key update request")

    # UA finalizes key update
    SKnewA = UA.finalize_key_update(msg_update_2, msg_update_1["Tnew1"])
    SKnewB = msg_update_2["SKnewB"]
    if SKnewA is None:
        raise RuntimeError("UA could not complete the key update")

    print("New Session Key (UA):", SKnewA.hex())
    print("New Session Key (UB):", SKnewB.hex())
    print("Keys match:", SKnewA == SKnewB)

    t1 = time.perf_counter()
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Always stop tracing so a failed run doesn't keep slowing later cells.
    tracemalloc.stop()

print(f"\n==== BENCHMARK RESULTS ====")
print(f"Total execution time: {t1 - t0:.4f} seconds")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Current memory usage after run: {current / 1024:.2f} KB")





==== BENCHMARKING ====

[TA] Initialized. PK_TA = (113924635000372168493579011367318926725177492019983546479738305477293492862610, 64984722967561783211858998758421907360510777004522356183046091780404059283625) on "secp256r1" => y^2 = x^3 + 115792089210356248762697446949407573530086143415290314195533631308867097853948x + 41058363725152142129326129780047268409114441015993725554835256314039467401291 (mod 115792089210356248762697446949407573530086143415290314195533631308867097853951)
[TA] Registered user User. Xi = (42504485038201844400439293296210959755935356321032354830218038221972112396773, 97751789295715626529685366834677599800179737800208338403207922468342984625603) on "secp256r1" => y^2 = x^3 + 115792089210356248762697446949407573530086143415290314195533631308867097853948x + 41058363725152142129326129780047268409114441015993725554835256314039467401291 (mod 115792089210356248762697446949407573530086143415290314195533631308867097853951), Yi = 10904865404020276191293962686778871748807