# YAK authenticated key exchange protocol

In [18]:
from cryptography.hazmat.primitives.asymmetric import x25519, ed25519
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes, constant_time, serialization

def b(pk):
    """Return the raw byte encoding of public key *pk*.

    Uses Raw encoding with Raw format, i.e. the bare key bytes with no
    PEM/DER wrapping, so the result can be concatenated into transcripts.
    """
    raw_encoding = serialization.Encoding.Raw
    raw_format = serialization.PublicFormat.Raw
    return pk.public_bytes(encoding=raw_encoding, format=raw_format)

class Party:
    """One participant in the signed X25519 key-exchange demo.

    Each party holds a long-term Ed25519 signing key, a long-term X25519
    static key and, after ``gen_eph``, a per-session X25519 ephemeral key.
    """

    # Roles accepted by ``shared``; anything else is rejected instead of
    # being silently treated as the responder.
    _ROLES = ("initiator", "responder")

    def __init__(self, name):
        self.name = name
        # Long-term identity key used to sign the session transcript.
        self.sig_priv = ed25519.Ed25519PrivateKey.generate()
        self.sig_pub = self.sig_priv.public_key()
        # Long-term X25519 key mixed into the DH computation.
        self.static_priv = x25519.X25519PrivateKey.generate()
        self.static_pub = self.static_priv.public_key()

    def gen_eph(self):
        """Generate a fresh ephemeral X25519 key pair for a new session."""
        self.eph_priv = x25519.X25519PrivateKey.generate()
        self.eph_pub = self.eph_priv.public_key()

    def sign(self, peer, peer_static, peer_eph):
        """Sign the session transcript from this party's point of view.

        Transcript layout: own name || peer name || own static key ||
        peer static key || own ephemeral key || peer ephemeral key.
        Own keys are encoded with ``b``; peer keys are expected as raw bytes
        in the same encoding. Requires ``gen_eph`` to have been called.

        NOTE(review): the two names are concatenated without a length prefix
        or separator, so different (name, peer) splits can share a prefix.
        The format is kept unchanged so existing verifiers still match.

        Returns:
            (signature, message): the Ed25519 signature and the signed bytes.
        """
        msg = (self.name.encode() + peer.encode() +
               b(self.static_pub) + peer_static +
               b(self.eph_pub) + peer_eph)
        return self.sig_priv.sign(msg), msg

    def verify(self, sig, pub, msg):
        """Check Ed25519 signature *sig* over *msg* with public key *pub*.

        Returns None on success. The cryptography library raises
        ``InvalidSignature`` if the signature does not verify.
        """
        pub.verify(sig, msg)

    def shared(self, my_role, peer_static, peer_eph):
        """Return the concatenated DH outputs ee || es || se.

        The three secrets are laid out in a canonical order, so the
        initiator and the responder produce identical bytes:
          ee: initiator ephemeral x responder ephemeral
          es: initiator ephemeral x responder static
          se: initiator static    x responder ephemeral

        Args:
            my_role: "initiator" or "responder".
            peer_static: the peer's raw X25519 static public key.
            peer_eph: the peer's raw X25519 ephemeral public key.

        Raises:
            ValueError: if *my_role* is not a recognised role. A typo here
                would otherwise silently produce a mismatched key.
        """
        if my_role not in self._ROLES:
            raise ValueError(
                f"unknown role {my_role!r}; expected one of {self._ROLES}"
            )
        peer_static_pub = x25519.X25519PublicKey.from_public_bytes(peer_static)
        peer_eph_pub = x25519.X25519PublicKey.from_public_bytes(peer_eph)

        dh_e_e = self.eph_priv.exchange(peer_eph_pub)
        if my_role == "initiator":
            dh_es = self.eph_priv.exchange(peer_static_pub)
            dh_se = self.static_priv.exchange(peer_eph_pub)
        else:
            # Responder: swap which own key meets which peer key so each
            # term matches the initiator's term of the same name.
            dh_es = self.static_priv.exchange(peer_eph_pub)
            dh_se = self.eph_priv.exchange(peer_static_pub)

        return dh_e_e + dh_es + dh_se

def yak_protocol():
    """Run the demo exchange between Alice and Bob and print the session key.

    Each side signs its view of the transcript. The receiver verifies that
    signature against a transcript it rebuilds from its own view, so a
    signature over different names or keys is rejected rather than accepted
    merely because it is a valid signature over *some* message.

    Raises:
        cryptography.exceptions.InvalidSignature: if a transcript signature
            does not verify.
        RuntimeError: if the two derived session keys differ.
    """
    alice, bob = Party("Alice"), Party("Bob")
    alice.gen_eph()
    bob.gen_eph()
    A_s, A_e = b(alice.static_pub), b(alice.eph_pub)
    B_s, B_e = b(bob.static_pub), b(bob.eph_pub)

    a_sig, _ = alice.sign("Bob", B_s, B_e)
    b_sig, _ = bob.sign("Alice", A_s, A_e)

    # Rebuild each expected transcript locally, using the layout of
    # Party.sign: signer name, peer name, signer static, peer static,
    # signer ephemeral, peer ephemeral. Do not trust the signer's own bytes.
    expected_from_alice = b"Alice" + b"Bob" + A_s + B_s + A_e + B_e
    expected_from_bob = b"Bob" + b"Alice" + B_s + A_s + B_e + A_e
    bob.verify(a_sig, alice.sig_pub, expected_from_alice)
    alice.verify(b_sig, bob.sig_pub, expected_from_bob)

    a_mat = alice.shared("initiator", B_s, B_e)
    b_mat = bob.shared("responder", A_s, A_e)

    # Bind both identities and all four public keys into the KDF context.
    info = b"AliceBob" + A_s + B_s + A_e + B_e

    def hkdf(material):
        # Build a new HKDF per call: an HKDF instance can derive only once.
        return HKDF(algorithm=hashes.SHA256(), length=32, salt=None,
                    info=info).derive(material)

    a_key, b_key = hkdf(a_mat), hkdf(b_mat)
    # Explicit check (not assert) so it still runs under `python -O`.
    if not constant_time.bytes_eq(a_key, b_key):
        raise RuntimeError("session keys do not match")
    print("Shared session key:", a_key.hex())


if __name__ == "__main__":
    yak_protocol()

Shared session key: 32ac1f0c572e6f6509d773931c31155aa86151285d6f0ee8df8626ca7ca8c525


# Implementation of Standard Asymmetric (Public-Key) Ciphers
## 1. Elgamal Cryptosystems

In [20]:
from Crypto.Random import random
from Crypto.Util.number import getPrime, inverse

# Key generation
def gen_keys(bits=256):
    """Generate an ElGamal key pair over Z_p^* for a random *bits*-bit prime p.

    Returns:
        ((p, g, y), x): the public key (modulus, base, y = g^x mod p) and
        the private exponent x.

    NOTE(review): g is drawn uniformly from [2, p-2] with no check of its
    multiplicative order, and p is not chosen as a safe prime. This is fine
    for a textbook demo, but it is not a hardened parameter choice.
    """
    modulus = getPrime(bits)
    upper = modulus - 2
    base = random.randint(2, upper)
    secret = random.randint(2, upper)       # private exponent x
    public = pow(base, secret, modulus)     # y = g^x mod p
    return (modulus, base, public), secret

# Encryption
def encrypt(pub, m):
    """Encrypt integer *m* under the ElGamal public key *pub* = (p, g, y).

    Returns (a, b) with a = g^k mod p and b = m * y^k mod p, where k is a
    fresh random exponent drawn for this message.

    Raises:
        ValueError: if m is not in [1, p-1]. Values >= p (or negative) would
            be silently reduced mod p and decrypt to a different number.
            m = 0 gives b = 0, which reveals the plaintext.
    """
    p, g, y = pub
    if not 0 < m < p:
        raise ValueError("message must satisfy 0 < m < p")
    k = random.randint(2, p - 2)  # ephemeral exponent; never reuse across messages
    a = pow(g, k, p)
    b = (pow(y, k, p) * m) % p
    return (a, b)

# Decryption
def decrypt(priv, pub, cipher):
    """Recover the integer plaintext from an ElGamal ciphertext (a, b).

    Computes the shared value s = a^x mod p and returns b * s^{-1} mod p.
    Only the modulus is used from *pub*, but it is still unpacked as a
    3-tuple (p, g, y).
    """
    modulus, _base, _y = pub
    c1, c2 = cipher
    shared_secret = pow(c1, priv, modulus)
    return c2 * inverse(shared_secret, modulus) % modulus

if __name__ == "__main__":
    # Round-trip demo: encrypt a small integer, then decrypt it again.
    # Guarded like the other cells so importing this code has no side effects.
    pub, priv = gen_keys()
    msg = 12345
    cipher = encrypt(pub, msg)
    plain = decrypt(priv, pub, cipher)

    print("Ciphertext:", cipher)
    print("Decrypted:", plain)
    assert plain == msg, "ElGamal round trip failed"


Ciphertext: (39420336393945753614445400066421493882867244502627428542876858097779063502771, 60250286652549625932877508216887267951595917282203541756300509294822779214893)
Decrypted: 12345


## 2. Elliptic Curve Cryptography

In [19]:
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import os

def pub_bytes(pub):
    """Serialize an EC public key as an uncompressed X9.62 point.

    For SECP256R1 this is 65 bytes: 0x04 || X || Y.
    """
    point_format = serialization.PublicFormat.UncompressedPoint
    x962 = serialization.Encoding.X962
    return pub.public_bytes(encoding=x962, format=point_format)

def pub_from_bytes(bts):
    """Parse an X9.62-encoded SECP256R1 point into a public-key object."""
    curve = ec.SECP256R1()
    return ec.EllipticCurvePublicKey.from_encoded_point(curve, bts)

# key generation
def ec_keygen():
    """Create a fresh SECP256R1 key pair and return it as (private, public)."""
    private_key = ec.generate_private_key(ec.SECP256R1())
    return private_key, private_key.public_key()

def ecies_encrypt(receiver_pub, plaintext: bytes, info=b"ecies-v1"):
    """Encrypt *plaintext* for the holder of *receiver_pub* (ECIES-style).

    A one-off ephemeral SECP256R1 key performs ECDH with the receiver's key.
    HKDF-SHA256 (no salt, context *info*) turns the shared secret into a
    32-byte AES-GCM key.

    Returns the wire blob:
        ephemeral public key (uncompressed point) || nonce length (1 byte)
        || nonce || AES-GCM ciphertext with tag
    """
    ephemeral = ec.generate_private_key(ec.SECP256R1())
    secret = ephemeral.exchange(ec.ECDH(), receiver_pub)
    kdf = HKDF(algorithm=hashes.SHA256(), length=32, salt=None, info=info)
    aead = AESGCM(kdf.derive(secret))
    nonce = os.urandom(12)
    sealed = aead.encrypt(nonce, plaintext, None)
    header = pub_bytes(ephemeral.public_key()) + bytes([len(nonce)])
    return header + nonce + sealed

def ecies_decrypt(receiver_prv, blob: bytes, info=b"ecies-v1"):
    """Decrypt a blob produced by ``ecies_encrypt`` using the receiver's key.

    Blob layout: 65-byte uncompressed SECP256R1 ephemeral public key ||
    1-byte nonce length n || n-byte nonce || AES-GCM ciphertext with tag.

    Raises:
        ValueError: if the blob is too short to hold the header and nonce,
            or if the ephemeral point cannot be parsed.
        cryptography.exceptions.InvalidTag: if authentication fails (wrong
            key, wrong *info*, or tampered data).
    """
    eph_len = 65  # uncompressed P-256 point: 0x04 || X(32) || Y(32)
    # Validate explicitly: otherwise a short blob raises a bare IndexError,
    # or the nonce slice silently comes back truncated.
    if len(blob) < eph_len + 1:
        raise ValueError("ciphertext too short: missing ephemeral key or nonce length")
    eph_pub_bytes = blob[:eph_len]
    pos = eph_len
    nlen = blob[pos]
    pos += 1
    if len(blob) < pos + nlen:
        raise ValueError("ciphertext too short: truncated nonce")
    nonce = blob[pos:pos + nlen]
    pos += nlen
    ct = blob[pos:]

    eph_pub = pub_from_bytes(eph_pub_bytes)
    shared = receiver_prv.exchange(ec.ECDH(), eph_pub)
    key = HKDF(algorithm=hashes.SHA256(), length=32, salt=None, info=info).derive(shared)
    return AESGCM(key).decrypt(nonce, ct, None)

if __name__ == "__main__":
    # Demo: make a receiver key pair, encrypt a message to it, decrypt it back.
    print("Generating EC keypair (secp256r1)...")
    prv, pub = ec_keygen()
    msg = b"Hello ECIES-style world!"

    blob = ecies_encrypt(pub, msg)   # sender side
    rec = ecies_decrypt(prv, blob)   # receiver side

    assert msg == rec
    print("Decrypted:", rec)


Generating EC keypair (secp256r1)...
Decrypted: b'Hello ECIES-style world!'
