# Module imports

In [22]:
import os
import random
import time
from hashlib import shake_256
from sage.all import (
    EllipticCurve, GF, ZZ, Zmod, matrix, vector, QuaternionAlgebra,
    randint, is_prime, gcd, floor, ceil, log, inverse_mod, Integer
)



from montgomery_isogenies.kummer_line import KummerLine
from montgomery_isogenies.kummer_isogeny import KummerLineIsogeny
from montgomery_isogenies.isogenies_x_only import lift_image_to_curve
from theta_structures.couple_point import CouplePoint 
from theta_isogenies.product_isogeny import EllipticProductIsogeny 
from utilities.utilities1.discrete_log import BiDLP, discrete_log_pari


from precomputations import Precomputations 
from sqisign2d_west.ideal_to_isogeny_clapotis import IdealToIsogenyClapotis 
from quaternions.ideals import pushforward_ideal 
from sqisign2d_west.dim2_wrapper import Dim2Isogeny 


from sqisign2d_west.endomorphisms import FullrepresentInteger, iota
from utilities.supersingular import torsion_basis_with_pairing_2e
from utilities.discrete_log import discrete_log_pair_power_two
from utilities.pairing import (
    weil_pairing_pari, tate_pairing_pari, weil_pairing_biextension,
    tate_pairing_biextension
)
from utilities.utilities1.strategy import optimised_strategy

import os
import random
import time
from hashlib import shake_256
from sage.all import (
    EllipticCurve, GF, ZZ, Zmod, matrix, vector, QuaternionAlgebra,
    randint, is_prime, gcd, floor, ceil, log, inverse_mod, Integer
)



# Additional imports provided by the user (kept for completeness, though not all are used in this specific OWF implementation)
from sqisign2d_west.endomorphisms import FullrepresentInteger, iota
from utilities.supersingular import torsion_basis_with_pairing_2e
from utilities.discrete_log import discrete_log_pair_power_two
from utilities.pairing import (
    weil_pairing_pari, tate_pairing_pari, weil_pairing_biextension,
    tate_pairing_biextension
)
from utilities.utilities1.strategy import optimised_strategy



# POKE-KEM

In [22]:
# import os
# import random
# import time
# from hashlib import shake_256
# import secrets
# from sage.structure.element import Element

# from montgomery_isogenies.kummer_line import KummerLine
# from montgomery_isogenies.kummer_isogeny import KummerLineIsogeny
# from montgomery_isogenies.isogenies_x_only import lift_image_to_curve

# from theta_structures.couple_point import CouplePoint
# from theta_isogenies.product_isogeny import EllipticProductIsogeny

# from utilities.discrete_log import BiDLP, discrete_log_pari
# from utilities.supersingular import torsion_basis, torsion_basis_with_pairing

# Switch off Sage's global "proof" flags. `proof` is not imported in this
# cell; it is presumably the object Sage injects into the interactive
# namespace -- confirm when running this outside a Sage session.
proof.all(False)

# --- Helper functions ---

def random_unit_with_rng(modulus, rng):
    """Return an integer in [0, modulus - 1] that is coprime to `modulus`.

    Candidates are drawn one at a time with ``rng.randint(0, modulus - 1)``
    and the first one whose gcd with `modulus` equals 1 is returned, so the
    number of draws taken from `rng` depends only on the RNG stream.
    """
    candidate = rng.randint(0, modulus - 1)
    while gcd(candidate, modulus) != 1:
        # Rejected: not invertible modulo `modulus`; draw again.
        candidate = rng.randint(0, modulus - 1)
    return candidate

def random_matrix_with_rng(modulus, rng):
    """Return entries (d1, d2, d3, d4) of a 2x2 matrix invertible mod `modulus`.

    Four values are drawn per attempt, in the order d1, d2, d3, d4, each with
    ``rng.randint(0, modulus - 1)``. The attempt is accepted as soon as the
    determinant ``d1*d4 - d2*d3`` is coprime to `modulus`.
    """
    while True:
        entries = [rng.randint(0, modulus - 1) for _ in range(4)]
        d1, d2, d3, d4 = entries
        determinant = d1 * d4 - d2 * d3
        if gcd(determinant, modulus) == 1:
            return d1, d2, d3, d4

def eval_dimtwo_isog(Phi, q, P, Q, ord, RS=None):
    """Evaluate a product isogeny `Phi` on P and Q and return scaled images.

    Steps performed, in order:

    1. Push (P, 0), (Q, 0) and (P - Q, 0) through `Phi`, keeping component 0
       of each image. The image of Q is negated when the first coordinate of
       ``imP - imQ`` disagrees with that of the image of P - Q.
    2. Obtain a pair (R, S) on the second domain factor together with a
       pairing value WP: from `RS` when given (WP is then
       ``R.weil_pairing(S, ord)``), otherwise from
       ``torsion_basis_with_pairing(Phi.domain()[1], ord)``.
    3. Push (0, R), (0, S), (0, R - S) through `Phi` the same way, fixing the
       sign of the image of S as in step 1.
    4. Use ``BiDLP`` to write imP and imQ as combinations of (imR, imS),
       passing ``WP**q`` as the ``ePQ`` pairing hint.
    5. Return ``q * (x*R + y*S)`` and ``q * (w*R + z*S)`` for the
       coefficients found in step 4.

    Parameters:
        Phi: callable isogeny object exposing ``domain()`` (indexable with 0
            and 1) and accepting a ``CouplePoint``.
        q: integer used both as the exponent of the pairing hint and as the
            final scalar.
        P, Q: points on ``Phi.domain()[0]``.
        ord: order passed to the pairing / discrete-log helpers.
        RS: optional pair (R, S) on ``Phi.domain()[1]``; a fresh basis is
            computed when omitted.

    Returns:
        Tuple ``(imP, imQ)`` of points expressed through R and S.
    """
    # Identity of the second factor, shared by the three evaluations below.
    zero_second = Phi.domain()[1](0)

    imP = Phi(CouplePoint(P, zero_second))[0]
    imQ = Phi(CouplePoint(Q, zero_second))[0]
    imPQ = Phi(CouplePoint(P - Q, zero_second))[0]

    # Make the sign of imQ consistent with the image of P - Q.
    if (imP - imQ)[0] != imPQ[0]:
        imQ = -imQ

    # Identity test, not `== None`: `RS` may hold objects with custom equality.
    if RS is None:
        R, S, WP = torsion_basis_with_pairing(Phi.domain()[1], ord)
    else:
        R, S = RS
        WP = R.weil_pairing(S, ord)

    # Identity of the first factor, shared by the three evaluations below.
    zero_first = Phi.domain()[0](0)

    imR = Phi(CouplePoint(zero_first, R))[0]
    imS = Phi(CouplePoint(zero_first, S))[0]
    imRS = Phi(CouplePoint(zero_first, R - S))[0]

    # Make the sign of imS consistent with the image of R - S.
    if (imR - imS)[0] != imRS[0]:
        imS = -imS

    wp = WP**q
    x, y = BiDLP(imP, imR, imS, ord, ePQ=wp)
    w, z = BiDLP(imQ, imR, imS, ord, ePQ=wp)

    imP = x * R + y * S
    imQ = w * R + z * S

    imP *= q
    imQ *= q

    return imP, imQ

def point_to_xonly(P, Q):
    """Return the Kummer line of P's curve and the x-only forms of P, Q, P - Q.

    The line is built with ``KummerLine(P.curve())``; each x-only point is
    the line applied to coordinate 0 of the corresponding curve point.
    """
    line = KummerLine(P.curve())
    xP, xQ, xPQ = (line(point[0]) for point in (P, Q, P - Q))
    return line, xP, xQ, xPQ

def make_prime(N):
    """Return the first prime of the form ``N * f - 1`` with ``1 <= f < 1000``.

    Parameters:
        N: integer (Sage ``Integer`` or anything ``Integer(...)`` accepts).

    Returns:
        The prime as a Sage ``Integer``.

    Raises:
        RuntimeError: if no multiplier in the searched range yields a prime.
            Previously the function fell off the end and returned ``None``,
            which only failed later, at the first arithmetic use of the result.
    """
    # Coerce once so `.is_prime()` is available even for a plain Python int.
    N = Integer(N)
    for f in range(1, 1000):
        candidate = N * f - 1
        if candidate.is_prime():
            return candidate
    raise RuntimeError(f"Could not find suitable prime for N={N}")

def xof_kdf(X, Y):
    """Return a SHAKE-256 object fed with ``X.to_bytes()`` then ``Y.to_bytes()``.

    The caller reads as many output bytes as it needs from the returned XOF.
    """
    xof = shake_256()
    xof.update(X.to_bytes())
    xof.update(Y.to_bytes())
    return xof

def xof_encrypt(xof, msg):
    """XOR `msg` with the first ``len(msg)`` bytes of the XOF output.

    Parameters:
        xof: object exposing ``digest(length)`` (e.g. a ``shake_256`` instance).
        msg: bytes-like message; applying the function twice with an
            identically seeded XOF recovers the input.

    Returns:
        ``bytes`` of the same length as `msg`.
    """
    # `operator.xor` instead of the `^^` operator: `^^` only parses under the
    # Sage preparser, whereas this form is valid (and identical in effect)
    # both in Sage and in plain Python.
    from operator import xor

    key = xof.digest(len(msg))
    return bytes(map(xor, key, msg))

# --- KEM Specific Functions ---
def G_lambda(m):
    """Map message `m` to an integer seed.

    The seed is the first 16 bytes of SHAKE-256(m), read as a big-endian
    unsigned integer, so equal messages always give equal seeds.
    """
    hasher = shake_256()
    hasher.update(m)
    return int.from_bytes(hasher.digest(16), byteorder='big')

def H_lambda(m_or_s, ct):
    """Derive a 32-byte key as SHAKE-256(m_or_s || serialised ciphertext).

    Parameters:
        m_or_s: ``bytes`` used as-is, or a Python ``int`` encoded big-endian
            on the minimal number of bytes (``0`` encodes to ``b""``).
        ct: iterable of ciphertext components. Each component is serialised
            by the first rule that applies:

            1. a non-bytes, non-int object with ``to_bytes`` -> ``to_bytes()``
            2. an object with ``x()`` whose result has ``to_bytes`` -> the
               bytes of ``x()`` only
            3. ``bytes`` -> used as-is
            4. an object with ``a4`` and ``a6`` -> bytes of ``a4()`` then
               ``a6()``
            5. anything else -> ``str(component)`` encoded as UTF-8

    Returns:
        32 bytes of SHAKE-256 output.

    Raises:
        TypeError: if `m_or_s` is neither ``bytes`` nor ``int``.
    """
    if isinstance(m_or_s, bytes):
        m_s_bytes = m_or_s
    elif isinstance(m_or_s, int):
        m_s_bytes = m_or_s.to_bytes((m_or_s.bit_length() + 7) // 8, 'big')
    else:
        raise TypeError("m_or_s must be bytes or int")

    # NOTE(review): components are concatenated without length prefixes and
    # points contribute only their x-coordinate; confirm this is acceptable
    # for the intended security argument.
    ct_components_bytes = []
    for component in ct:
        # Duck-typed replacement for the former `isinstance(component, Element)`
        # test: `Element` is not imported in this file (its import is commented
        # out), so the old check raised NameError on the first component.
        # Python ints are excluded because they were never Sage elements and
        # must keep falling through to the `str` rule.
        if not isinstance(component, (bytes, int)) and hasattr(component, 'to_bytes'):
            ct_components_bytes.append(component.to_bytes())
        elif hasattr(component, 'x') and hasattr(component.x(), 'to_bytes'):
            ct_components_bytes.append(component.x().to_bytes())
        elif isinstance(component, bytes):
            ct_components_bytes.append(component)
        elif hasattr(component, 'a4') and hasattr(component, 'a6'):
            ct_components_bytes.append(component.a4().to_bytes())
            ct_components_bytes.append(component.a6().to_bytes())
        else:
            ct_components_bytes.append(str(component).encode('utf-8'))

    ct_bytes_for_hash = b"".join(ct_components_bytes)
    return shake_256(m_s_bytes + ct_bytes_for_hash).digest(32)

# --- Main Implementation ---
def run_benchmarks(lambda_):
    bs = [162, 243, 324]
    i = [128, 192, 256].index(lambda_)
    b = bs[i]
    
    a = lambda_ - 2
    c = floor(lambda_ / 3 * log(2) / log(5))
    A = Integer(2**a)
    B = Integer(3**b)
    assert B > 2**(2 * lambda_)
    C = 5**c
    p = make_prime(4 * A * B * C)

    x = C
    a = factor(p + 1)[0][1] - 2
    A = Integer(2**a)

    FF.<xx> = GF(p)[]
    F.<i> = GF(p**2, modulus=xx**2 + 1)
    E0 = EllipticCurve(F, [1, 0])
    E0.set_order((p + 1)**2)

    PA, QA = torsion_basis(E0, 4 * A)
    PB, QB = torsion_basis(E0, B)
    X0, Y0 = torsion_basis(E0, C)

    PQA = PA - QA
    PQB = PB - QB
    XY0 = X0 - Y0

    _E0 = KummerLine(E0)
    _PB = _E0(PB[0])
    _QB = _E0(QB[0])
    _PQB = _E0(PQB[0])

    xPA = _E0(PA[0])
    xQA = _E0(QA[0])
    xPQA = _E0(PQA[0])

    xX_0 = _E0(X0[0])
    xY_0 = _E0(Y0[0])
    xXY_0 = _E0(XY0[0])

    def keygenA():
        sys_rng = random.SystemRandom()

        for _ in range(1000):
            q = sys_rng.randint(0, 2**(a - 1) - 1)
            if q % 2 == 1 and q % 3 != 0 and q % 5 != 0 and (A - q) % 3 != 0 and (A - q) % 5 != 0:
                break
        else:
            raise ValueError("Could not find suitable q.")

        deg = q
        rhs = deg * (2**a - deg) * B

        upper_bound = ZZ((rhs.nbits() - p.nbits()) // 2 - 2)
        
        alpha = random_unit_with_rng(A, sys_rng)
        beta = random_unit_with_rng(A, sys_rng)
        gamma = random_unit_with_rng(B, sys_rng)
        delta = random_unit_with_rng(C, sys_rng)

        if upper_bound < 0:
            raise ValueError("Degree is too small.")

        P2, Q2, P3, Q3 = PA, QA, PB, QB

        QF = gp.Qfb(1, 0, 1)

        for _ in range(10_000):
            zz = sys_rng.randint(0, ZZ(2**upper_bound))
            tt = sys_rng.randint(0, ZZ(2**upper_bound))
            sq = rhs - p * (zz**2 + tt**2)

            if sq <= 0:
                continue

            if not sq.is_prime() or sq % 4 != 1:
                continue

            try:
                xx, yy = QF.qfbsolve(sq)
                break
            except ValueError:
                continue
        else:
            raise ValueError("Could not find a suitable endomorphism.")

        i_end = lambda P: E0(-P[0], i * P[1])
        pi_end = lambda P: E0(P[0]**p, P[1]**p)
        theta = lambda P: xx * P + yy * i_end(P) + zz * pi_end(P) + tt * i_end(pi_end(P))

        P2_ = theta(P2)
        Q2_ = theta(Q2)
        P3_ = theta(P3)
        Q3_ = theta(Q3)

        try:
            R = Q3
            wp_ = P3_.weil_pairing(R, B)
            wp = Q3_.weil_pairing(R, B)
            discrete_log_pari(wp_, wp, B)

            K3_dual = theta(Q3)
        except TypeError:
            R = P3
            K3_dual = theta(P3)

        K3_dual = _E0(K3_dual[0])
        phi3 = KummerLineIsogeny(_E0, K3_dual, B)

        xP2_ = _E0(P2_[0])
        xQ2_ = _E0(Q2_[0])
        xPQ2_ = _E0((P2_ - Q2_)[0])

        ximP2_ = phi3(xP2_)
        ximQ2_ = phi3(xQ2_)
        ximPQ2_ = phi3(xPQ2_)

        P2_ = ximP2_.curve_point()
        Q2_ = ximQ2_.curve_point()

        if (P2_ - Q2_)[0] != ximPQ2_.x():
            Q2_ = -Q2_

        inverse = inverse_mod(B, 4 * A)
        P2_ = inverse * P2_
        Q2_ = inverse * Q2_

        P, Q = CouplePoint(-deg * P2, P2_), CouplePoint(-deg * Q2, Q2_)
        kernel = (P, Q)

        Phi = EllipticProductIsogeny(kernel, a)

        P23x = PA + PB + X0
        Q23x = QA + QB + Y0

        imP23x, imQ23x = eval_dimtwo_isog(Phi, A - deg, P23x, Q23x, 4 * A * B * x)

        X_A = inverse_mod(4 * A * B, x) * (4 * A * B * imP23x)
        Y_A = inverse_mod(4 * A * B, x) * (4 * A * B * imQ23x)

        P2_og = alpha * inverse_mod(B * x, 4 * A) * (B * x * imP23x)
        Q2_og = beta * inverse_mod(B * x, 4 * A) * (B * x * imQ23x)

        P3_og = gamma * inverse_mod(4 * A * x, B) * (4 * A * x * imP23x)
        Q3_og = gamma * inverse_mod(4 * A * x, B) * (4 * A * x * imQ23x)

        _, xP2, xQ2, xPQ2 = point_to_xonly(P2_og, Q2_og)
        _, xP3, xQ3, xPQ3 = point_to_xonly(P3_og, Q3_og)

        return (deg, alpha, beta, delta), (xP2, xQ2, xPQ2, xP3, xQ3, xPQ3, delta * X_A, delta * Y_A)

    def Enc(pkA, m, rng_for_enc):
        """PKE Encapsulation with deterministic RNG"""
        xP2, xQ2, xPQ2, xP3, xQ3, xPQ3, X_A, Y_A = pkA
        
        beta_enc = 0
        d1, d2, d3, d4 = random_matrix_with_rng(C, rng_for_enc) 
        omega = random_unit_with_rng(A, rng_for_enc)
        omega_inv = inverse_mod(omega, A)

        _KB = _QB.ladder_3_pt(_PB, _PQB, beta_enc)
        phiB = KummerLineIsogeny(_E0, _KB, B)

        EB = phiB.codomain()
        xP2_B = phiB(xPA)
        xQ2_B = phiB(xQA)
        X_B = phiB(xX_0).curve_point()
        Y_B = phiB(xY_0).curve_point()
        xXY_B = phiB(xXY_0)

        if (X_B - Y_B)[0] != xXY_B.x():
            Y_B = -Y_B

        X_B, Y_B = d1 * X_B + d2 * Y_B, d3 * X_B + d4 * Y_B

        P2_B, Q2_B = lift_image_to_curve(PA, QA, xP2_B, xQ2_B, 4 * A, B)
        P2_B = omega * P2_B
        Q2_B = omega_inv * Q2_B

        EA = xP3.parent()
        xK = xQ3.ladder_3_pt(xP3, xPQ3, beta_enc)
        phiB_ = KummerLineIsogeny(EA, xK, B)

        EAB = phiB_.codomain().curve()
        xP2_AB = phiB_(xP2)
        xQ2_AB = phiB_(xQ2)
        xPQ2_AB = phiB_(xPQ2)

        P2_AB = xP2_AB.curve_point()
        Q2_AB = xQ2_AB.curve_point()

        if (P2_AB - Q2_AB)[0] != xPQ2_AB.x():
            Q2_AB = -Q2_AB

        P2_AB *= omega
        Q2_AB *= omega_inv

        xX_AB = phiB_(EA(X_A[0]))
        xY_AB = phiB_(EA(Y_A[0]))
        xXY_AB = phiB_(EA((X_A - Y_A)[0]))

        X_AB = xX_AB.curve_point()
        Y_AB = xY_AB.curve_point()

        if (X_AB - Y_AB)[0] != xXY_AB.x():
            Y_AB = -Y_AB

        X_AB, Y_AB = d1 * X_AB + d2 * Y_AB, d3 * X_AB + d4 * Y_AB

        xof = xof_kdf(X_AB[0], Y_AB[0])
        ct_bytes = xof_encrypt(xof, m)

        return (EB.curve(), P2_B, Q2_B, X_B, Y_B, EAB, P2_AB, Q2_AB, ct_bytes)

    def Dec(skA, ct):
        """PKE Decryption"""
        deg, alpha, beta, delta = skA
        EB, P2_B, Q2_B, X_B, Y_B, EAB, P2_AB, Q2_AB, ct_bytes = ct

        P2_AB = inverse_mod(alpha, A) * P2_AB
        Q2_AB = inverse_mod(beta, A) * Q2_AB

        UAB, VAB, wp = torsion_basis_with_pairing(EAB, x)

        P, Q = CouplePoint(-deg * P2_B, P2_AB), CouplePoint(-deg * Q2_B, Q2_AB)
        kernel = (P, Q)
        Phi = EllipticProductIsogeny(kernel, a)

        X_AB, Y_AB = eval_dimtwo_isog(Phi, A - deg, X_B, Y_B, x, RS=(UAB, VAB))

        X_AB *= delta
        Y_AB *= delta

        xof = xof_kdf(X_AB[0], Y_AB[0])
        pt = xof_encrypt(xof, ct_bytes)

        return pt

    def KEM_KeyGen():
        """KEM Key Generation (reuses PKE keygen)"""
        return keygenA()

    def KEM_Encapsulate(pk_kem):
        """KEM Encapsulation"""
        m = secrets.token_bytes(32)
        g_lambda_m_seed = G_lambda(m)
        rng_encapsulate = random.Random(g_lambda_m_seed)
        ct = Enc(pk_kem, m, rng_encapsulate)
        K = H_lambda(m, ct)
        return K, ct

    def KEM_Decapsulate(sk, ct, pk_for_reencrypt):
        """KEM Decapsulation with re-encryption check"""
        m_prime = Dec(sk, ct)
        g_lambda_m_prime_seed = G_lambda(m_prime)
        rng_decapsulate = random.Random(g_lambda_m_prime_seed)
        re_encrypted_ct = Enc(pk_for_reencrypt, m_prime, rng_decapsulate)

        is_valid = True
        if len(re_encrypted_ct) != len(ct):
            is_valid = False
        else:
            for idx in range(len(ct)):
                comp_orig = ct[idx]
                comp_reenc = re_encrypted_ct[idx]

                if isinstance(comp_orig, (Element, int, bytes)):
                    if comp_orig != comp_reenc:
                        is_valid = False
                        break
                elif hasattr(comp_orig, 'x') and hasattr(comp_reenc, 'x'):
                    if comp_orig.x() != comp_reenc.x():
                        is_valid = False
                        break
                elif hasattr(comp_orig, 'a4') and hasattr(comp_reenc, 'a4'):
                    if comp_orig.a4() != comp_reenc.a4() or comp_orig.a6() != comp_reenc.a6():
                        is_valid = False
                        break
                else:
                    if str(comp_orig) != str(comp_reenc):
                        is_valid = False
                        break
        
        if is_valid:
            return H_lambda(m_prime, ct)
        else:
            s = secrets.token_bytes(32)
            return H_lambda(s, ct)

    # Print parameters first
    print(f"\n{'='*80}")
    print(f"{' Benchmarking Results ':=^80}")
    print(f"{' Security Level: λ = ' + str(lambda_) + ' ':=^80}")
    print(f"{' Parameters: a=' + str(a) + ', b=' + str(b) + ', c=' + str(c) + ' ':=^80}")
    print(f"{' Prime p: ' + str(p.nbits()) + ' bits ':=^80}")
    
    # Calculate compressed sizes
    pk_compressed = ceil(5 * p.nbits() / 8)
    kem_ct_compressed = ceil(4 * p.nbits() / 8) + ceil(3 * A.nbits() / 8)
    pke_ct_compressed = ceil((4*p.nbits())/8 + (3*A.nbits())/8)
    
    print("\nKey and Ciphertext Sizes:")
    print(f"{'PKE Public key (full):':<30} {ceil(6*p.nbits()/8)} bytes")
    print(f"{'PKE Public key (compressed):':<30} {pk_compressed} bytes")
    print(f"{'PKE Ciphertext (full):':<30} {2*ceil(6*p.nbits()/8)} bytes")
    print(f"{'PKE Ciphertext (compressed):':<30} {pke_ct_compressed} bytes")
    print(f"{'KEM Public key (full):':<30} {ceil(6*p.nbits()/8)} bytes")
    print(f"{'KEM Public key (compressed):':<30} {pk_compressed} bytes")
    print(f"{'KEM Ciphertext (full):':<30} {ceil(4*p.nbits()/8) + 2*ceil(3*A.nbits()/8)} bytes")
    print(f"{'KEM Ciphertext (compressed):':<30} {kem_ct_compressed} bytes")
    print(f"{'='*80}\n")

    # Benchmark both PKE and KEM
    N = 100
    pke_times = [0, 0, 0]  # KeyGen, Encrypt, Decrypt
    kem_times = [0, 0, 0]  # KeyGen, Encapsulate, Decapsulate
    pke_correct = 0
    kem_correct = 0

    for _ in range(N):
        # PKE Benchmarks
        t0 = time.time_ns()
        sk_pke, pk_pke = keygenA()
        pke_times[0] += time.time_ns() - t0

        m = os.urandom(32)
        rng_encrypt = random.Random(os.urandom(16))
        t0 = time.time_ns()
        ct_pke = Enc(pk_pke, m, rng_encrypt)
        pke_times[1] += time.time_ns() - t0

        t0 = time.time_ns()
        m2 = Dec(sk_pke, ct_pke)
        pke_times[2] += time.time_ns() - t0

        if m == m2:
            pke_correct += 1

        # KEM Benchmarks
        t0 = time.time_ns()
        sk_kem, pk_kem = KEM_KeyGen()
        kem_times[0] += time.time_ns() - t0

        t0 = time.time_ns()
        K, ct_kem = KEM_Encapsulate(pk_kem)
        kem_times[1] += time.time_ns() - t0

        t0 = time.time_ns()
        K_prime = KEM_Decapsulate(sk_kem, ct_kem, pk_kem)
        kem_times[2] += time.time_ns() - t0

        if K == K_prime:
            kem_correct += 1

    # Convert to milliseconds
    pke_times = [t/N/1e6 for t in pke_times]
    kem_times = [t/N/1e6 for t in kem_times]

    # Print formatted results
    print(f"{'PKE Performance':^40} | {'KEM Performance':^40}")
    print(f"{'-'*40}+{'-'*40}")
    print(f"KeyGen: {pke_times[0]:>6.1f} ms {'|':^5} KeyGen: {kem_times[0]:>6.1f} ms")
    print(f"Encrypt: {pke_times[1]:>6.1f} ms {'|':^5} Encaps: {kem_times[1]:>6.1f} ms") 
    print(f"Decrypt: {pke_times[2]:>6.1f} ms {'|':^5} Decaps: {kem_times[2]:>6.1f} ms")
    print(f"\nCorrectness:")
    print(f"PKE: {pke_correct}/{N} successful decryptions")
    print(f"KEM: {kem_correct}/{N} successful decapsulations")
    print(f"{'='*80}\n")

# Run benchmarks for all security levels.
# This executes at module/cell load time: each call prints its own parameter
# report followed by the timing and correctness summary.
for lambda_ in [128, 192, 256]:
    run_benchmarks(lambda_)



Key and Ciphertext Sizes:
PKE Public key (full):         324 bytes
PKE Public key (compressed):   270 bytes
PKE Ciphertext (full):         648 bytes
PKE Ciphertext (compressed):   264 bytes
KEM Public key (full):         324 bytes
KEM Public key (compressed):   270 bytes
KEM Ciphertext (full):         312 bytes
KEM Ciphertext (compressed):   264 bytes

            PKE Performance              |             KEM Performance             
----------------------------------------+----------------------------------------
KeyGen: 1477.1 ms   |   KeyGen: 1487.2 ms
Encrypt:  458.5 ms   |   Encaps:  461.9 ms
Decrypt:  364.9 ms   |   Decaps:  827.1 ms

Correctness:
PKE: 100/100 successful decryptions
KEM: 100/100 successful decapsulations



Key and Ciphertext Sizes:
PKE Public key (full):         486 bytes
PKE Public key (compressed):   405 bytes
PKE Ciphertext (full):         972 bytes
PKE Ciphertext (compressed):   396 bytes
KEM Public key (full):         486 bytes
KEM Public key (compressed

# POKE-KEM with a modified key-generation algorithm

In [21]:
import secrets
from montgomery_isogenies.montgomery_isogenies1.kummer_line import KummerLine
from montgomery_isogenies.montgomery_isogenies1.kummer_isogeny import KummerLineIsogeny
from montgomery_isogenies.montgomery_isogenies1.isogenies_x_only import lift_image_to_curve
from theta_structures.theta_structures1.couple_point1 import CouplePoint
from theta_isogenies.theta_isogenies1.product_isogeny1 import EllipticProductIsogeny
from utilities.utilities1.discrete_log import BiDLP, discrete_log_pari
from utilities.utilities1.supersingular import torsion_basis, torsion_basis_with_pairing
import random
import time
from math import log, floor
from sage.all import gcd, ZZ, Integer, EllipticCurve, GF, factor, inverse_mod, ceil
from hashlib import shake_256
import os

# --- Security Parameters Class ---
class SecurityParameters:
    """Parameter set derived from a security level.

    After construction the instance exposes ``lambda_``, the exponents ``a``,
    ``b`` and ``c_exponent``, the powers ``A = 2**a``, ``B = 3**b`` and
    ``C = 5**c_exponent``, their product ``N_prod = 4*A*B*C``, the prime
    ``p`` of the form ``N_prod * f - 1`` and ``x`` (equal to ``C``).
    """

    def __init__(self, lambda_):
        """Validate `lambda_` (128, 192 or 256) and derive every parameter.

        Raises ValueError for any other value.
        """
        supported_levels = (128, 192, 256)
        if lambda_ not in supported_levels:
            raise ValueError("Unsupported lambda value. Choose from 128, 192, 256.")
        self.lambda_ = lambda_
        self._derive_parameters()

    def _derive_parameters(self):
        """Fill in exponents, prime powers, their product, the prime and x."""
        level = self.lambda_
        three_exponents = {128: 162, 192: 243, 256: 324}

        self.a = level
        self.b = three_exponents[level]
        self.c_exponent = int(level / 3 * log(2) / log(5))

        self.A = Integer(2 ** self.a)
        self.B = Integer(3 ** self.b)
        self.C = Integer(5 ** self.c_exponent)

        self.N_prod = 4 * self.A * self.B * self.C
        self.p = self._find_suitable_prime(self.N_prod)
        self.x = self.C

    def _find_suitable_prime(self, N):
        """Return the first prime ``N * f - 1`` for f = 1, ..., 999.

        Raises RuntimeError when none of the candidates is prime.
        """
        candidates = (N * f - 1 for f in range(1, 1000))
        found = next((cand for cand in candidates if cand.is_prime()), None)
        if found is None:
            raise RuntimeError(f"Could not find suitable prime for N={N}")
        return Integer(found)

# --- Helper functions ---
def random_unit_with_rng(modulus, rng):
    """Sample from `rng` until a value coprime to `modulus` appears; return it.

    Every attempt is a single ``rng.randint(0, modulus - 1)`` call.
    """
    upper = modulus - 1
    # iter(callable, sentinel): randint never returns None, so this is an
    # endless stream of draws that stops at the first unit.
    for candidate in iter(lambda: rng.randint(0, upper), None):
        if gcd(candidate, modulus) == 1:
            return candidate

def random_matrix_with_rng(modulus, rng):
    """Sample (d1, d2, d3, d4) with ``d1*d4 - d2*d3`` coprime to `modulus`.

    Each attempt draws the four entries in order with
    ``rng.randint(0, modulus - 1)``; attempts repeat until one is accepted.
    """
    def draw():
        return rng.randint(0, modulus - 1)

    d1, d2, d3, d4 = draw(), draw(), draw(), draw()
    while gcd(d1 * d4 - d2 * d3, modulus) != 1:
        d1, d2, d3, d4 = draw(), draw(), draw(), draw()
    return d1, d2, d3, d4

def eval_dimtwo_isog(Phi, q, P, Q, ord, RS=None):
    """Evaluate `Phi` on P and Q via a basis (R, S) of the second factor.

    P, Q and P - Q are paired with the identity of the second domain factor
    and pushed through `Phi`; R, S and R - S are paired with the identity of
    the first factor. In both triples the sign of the second image is
    flipped when the difference check on coordinate 0 fails. ``BiDLP`` then
    expresses the images of P and Q in terms of those of R and S (with
    ``WP**q`` passed as ``ePQ``), and the result is ``q`` times the same
    combinations of R and S.

    `RS` supplies (R, S) directly; when it is None the pair and its pairing
    come from ``torsion_basis_with_pairing(Phi.domain()[1], ord)``.
    """
    domain = Phi.domain()
    zero_first = domain[0](0)
    zero_second = domain[1](0)

    def image_from_first(point):
        # Component 0 of Phi(point, 0).
        return Phi(CouplePoint(point, zero_second))[0]

    def image_from_second(point):
        # Component 0 of Phi(0, point).
        return Phi(CouplePoint(zero_first, point))[0]

    imP = image_from_first(P)
    imQ = image_from_first(Q)
    if (imP - imQ)[0] != image_from_first(P - Q)[0]:
        imQ = -imQ

    if RS is not None:
        R, S = RS
        WP = R.weil_pairing(S, ord)
    else:
        R, S, WP = torsion_basis_with_pairing(domain[1], ord)

    imR = image_from_second(R)
    imS = image_from_second(S)
    if (imR - imS)[0] != image_from_second(R - S)[0]:
        imS = -imS

    pairing_hint = WP**q
    coeff_PR, coeff_PS = BiDLP(imP, imR, imS, ord, ePQ=pairing_hint)
    coeff_QR, coeff_QS = BiDLP(imQ, imR, imS, ord, ePQ=pairing_hint)

    imP = coeff_PR * R + coeff_PS * S
    imQ = coeff_QR * R + coeff_QS * S
    imP *= q
    imQ *= q
    return imP, imQ

def point_to_xonly(P, Q):
    """Convert P, Q and P - Q to x-only points on the Kummer line of P's curve.

    Returns ``(line, xP, xQ, xPQ)`` where each x-only point is the line
    applied to coordinate 0 of the corresponding curve point.
    """
    kummer = KummerLine(P.curve())
    difference = P - Q
    return kummer, kummer(P[0]), kummer(Q[0]), kummer(difference[0])

def xof_kdf(X, Y):
    """Return a SHAKE-256 XOF seeded with the bytes of X followed by those of Y."""
    seed_material = b"".join((X.to_bytes(), Y.to_bytes()))
    return shake_256(seed_material)

def xof_encrypt(xof, msg):
    """XOR `msg` byte-wise with ``xof.digest(len(msg))`` and return the bytes.

    The operation is its own inverse when repeated with an identically
    seeded XOF, so it serves for both encryption and decryption.
    """
    # `^^` is Sage-preparser syntax and a SyntaxError in plain Python, which
    # is what the rest of this cell is written in; `operator.xor` gives the
    # same result in both environments.
    from operator import xor

    keystream = xof.digest(len(msg))
    return bytes(map(xor, keystream, msg))

# --- KEM Specific Functions ---
def G_lambda(m):
    """Return the first 16 bytes of SHAKE-256(m) as a big-endian integer."""
    return int.from_bytes(shake_256(m).digest(16), 'big')

def H_lambda(m_or_s, ct):
    """Hash ``m_or_s`` together with the ciphertext components into 32 bytes.

    `m_or_s` is either ``bytes`` (used directly) or an ``int`` (minimal
    big-endian encoding); anything else raises TypeError. Each component of
    `ct` contributes, by the first matching rule: its raw value if it is
    ``bytes``; the bytes of ``x()`` if it has an ``x`` whose result offers
    ``to_bytes``; the bytes of ``a4()`` then ``a6()`` if it has both; and
    otherwise its ``str`` form encoded as UTF-8. The output is
    SHAKE-256 over the concatenation, truncated to 32 bytes.
    """
    if isinstance(m_or_s, bytes):
        prefix = m_or_s
    elif isinstance(m_or_s, int):
        byte_length = (m_or_s.bit_length() + 7) // 8
        prefix = m_or_s.to_bytes(byte_length, 'big')
    else:
        raise TypeError("m_or_s must be bytes or int")

    def serialise(component):
        if isinstance(component, bytes):
            return component
        if hasattr(component, 'x') and hasattr(component.x(), 'to_bytes'):
            return component.x().to_bytes()
        if hasattr(component, 'a4') and hasattr(component, 'a6'):
            return component.a4().to_bytes() + component.a6().to_bytes()
        return str(component).encode('utf-8')

    hasher = shake_256(prefix)
    for component in ct:
        hasher.update(serialise(component))
    return hasher.digest(32)

# --- Main Implementation ---
def run_code_with_new_keygen():
    security_levels = [128, 192, 256]
    for lambda_ in security_levels:
        params = SecurityParameters(lambda_)
        p = params.p
        a = params.a
        b = params.b
        c = params.c_exponent
        A = params.A
        B = params.B
        C = params.C
        x = params.x

        FF = GF(p)['xx']; (xx,) = FF._first_ngens(1)
        F = GF(p**2, modulus=xx**2 + 1, names=('i',)); (i,) = F._first_ngens(1)
        E0 = EllipticCurve(F, [1, 0])
        E0.set_order((p + 1)**2)

        PA, QA = torsion_basis(E0, 4 * A)
        PB, QB = torsion_basis(E0, B)
        X0, Y0 = torsion_basis(E0, C)

        PQA = PA - QA
        PQB = PB - QB
        XY0 = X0 - Y0

        _E0 = KummerLine(E0)
        _PB = _E0(PB[0])
        _QB = _E0(QB[0])
        _PQB = _E0(PQB[0])
        xPA = _E0(PA[0])
        xQA = _E0(QA[0])
        xPQA = _E0(PQA[0])
        xX_0 = _E0(X0[0])
        xY_0 = _E0(Y0[0])
        xXY_0 = _E0(XY0[0])

        def keygenA():
            sys_rng = random.SystemRandom()
            for _ in range(1000):
                q = sys_rng.randint(0, 2**(a - 1) - 1)
                if q % 2 == 1 and q % 3 != 0 and q % 5 != 0 and (A - q) % 3 != 0 and (A - q) % 5 != 0:
                    break
            else:
                raise ValueError("Could not find suitable q.")

            deg = q
            rhs = deg * (2**a - deg) * B
            upper_bound = ZZ((rhs.nbits() - p.nbits()) // 2 - 2)

            alpha = random_unit_with_rng(A, sys_rng)
            beta = random_unit_with_rng(A, sys_rng)
            gamma = random_unit_with_rng(B, sys_rng)
            delta = random_unit_with_rng(C, sys_rng)

            if upper_bound < 0:
                raise ValueError("Degree is too small.")

            P2, Q2, P3, Q3 = PA, QA, PB, QB
            QF = gp.Qfb(1, 0, 1)

            for _ in range(10_000):
                zz = sys_rng.randint(0, ZZ(2**upper_bound))
                tt = sys_rng.randint(0, ZZ(2**upper_bound))
                sq = rhs - p * (zz**2 + tt**2)
                if sq <= 0:
                    continue
                if not sq.is_prime() or sq % 4 != 1:
                    continue
                try:
                    xx, yy = QF.qfbsolve(sq)
                    break
                except ValueError:
                    continue
            else:
                raise ValueError("Could not find a suitable endomorphism.")

            i_end = lambda P: E0(-P[0], i * P[1])
            pi_end = lambda P: E0(P[0]**p, P[1]**p)
            theta = lambda P: xx * P + yy * i_end(P) + zz * pi_end(P) + tt * i_end(pi_end(P))

            P2_ = theta(P2)
            Q2_ = theta(Q2)
            P3_ = theta(P3)
            Q3_ = theta(Q3)

            try:
                R = Q3
                wp_ = P3_.weil_pairing(R, B)
                wp = Q3_.weil_pairing(R, B)
                discrete_log_pari(wp_, wp, B)
                K3_dual = theta(Q3)
            except TypeError:
                R = P3
                K3_dual = theta(P3)

            K3_dual = _E0(K3_dual[0])
            phi3 = KummerLineIsogeny(_E0, K3_dual, B)

            xP2_ = _E0(P2_[0])
            xQ2_ = _E0(Q2_[0])
            xPQ2_ = _E0((P2_ - Q2_)[0])
            ximP2_ = phi3(xP2_)
            ximQ2_ = phi3(xQ2_)
            ximPQ2_ = phi3(xPQ2_)

            P2_ = ximP2_.curve_point()
            Q2_ = ximQ2_.curve_point()
            if (P2_ - Q2_)[0] != ximPQ2_.x():
                Q2_ = -Q2_

            inverse = inverse_mod(B, 4 * A)
            P2_ = inverse * P2_
            Q2_ = inverse * Q2_

            P, Q = CouplePoint(-deg * P2, P2_), CouplePoint(-deg * Q2, Q2_)
            kernel = (P, Q)
            Phi = EllipticProductIsogeny(kernel, a)

            P23x = PA + PB + X0
            Q23x = QA + QB + Y0
            imP23x, imQ23x = eval_dimtwo_isog(Phi, A - deg, P23x, Q23x, 4 * A * B * x)

            X_A = inverse_mod(4 * A * B, x) * (4 * A * B * imP23x)
            Y_A = inverse_mod(4 * A * B, x) * (4 * A * B * imQ23x)

            P2_og = alpha * inverse_mod(B * x, 4 * A) * (B * x * imP23x)
            Q2_og = beta * inverse_mod(B * x, 4 * A) * (B * x * imQ23x)
            P3_og = gamma * inverse_mod(4 * A * x, B) * (4 * A * x * imP23x)
            Q3_og = gamma * inverse_mod(4 * A * x, B) * (4 * A * x * imQ23x)

            _, xP2, xQ2, xPQ2 = point_to_xonly(P2_og, Q2_og)
            _, xP3, xQ3, xPQ3 = point_to_xonly(P3_og, Q3_og)

            return (deg, alpha, beta, delta), (xP2, xQ2, xPQ2, xP3, xQ3, xPQ3, delta * X_A, delta * Y_A)

        def Enc(pkA, m, rng_for_enc):
            xP2, xQ2, xPQ2, xP3, xQ3, xPQ3, X_A, Y_A = pkA

            beta_enc = 0
            d1, d2, d3, d4 = random_matrix_with_rng(C, rng_for_enc)
            omega = random_unit_with_rng(A, rng_for_enc)
            omega_inv = inverse_mod(omega, A)

            _KB = _QB.ladder_3_pt(_PB, _PQB, beta_enc)
            phiB = KummerLineIsogeny(_E0, _KB, B)
            EB = phiB.codomain()

            xP2_B = phiB(xPA)
            xQ2_B = phiB(xQA)
            X_B = phiB(xX_0).curve_point()
            Y_B = phiB(xY_0).curve_point()
            xXY_B = phiB(xXY_0)
            if (X_B - Y_B)[0] != xXY_B.x():
                Y_B = -Y_B
            X_B, Y_B = d1 * X_B + d2 * Y_B, d3 * X_B + d4 * Y_B

            P2_B, Q2_B = lift_image_to_curve(PA, QA, xP2_B, xQ2_B, 4 * A, B)
            P2_B = omega * P2_B
            Q2_B = omega_inv * Q2_B

            EA = xP3.parent()
            xK = xQ3.ladder_3_pt(xP3, xPQ3, beta_enc)
            phiB_ = KummerLineIsogeny(EA, xK, B)
            EAB = phiB_.codomain().curve()

            xP2_AB = phiB_(xP2)
            xQ2_AB = phiB_(xQ2)
            xPQ2_AB = phiB_(xPQ2)
            P2_AB = xP2_AB.curve_point()
            Q2_AB = xQ2_AB.curve_point()
            if (P2_AB - Q2_AB)[0] != xPQ2_AB.x():
                Q2_AB = -Q2_AB
            P2_AB *= omega
            Q2_AB *= omega_inv

            xX_AB = phiB_(EA(X_A[0]))
            xY_AB = phiB_(EA(Y_A[0]))
            xXY_AB = phiB_(EA((X_A - Y_A)[0]))
            X_AB = xX_AB.curve_point()
            Y_AB = xY_AB.curve_point()
            if (X_AB - Y_AB)[0] != xXY_AB.x():
                Y_AB = -Y_AB
            X_AB, Y_AB = d1 * X_AB + d2 * Y_AB, d3 * X_AB + d4 * Y_AB

            xof = xof_kdf(X_AB[0], Y_AB[0])
            ct_bytes = xof_encrypt(xof, m)

            return (EB.curve(), P2_B, Q2_B, X_B, Y_B, EAB, P2_AB, Q2_AB, ct_bytes)

        def Dec(skA, ct):
            deg, alpha, beta, delta = skA
            EB, P2_B, Q2_B, X_B, Y_B, EAB, P2_AB, Q2_AB, ct_bytes = ct

            P2_AB = inverse_mod(alpha, A) * P2_AB
            Q2_AB = inverse_mod(beta, A) * Q2_AB

            UAB, VAB, wp = torsion_basis_with_pairing(EAB, x)

            P, Q = CouplePoint(-deg * P2_B, P2_AB), CouplePoint(-deg * Q2_B, Q2_AB)
            kernel = (P, Q)
            Phi = EllipticProductIsogeny(kernel, a)

            X_AB, Y_AB = eval_dimtwo_isog(Phi, A - deg, X_B, Y_B, x, RS=(UAB, VAB))
            X_AB *= delta
            Y_AB *= delta

            xof = xof_kdf(X_AB[0], Y_AB[0])
            pt = xof_encrypt(xof, ct_bytes)
            return pt

        def KEM_KeyGen():
            return keygenA()

        def KEM_Encapsulate(pk_kem):
            m = secrets.token_bytes(32)
            g_lambda_m_seed = G_lambda(m)
            rng_encapsulate = random.Random(g_lambda_m_seed)
            ct = Enc(pk_kem, m, rng_encapsulate)
            K = H_lambda(m, ct)
            return K, ct

        def KEM_Decapsulate(sk, ct, pk_for_reencrypt):
            """Decapsulate ``ct`` with ``sk`` and return the shared key.

            The message recovered by ``Dec`` is re-encrypted under
            ``pk_for_reencrypt`` with an RNG re-seeded from ``G_lambda``; the result
            is compared component by component with ``ct``.  On a match the key is
            ``H_lambda(recovered, ct)``; otherwise a key derived from 32 fresh
            random bytes is returned.
            """
            recovered = Dec(sk, ct)
            replay_rng = random.Random(G_lambda(recovered))
            expected_ct = Enc(pk_for_reencrypt, recovered, replay_rng)

            def _differs(given, expected):
                # Plain ints / bytes: direct comparison.
                if isinstance(given, (int, bytes)):
                    return given != expected
                # Objects exposing x(): compared by x() only.
                # NOTE(review): if these are curve points, P and -P presumably
                # compare equal here -- confirm that is intended.
                if hasattr(given, 'x') and hasattr(expected, 'x'):
                    return given.x() != expected.x()
                # Objects exposing a4()/a6(): compared by those two coefficients.
                if hasattr(given, 'a4') and hasattr(expected, 'a4'):
                    return given.a4() != expected.a4() or given.a6() != expected.a6()
                # Anything else falls back to comparing string representations.
                return str(given) != str(expected)

            mismatch = len(expected_ct) != len(ct) or any(
                _differs(given, expected) for given, expected in zip(ct, expected_ct)
            )

            if mismatch:
                # NOTE(review): the rejection secret is drawn fresh on every call,
                # so rejecting the same ct twice yields different keys.
                reject_secret = secrets.token_bytes(32)
                return H_lambda(reject_secret, ct)
            return H_lambda(recovered, ct)

        # Report the parameter set, the estimated key/ciphertext sizes, and then
        # benchmark the PKE and KEM operations over N independent runs.

        # Print parameters first
        print(f"\n{'='*80}")
        print(f"{' Benchmarking Results ':=^80}")
        print(f"{' Security Level: λ = ' + str(lambda_) + ' ':=^80}")
        print(f"{' Parameters: a=' + str(a) + ', b=' + str(b) + ', c=' + str(c) + ' ':=^80}")
        print(f"{' Prime p: ' + str(p.nbits()) + ' bits ':=^80}")

        # Calculate sizes (in bytes) from the bit lengths of p and A
        pk_full = ceil(6*p.nbits()/8)
        pk_compressed = ceil(5 * p.nbits() / 8)
        kem_ct_compressed = ceil(4 * p.nbits() / 8) + ceil(3 * A.nbits() / 8)

        print("\nKey and Ciphertext Sizes:")
        print(f"{'PKE Public key (full):':<30} {pk_full} bytes")
        print(f"{'PKE Public key (compressed):':<30} {pk_compressed} bytes")
        print(f"{'PKE Ciphertext (full):':<30} {2*pk_full} bytes")
        # NOTE(review): this rounds the sum once, whereas kem_ct_compressed rounds
        # each term separately; the two can differ by one byte -- confirm which
        # is intended.
        print(f"{'PKE Ciphertext (compressed):':<30} {ceil((4*p.nbits())/8 + (3*A.nbits())/8)} bytes")
        print(f"{'KEM Public key (full):':<30} {pk_full} bytes")
        print(f"{'KEM Public key (compressed):':<30} {pk_compressed} bytes")
        print(f"{'KEM Ciphertext (full):':<30} {ceil(4*p.nbits()/8) + 2*ceil(3*A.nbits()/8)} bytes")
        print(f"{'KEM Ciphertext (compressed):':<30} {kem_ct_compressed} bytes")
        print(f"{'='*80}\n")

        def _timed(fn, *args):
            """Call ``fn(*args)`` and return ``(result, elapsed_nanoseconds)``.

            Uses ``time.perf_counter_ns`` rather than ``time.time_ns``: the
            performance counter is monotonic and has the highest available
            resolution, so measurements are not disturbed by wall-clock
            adjustments.
            """
            start = time.perf_counter_ns()
            result = fn(*args)
            return result, time.perf_counter_ns() - start

        # Benchmark both PKE and KEM
        N = 100
        pke_times = [0, 0, 0]  # KeyGen, Encrypt, Decrypt (accumulated ns)
        kem_times = [0, 0, 0]  # KeyGen, Encapsulate, Decapsulate (accumulated ns)
        pke_correct = 0
        kem_correct = 0
        for _ in range(N):
            # PKE Benchmarks
            (sk_pke, pk_pke), elapsed = _timed(keygenA)
            pke_times[0] += elapsed
            # Message and RNG setup are deliberately kept outside the timed region.
            m = os.urandom(32)
            rng_encrypt = random.Random(os.urandom(16))
            ct_pke, elapsed = _timed(Enc, pk_pke, m, rng_encrypt)
            pke_times[1] += elapsed
            m2, elapsed = _timed(Dec, sk_pke, ct_pke)
            pke_times[2] += elapsed
            if m == m2:
                pke_correct += 1

            # KEM Benchmarks
            (sk_kem, pk_kem), elapsed = _timed(KEM_KeyGen)
            kem_times[0] += elapsed
            (K, ct_kem), elapsed = _timed(KEM_Encapsulate, pk_kem)
            kem_times[1] += elapsed
            K_prime, elapsed = _timed(KEM_Decapsulate, sk_kem, ct_kem, pk_kem)
            kem_times[2] += elapsed
            if K == K_prime:
                kem_correct += 1

        # Convert accumulated nanoseconds to mean milliseconds per operation
        pke_times = [t/N/1e6 for t in pke_times]
        kem_times = [t/N/1e6 for t in kem_times]
        # Print formatted results
        print(f"{'PKE Performance':^40} | {'KEM Performance':^40}")
        print(f"{'-'*40}+{'-'*40}")
        print(f"KeyGen: {pke_times[0]:>6.1f} ms {'|':^5} KeyGen: {kem_times[0]:>6.1f} ms")
        print(f"Encrypt: {pke_times[1]:>6.1f} ms {'|':^5} Encaps: {kem_times[1]:>6.1f} ms")
        print(f"Decrypt: {pke_times[2]:>6.1f} ms {'|':^5} Decaps: {kem_times[2]:>6.1f} ms")
        print("\nCorrectness:")
        print(f"PKE: {pke_correct}/{N} successful decryptions")
        print(f"KEM: {kem_correct}/{N} successful decapsulations")
        print(f"{'='*80}\n")

# Run the updated code with new key generation.
# The size tables and PKE/KEM timing results that follow in this file are
# presumably the captured output of this call.
run_code_with_new_keygen()




Key and Ciphertext Sizes:
PKE Public key (full):         327 bytes
PKE Public key (compressed):   273 bytes
PKE Ciphertext (full):         654 bytes
PKE Ciphertext (compressed):   267 bytes
KEM Public key (full):         327 bytes
KEM Public key (compressed):   273 bytes
KEM Ciphertext (full):         316 bytes
KEM Ciphertext (compressed):   267 bytes

            PKE Performance              |             KEM Performance             
----------------------------------------+----------------------------------------
KeyGen: 1434.5 ms   |   KeyGen: 1436.8 ms
Encrypt:  433.1 ms   |   Encaps:  442.0 ms
Decrypt:  359.7 ms   |   Decaps:  797.4 ms

Correctness:
PKE: 100/100 successful decryptions
KEM: 100/100 successful decapsulations



Key and Ciphertext Sizes:
PKE Public key (full):         487 bytes
PKE Public key (compressed):   406 bytes
PKE Ciphertext (full):         974 bytes
PKE Ciphertext (compressed):   397 bytes
KEM Public key (full):         487 bytes
KEM Public key (compressed