<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Differentiable_ILP_Reasoner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# pip install cryptography

import os
import json
import base64
from typing import Tuple
from cryptography.hazmat.primitives.asymmetric import x25519, ed25519
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305


# ---------- Utilities ----------
def b64e(b: bytes) -> str:
    """Return *b* encoded as standard, padded base64 text."""
    encoded = base64.b64encode(b)
    # base64 output only ever contains ASCII characters.
    return str(encoded, "ascii")

def b64d(s: str) -> bytes:
    """Decode standard, padded base64 text back into bytes.

    Decoding is strict (``validate=True``): any character outside the base64
    alphabet raises ``binascii.Error`` (a ``ValueError`` subclass). This
    includes whitespace and non-ASCII text. The default lenient mode would
    silently discard such characters. For fields of a security envelope that
    would turn corruption or tampering into quietly different bytes instead of
    an error. Everything produced by ``b64e`` decodes unchanged.

    Raises:
        ValueError: *s* is not valid, canonical-alphabet base64.
    """
    return base64.b64decode(s, validate=True)


# ---------- Long-term key management ----------
def generate_recipient_decrypt_keys() -> Tuple[bytes, bytes]:
    """Create a fresh static X25519 key pair for a message recipient.

    The recipient keeps the private half to open envelopes. Senders encrypt to
    the public half.

    Returns:
        ``(private_raw, public_raw)``, each a 32-byte raw key encoding. The
        private key is exported unencrypted, so protect it at rest.
    """
    raw = serialization.Encoding.Raw
    secret_key = x25519.X25519PrivateKey.generate()
    public_key = secret_key.public_key()

    secret_bytes = secret_key.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_bytes = public_key.public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )
    return secret_bytes, public_bytes

def generate_sender_sign_keys() -> Tuple[bytes, bytes]:
    """Create a fresh static Ed25519 key pair the sender uses to sign envelopes.

    Returns:
        ``(private_raw, public_raw)``, each a 32-byte raw key encoding. The
        private key is exported unencrypted, so protect it at rest.
    """
    raw = serialization.Encoding.Raw
    signing_key = ed25519.Ed25519PrivateKey.generate()
    verify_key = signing_key.public_key()

    signing_bytes = signing_key.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    verify_bytes = verify_key.public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )
    return signing_bytes, verify_bytes


# ---------- Encrypt / Decrypt ----------
# Protocol label. It is used as the HKDF "info" string below and as the first
# element of the transcript that encrypt_message signs and decrypt_message
# verifies.
PROTOCOL_INFO = b"e2e-single-v1"

def derive_aead_key(shared_secret: bytes, salt: bytes) -> bytes:
    """Derive a 32-byte ChaCha20-Poly1305 key from an X25519 shared secret.

    Runs HKDF-SHA256 over *shared_secret* with the caller-supplied *salt* and
    ``PROTOCOL_INFO`` as the context string, so the key is bound to this
    protocol label.
    """
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        info=PROTOCOL_INFO,
    )
    return kdf.derive(shared_secret)

def encrypt_message(
    plaintext: bytes,
    recipient_pub_x25519_raw: bytes,
    sender_sign_priv_raw: bytes,
    aad: bytes = b"",  # optional associated data (e.g., conversation ID)
) -> str:
    """Encrypt *plaintext* to one recipient and sign the resulting envelope.

    Scheme:
        1. A fresh ephemeral X25519 key is agreed with the recipient's static key.
        2. HKDF-SHA256 with a random 16-byte salt turns the shared secret into a
           ChaCha20-Poly1305 key.
        3. The sender's Ed25519 key signs
           ``PROTOCOL_INFO | eph_pub | nonce | salt | aad | ciphertext``.

    Only the sender side is ephemeral. Anyone holding the recipient's static
    private key can decrypt every envelope sent to it.

    Args:
        plaintext: message bytes to encrypt.
        recipient_pub_x25519_raw: recipient's raw X25519 public key.
        sender_sign_priv_raw: sender's raw Ed25519 private key.
        aad: data authenticated (not encrypted) by the AEAD, also covered by the
            signature and carried in the envelope.

    Returns:
        A JSON object string holding the version tag ``v`` plus the base64
        fields ``eph_pub``, ``sender_sign_pub``, ``nonce``, ``salt``, ``aad``,
        ``ciphertext`` and ``signature``. This is the input for
        ``decrypt_message``.
    """

    def _raw(public_key) -> bytes:
        # Raw 32-byte encoding of an X25519 / Ed25519 public key.
        return public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

    # Ephemeral-static ECDH: a one-off key pair per message.
    ephemeral = x25519.X25519PrivateKey.generate()
    ephemeral_public = ephemeral.public_key()
    peer = x25519.X25519PublicKey.from_public_bytes(recipient_pub_x25519_raw)
    secret = ephemeral.exchange(peer)

    # Fresh salt and 96-bit nonce for every message.
    salt = os.urandom(16)
    cipher = ChaCha20Poly1305(derive_aead_key(secret, salt))
    nonce = os.urandom(12)
    ciphertext = cipher.encrypt(nonce, plaintext, aad or None)

    signer = ed25519.Ed25519PrivateKey.from_private_bytes(sender_sign_priv_raw)
    eph_pub_raw = _raw(ephemeral_public)
    signer_pub_raw = _raw(signer.public_key())

    # NOTE(review): fields are joined with a b"|" separator and no length
    # prefixes. aad and ciphertext are variable-length and may themselves
    # contain b"|", so this byte string does not pin down field boundaries on
    # its own. The AEAD tag still authenticates aad/ciphertext. A
    # length-prefixed encoding (with a version bump) would be more robust.
    # decrypt_message rebuilds this exact byte string, so change both together.
    transcript = b"|".join((PROTOCOL_INFO, eph_pub_raw, nonce, salt, aad, ciphertext))
    signature = signer.sign(transcript)

    envelope = {
        "v": "e2e-single-v1",
        "eph_pub": b64e(eph_pub_raw),
        "sender_sign_pub": b64e(signer_pub_raw),
        "nonce": b64e(nonce),
        "salt": b64e(salt),
        "aad": b64e(aad),
        "ciphertext": b64e(ciphertext),
        "signature": b64e(signature),
        # Room for optional metadata (timestamps, message IDs, ...).
    }
    return json.dumps(envelope)

def decrypt_message(
    package_json: str,
    recipient_priv_x25519_raw: bytes,
) -> Tuple[bytes, bytes]:
    """Verify and open a JSON envelope produced by ``encrypt_message``.

    The Ed25519 signature is checked before any key agreement or decryption.
    The verifying key is read from the envelope itself (``sender_sign_pub``).
    A valid signature therefore only proves the envelope was signed by *that*
    key. Callers must compare the returned ``sender_sign_pub_raw`` against a
    key they already trust before treating the plaintext as authentic.

    Args:
        package_json: JSON text as returned by ``encrypt_message``.
        recipient_priv_x25519_raw: recipient's raw X25519 private key.

    Returns:
        ``(plaintext, sender_sign_pub_raw)``.

    Raises:
        ValueError: the input is not JSON, is not a JSON object, has an
            unsupported version tag, or contains malformed base64/key bytes.
        KeyError: a required envelope field is missing.
        cryptography.exceptions.InvalidSignature: the signature does not verify.
        cryptography.exceptions.InvalidTag: AEAD authentication fails, e.g. the
            envelope was not encrypted to this recipient's key.
    """
    package = json.loads(package_json)

    # 1) Validate shape and version. Use explicit raises rather than `assert`:
    #    asserts are stripped under `python -O`, which would let envelopes of
    #    any (or no) version through.
    if not isinstance(package, dict):
        raise ValueError("Malformed package: expected a JSON object")
    if package.get("v") != "e2e-single-v1":
        raise ValueError("Unsupported version")

    # 2) Decode the base64 fields.
    eph_pub_raw = b64d(package["eph_pub"])
    sender_sign_pub_raw = b64d(package["sender_sign_pub"])
    nonce = b64d(package["nonce"])
    salt = b64d(package["salt"])
    aad = b64d(package["aad"])
    ciphertext = b64d(package["ciphertext"])
    signature = b64d(package["signature"])

    # 3) Verify the signature first. The transcript must match the one built by
    #    encrypt_message byte-for-byte.
    sender_sign_pub = ed25519.Ed25519PublicKey.from_public_bytes(sender_sign_pub_raw)
    to_sign = b"|".join([PROTOCOL_INFO, eph_pub_raw, nonce, salt, aad, ciphertext])
    sender_sign_pub.verify(signature, to_sign)  # raises InvalidSignature if tampered

    # 4) Recompute the shared secret from the sender's ephemeral key and decrypt.
    recipient_priv = x25519.X25519PrivateKey.from_private_bytes(recipient_priv_x25519_raw)
    eph_pub = x25519.X25519PublicKey.from_public_bytes(eph_pub_raw)
    shared_secret = recipient_priv.exchange(eph_pub)

    key = derive_aead_key(shared_secret, salt)
    aead = ChaCha20Poly1305(key)
    plaintext = aead.decrypt(nonce, ciphertext, aad if aad else None)  # raises InvalidTag

    return plaintext, sender_sign_pub_raw


# ---------- Demo ----------
if __name__ == "__main__":
    # Demo: one round trip between freshly generated long-term key pairs.
    # NOTE(review): in a notebook these names become globals. They are kept
    # stable in case later cells rely on them.
    # NOTE(review): the recipient key is long-term, so whoever holds
    # recipient_priv_raw can decrypt every past envelope. Only the sender side
    # uses a per-message (ephemeral) key.
    recipient_priv_raw, recipient_pub_raw = generate_recipient_decrypt_keys()
    sender_sign_priv_raw, sender_sign_pub_raw = generate_sender_sign_keys()

    message = b"Hello Kyaw \xe2\x80\x94 end-to-end, single message, forward secret."
    aad = b"chat-1234"

    envelope = encrypt_message(
        message,
        recipient_pub_x25519_raw=recipient_pub_raw,
        sender_sign_priv_raw=sender_sign_priv_raw,
        aad=aad,
    )
    recovered, sender_pub = decrypt_message(
        envelope,
        recipient_priv_x25519_raw=recipient_priv_raw,
    )

    # Round-trip check, plus the caller-side trust check decrypt_message relies
    # on: the signer must be the key we expect.
    assert recovered == message
    assert sender_pub == sender_sign_pub_raw

    print("Decrypted:", recovered.decode())
    print("Sender verified:", b64e(sender_pub)[:16], "...")