# An Enhanced Authentication and Key Agreement Protocol for Smart Grid Communication

<h5>Authored by: Zewei Liu, Student Member, IEEE; Chunqiang Hu, Member, IEEE; Conghao Ruan, Student Member, IEEE;
Pengfei Hu, Member, IEEE; Meng Han, Senior Member, IEEE; and Jiguo Yu, Fellow, IEEE</h5>

This notebook implements a lightweight authentication and key agreement protocol between a **Smart Meter (SM)**, an **Energy Supplier (ES)**, and a **Key Generation Center (KGC)**.
It covers the system setup, user key generation, mutual authentication, and session key agreement using number-theoretic operations and secure hash functions.

#### Referenced Research Paper:  
[An Enhanced Authentication and Key Agreement Protocol for Smart Grid Communication](https://ieeexplore.ieee.org/document/10478494)

In [40]:
from sympy import randprime
from secrets import randbelow
from hashlib import sha256
from math import gcd


class KGC:
    """Key Generation Center: owns the system parameters and registers users.

    Attributes set by ``setup``:
        p, q: two distinct random primes drawn from ``[2**(λ-1), 2**λ)``.
        n: the modulus ``p * q``.
        g: a random base in ``[2, n-1]`` that is coprime to ``n``.
        msk: the master secret key (this implementation uses ``p``).
        MPK: the master public key ``g**p mod n``.

    ``ListID`` holds every identity that has been registered so far.
    """

    def __init__(self, λ):
        """Store the security parameter ``λ`` (prime bit length) and run ``setup``."""
        self.λ = λ
        self.ListID = set()
        self.setup()

    def _random_unit(self):
        """Return a random integer in ``[2, n-1]`` that is coprime to ``n``."""
        while True:
            candidate = randbelow(self.n - 2) + 2
            if gcd(candidate, self.n) == 1:
                return candidate

    def setup(self):
        """Generate ``p``, ``q``, ``n``, ``g``, ``msk`` and ``MPK`` and print them."""
        lo, hi = 2 ** (self.λ - 1), 2 ** self.λ
        self.p = randprime(lo, hi)
        self.q = randprime(lo, hi)
        # Two independent draws can coincide; n must be a product of two
        # distinct primes, so redraw q in that case.
        while self.q == self.p:
            self.q = randprime(lo, hi)
        self.n = self.p * self.q
        self.g = self._random_unit()
        self.msk = self.p
        self.MPK = pow(self.g, self.p, self.n)

        # NOTE(review): the lines below print the master secret key and the
        # factorisation of n. Kept as-is because the notebook shows them on
        # purpose; do not keep this in anything but a demo.
        print("[KGC] System initialized:")
        print(f"   p (msk) = {self.p}")
        print(f"   q       = {self.q}")
        print(f"   n       = {self.n}")
        print(f"   g       = {self.g}")
        print(f"   MPK     = {self.MPK}")

    def _hash_to_range(self, label, args):
        """Hash ``args`` with SHA-256 and map the digest into ``[1, n-1]``.

        ``label`` ("H1" or "H2") is hashed first so that the two oracles never
        return the same value for the same arguments. Every argument is fed as
        ``type tag || 8-byte length || payload`` so that different argument
        lists can never serialise to the same byte string (plain concatenation
        made ``("ab", "c")`` and ``("a", "bc")`` collide).

        Raises:
            TypeError: if an argument is neither ``str`` nor ``int``.
        """
        hasher = sha256()
        hasher.update(label.encode())
        for v in args:
            if isinstance(v, str):
                tag, payload = b"s", v.encode()
            elif isinstance(v, int):
                tag, payload = b"i", v.to_bytes((v.bit_length() + 7) // 8, 'big')
            else:
                raise TypeError(f"Invalid {label} input")
            hasher.update(tag + len(payload).to_bytes(8, 'big') + payload)
        return (int.from_bytes(hasher.digest(), 'big') % (self.n - 1)) + 1

    def H1(self, *args):
        """Hash any number of ``str``/``int`` values to an integer in ``[1, n-1]``.

        Raises:
            TypeError: with message "Invalid H1 input" for any other type.
        """
        return self._hash_to_range("H1", args)

    def H2(self, *args):
        """Same contract as ``H1`` but domain-separated from it.

        Raises:
            TypeError: with message "Invalid H2 input" for any other type.
        """
        return self._hash_to_range("H2", args)

    def register_user(self, IDi, pki):
        """Register identity ``IDi`` with user public value ``pki``.

        Returns:
            ``(si, Ri, hi)`` where ``Ri = g**ri mod n`` for a fresh random
            ``ri``, ``hi = H1(IDi, pki, Ri)`` and ``si = ri + p * hi`` (an
            integer, not reduced), or ``None`` if ``IDi`` is already
            registered.
        """
        if IDi in self.ListID:
            print(f"[KGC] ID {IDi} already registered.")
            return None
        ri = self._random_unit()
        Ri = pow(self.g, ri, self.n)
        hi = self.H1(IDi, pki, Ri)
        si = ri + self.p * hi
        self.ListID.add(IDi)
        print(f"[KGC] Registered user {IDi}")
        return si, Ri, hi


class User:
    """A protocol participant (the notebook creates one for the SM and one for the ES).

    The constructor copies the public parameters ``n``, ``g`` and ``MPK`` from
    the given ``kgc`` object. ``generate_user_keys`` must be called before any
    of the authentication or session-key methods.
    """

    def __init__(self, kgc, ID):
        """Bind this user to ``kgc`` and remember its identity string ``ID``."""
        self.kgc = kgc
        self.ID = ID
        self.n = kgc.n
        self.g = kgc.g
        self.MPK = kgc.MPK

    def _random_unit(self):
        """Return a random integer in ``[2, n-1]`` that is coprime to ``n``."""
        while True:
            candidate = randbelow(self.n - 2) + 2
            if gcd(candidate, self.n) == 1:
                return candidate

    def generate_user_keys(self):
        """Create the user secret, register with the KGC and store the key pair.

        On success sets ``xi``, ``Xi = g**xi mod n``, ``s_hat``, ``Ri``, ``hi``,
        ``SKi = (xi, si)`` and ``PKi = (Xi, Ri, hi)``, plus an initial
        ``theta_rand`` / ``theta_a`` pair.

        Returns:
            ``True`` on success, ``False`` if the KGC refused the registration.
        """
        xi = self._random_unit()
        Xi = pow(self.g, xi, self.n)
        result = self.kgc.register_user(self.ID, Xi)
        if result is None:
            print("[User] Registration failed.")
            return False
        si, Ri, hi = result
        self.xi = xi
        self.s_hat = si
        self.Xi = Xi
        self.Ri = Ri
        self.hi = hi
        self.SKi = (xi, si)
        self.PKi = (Xi, Ri, hi)

        # Kept so the attributes exist right after key generation; every call
        # to generate_auth_message replaces them with a fresh pair.
        self.theta_rand = self._random_unit()
        self.theta_a = pow(self.g, self.theta_rand, self.n)

        print(f"[{self.ID}] User keys generated.")
        return True

    def generate_auth_message(self, ID_other, PK_other):
        """Build this user's authentication message.

        ``ID_other`` and ``PK_other`` are accepted for interface compatibility
        but are not used by this method.

        Returns:
            dict with keys ``'a'`` (ephemeral ``g**phi_a mod n``),
            ``'alpha_a'`` (``theta_rand + Ha * (xi + s_hat)``),
            ``'theta_a'``, ``'IDa'`` and ``'PKa'``.
        """
        phi_a = self._random_unit()
        a = pow(self.g, phi_a, self.n)

        # The masking nonce must be fresh for every message: alpha_a is an
        # integer sum, so two messages sharing theta_rand reveal xi + s_hat
        # as (alpha_1 - alpha_2) / (Ha_1 - Ha_2).
        self.theta_rand = self._random_unit()
        self.theta_a = pow(self.g, self.theta_rand, self.n)

        Ha = self.kgc.H1(self.ID, self.theta_a, a)
        # NOTE(review): no modular reduction is applied and theta_rand < n,
        # so theta_rand may not fully mask Ha * (xi + s_hat); confirm the
        # required nonce size against the paper.
        alpha_a = self.theta_rand + Ha * (self.xi + self.s_hat)

        print(f"[{self.ID}] Auth message generated.")
        self.last_phi = phi_a
        self.last_a = a

        return {
            'a': a,
            'alpha_a': alpha_a,
            'theta_a': self.theta_a,
            'IDa': self.ID,
            'PKa': self.PKi
        }

    def verify_auth_message(self, msg, ID_other, PK_other):
        """Check an authentication message produced by ``generate_auth_message``.

        Two conditions must hold:

        1. ``ha == H1(IDa, Xa, Ra)`` - the same binding the KGC computes in
           registration. Without it a sender could pick ``ha`` freely (for
           example ``0``) together with self-chosen ``Xa``/``Ra`` and pass
           the equation below without any KGC-issued key.
        2. ``g**alpha_a == theta_a * (Xa * Ra * MPK**ha)**Ha  (mod n)`` with
           ``Ha = H1(IDa, theta_a, a)``.

        NOTE(review): ``ID_other`` and ``PK_other`` are not consulted; the
        identity and public key are taken from ``msg``. Confirm whether the
        caller is expected to compare them.

        Returns:
            ``True`` if both checks pass, otherwise ``False``.
        """
        a = msg['a']
        alpha_a = msg['alpha_a']
        theta_a = msg['theta_a']
        IDa = msg['IDa']
        pka, p_hat_ka, ha = msg['PKa']

        binding_ok = (ha == self.kgc.H1(IDa, pka, p_hat_ka))

        Ha = self.kgc.H1(IDa, theta_a, a)
        combined = (pka * p_hat_ka * pow(self.MPK, ha, self.n)) % self.n
        right = (theta_a * pow(combined, Ha, self.n)) % self.n
        left = pow(self.g, alpha_a, self.n)

        verified = binding_ok and (left == right)
        print(f"[{self.ID}] Verification of {IDa} auth message: {verified}")
        return verified

    def compute_session_key(self, other_msg, ID_other, PK_other):
        """Derive the shared session key after both auth messages were exchanged.

        Uses the ephemeral secret and value stored by this user's last
        ``generate_auth_message`` call, the peer's ``other_msg['a']`` and the
        peer public key ``PK_other = (Xb, Rb, hb)``.

        Returns:
            ``H2(ID_1, ID_2, a_1, a_2, SSK1, SSK2)`` where the identities and
            the ephemeral values are each sorted so both sides hash the same
            tuple.
        """
        other_a = other_msg["a"]
        own_phi = self.last_phi
        own_a = self.last_a
        Xb, Rb, hb = PK_other[0], PK_other[1], PK_other[2]

        # Ephemeral-ephemeral share: other_a ** own_phi.
        SSK1 = pow(other_a, own_phi, self.n)

        # Mixed share: peer ephemeral raised to our long-term secret, times
        # the peer's long-term public values raised to our ephemeral secret.
        long_term = self.xi + self.s_hat
        peer_public = (Rb * pow(self.MPK, hb, self.n)) % self.n
        SSK2 = (
            pow(other_a, long_term, self.n)
            * pow(Xb, own_phi, self.n)
            * pow(peer_public, own_phi, self.n)
        ) % self.n

        # Canonical ordering so both parties feed H2 identical arguments.
        ID_1, ID_2 = sorted([self.ID, ID_other])
        a_1, a_2 = sorted([own_a, other_a])

        session_key = self.kgc.H2(
            ID_1, ID_2,
            a_1, a_2,
            SSK1, SSK2
        )

        print(f"[{self.ID}] Session Key = {session_key}")
        return session_key



In [43]:
# Demo / benchmark cell: runs one full protocol exchange between a smart meter
# ("SM001") and an energy supplier ("ES002") and reports wall-clock time and
# traced memory for the whole run. Relies on the KGC and User classes defined
# in the previous cell.
import time
import tracemalloc

print("\n\n==== BENCHMARKING ====\n")
# Tracing starts before the timer, so the reported time is measured with
# tracemalloc active.
tracemalloc.start()
t0 = time.perf_counter()

# Bit length of each prime drawn by the KGC; the modulus n = p * q is about
# twice as long.
λ = 134
kgc = KGC(λ)

# Create SM and ES users (both share the same KGC instance)
sm = User(kgc, "SM001")
es = User(kgc, "ES002")

# Return values (True/False) are not checked here.
sm.generate_user_keys()
es.generate_user_keys()

# SM → ES: SM produces an auth message, ES verifies it
auth_msg_sm = sm.generate_auth_message(es.ID, es.PKi)
verified_sm = es.verify_auth_message(auth_msg_sm, sm.ID, sm.PKi)
print(f"SM→ES Auth: {verified_sm}")

# ES → SM: the same exchange in the opposite direction
auth_msg_es = es.generate_auth_message(sm.ID, sm.PKi)
verified_es = sm.verify_auth_message(auth_msg_es, es.ID, es.PKi)
print(f"ES→SM Auth: {verified_es}")

# Session key computation: each side combines the peer's message with its own
# stored ephemeral values. This runs regardless of the verification results
# printed above.
ssk_sm = sm.compute_session_key(auth_msg_es, es.ID, es.PKi)
ssk_es = es.compute_session_key(auth_msg_sm, sm.ID, sm.PKi)

print(f"[SM001] Session Key: {ssk_sm}")
print(f"[ES002] Session Key: {ssk_es}")
print(f"Session keys match: {ssk_sm == ssk_es}")

# Stop the clock first, then read (current, peak) traced sizes in bytes.
t1 = time.perf_counter()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"\n==== BENCHMARK RESULTS ====")
print(f"Total execution time: {t1 - t0:.4f} seconds")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Current memory usage after run: {current / 1024:.2f} KB")



==== BENCHMARKING ====

[KGC] System initialized:
   p (msk) = 18848233995731295247546671581070425043289
   q       = 12170484092486615097641148319054222198349
   n       = 229392132016513159943017963871290424730379058573326837853256806601625953069329861
   g       = 143320226100105027282395231980749843381342144024240005422257981992165182125766172
   MPK     = 58788891714756002342587926916326177321920887088198061599195581353170154825558353
[KGC] Registered user SM001
[SM001] User keys generated.
[KGC] Registered user ES002
[ES002] User keys generated.
[SM001] Auth message generated.
[ES002] Verification of SM001 auth message: True
SM→ES Auth: True
[ES002] Auth message generated.
[SM001] Verification of ES002 auth message: True
ES→SM Auth: True
[SM001] Session Key = 89296552708809283062954060844582287433605241581222968017463165905397441612914
[ES002] Session Key = 89296552708809283062954060844582287433605241581222968017463165905397441612914
[SM001] Session Key: 89296552708809283062954