<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/SafeDreamer_Wrapper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# pip install cryptography

import os
import json
import base64
from typing import Tuple, Union
from cryptography.hazmat.primitives.asymmetric import x25519, ed25519
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305


# ---------- Utilities ----------
def b64e(b: bytes) -> str:
    """Return the standard, padded Base64 text form of ``b`` (alphabet ``A-Z a-z 0-9 + /``)."""
    encoded = base64.b64encode(b)
    return str(encoded, "utf-8")

def b64d(s: Union[str, bytes]) -> bytes:
    """
    Strictly decode standard (padded) Base64 text back to bytes.

    Envelope fields come from untrusted input, so decoding is strict
    (``validate=True``). Any character outside the Base64 alphabet, and any
    bad padding, raises ``binascii.Error`` (a ``ValueError`` subclass) instead
    of being silently dropped. ``str`` input must be ASCII; a non-ASCII
    string raises ``ValueError``.
    """
    return base64.b64decode(s, validate=True)

def to_bytes(x: Union[str, bytes]) -> bytes:
    """
    Coerce text or a bytes-like value to an immutable ``bytes`` object.

    - ``bytes`` / ``bytearray`` / ``memoryview``: converted with ``bytes(x)``.
      Without this, a ``memoryview`` would fall through to ``str()`` and be
      encoded as its repr, and a ``bytearray`` would leak out as a mutable
      object.
    - ``str``: encoded as UTF-8.
    - anything else: ``str(x)`` encoded as UTF-8. This fallback is kept for
      backward compatibility with the original behaviour.
    """
    if isinstance(x, (bytes, bytearray, memoryview)):
        return bytes(x)
    return str(x).encode("utf-8")

def pack_for_signature(*fields: bytes) -> bytes:
    """
    Build the canonical byte string that is signed and later verified.

    Layout: the protocol label ``b"e2e-single-v1"`` plus ``b"|sig|"``, then
    for each field in order a 4-byte big-endian length followed by the
    field's bytes. The length prefixes keep field boundaries unambiguous:
    ``(b"ab", b"c")`` and ``(b"a", b"bc")`` encode differently. The leading
    label domain-separates these signatures by protocol.
    Each field is passed through ``bytes()``, so bytearray/memoryview work.
    """
    header = b"e2e-single-v1" + b"|sig|"
    body = b"".join(
        len(chunk).to_bytes(4, "big") + chunk
        for chunk in map(bytes, fields)
    )
    return header + body


# ---------- Long-term key management ----------
def generate_recipient_decrypt_keys() -> Tuple[bytes, bytes]:
    """
    Create the recipient's long-term X25519 key pair used for decryption.

    Returns:
        (private_raw, public_raw): both as 32-byte raw key encodings, with the
        private key unencrypted.
    """
    raw = serialization.Encoding.Raw
    secret = x25519.X25519PrivateKey.generate()
    secret_raw = secret.private_bytes(
        raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()
    )
    public_raw = secret.public_key().public_bytes(raw, serialization.PublicFormat.Raw)
    return secret_raw, public_raw

def generate_sender_sign_keys() -> Tuple[bytes, bytes]:
    """
    Create the sender's long-term Ed25519 key pair used to sign packages.

    Returns:
        (private_raw, public_raw): raw key encodings, with the private key
        unencrypted.
    """
    signing_key = ed25519.Ed25519PrivateKey.generate()
    verifying_key = signing_key.public_key()
    return (
        signing_key.private_bytes(
            serialization.Encoding.Raw,
            serialization.PrivateFormat.Raw,
            serialization.NoEncryption(),
        ),
        verifying_key.public_bytes(
            serialization.Encoding.Raw,
            serialization.PublicFormat.Raw,
        ),
    )


# ---------- Derivation ----------
def derive_aead_key(shared_secret: bytes, salt: bytes) -> bytes:
    """
    Turn an X25519 shared secret into a 32-byte ChaCha20-Poly1305 key.

    Uses HKDF-SHA256 with the caller-supplied ``salt`` and the fixed context
    string ``b"e2e-single-v1"`` as ``info``.
    """
    kdf = HKDF(
        algorithm=hashes.SHA256(), length=32, salt=salt, info=b"e2e-single-v1"
    )
    return kdf.derive(shared_secret)


# ---------- Encrypt / Decrypt ----------
def encrypt_message(
    plaintext: Union[str, bytes],
    recipient_pub_x25519_raw: bytes,
    sender_sign_priv_raw: bytes,
    aad: Union[str, bytes] = b"",
) -> str:
    """
    Seal a single message for one recipient and sign the resulting package.

    Steps:
      1. A fresh ephemeral X25519 key agrees a shared secret with the
         recipient's static key.
      2. HKDF-SHA256 with a random 16-byte salt derives the AEAD key.
      3. ChaCha20-Poly1305 encrypts the plaintext under a random 12-byte
         nonce, authenticating ``aad`` alongside.
      4. The sender's Ed25519 key signs (eph_pub, nonce, salt, aad, ciphertext).

    Args:
        plaintext: message to encrypt; a ``str`` is UTF-8 encoded.
        recipient_pub_x25519_raw: recipient's raw X25519 public key.
        sender_sign_priv_raw: sender's raw Ed25519 private key.
        aad: optional associated data (str or bytes). It is authenticated and
            carried in the package as plain Base64, not encrypted.

    Returns:
        JSON text holding every field needed to decrypt and verify.

    NOTE(review): the signature covers only fields that appear in the package
    itself. Anyone holding a package can therefore strip the signature and
    re-sign it with their own key. Receivers must compare the signer key
    against the identity they expect.
    """
    message_bytes = to_bytes(plaintext)
    associated = to_bytes(aad)

    raw_public = dict(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw,
    )

    # Per-message ephemeral key agreed with the recipient's static key.
    ephemeral = x25519.X25519PrivateKey.generate()
    ephemeral_public = ephemeral.public_key()
    peer = x25519.X25519PublicKey.from_public_bytes(recipient_pub_x25519_raw)
    dh_secret = ephemeral.exchange(peer)

    # Fresh HKDF salt, then a fresh 96-bit nonce for the AEAD.
    kdf_salt = os.urandom(16)
    sym_key = derive_aead_key(dh_secret, kdf_salt)
    iv = os.urandom(12)
    sealed = ChaCha20Poly1305(sym_key).encrypt(iv, message_bytes, associated)

    # Sign everything the receiver needs to rebuild the AEAD inputs.
    signer = ed25519.Ed25519PrivateKey.from_private_bytes(sender_sign_priv_raw)
    signer_public = signer.public_key()
    ephemeral_pub_raw = ephemeral_public.public_bytes(**raw_public)
    signer_pub_raw = signer_public.public_bytes(**raw_public)
    sig = signer.sign(
        pack_for_signature(ephemeral_pub_raw, iv, kdf_salt, associated, sealed)
    )

    # Key order of the emitted JSON: v, alg, then the Base64 fields below.
    envelope = {
        "v": "e2e-single-v1",
        "alg": {
            "dh": "x25519",
            "kdf": "hkdf-sha256",
            "aead": "chacha20poly1305",
            "sig": "ed25519",
        },
    }
    envelope.update(
        (field_name, b64e(field_value))
        for field_name, field_value in (
            ("eph_pub", ephemeral_pub_raw),
            ("sender_sign_pub", signer_pub_raw),
            ("nonce", iv),
            ("salt", kdf_salt),
            ("aad", associated),
            ("ciphertext", sealed),
            ("signature", sig),
        )
    )
    return json.dumps(envelope)


def decrypt_message(
    package_json: str,
    recipient_priv_x25519_raw: bytes,
    expected_sender_sign_pub_raw: Union[bytes, None] = None,
) -> Tuple[bytes, bytes]:
    """
    Verify and decrypt a JSON package produced by ``encrypt_message``.

    Args:
        package_json: JSON text of the package.
        recipient_priv_x25519_raw: recipient's raw X25519 private key.
        expected_sender_sign_pub_raw: optional raw Ed25519 public key the
            sender is expected to sign with. When given, a package carrying
            any other signer key is rejected before decryption. When omitted,
            the signature is checked only against the key embedded in the
            package itself. That proves integrity but NOT origin, so the
            caller must compare the returned key to a trusted one.

    Returns:
        (plaintext_bytes, sender_sign_pub_raw)

    Raises:
        ValueError: the package is not a JSON object, has an unsupported
            version, or names an unexpected sender key. Malformed Base64 or
            key bytes also surface as ``ValueError`` subclasses.
        KeyError: a required field is missing.
        cryptography.exceptions.InvalidSignature: the signature does not verify.
        cryptography.exceptions.InvalidTag: AEAD authentication failed.
    """
    package = json.loads(package_json)
    # json.loads may return a list, string or number. Fail with a clear
    # ValueError rather than an AttributeError on .get().
    if not isinstance(package, dict):
        raise ValueError("Malformed package: expected a JSON object")
    if package.get("v") != "e2e-single-v1":
        raise ValueError("Unsupported version")

    eph_pub_raw = b64d(package["eph_pub"])
    sender_sign_pub_raw = b64d(package["sender_sign_pub"])
    nonce = b64d(package["nonce"])
    salt = b64d(package["salt"])
    aad_bytes = b64d(package["aad"])
    ciphertext = b64d(package["ciphertext"])
    signature = b64d(package["signature"])

    # 1) Sender pinning. The signature check below uses the key the package
    #    itself supplies, so on its own it cannot tell who produced it.
    if (
        expected_sender_sign_pub_raw is not None
        and bytes(expected_sender_sign_pub_raw) != sender_sign_pub_raw
    ):
        raise ValueError("Unexpected sender signing key")

    # 2) Verify the signature before touching the ciphertext.
    sender_sign_pub = ed25519.Ed25519PublicKey.from_public_bytes(sender_sign_pub_raw)
    to_sign = pack_for_signature(eph_pub_raw, nonce, salt, aad_bytes, ciphertext)
    sender_sign_pub.verify(signature, to_sign)  # raises InvalidSignature if tampered

    # 3) Recompute the shared secret and decrypt (raises InvalidTag on failure).
    recipient_priv = x25519.X25519PrivateKey.from_private_bytes(recipient_priv_x25519_raw)
    eph_pub = x25519.X25519PublicKey.from_public_bytes(eph_pub_raw)
    shared_secret = recipient_priv.exchange(eph_pub)

    key = derive_aead_key(shared_secret, salt)
    aead = ChaCha20Poly1305(key)
    plaintext = aead.decrypt(nonce, ciphertext, aad_bytes)

    return plaintext, sender_sign_pub_raw


# ---------- Demo ----------
if __name__ == "__main__":
    # Long-term keys
    recipient_priv_raw, recipient_pub_raw = generate_recipient_decrypt_keys()
    sender_sign_priv_raw, sender_sign_pub_raw = generate_sender_sign_keys()

    # Unicode-safe inputs (no ASCII-only restriction).
    # NOTE: only the sender side is ephemeral per message. Anyone holding
    # recipient_priv_raw can decrypt every package addressed to it.
    message = "Hello Kyaw — end-to-end, single message, forward secret."
    aad = "chat-1234"

    envelope = encrypt_message(
        message,
        recipient_pub_x25519_raw=recipient_pub_raw,
        sender_sign_priv_raw=sender_sign_priv_raw,
        aad=aad,
    )

    recovered, sender_pub = decrypt_message(
        envelope,
        recipient_priv_x25519_raw=recipient_priv_raw,
    )

    # decrypt_message checks the signature against the key carried inside
    # the package, so a package self-signed by anyone would pass. The sender
    # is only authenticated once that key matches the one we trust.
    if sender_pub != sender_sign_pub_raw:
        raise SystemExit("Sender signing key does not match the expected key")

    print("Decrypted:", recovered.decode("utf-8"))
    print("Sender verified (base64 pub prefix):", base64.b64encode(sender_pub).decode()[:16], "…")