# LightGrid-AKA: Lightweight Grid Authentication and Key Agreement Protocol for Smart Grid Security
Authored by: Bandi Prem Kumar, Chhaya Hirwani, Syed Taqi Ali


In [1]:
import secrets
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.utils import (
    encode_dss_signature,
    decode_dss_signature,
)
from cryptography.exceptions import InvalidSignature
from tinyec import registry
from hashlib import sha256
import time
import tracemalloc


def H(data: bytes, q: int) -> int:
    """Hash ``data`` with SHA-256 and reduce the result into ``[0, q)``.

    The 32-byte digest is interpreted as a big-endian unsigned integer and
    taken modulo ``q``. In this file ``q`` is the group order ``TA.q``.
    """
    digest = sha256(data).digest()
    as_int = int.from_bytes(digest, byteorder="big")
    return as_int % q


class TrustedAuthority:
    """Trusted Authority (TA) of the scheme.

    Holds an ECDSA P-256 signing key (via ``cryptography``), exposes the
    secp256r1 parameters (via ``tinyec``) used for the protocol's point
    arithmetic, and issues/verifies certificates. A certificate is an ECDSA
    signature over ``ID || x || y`` of an entity's public-key point. Both smart
    meters and service providers are certified through ``sign_SM``.
    """

    def __init__(self):
        # Signing key pair used to certify entities.
        signing_key = ec.generate_private_key(ec.SECP256R1())
        self.private_key = signing_key
        self.public_key = signing_key.public_key()

        # The same curve through tinyec, used for scalar multiplication.
        curve = registry.get_curve('secp256r1')
        self.curve = curve
        self.p = curve.field.p
        self.a = curve.a
        self.b = curve.b
        self.P = curve.g          # generator
        self.q = curve.field.n    # order of the generator's subgroup

        # Parameters published to every participant.
        self.public_params = dict(
            curve_eq=f"y² = x³ + {self.a}x + {self.b} mod {self.p}",
            p=self.p,
            q=self.q,
            P=(self.P.x, self.P.y),
            H="SHA-256",
            pk_TA=self.get_public_key_bytes(),
        )

    def get_public_params(self):
        """Return the dict of published system parameters."""
        return self.public_params

    def get_public_key_bytes(self):
        """Return the TA verification key as PEM-encoded SubjectPublicKeyInfo."""
        pem = serialization.Encoding.PEM
        spki = serialization.PublicFormat.SubjectPublicKeyInfo
        return self.public_key.public_bytes(encoding=pem, format=spki)

    @staticmethod
    def _certificate_message(identity, point):
        """Encode ``identity || point.x || point.y`` (coordinates as 32-byte big-endian)."""
        return b"".join((
            identity.encode(),
            point.x.to_bytes(32, 'big'),
            point.y.to_bytes(32, 'big'),
        ))

    def sign_SM(self, IDi: str, Pi):
        """Issue a certificate: ECDSA/SHA-256 signature over ``IDi || Pi.x || Pi.y``."""
        return self.private_key.sign(
            self._certificate_message(IDi, Pi),
            ec.ECDSA(hashes.SHA256()),
        )

    def verify_SM_signature(self, IDi, Pi, signature):
        """Check a certificate issued by ``sign_SM``.

        Also used for service-provider certificates; the printed messages say
        "Smart Meter" regardless of the entity type.

        Returns:
            ``True`` if ``signature`` verifies under the TA key, else ``False``.
        """
        message = self._certificate_message(IDi, Pi)
        try:
            self.public_key.verify(signature, message, ec.ECDSA(hashes.SHA256()))
        except InvalidSignature:
            print(f"[ERROR] Invalid signature for Smart Meter {IDi}.\n")
            return False
        print(f"[SUCCESS] Signature valid for Smart Meter {IDi}.\n")
        return True



In [2]:
# Smart Meter 

class SmartMeter:
    """Smart meter (SM) side of the LightGrid-AKA protocol.

    Holds a long-term key pair ``(xi, Pi = xi·P)`` on the TA's curve, obtains a
    TA signature over ``(IDi, Pi)`` as its certificate, sends the step-1
    authentication request and processes the SP's step-2 response (step 3).
    """

    def __init__(self, IDi, TA: TrustedAuthority):
        """Create a smart meter with identity ``IDi`` bound to authority ``TA``.

        Copies the curve, group order ``q`` and generator ``P`` from ``TA`` and
        draws a fresh long-term key pair.
        """
        self.IDi = IDi
        self.TA = TA
        self.curve = TA.curve
        self.q = TA.q
        self.P = TA.P

        # private key, uniform in [1, q-1]
        self.xi = secrets.randbelow(self.q - 1) + 1

        # Compute public key Pi = xi·P
        self.Pi = self.xi * self.P

        # Session key; set by respond_to_SP once the SP's M2 has verified.
        self.SK = None

    def register(self):
        """Obtain this meter's certificate from the TA.

        Returns:
            dict with ``IDi``, ``Pi`` and ``CertSM`` (the TA's ECDSA signature
            bytes over ``IDi || Pi.x || Pi.y``).
        """
        # Send (IDi, Pi) to TA for certificate generation
        signature = self.TA.sign_SM(self.IDi, self.Pi)

        self.CertSM = signature

        print(f"[SUCCESS] SM {self.IDi} registered.")
        print(f"Public Key Pi: ({self.Pi.x}, {self.Pi.y})")
        print(f"Certificate (CertSM): {self.CertSM.hex()}")

        return {
            "IDi": self.IDi,
            "Pi": self.Pi,
            "CertSM": signature
        }

    def initiate_authentication(self):
        """Step 1: build the authentication request sent to the SP.

        Draws ephemeral scalars ``r1`` and ``alpha``, computes ``R1 = r1·P`` and
        ``Q1 = alpha·P``, and binds ``R1`` to the current Unix time ``T1`` via
        ``M1 = H(IDi || R1.x || R1.y || T1 || CertSM)``.

        Requires ``register`` to have been called, because it reads ``self.CertSM``.

        Returns:
            dict with ``IDi``, ``R1``, ``T1``, ``M1``, ``CertSM``, ``Q1``, ``Pi``.
        """
        # Generate ephemeral values
        self.r1 = secrets.randbelow(self.q - 1) + 1
        self.alpha = secrets.randbelow(self.q - 1) + 1

        self.R1 = self.r1 * self.P
        self.Q1 = self.alpha * self.P

        # Timestamp
        T1 = int(time.time())

        # Serialize data for hashing
        ID_bytes = self.IDi.encode()
        R1x_bytes = self.R1.x.to_bytes(32, 'big')
        R1y_bytes = self.R1.y.to_bytes(32, 'big')
        T1_bytes = T1.to_bytes(8, 'big')  # enough space for timestamp
        CertSM_bytes = self.CertSM

        data = ID_bytes + R1x_bytes + R1y_bytes + T1_bytes + CertSM_bytes

        M1 = H(data, self.q)

        print(f"[SM] Step 1 initiated by {self.IDi}")
        print(f"r1: {self.r1}")
        print(f"alpha: {self.alpha}")
        print(f"R1: ({self.R1.x}, {self.R1.y})")
        print(f"Q1: ({self.Q1.x}, {self.Q1.y})")
        print(f"T1: {T1}")
        print(f"M1: {M1}\n")

        return {
            "IDi": self.IDi,
            "R1": self.R1,
            "T1": T1,
            "M1": M1,
            "CertSM": self.CertSM,
            "Q1": self.Q1,
            "Pi": self.Pi,
        }

    def respond_to_SP(self, response_msg_from_SP, TA: TrustedAuthority):
        """Step 3: verify the SP's step-2 response and answer with M3.

        Checks, in order:
        - the SP certificate ``CertSP`` (via ``TA.verify_SM_signature``);
        - freshness of ``T2`` (within 60 s);
        - ``M2``, recomputed with ``Z' = r1·R2``.

        On success it derives the session key the same way the SP does in its
        step 4 and stores it in ``self.SK``. The key is never put into the
        returned message.

        Args:
            response_msg_from_SP: dict from
                ``ServiceProvider.respond_to_auth_request``, or ``None`` if the
                SP rejected step 1.
            TA: authority used to verify ``CertSP``.

        Returns:
            ``{"M3": int, "T3": int}``, or ``None`` if any check fails.
        """
        print(f"\n[SM] Step 3 processing response from SP...")

        # The SP returns None when it rejects step 1; there is nothing to answer.
        if response_msg_from_SP is None:
            print("[ERROR] No response from SP; authentication aborted.")
            return None

        IDj = response_msg_from_SP["IDj"]
        R2 = response_msg_from_SP["R2"]
        Q2 = response_msg_from_SP["Q2"]
        T2 = response_msg_from_SP["T2"]
        M2 = response_msg_from_SP["M2"]
        CertSP = response_msg_from_SP["CertSP"]
        Pj = response_msg_from_SP["Pj"]

        # Step 3-1: Verify CertSP
        is_valid = TA.verify_SM_signature(IDj, Pj, CertSP)
        if not is_valid:
            print(f"[ERROR] CertSP verification failed for SP {IDj}")
            return None

        print(f"[SUCCESS] CertSP verified for SP {IDj}")

        # Step 3-2: Check timestamp T2
        current_time = int(time.time())
        if abs(current_time - T2) > 60:
            print(f"[ERROR] Timestamp T2 expired or invalid.")
            return None
        print(f"[SUCCESS] Timestamp T2 valid.")

        # Step 3-3: Compute shared secret Z′ = r1 · R2 (equals the SP's r2 · R1)
        Z_prime = self.r1 * R2
        print(f"[SM] Z′ computed: ({Z_prime.x}, {Z_prime.y})")

        # Step 3-4: Verify M2 (same field order the SP hashes)
        concat_data = (
            IDj.encode() +
            R2.x.to_bytes(32, 'big') +
            R2.y.to_bytes(32, 'big') +
            Q2.x.to_bytes(32, 'big') +
            Q2.y.to_bytes(32, 'big') +
            Z_prime.x.to_bytes(32, 'big') +
            Z_prime.y.to_bytes(32, 'big') +
            T2.to_bytes(8, 'big') +
            CertSP
        )
        M2_check = H(concat_data, self.q)
        if M2 != M2_check:
            print("[ERROR] M2 verification failed!")
            return None
        print("[SUCCESS] M2 verified successfully.")

        # Step 3-5: Generate T3
        T3 = int(time.time())

        # Step 3-6: Compute M3
        concat_m3 = (
            Z_prime.x.to_bytes(32, 'big') +
            Z_prime.y.to_bytes(32, 'big') +
            self.R1.x.to_bytes(32, 'big') +
            self.R1.y.to_bytes(32, 'big') +
            Q2.x.to_bytes(32, 'big') +
            Q2.y.to_bytes(32, 'big') +
            T3.to_bytes(8, 'big')
        )
        M3 = H(concat_m3, self.q)

        # Step 3-7: Derive the session key exactly as the SP does, so both
        # parties hold the same SK. It is kept locally, never sent.
        self.SK = sha256(
            Z_prime.x.to_bytes(32, 'big') +
            Z_prime.y.to_bytes(32, 'big') +
            b"session key derivation"
        ).hexdigest()

        print(f"[SM] M3 computed: {M3}")
        print(f"[SM] Session key established: {self.SK}")
        print(f"[SM] Sending M3 and T3 back to SP.")

        return {
            "M3": M3,
            "T3": T3
        }
        

In [3]:
class ServiceProvider:
    """Service provider (SP) side of the LightGrid-AKA protocol.

    Mirrors ``SmartMeter``: a long-term key pair ``(xj, Pj = xj·P)``, a TA
    signature over ``(IDj, Pj)`` as certificate, the step-2 response to a
    smart meter, and the final step-4 check that yields the session key.
    """

    def __init__(self, IDj, TA: TrustedAuthority):
        """Create a service provider with identity ``IDj`` bound to ``TA``."""
        self.IDj = IDj
        self.TA = TA
        self.curve = TA.curve
        self.q = TA.q
        self.P = TA.P

        # Generate private key, uniform in [1, q-1]
        self.xj = secrets.randbelow(self.q - 1) + 1

        # Compute public key Pj = xj·P
        self.Pj = self.xj * self.P

        # Session key; set by process_step4_from_SM once M3 has verified.
        self.SK = None

    def register(self):
        """Obtain this provider's certificate from the TA.

        Returns:
            dict with ``IDj``, ``Pj`` and ``CertSP`` (the TA's ECDSA signature
            bytes over ``IDj || Pj.x || Pj.y``).
        """
        # Send (IDj, Pj) to TA for certificate generation
        signature = self.TA.sign_SM(self.IDj, self.Pj)

        self.CertSP = signature

        print(f"[SUCCESS] SP {self.IDj} registered.")
        print(f"Public Key Pj: ({self.Pj.x}, {self.Pj.y})")
        print(f"Certificate (CertSP): {self.CertSP.hex()}")

        return {
            "IDj": self.IDj,
            "Pj": self.Pj,
            "CertSP": signature
        }

    def respond_to_auth_request(self, auth_msg_from_SM):
        """Step 2: verify the SM's request and build the response.

        Checks the SM certificate, the freshness of ``T1`` (within 300 s) and
        the integrity value ``M1``. It then:
        - draws ephemeral ``r2`` and ``beta``;
        - computes ``R2 = r2·P``, ``Q2 = beta·P`` and the shared point
          ``Z = r2·R1``;
        - computes ``M2 = H(IDj || R2 || Q2 || Z || T2 || CertSP)``.

        Args:
            auth_msg_from_SM: dict from ``SmartMeter.initiate_authentication``.

        Returns:
            dict with ``IDj``, ``R2``, ``Q2``, ``T2``, ``M2``, ``CertSP`` and
            ``Pj``, or ``None`` if a check fails. The returned dict is the
            message sent to the SM, so ``Z`` is deliberately left out. The SM
            recomputes ``Z`` itself as ``r1·R2``.
        """
        IDi = auth_msg_from_SM["IDi"]
        Pi = auth_msg_from_SM["Pi"]
        R1 = auth_msg_from_SM["R1"]
        T1 = auth_msg_from_SM["T1"]
        M1 = auth_msg_from_SM["M1"]
        CertSM = auth_msg_from_SM["CertSM"]
        Q1 = auth_msg_from_SM["Q1"]

        # Verify Smart Meter's Certificate
        is_valid = self.TA.verify_SM_signature(IDi, Pi, CertSM)
        if not is_valid:
            print(f"[ERROR] CertSM verification failed for SM {IDi}")
            return None

        # Timestamp freshness check ( ±300 seconds allowed)
        current_time = int(time.time())
        if abs(current_time - T1) > 300:
            print("[ERROR] T1 timestamp is stale.")
            return None

        # Verify M1 with the same encoding the SM uses in step 1:
        # IDi || R1.x || R1.y || T1 (8 bytes) || CertSM.
        M1_check = H(
            IDi.encode() +
            R1.x.to_bytes(32, 'big') +
            R1.y.to_bytes(32, 'big') +
            T1.to_bytes(8, 'big') +
            CertSM,
            self.q,
        )
        if M1 != M1_check:
            print("[ERROR] M1 verification failed!")
            return None
        print("[SUCCESS] M1 verified successfully.")

        # Generate ephemeral keys
        self.r2 = secrets.randbelow(self.q - 1) + 1
        self.beta = secrets.randbelow(self.q - 1) + 1

        self.R2 = self.r2 * self.P
        self.Q2 = self.beta * self.P

        # Compute shared secret Z = r2 * R1 (kept local, never sent)
        Z = self.r2 * R1

        T2 = int(time.time())

        # Compute M2
        IDj_bytes = self.IDj.encode()
        R2x_bytes = self.R2.x.to_bytes(32, 'big')
        R2y_bytes = self.R2.y.to_bytes(32, 'big')
        Q2x_bytes = self.Q2.x.to_bytes(32, 'big')
        Q2y_bytes = self.Q2.y.to_bytes(32, 'big')
        Zx_bytes = Z.x.to_bytes(32, 'big')
        Zy_bytes = Z.y.to_bytes(32, 'big')
        T2_bytes = T2.to_bytes(8, 'big')
        CertSP_bytes = self.CertSP

        data = (
            IDj_bytes +
            R2x_bytes + R2y_bytes +
            Q2x_bytes + Q2y_bytes +
            Zx_bytes + Zy_bytes +
            T2_bytes +
            CertSP_bytes
        )

        M2 = H(data, self.q)

        print(f"[SP] Step 2 response prepared by {self.IDj}")
        print(f"r2: {self.r2}")
        print(f"beta: {self.beta}")
        print(f"R2: ({self.R2.x}, {self.R2.y})")
        print(f"Q2: ({self.Q2.x}, {self.Q2.y})")
        print(f"Z: ({Z.x}, {Z.y})")
        print(f"T2: {T2}")
        print(f"M2: {M2}\n")

        return {
            "IDj": self.IDj,
            "R2": self.R2,
            "Q2": self.Q2,
            "T2": T2,
            "M2": M2,
            "CertSP": self.CertSP,
            "Pj": self.Pj,
        }

    def process_step4_from_SM(self, msg_from_SM, auth_msg_from_SM):
        """Step 4: verify the SM's M3 and derive the session key.

        Recomputes ``Z = r2·R1`` from this provider's ephemeral ``r2``. Then
        checks that ``T3`` is within 60 s and verifies
        ``M3 = H(Z || R1 || Q2 || T3)``.

        Args:
            msg_from_SM: dict from ``SmartMeter.respond_to_SP``, or ``None`` if
                the SM aborted step 3.
            auth_msg_from_SM: the SM's original step-1 message (supplies ``R1``).

        Returns:
            The session key as a SHA-256 hex digest (also stored in
            ``self.SK``), or ``None`` if verification fails.
        """
        print(f"\n[SP] Step 4 processing response from SM...")

        # The SM returns None when it rejects step 2.
        if msg_from_SM is None:
            print("[ERROR] No response from SM; authentication aborted.")
            return None

        M3 = msg_from_SM["M3"]
        T3 = msg_from_SM["T3"]

        # Retrieve values from earlier messages:
        R1 = auth_msg_from_SM["R1"]
        Q2 = self.Q2
        Z = self.r2 * R1

        # Step 4-1: Verify timestamp T3
        current_time = int(time.time())
        if abs(current_time - T3) > 60:
            print("[ERROR] Timestamp T3 expired or invalid.")
            return None
        print("[SUCCESS] Timestamp T3 valid.")

        # Step 4-2: Compute M3
        concat_m3 = (
            Z.x.to_bytes(32, 'big') +
            Z.y.to_bytes(32, 'big') +
            R1.x.to_bytes(32, 'big') +
            R1.y.to_bytes(32, 'big') +
            Q2.x.to_bytes(32, 'big') +
            Q2.y.to_bytes(32, 'big') +
            T3.to_bytes(8, 'big')
        )
        M3_check = H(concat_m3, self.q)

        if M3 != M3_check:
            print("[ERROR] M3 verification failed!")
            return None
        print("[SUCCESS] M3 verified successfully.")

        # Step 4-3: Derive session key
        SK = sha256(
            Z.x.to_bytes(32, 'big') +
            Z.y.to_bytes(32, 'big') +
            b"session key derivation"
        ).hexdigest()
        self.SK = SK

        print(f"[SP] Session key established: {SK}\n")
        return SK


In [4]:
# Benchmark one complete run: TA setup, registration of one SM and one SP, and
# the four-message authentication/key-agreement exchange. The timing and
# tracemalloc figures include the console output printed along the way.
print("\n\n==== BENCHMARKING ====\n")

tracemalloc.start()
try:
    t0 = time.perf_counter()

    TA = TrustedAuthority()

    print("Public Parameters:")
    for k, v in TA.get_public_params().items():
        print(f"{k}: {v}")
    print("\n")

    # Create Smart Meter
    SM = SmartMeter("SM-001", TA)
    reg_data = SM.register()
    TA.verify_SM_signature(reg_data["IDi"], reg_data["Pi"], reg_data["CertSM"])

    SP1 = ServiceProvider("SP1", TA)
    sp_info = SP1.register()
    TA.verify_SM_signature(sp_info["IDj"], sp_info["Pj"], sp_info["CertSP"])

    # Each step returns None when the receiving party rejects the message;
    # stop there instead of handing None to the next step.
    session_key = None
    sm_response = None

    # Step 1 - SM initiates authentication
    auth_msg = SM.initiate_authentication()

    # Step 2 - SP checks CertSM and T1, then responds
    sp_response = SP1.respond_to_auth_request(auth_msg)

    # Step 3 - SM checks CertSP, T2 and M2, then answers with M3
    if sp_response is not None:
        sm_response = SM.respond_to_SP(sp_response, TA)

    # Step 4 - SP checks T3 and M3, then derives the session key
    if sm_response is not None:
        session_key = SP1.process_step4_from_SM(sm_response, auth_msg)

    t1 = time.perf_counter()
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Stop tracing even if a step raises, so it does not stay active afterwards.
    tracemalloc.stop()

if session_key is None:
    print("\n[ERROR] Authentication failed; no session key established.")

print("\n==== BENCHMARK RESULTS ====")
print(f"Total execution time: {t1 - t0:.4f} seconds")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Current memory usage after run: {current / 1024:.2f} KB")




==== BENCHMARKING ====

Public Parameters:
curve_eq: y² = x³ + 115792089210356248762697446949407573530086143415290314195533631308867097853948x + 41058363725152142129326129780047268409114441015993725554835256314039467401291 mod 115792089210356248762697446949407573530086143415290314195533631308867097853951
p: 115792089210356248762697446949407573530086143415290314195533631308867097853951
q: 115792089210356248762697446949407573529996955224135760342422259061068512044369
P: (48439561293906451759052585252797914202762949526041747995844080717082404635286, 36134250956749795798585127919587881956611106672985015071877198253568414405109)
H: SHA-256
pk_TA: b'-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE9kX4jhiu/ppOC23vss7Qr2zzRVdh\ndhBv3WDH/AZNDSqaxdshJPSFLYHFcGFCegx208JzRvlUYrWsyeJ0QDVosA==\n-----END PUBLIC KEY-----\n'


[SUCCESS] SM SM-001 registered.
Public Key Pi: (34175771300690775079195927236487635065958093114685302457528746601947208154533, 705569845473061976061043674513245