# In[ ]:
import os

from cryptography.exceptions import InvalidSignature, InvalidTag
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF


# ==========================
# 🔏 Certificate (CA Simulation)
# ==========================

def generate_certificate(public_key_pem: bytes, identity: str) -> dict:
    """Build a simulated, self-signed certificate binding ``identity`` to a key.

    The "signature" is SHA-256 over the UTF-8 encoded identity immediately
    followed by the PEM-encoded public key. No secret key is involved, so this
    is a stand-in for a real CA signature, not an actual one.

    Args:
        public_key_pem: Subject public key, PEM-encoded bytes.
        identity: Subject name (e.g. "Bob").

    Returns:
        A dict with keys "identity", "public_key" and "signature" (digest bytes).
    """
    hasher = hashes.Hash(hashes.SHA256())
    # Hash order matters: identity first, then the PEM bytes.
    for chunk in (identity.encode("utf-8"), public_key_pem):
        hasher.update(chunk)

    return {
        "identity": identity,
        "public_key": public_key_pem,
        "signature": hasher.finalize(),
    }


def verify_certificate(certificate: dict) -> bool:
    """Check that a simulated certificate's "signature" matches its contents.

    Recomputes SHA-256(identity.encode("utf-8") || public_key) in the same way
    ``generate_certificate`` builds it, then compares it with
    ``certificate["signature"]``.

    NOTE(review): the "signature" is an unkeyed hash of public data. Anyone can
    produce a certificate that passes this check for any identity/key pair. It
    detects accidental corruption only and does not authenticate an issuer.
    Callers must also check that the identity is the one they expect.

    Args:
        certificate: Mapping with "identity" (str), "public_key" (PEM bytes)
            and "signature" (bytes).

    Returns:
        True if the recomputed digest equals the stored signature. False if it
        differs or if the certificate is malformed: missing keys, wrong value
        types, or an identity that cannot be UTF-8 encoded.
    """
    try:
        identity = certificate["identity"]
        public_key_pem = certificate["public_key"]
        signature = certificate["signature"]

        digest = hashes.Hash(hashes.SHA256())
        digest.update(identity.encode("utf-8"))
        digest.update(public_key_pem)
        computed_hash = digest.finalize()
    except (KeyError, TypeError, AttributeError, UnicodeEncodeError):
        # Malformed certificate: report a verification failure instead of crashing.
        # Unexpected errors (genuine bugs) are no longer silently turned into False.
        return False

    return computed_hash == signature


# ==========================
# 🔑 ECC Key Generation
# ==========================

def generate_ecc_keys() -> tuple:
    """Create a fresh SECP256R1 key pair.

    Returns:
        ``(private_key, public_pem)``. ``private_key`` is the key object, and
        ``public_pem`` is its public half serialized as PEM
        SubjectPublicKeyInfo bytes.
    """
    key = ec.generate_private_key(ec.SECP256R1())
    pem = key.public_key().public_bytes(
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
        encoding=serialization.Encoding.PEM,
    )
    return key, pem


# ==========================
# 🔁 Shared Key Derivation (ECDH + HKDF)
# ==========================

def derive_shared_key(private_key, peer_public_key_pem: bytes) -> bytes:
    """Derive a 32-byte AES session key from an ECDH exchange.

    The peer's PEM public key is loaded and combined with ``private_key`` via
    ECDH. The raw shared secret is then passed through HKDF-SHA256 with no
    salt and the fixed info label ``b"tls-handshake-sim"``. Two parties that
    each call this with their own private key and the other's public key end
    up with the same key.

    Args:
        private_key: Local EC private key.
        peer_public_key_pem: Peer's public key, PEM-encoded.

    Returns:
        32 bytes of key material suitable for AES-256.
    """
    peer_key = serialization.load_pem_public_key(peer_public_key_pem)
    return HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"tls-handshake-sim",
    ).derive(private_key.exchange(ec.ECDH(), peer_key))


# ==========================
# 🔒 AES-GCM Encrypt / Decrypt
# ==========================

def encrypt_message(message: str, aes_key: bytes) -> tuple:
    """Encrypt a text message with AES-GCM.

    The message is UTF-8 encoded, and a fresh random 12-byte nonce is drawn
    for every call. No associated data is authenticated.

    Args:
        message: Plaintext string.
        aes_key: AES key bytes.

    Returns:
        ``(iv, ciphertext, tag)``, where ``tag`` is the GCM authentication tag.
    """
    nonce = os.urandom(12)
    gcm = Cipher(algorithms.AES(aes_key), modes.GCM(nonce)).encryptor()

    ciphertext = gcm.update(message.encode("utf-8"))
    ciphertext += gcm.finalize()
    # The tag is only available once finalize() has run.
    return nonce, ciphertext, gcm.tag


def decrypt_message(iv: bytes, ciphertext: bytes, tag: bytes, aes_key: bytes):
    """Decrypt and authenticate AES-GCM data.

    Args:
        iv: GCM nonce used at encryption time.
        ciphertext: Encrypted bytes.
        tag: GCM authentication tag.
        aes_key: AES key bytes.

    Returns:
        The plaintext bytes if the tag verifies. Returns None in two cases:
        authentication fails (``InvalidTag``), or the IV/tag/key is rejected
        by the library as malformed (``ValueError``, e.g. bad lengths).
        Other errors, such as non-bytes arguments, propagate, so programming
        mistakes are not reported as integrity failures.
    """
    try:
        decryptor = Cipher(algorithms.AES(aes_key), modes.GCM(iv, tag)).decryptor()
        # Output of update() is unauthenticated until finalize() checks the tag,
        # so nothing is returned unless finalize() succeeds.
        return decryptor.update(ciphertext) + decryptor.finalize()
    except (InvalidTag, ValueError):
        return None


# ==========================
# 🖥️ Interactive TLS Handshake Simulation
# ==========================

def main():
    """Run an interactive, simulated TLS-style handshake between a client and "Bob".

    Setup (once):
    - Bob's long-term ECC key pair is generated.
    - A simulated certificate is issued for that key.

    Each loop iteration is one "connection":
    1. The client checks the certificate: its digest, and that it names the
       expected server.
    2. Bob signs a fresh ephemeral ECDH public key with his long-term key, and
       the client verifies that signature.
    3. Both sides derive an AES-GCM session key via ECDH + HKDF.
    4. The message typed by the user is encrypted by the client and decrypted
       by Bob.

    The loop repeats while the user answers "y". If stdin is closed (EOF), the
    session ends cleanly instead of raising.
    """
    print("\n🔐 Simulating Secure TLS Communication...\n")

    expected_server_identity = "Bob"

    # Long-term cert/signing key (Bob)
    print("[Server] Generating Bob's certificate (signing) key pair...\n")
    server_cert_private_key, server_cert_public_pem = generate_ecc_keys()

    # Self-signed certificate (CA simulation)
    print("[CA] Issuing a certificate for Bob...\n")
    certificate = generate_certificate(server_cert_public_pem, expected_server_identity)

    conn = 1
    while True:
        print(f"\n===== Connection #{conn} =====\n")

        # Client verifies the certificate. The digest must match AND the certificate
        # must name the server we intend to talk to: verify_certificate() by itself
        # accepts a well-formed certificate for any identity.
        print("[Client] Verifying Bob’s certificate...\n")
        if not verify_certificate(certificate):
            print("[Client] ❌ Certificate verification failed! Exiting.\n")
            return
        if certificate["identity"] != expected_server_identity:
            print("[Client] ❌ Certificate identity mismatch! Exiting.\n")
            return
        print("[Client] ✅ Certificate verified!\n")

        # Server generates ephemeral ECDH keys
        print("[Server] Generating ephemeral ECDH key pair...\n")
        server_eph_private_key, server_eph_public_pem = generate_ecc_keys()

        # Server signs ephemeral parameters with long-term cert key.
        # NOTE(review): only the ephemeral public key is signed. Unlike real TLS,
        # no client/server nonces are included, so a previously captured
        # (key, signature) pair could be replayed.
        server_signature = server_cert_private_key.sign(
            server_eph_public_pem,
            ec.ECDSA(hashes.SHA256())
        )

        # Client verifies signature using certified public key
        print("[Client] Verifying server signature on ephemeral parameters...\n")
        bob_cert_pub = serialization.load_pem_public_key(certificate["public_key"])
        try:
            bob_cert_pub.verify(server_signature, server_eph_public_pem, ec.ECDSA(hashes.SHA256()))
        except InvalidSignature:
            print("[Client] ❌ Ephemeral parameters NOT authenticated! Exiting.\n")
            return
        print("[Client] ✅ Ephemeral parameters authenticated.\n")

        # Client generates ephemeral ECDH keys
        print("[Client] Generating ephemeral ECDH key pair...\n")
        client_private_key, client_public_pem = generate_ecc_keys()

        # Client derives AES key
        print("[Client] Deriving shared AES session key...\n")
        client_aes_key = derive_shared_key(client_private_key, server_eph_public_pem)

        # Client encrypts message
        try:
            message = input("[Client] 📝 Enter a secure message to send to Bob: ")
        except EOFError:
            # stdin closed (e.g. non-interactive run): stop instead of a traceback.
            print()
            break
        iv, ciphertext, tag = encrypt_message(message, client_aes_key)

        # Server derives AES key
        print("\n[Server] 🔓 Deriving shared AES session key...\n")
        server_aes_key = derive_shared_key(server_eph_private_key, client_public_pem)

        # Server decrypts message
        print("\n[Server] 🔓 Decrypting message...\n")
        decrypted = decrypt_message(iv, ciphertext, tag, server_aes_key)
        if decrypted is None:
            print("\n📖 Decrypted Message: Decryption Failed (Integrity Check Failed)\n")
        else:
            print("\n📖 Decrypted Message:", decrypted.decode("utf-8"), "\n")

        print("✅ Connection complete. Tearing down session state...\n")

        conn += 1
        try:
            again = input("Start another connection? (y/n): ").strip().lower()
        except EOFError:
            again = "n"
        if again != "y":
            break

    print("\n✅ TLS Handshake Simulation Finished!\n")


if __name__ == "__main__":
    main()
