# LAKAF: Lightweight authentication and key agreement framework for smart grid network

<h5>Authored by: Akber Ali Khan, Vinod Kumar, Musheer Ahmad, Saurabh Rana</h5>

This notebook implements a lightweight authentication and key agreement protocol between a User (U), a Service Provider (S), and a Trusted Authority (TA).  
It simulates the **initialization, user and service-provider registration, and login authentication and key agreement phases**, plus the password/biometric change and dynamic device login phases, using SHA-256 hashing, elliptic-curve (secp192r1) point multiplication, AES-GCM encryption, and XOR operations.

#### Referenced Research Paper:  
[LAKAF: Lightweight authentication and key agreement framework for smart grid network](https://doi.org/10.1016/j.sysarc.2021.102053)

In [7]:
import time
import hashlib
import secrets
from tinyec import registry, ec
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes


# Section 3.1: Initialization Phase
# Public domain parameters shared by the TA, the SP and the user.

curve = registry.get_curve('secp192r1')

# q: prime modulus of the curve's underlying field; g: the curve's base point.
q, g = curve.field.p, curve.g

def h(data: bytes) -> bytes:
    """Protocol hash h(·): return the raw 32-byte SHA-256 digest of ``data``."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.digest()

# Fuzzy extractor functions (simulated)
def Gen(biometric_data: bytes):
    """Simulated fuzzy-extractor generation: Gen(B) -> (sigma, tau).

    sigma is the first 16 bytes of h(B), so it can only be reproduced from
    exactly the same biometric bytes (no noise tolerance). tau is a fresh
    random 16-byte helper string; the simulated ``Rep`` does not consume it.
    """
    key = h(biometric_data)[:16]
    helper = secrets.token_bytes(16)
    return key, helper

def Rep(biometric_data: bytes, tau: bytes):
    """Simulated fuzzy-extractor reproduction: Rep(B*, tau) -> sigma.

    Recomputes sigma = h(B*)[:16]. ``tau`` is accepted to mirror the paper's
    interface but is ignored by this simulation.
    """
    return h(biometric_data)[:16]

def aes_encrypt(key: bytes, plaintext: bytes) -> bytes:
    """Encrypt ``plaintext`` with AES-GCM under (at most) the first 32 bytes of ``key``.

    Returns the bundle ``nonce (12 B) || tag (16 B) || ciphertext``, which is
    the layout ``aes_decrypt`` expects. A fresh random nonce is used per call.
    """
    iv = get_random_bytes(12)
    gcm = AES.new(key[:32], AES.MODE_GCM, nonce=iv)
    body, mac = gcm.encrypt_and_digest(plaintext)
    return b"".join((iv, mac, body))

def aes_decrypt(key: bytes, ciphertext_bundle: bytes) -> bytes:
    """Open a ``nonce || tag || ciphertext`` bundle produced by ``aes_encrypt``.

    Uses (at most) the first 32 bytes of ``key``. The GCM tag is verified;
    pycryptodome raises ``ValueError`` if authentication fails.
    """
    iv = ciphertext_bundle[:12]
    mac = ciphertext_bundle[12:28]
    body = ciphertext_bundle[28:]
    return AES.new(key[:32], AES.MODE_GCM, nonce=iv).decrypt_and_verify(body, mac)


# Trusted Authority Class (TA)


class TrustedAuthority:
    """Trusted Authority (TA): owns the master secret and the registration DBs.

    Private state:
        _curve, _q, _g: the curve, its field prime p and its base point g.
        _n: order of g; EC secret scalars are drawn from [1, n-1].
        _x: master secret; _P_pub = x * g is published.
        _user_db: IDU -> {"CU": counter, "beta": beta}.
        _sp_db:   IDS -> {"CS": counter, "xi": xi}.
    """

    def __init__(self, curve):
        self._curve = curve
        self._q = curve.field.p
        self._g = curve.g
        # A scalar multiplying g must come from [1, n-1], n = order(g).
        # On secp192r1 the field prime p exceeds n, so sampling below p would
        # alias modulo n (and could in principle hit n, giving x*g = infinity).
        self._n = curve.field.n
        self._x = secrets.randbelow(self._n - 1) + 1
        self._P_pub = self._x * self._g
        self._user_db = {}
        self._sp_db = {}

    def get_public_params(self):
        """Return the published system parameters (curve, q, g, hash, P_pub)."""
        return {
            "curve_eq": f"w² = v³ + {self._curve.a}v + {self._curve.b} mod {self._q}",
            "q": self._q,
            "g": (self._g.x, self._g.y),
            "hash_func": "SHA-256",
            "P_pub": (self._P_pub.x, self._P_pub.y)
        }

    def _derive_secret(self, identity, counter):
        """Return h(identity || x || counter): A for users, xi for SPs.

        x is encoded big-endian in its minimal byte length and the counter as
        4 big-endian bytes; every registration/update step shares this helper
        so the derivation can never drift between call sites.
        """
        x_bytes = self._x.to_bytes((self._x.bit_length() + 7) // 8, 'big')
        return h(identity.encode() + x_bytes + counter.to_bytes(4, 'big'))

    def register_user(self, IDU, gammaU):
        """Register (or re-register) user IDU from its masked value gammaU.

        Bumps the user's counter CU on re-registration, computes
        beta = h(IDU || x || CU) XOR gammaU, stores it, and returns (beta, CU).
        """
        print("\n[Step 2 - TA Processing User Registration]")
        CU = self._user_db[IDU]["CU"] + 1 if IDU in self._user_db else 1
        print(f"  (2.1) Retrieved/initialized counter CU = {CU}")

        A = self._derive_secret(IDU, CU)
        print(f"  (2.2) Computed A = h(IDU || x || CU) → {A.hex()}")

        beta = bytes(x ^ y for x, y in zip(A, gammaU))
        print(f"  (2.3) Computed β = A ⊕ γU → {beta.hex()}")

        self._user_db[IDU] = {"CU": CU, "beta": beta}
        print(f"  (2.4) Stored {{IDU, β, CU}} in TA's database.")
        return beta, CU

    def register_service_provider(self, IDS):
        """Register (or re-register) service provider IDS.

        Bumps the SP's counter CS on re-registration, computes
        xi = h(IDS || x || CS), stores it, and returns (xi, CS).
        """
        print("\n[Step 2 - TA Processing SP Registration]")
        CS = self._sp_db[IDS]["CS"] + 1 if IDS in self._sp_db else 1
        print(f"  (2.1) Retrieved/initialized counter CS = {CS}")

        xi = self._derive_secret(IDS, CS)
        print(f"  (2.2) Computed ξ = h(IDS || x || CS) → {xi.hex()}")

        self._sp_db[IDS] = {"CS": CS, "xi": xi}
        print(f"  (2.3) Stored {{IDS, ξ, CS}} in TA's SP database.")
        return xi, CS

    def process_user_credential_update(self, IDU, gammaU):
        """Recompute beta for an already-registered user after a credential change.

        Uses the user's existing counter CU (it is not incremented).

        Returns:
            ``(beta_new, CU)``, or ``None`` if IDU is not registered.
        """
        print("\n[Change PW/Biometrics - TA Processing Request]")
        if IDU not in self._user_db:
            print("  ❌ User not registered.")
            return None

        CU = self._user_db[IDU]["CU"]
        print(f"  (2.1) Retrieved CU = {CU}")

        A = self._derive_secret(IDU, CU)
        print(f"  (2.2) Computed A = h(IDU || x || CU) → {A.hex()}")

        beta_new = bytes(x ^ y for x, y in zip(A, gammaU))
        print(f"  (2.3) Computed β_new = A ⊕ γ_new_U → {beta_new.hex()}")

        # Update database with new β
        self._user_db[IDU]["beta"] = beta_new
        print(f"  (2.4) Updated β in TA's user DB.")

        return beta_new, CU


# Service Provider Class (S)

class ServiceProvider:
    """Service Provider (S): registers with the TA and authenticates users.

    Private state: identity IDS, key pair (rS, PKS = rS * g), and the
    TA-issued secret xi with its counter CS.
    """

    def __init__(self, IDS, curve):
        self.__IDS = IDS
        self.__curve = curve
        self.__q = curve.field.p
        # Order of g: EC scalars (rS, b) are drawn from [1, n-1].
        self.__n = curve.field.n
        self.__g = curve.g
        self.__rS = None
        self.__PKS = None
        self.__xi = None
        self.__CS = None

    def initiate_registration(self):
        """Registration step 1: hand the SP identity IDS to the TA."""
        print(f"[Step 1 - SP Initiates Registration]")
        print(f"  (1.1) Sending IDS = {self.__IDS}")
        return self.__IDS

    def complete_registration(self, xi, CS):
        """Registration step 3: pick (rS, PKS = rS * g) and store the TA's (xi, CS)."""
        print("\n[Step 3 - SP Completes Registration]")
        self.__rS = secrets.randbelow(self.__n - 1) + 1
        print(f"  (3.1) Chose random rS = {self.__rS}")
        self.__PKS = self.__rS * self.__g
        print(f"  (3.2) Computed PKS = rS * g → ({self.__PKS.x}, {self.__PKS.y})")
        self.__xi = xi
        self.__CS = CS
        print(f"  (3.3) Stored ξ and CS locally.")

    def export_public_info(self):
        """Return the SP's info dict (IDS, PKS coordinates, CS, xi as hex).

        NOTE(review): "xi" is a TA-issued secret, so it arguably does not belong
        in *public* info; it is kept here for the simulation's existing callers.
        """
        return {
            "IDS": self.__IDS,
            "PKS": (self.__PKS.x, self.__PKS.y) if self.__PKS else None,
            "CS": self.__CS,
            "xi": self.__xi.hex() if self.__xi else None
        }

    def verify_user(self, M1, TA_xi, K1_user=None):
        """Authenticate a user's login message M1 and build the reply M2.

        Args:
            M1: ``(E1, I1, t1)`` as produced by ``User.login``.
            TA_xi: the TA-issued secret xi for this SP.
            K1_user: ignored; kept only so existing call sites keep working.
                The SP derives the decryption key K2 itself from rS — it must
                not decrypt with a key handed over by the party it is
                authenticating.

        Returns:
            ``((E2, t3), SKS)`` on success, or ``None`` if the timestamp is
            stale, E1 fails authenticated decryption, S1 does not verify, or
            the user's nonce ``a`` is degenerate.
        """
        print("\n[LOGIN PHASE - SP PROCESSING]")
        E1, I1, t1 = M1
        t2 = int(time.time())
        if abs(t2 - t1) > 60:
            print("\n  --- Timestamp expired. Abort session.\n")
            return None

        # Recompute I2 = I1 XOR h(t1), which recovers the user identity.
        t1_bytes = t1.to_bytes(8, 'big')
        h_t1 = h(t1_bytes)
        I2 = bytes(x ^ y for x, y in zip(I1, h_t1))
        IDU_star = I2.decode(errors='ignore')
        print(f"  (2.1) Recovered IDU* = {IDU_star}")

        # rS * g equals the PKS the user read from export_public_info(), so for
        # an honest session K2 is identical to the user's K1.
        rSg = self.__rS * self.__g
        rSg_bytes = rSg.x.to_bytes(24, 'big') + rSg.y.to_bytes(24, 'big')
        t1_xor_rSg = bytes(x ^ y for x, y in zip(t1_bytes.ljust(len(rSg_bytes), b'\x00'), rSg_bytes))
        K2 = h(I2 + h(t1_xor_rSg) + t1_bytes)
        print(f"  (2.2) Computed K2 → {K2.hex()}")

        try:
            plaintext_E1 = aes_decrypt(K2, E1)
        except ValueError:
            # GCM tag mismatch: wrong key or tampered E1.
            print(" \n---- E1 authentication failed. Abort.\n")
            return None
        a = int.from_bytes(plaintext_E1[:24], 'big')
        S1 = plaintext_E1[24:56]
        CU = int.from_bytes(plaintext_E1[56:60], 'big')
        print(f"  (2.3) Decrypted a={a}, CU={CU}")

        expected_S1 = h(IDU_star.encode() + a.to_bytes((a.bit_length() + 7)//8, 'big') + CU.to_bytes(4, 'big'))
        # Constant-time comparison of the authenticator.
        if not secrets.compare_digest(S1, expected_S1):
            print(" \n---- S1 mismatch. Abort.\n")
            return None

        if a % self.__n == 0:
            # a * (b * g) would be the point at infinity (no coordinates).
            print(" \n---- Invalid user nonce a. Abort.\n")
            return None

        print(f"\n User verified successfully.")

        b = secrets.randbelow(self.__n - 1) + 1
        print(f"  (2.4) Chose random b = {b}")

        ab_point = a * (b * self.__g)
        ab_bytes = ab_point.x.to_bytes(24, 'big') + ab_point.y.to_bytes(24, 'big')
        t3 = int(time.time())
        SKS = h(IDU_star.encode() + self.__IDS.encode() +
                CU.to_bytes(4, 'big') + self.__CS.to_bytes(4, 'big') +
                ab_bytes + t3.to_bytes(8, 'big'))
        print(f"  (2.5) Computed session key SKS → {SKS.hex()}")

        # NOTE: S2 follows the paper but is not transmitted in this simulation.
        S2 = h(self.__IDS.encode() + IDU_star.encode() + S1 + TA_xi + SKS + t1_bytes)
        xi1_mask = h(CU.to_bytes(4, 'big') + IDU_star.encode() + K2)
        xi1 = bytes(x ^ y for x, y in zip(TA_xi, xi1_mask))
        # eta hides IDS; IDS longer than 32 bytes would be truncated by zip().
        eta_mask = h(b.to_bytes(24, 'big') + self.__CS.to_bytes(4, 'big') + CU.to_bytes(4, 'big'))
        eta = bytes(x ^ y for x, y in zip(self.__IDS.encode(), eta_mask))
        K3 = h(IDU_star.encode() + S1 + a.to_bytes(24, 'big') + CU.to_bytes(4, 'big') + t3.to_bytes(8, 'big'))

        plaintext_E2 = eta + xi1 + self.__CS.to_bytes(4, 'big') + b.to_bytes(24, 'big')
        E2 = aes_encrypt(K3, plaintext_E2)
        M2 = (E2, t3)
        return M2, SKS



# User Class


class User:
    """User (U): registers with the TA and runs login/key agreement with an SP.

    Local state after registration: tau_U, beta, beta1, beta2 and CU, plus
    the registration nonce r_U, which is reused when credentials change.
    """

    def __init__(self, IDU, PWU, BU, curve):
        self.__IDU = IDU
        self.__PWU = PWU
        self.__BU = BU
        # Keep the curve so EC operations use *this* user's curve instead of
        # whatever module-level `curve` / `g` happen to exist.
        self.__curve = curve
        self.__q = curve.field.p   # field prime; range for the mask r_U
        self.__n = curve.field.n   # order of g; range for EC scalars
        self.__g = curve.g

    def initiate_registration(self):
        """Registration step 1: pick r_U, run Gen(B_U), and return (IDU, gamma_U).

        gamma_U = h(PW_U || sigma_U) XOR r_U (r_U right-padded with zeros).
        """
        print("[Step 1 - User Initiates Registration]")
        self.__rU = secrets.randbelow(self.__q - 1) + 1
        print(f"  (1.1) Chose random rU = {self.__rU}")

        self.__sigmaU, self.__tauU = Gen(self.__BU)
        print(f"  (1.2) Ran fuzzy extractor Gen(BU): σU = {self.__sigmaU.hex()}, τU = {self.__tauU.hex()}")

        hash_pw_sigma = h(self.__PWU.encode() + self.__sigmaU)
        print(f"  (1.3) Computed h(PWU || σU) = {hash_pw_sigma.hex()}")

        rU_bytes = self.__rU.to_bytes((self.__rU.bit_length() + 7) // 8, 'big')
        gammaU = bytes(x ^ y for x, y in zip(hash_pw_sigma, rU_bytes.ljust(len(hash_pw_sigma), b'\x00')))
        print(f"  (1.4) Computed γU = h(PWU || σU) ⊕ rU → {gammaU.hex()}")

        self.__gammaU = gammaU
        return self.__IDU, gammaU

    def complete_registration(self, beta, CU):
        """Registration step 3: store beta/CU and derive beta1 and beta2 locally."""
        print("\n[Step 3 - User Completes Local Registration]")
        self.__beta = beta
        self.__CU = CU
        self.__beta1 = bytes(x ^ y for x, y in zip(beta, self.__sigmaU))
        print(f"  (3.1) Computed β1 = β ⊕ σU → {self.__beta1.hex()}")
        self.__beta2 = h(self.__IDU.encode() + self.__PWU.encode() + self.__beta1)
        print(f"  (3.2) Computed β2 = h(IDU || PWU || β1) → {self.__beta2.hex()}")
        print(f"  (3.3) Stored {{τU, β, β1, β2, CU}} locally.")

    def login(self, sp_public_info, PWU_star=None, BU_star=None):
        """User side of login: local check, then build M1 = (E1, I1, t1).

        Args:
            sp_public_info: dict from ``ServiceProvider.export_public_info``;
                only ``"PKS"`` is read.
            PWU_star, BU_star: entered password / biometric; default to the
                stored ones when ``None``.

        Returns:
            ``(M1, K1, a, S1)`` on success, ``None`` if local authentication fails.
        """
        print("\n[LOGIN PHASE - USER INITIATION]")

        IDU_star = self.__IDU
        PWU_star = PWU_star if PWU_star is not None else self.__PWU
        BU_star = BU_star if BU_star is not None else self.__BU

        sigma_star = Rep(BU_star, self.__tauU)
        print(f"  (1.1) Reproduced σ*U = {sigma_star.hex()}")

        beta1_star = bytes(x ^ y for x, y in zip(self.__beta, sigma_star))
        print(f"  (1.2) Computed β1* = β ⊕ σ*U → {beta1_star.hex()}")

        beta2_star = h(IDU_star.encode() + PWU_star.encode() + beta1_star)
        print(f"  (1.3) Computed β2* = h(IDU || PWU || β1*) → {beta2_star.hex()}")
        print(f"  (Stored β2 = {self.__beta2.hex()})")

        if not secrets.compare_digest(beta2_star, self.__beta2):
            print("  ❌ Local authentication failed!")
            return None

        print(f"\n Local authentication passed.")

        a = secrets.randbelow(self.__n - 1) + 1
        print(f"  (1.4) Chose random a = {a}")

        S1 = h(IDU_star.encode() + a.to_bytes((a.bit_length() + 7)//8, 'big') + self.__CU.to_bytes(4, 'big'))
        print(f"  (1.5) Computed S1 = h(IDU || a || CU) → {S1.hex()}")

        t1 = int(time.time())
        t1_bytes = t1.to_bytes(8, 'big')
        h_t1 = h(t1_bytes)
        # IDU longer than 32 bytes would be truncated by zip().
        I1 = bytes(x ^ y for x, y in zip(IDU_star.encode(), h_t1))
        print(f"  (1.6) Computed I1 = IDU ⊕ h(t1) → {I1.hex()}")

        PKS_x, PKS_y = sp_public_info["PKS"]
        PKS_point = ec.Point(self.__curve, PKS_x, PKS_y)
        PKS_bytes = PKS_point.x.to_bytes(24, 'big') + PKS_point.y.to_bytes(24, 'big')
        t1_xor_PKS = bytes(x ^ y for x, y in zip(t1_bytes.ljust(len(PKS_bytes), b'\x00'), PKS_bytes))
        K1 = h(IDU_star.encode() + h(t1_xor_PKS) + t1_bytes)
        print(f"  (1.7) Computed K1 → {K1.hex()}")

        a_bytes = a.to_bytes(24, 'big')
        CU_bytes = self.__CU.to_bytes(4, 'big')
        plaintext_E1 = a_bytes + S1 + CU_bytes
        E1 = aes_encrypt(K1, plaintext_E1)
        print(f"  (1.8) Encrypted E1 → {E1.hex()}")

        M1 = (E1, I1, t1)
        return M1, K1, a, S1

    def process_sp_response(self, M2, K1, a, S1, sp_info):
        """Finish login: check t3, decrypt E2, recover IDS and derive SKU.

        Args:
            M2: ``(E2, t3)`` from ``ServiceProvider.verify_user``.
            K1: unused here; kept for interface compatibility.
            a, S1: the nonce and authenticator this user sent in M1.
            sp_info: dict with at least ``"IDS"`` (used to size the eta field).

        Returns:
            The session key SKU, or ``None`` if the timestamp is stale, E2
            fails authenticated decryption, or the SP nonce ``b`` is degenerate.
        """
        print("\n[LOGIN PHASE - USER FINALIZATION]")
        E2, t3 = M2
        t4 = int(time.time())
        if abs(t4 - t3) > 60:
            print("  ❌ Timestamp expired!")
            return None

        K4 = h(
            self.__IDU.encode()
            + S1
            + a.to_bytes(24, 'big')
            + self.__CU.to_bytes(4, 'big')
            + t3.to_bytes(8, 'big')
        )
        try:
            plaintext_E2 = aes_decrypt(K4, E2)
        except ValueError:
            # GCM tag mismatch: wrong key or tampered E2.
            print("  ❌ E2 authentication failed!")
            return None

        IDS_value = sp_info["IDS"]
        IDS_len = len(IDS_value.encode())

        eta = plaintext_E2[:IDS_len]
        xi1 = plaintext_E2[IDS_len:IDS_len + 32]
        CS = int.from_bytes(plaintext_E2[IDS_len + 32 : IDS_len + 36], 'big')
        b = int.from_bytes(plaintext_E2[-24:], 'big')

        h_eta = h(
            b.to_bytes(24, 'big')
            + CS.to_bytes(4, 'big')
            + self.__CU.to_bytes(4, 'big')
        )

        recovered_IDS_bytes = bytes(
            x ^ y for x, y in zip(eta, h_eta)
        )

        recovered_IDS = recovered_IDS_bytes.decode(errors='ignore')
        print(f"  (3.1) Recovered IDS = {recovered_IDS}")

        if b % self.__n == 0:
            # b * (a * g) would be the point at infinity (no coordinates).
            print("  ❌ Invalid SP nonce b!")
            return None

        ab_point = b * (a * self.__g)
        ab_bytes = ab_point.x.to_bytes(24, 'big') + ab_point.y.to_bytes(24, 'big')
        SKU = h(
            self.__IDU.encode()
            + recovered_IDS.encode()
            + self.__CU.to_bytes(4, 'big')
            + CS.to_bytes(4, 'big')
            + ab_bytes
            + t3.to_bytes(8, 'big')
        )
        print(f"  (3.2) Computed session key SKU → {SKU.hex()}")
        return SKU

    def change_password_and_biometric(self, PWU_star, BU_star, PW_new_U, BU_new_U, TA):
        """Replace the password and biometric after re-checking the current ones.

        The TA recomputes beta for the new gamma (same r_U and CU); the local
        beta, beta1, beta2, sigma_U and tau_U are then refreshed.

        Returns:
            ``True`` on success, ``None`` if local authentication fails or the
            TA rejects the update.
        """
        print("\n[CHANGE PASSWORD / BIOMETRIC PHASE]")
        sigma_star = Rep(BU_star, self.__tauU)
        beta1_star = bytes(x ^ y for x, y in zip(self.__beta, sigma_star))
        beta2_star = h(self.__IDU.encode() + PWU_star.encode() + beta1_star)

        if not secrets.compare_digest(beta2_star, self.__beta2):
            print("  ❌ Local authentication failed during change process.")
            return None

        print(f"\nLocal authentication successful for changing credentials.")

        sigma_new_U, tau_new_U = Gen(BU_new_U)
        print(f"  (1) Generated new σ_new_U = {sigma_new_U.hex()}, τ_new_U = {tau_new_U.hex()}")

        hash_pw_sigma_new = h(PW_new_U.encode() + sigma_new_U)
        rU_bytes = self.__rU.to_bytes((self.__rU.bit_length() + 7) // 8, 'big')
        gamma_new_U = bytes(x ^ y for x, y in zip(hash_pw_sigma_new, rU_bytes.ljust(len(hash_pw_sigma_new), b'\x00')))
        print(f"  (2) Computed γ_new_U = h(PW_new_U || σ_new_U) ⊕ rU → {gamma_new_U.hex()}")

        update = TA.process_user_credential_update(self.__IDU, gamma_new_U)
        if update is None:
            # TA signals "not registered" with None; unpacking it would crash.
            print("  ❌ TA rejected the credential update.")
            return None
        beta_new, CU = update

        beta_new_1 = bytes(x ^ y for x, y in zip(beta_new, sigma_new_U))
        beta_new_2 = h(self.__IDU.encode() + PW_new_U.encode() + beta_new_1)

        self.__PWU = PW_new_U
        self.__BU = BU_new_U
        self.__beta = beta_new
        self.__beta1 = beta_new_1
        self.__beta2 = beta_new_2
        self.__sigmaU = sigma_new_U
        self.__tauU = tau_new_U
        # The TA's counter is authoritative (this phase does not change it).
        self.__CU = CU

        print("  (3) Updated local credentials:")
        print(f"      β_new = {beta_new.hex()}")
        print(f"      β_new_1 = {beta_new_1.hex()}")
        print(f"      β_new_2 = {beta_new_2.hex()}")
        print(f"      σ_new_U = {sigma_new_U.hex()}")
        print(f"      τ_new_U = {tau_new_U.hex()}")

        return True

    def dynamic_device_login(self, IDU_star, PWU_star, BU_star, sp_public_info):
        """Log in from a new device using freshly entered IDU, password and biometric.

        Performs the local beta2 check with the supplied identity, then
        delegates to ``login`` with the supplied password and biometric.

        Returns:
            Whatever ``login`` returns, or ``None`` if the local check fails.
        """
        print("\n[DYNAMIC DEVICE LOGIN PHASE]")
        sigma_star = Rep(BU_star, self.__tauU)
        print(f"  σ*U = {sigma_star.hex()}")

        beta1_star = bytes(x ^ y for x, y in zip(self.__beta, sigma_star))
        print(f"  β1* = {beta1_star.hex()}")

        beta2_star = h(IDU_star.encode() + PWU_star.encode() + beta1_star)
        print(f"  β2* = {beta2_star.hex()}")
        print(f"  Stored β2 = {self.__beta2.hex()}")

        if not secrets.compare_digest(beta2_star, self.__beta2):
            print("  ❌ Local auth failed during dynamic device login.")
            return None

        print(f"\nLocal auth successful on new device.")
        return self.login(sp_public_info, PWU_star, BU_star)



# Simulation


import time
import tracemalloc

def _expect(result, step):
    """Return ``result`` unchanged, or raise RuntimeError naming ``step`` if it is None.

    The protocol methods signal failure (stale timestamp, failed check) by
    returning None; unpacking that directly would surface as an opaque
    ``TypeError`` instead of saying which phase failed.
    """
    if result is None:
        raise RuntimeError(f"{step} failed; see the protocol log above.")
    return result


print("\n\n==== BENCHMARKING ====\n")

tracemalloc.start()
t0 = time.perf_counter()

try:
    # ------- Initialization Phase -------
    TA = TrustedAuthority(curve)
    public_params = TA.get_public_params()
    print("\n", public_params, "\n")

    IDU = "User001"
    PWU = "mypassword123"
    BU = b"MyFingerprintScanData"

    user = User(IDU, PWU, BU, curve)
    IDU_sent, gammaU_sent = user.initiate_registration()
    beta, CU = TA.register_user(IDU_sent, gammaU_sent)
    user.complete_registration(beta, CU)

    IDS = "ServiceProvider001"
    sp = ServiceProvider(IDS, curve)
    IDS_sent = sp.initiate_registration()
    xi, CS = TA.register_service_provider(IDS_sent)
    sp.complete_registration(xi, CS)

    # ------- Initial Login -------
    M1, K1, a, S1 = _expect(user.login(sp.export_public_info()), "Initial login")
    M2, SKS = _expect(sp.verify_user(M1, xi, K1), "SP verification")
    SKU = _expect(
        user.process_sp_response(M2, K1, a, S1, sp.export_public_info()),
        "User session-key derivation",
    )

    print(f"\n User and SP session keys:")
    print(f"SKS = {SKS.hex()}")
    print(f"SKU = {SKU.hex()}")
    print(f"Session keys match: {SKS == SKU}")

    # ------- Change Password/Biometrics -------
    PWU_star = "mypassword123"
    BU_star = b"MyFingerprintScanData"
    PW_new_U = "myNEWpassword456"
    BU_new_U = b"MyNEWFingerprintScanData"

    _expect(
        user.change_password_and_biometric(PWU_star, BU_star, PW_new_U, BU_new_U, TA),
        "Password/biometric change",
    )

    # ------- Dynamic Device Login -------
    IDU_star = "User001"
    PWU_star = "myNEWpassword456"
    BU_star = b"MyNEWFingerprintScanData"

    M1_dyn, K1_dyn, a_dyn, S1_dyn = _expect(
        user.dynamic_device_login(IDU_star, PWU_star, BU_star, sp.export_public_info()),
        "Dynamic device login",
    )

    M2_dyn, SKS_dyn = _expect(sp.verify_user(M1_dyn, xi, K1_dyn), "SP verification (dynamic device)")
    SKU_dyn = _expect(
        user.process_sp_response(M2_dyn, K1_dyn, a_dyn, S1_dyn, sp.export_public_info()),
        "User session-key derivation (dynamic device)",
    )

    print(f"\n Dynamic Device Session Keys:")
    print("SKS =", SKS_dyn.hex())
    print("SKU =", SKU_dyn.hex())
    print("Session keys match:", SKS_dyn == SKU_dyn)

    # `t_end` (not `t1`) so it is not confused with the protocol timestamp t1.
    t_end = time.perf_counter()
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Always stop tracing, even if a protocol phase failed.
    tracemalloc.stop()

print(f"\n==== BENCHMARK RESULTS ====")
print(f"Total execution time: {t_end - t0:.4f} seconds")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Current memory usage after run: {current / 1024:.2f} KB")






==== BENCHMARKING ====


 {'curve_eq': 'w² = v³ + 6277101735386680763835789423207666416083908700390324961276v + 2455155546008943817740293915197451784769108058161191238065 mod 6277101735386680763835789423207666416083908700390324961279', 'q': 6277101735386680763835789423207666416083908700390324961279, 'g': (602046282375688656758213480587526111916698976636884684818, 174050332293622031404857552280219410364023488927386650641), 'hash_func': 'SHA-256', 'P_pub': (2804158525071983397232857813508236121422805956824971701294, 4385352266885458267816593232874247373722639739024369716507)} 

[Step 1 - User Initiates Registration]
  (1.1) Chose random rU = 2229056439814058037861717518702924202815641636162178766751
  (1.2) Ran fuzzy extractor Gen(BU): σU = 016cd3137b7bd32773d6fb516e831cbe, τU = 9ef81c944b3a54c43638b7d366cf5465
  (1.3) Computed h(PWU || σU) = 1a4c5a25908d2c67aa7cf698d07d00104ec32b2586c443cb594f5d806e676cda
  (1.4) Computed γU = h(PWU || σU) ⊕ rU → 40a435460e143e8016c07c26c4590c7ec56ace