In [1]:
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import os

In [4]:
# 1. Generate ECDH key pairs for Party A and Party B
a_private_key = ec.generate_private_key(ec.SECP256R1())
b_private_key = ec.generate_private_key(ec.SECP256R1())

a_public_key = a_private_key.public_key()
b_public_key = b_private_key.public_key()

# 2. Derive shared secret
shared_secret_a = a_private_key.exchange(ec.ECDH(), b_public_key)
shared_secret_b = b_private_key.exchange(ec.ECDH(), a_public_key)

assert shared_secret_a == shared_secret_b

In [5]:
# 3. Derive root key and subsequent chain keys via HKDF
# Root key
hkdf_root = HKDF(
    algorithm=hashes.SHA256(),
    length=32,
    salt=None,
    info=b'protocol root key',
)
root_key = hkdf_root.derive(shared_secret_a)
# Chain keys (example: derive 5 chain keys)
chain_keys = []
prev_key = root_key
for i in range(5):
    hkdf_chain = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=f'chain key {i}'.encode(),
    )
    ck = hkdf_chain.derive(prev_key)
    chain_keys.append(ck)
    prev_key = ck

In [6]:
# 4. Encrypt and decrypt a message using one of the chain keys with AES-GCM
message = b"Secret message for AES encryption"
aes_key = chain_keys[0]  # use the first chain key
aesgcm = AESGCM(aes_key)

nonce = os.urandom(12)  # 96-bit nonce for AES-GCM
ciphertext = aesgcm.encrypt(nonce, message, associated_data=None)


In [19]:
os.urandom(12)

b'n\x06\x9f\x8a\x1b\xfb\x83,=\xc0\xc8v'

In [None]:
# 5. Decrypt
plaintext = aesgcm.decrypt(nonce, ciphertext, associated_data=None)
assert plaintext == message

In [8]:
# Output
print("Shared secret (hex):", shared_secret_a.hex())
print("Root key (hex):", root_key.hex())
for idx, ck in enumerate(chain_keys):
    print(f"Chain key {idx} (hex):", ck.hex())
print("Nonce (hex):", nonce.hex())
print("Ciphertext (hex):", ciphertext.hex())
print("Decrypted message:", plaintext.decode())

Shared secret (hex): 15df9ca19527760a9438103b8e232187c3e9d27f8e011360b35357d0a54290ec
Root key (hex): be6d54dca4c38d8a53f7546dd11982bdc3bcee55f1cc87f7f9a3b59988dc1269
Chain key 0 (hex): 5180e1d83c37fb10e0ed32c57856603b1f96765c19d6195c90205f6ee3a2fb23
Chain key 1 (hex): e15b54951fe98589b15d339bdcdeb1829256610d490ef5460786e88df9edfab6
Chain key 2 (hex): 94739ab6275fbf85acf591bb27843249e0026f18bc33de2b4cdb0419528524ed
Chain key 3 (hex): 6f123eb726ee2931200295d097b84e8228cf8905964f42b16d76ba1dd8fb1acf
Chain key 4 (hex): cfc720bc548876ffcbf08988455ed4f87bbacb8e561d5aaebb4e3bc9b0bd783e
Nonce (hex): 92bbac1352b29295ae39173e
Ciphertext (hex): a08ea2b00f1430d7c498f5128a015c25cc7be8cd655729588cfec349e939bdd535f4711702bb392072b17426ca886a60c8
Decrypted message: Secret message for AES encryption
