In [1]:
import hashlib
import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# Short Public Key Lamport Signatures

To shrink the public key, we can publish a vector commitment to the public key elements (the root of a Merkle tree built over them) instead of the elements themselves.

![Public Key Merkle Root](./images/public_key_merkle_root.png)

### Key Size Comparison

| Scheme | Private Key | Public Key | Total Size |
| --- | --- | --- | --- |
| Lamport: Naive | $2 * 256 * 32 = 16384$ bytes | $2 * 256 * 32 = 16384$ bytes | $32768$ bytes |
| Lamport: Short Private Key | $32 + 16 = 48$ bytes | $2 * 256 * 32 = 16384$ bytes | $16432$ bytes |
| Lamport: Short Public Key | $2 * 256 * 32 = 16384$ bytes | $32$ bytes | $16416$ bytes |

Each of the 256 message-hash bits needs two 32-byte key elements (one for a 0 bit, one for a 1 bit). By using a vector commitment (a Merkle root) to the public key elements we reduce the public key to 32 bytes, making it 512x smaller than in the *naive* method. 

### Signature Size Comparison

| Scheme | Signature |
| --- | --- |
| Lamport: Naive | $32 * 256 = 8192$ bytes |
| Lamport: Short Private Key | $32 * 256 = 8192$ bytes |
| Lamport: Short Public Key | $32 * 256 + 32 * 256 = 16384$ bytes |

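A downside of shrinking the public key is that the signature grows. The verifier needs every public key element to recompute the Merkle root. For each message bit, the signature therefore carries the revealed private key element together with the public key element (the hash) of the unrevealed one.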

In [2]:
# ----- Constants -----

# Size in bits of every key element (32 bytes). A 256-bit hash preimage
# gives roughly 128 bits of security against Grover's algorithm.
security_parameter = 256

# Number of message bits a key can sign (32 bytes). Messages are hashed with
# SHA3-256 before signing, so this must equal the 256-bit digest length.
max_message_length = 256


# ----- Type definitions -----
PrivateKey = list[bytes]  # 2 * max_message_length random elements
PublicKey = bytes         # Merkle root over the hashed private key elements
Signature = list[bytes]   # one element pair per bit of the message hash

### Helper Functions

In [3]:
def hash_bytes(b: bytes) -> bytes:
    '''Returns the 32-byte (256-bit) SHA3-256 digest of ``b``.'''

    hasher = hashlib.sha3_256()
    hasher.update(b)
    return hasher.digest()


def hash_message(message: str) -> bytes:
    '''Returns the SHA3-256 digest of ``message`` after UTF-8 encoding it.'''

    utf8_bytes = message.encode('utf-8')
    return hashlib.sha3_256(utf8_bytes).digest()


def convert_bytes_to_bits(b: bytes) -> list[bool]:
    '''Expands ``b`` into a flat list of booleans, eight per byte.

    Bits are emitted most significant first, so a byte of 0x80 becomes True
    followed by seven False values.
    '''

    return [((byte >> shift) & 1) == 1 for byte in b for shift in range(7, -1, -1)]


def merkle_root(ls: list[bytes]) -> bytes:
    '''Computes the Merkle root of a list of bytes. The list must have a length of 2^n.

    Leaves are used as-is (they are not hashed). Each internal node is
    ``hash_bytes`` of its left child's bytes concatenated with its right
    child's bytes.

    NOTE(review): leaves and internal nodes are not domain-separated. Callers
    should therefore always use a fixed number of leaves, so that an internal
    node cannot be passed off as a leaf of a smaller tree.

    Args:
        ls: The leaves, in order. Must be non-empty with a power-of-two length.

    Returns:
        The root node. For a single leaf this is the leaf itself.

    Raises:
        ValueError: If ``ls`` is empty or its length is not a power of two.
    '''

    count = len(ls)
    # Without this check an empty list recursed forever (RecursionError), and
    # other lengths silently produced an unbalanced tree.
    if count == 0 or (count & (count - 1)) != 0:
        raise ValueError(f'merkle_root requires a power-of-two number of leaves, got {count}')

    # Base case: a single leaf is its own root
    # IMPORTANT: The leaf element is not hashed!
    if count == 1:
        return ls[0]

    # Recursive case: hash the concatenated roots of the left and right halves
    half = count // 2
    left_subtree = merkle_root(ls[:half])
    right_subtree = merkle_root(ls[half:])
    return hash_bytes(left_subtree + right_subtree)

In [4]:
def generate_private_key() -> PrivateKey:
    '''Creates a fresh private Lamport key.

    The key is a list of 2 * max_message_length independent random elements of
    security_parameter // 8 bytes each, drawn with the ``secrets`` module.
    Element 2*i is revealed when bit i of the message hash is 0, and element
    2*i + 1 when it is 1.
    '''

    element_size = security_parameter // 8
    element_count = 2 * max_message_length
    private_key = []
    for _ in range(element_count):
        private_key.append(secrets.token_bytes(element_size))
    return private_key


def generate_public_key(private_key: PrivateKey) -> PublicKey:
    '''Derives the compact public key from a private Lamport key.

    Each private key element is hashed to form the expanded public key. The
    Merkle root of those hashes is returned as a vector commitment to them.
    '''

    return merkle_root([hash_bytes(element) for element in private_key])


def generate_keypair() -> tuple[PublicKey, PrivateKey]:
    '''Creates a new Lamport keypair, returned as ``(public_key, private_key)``.'''

    private_key = generate_private_key()
    return generate_public_key(private_key), private_key


def sign_message(message: str, private_key: PrivateKey) -> Signature:
    '''Signs ``message`` with a Lamport private key.

    The message is hashed with ``hash_message``. For each bit i of the digest,
    the signature receives the pair at positions 2*i and 2*i + 1:
    - the private key element selected by the bit is revealed as-is;
    - its partner is replaced by its hash, i.e. its public key element.

    Every signature reveals half of the private key elements, so a key should
    sign only one message.
    '''

    digest_bits = convert_bytes_to_bits(hash_message(message))

    signature = []
    for index, bit in enumerate(digest_bits):
        zero_element = private_key[2 * index]
        one_element = private_key[2 * index + 1]
        if bit:
            # Bit is 1: hide the "zero" element behind its hash, reveal the "one" element
            signature.extend((hash_bytes(zero_element), one_element))
        else:
            # Bit is 0: reveal the "zero" element, hide the "one" element behind its hash
            signature.extend((zero_element, hash_bytes(one_element)))

    return signature


def verify_signature(message: str, public_key: PublicKey, signature: Signature) -> bool:
    '''Verifies a Lamport signature against a Merkle-root public key.

    For every bit of the message hash, the element the signer revealed is
    hashed back into its public key element. The signature then becomes the
    full expanded public key, and its Merkle root must equal ``public_key``.

    Args:
        message: The message that was allegedly signed.
        public_key: The Merkle root produced by ``generate_public_key``.
        signature: The element pairs produced by ``sign_message``. Any
            sequence of bytes is accepted; it is not modified.

    Returns:
        True if the signature is valid for ``message`` under ``public_key``.
        False otherwise, including when the signature has the wrong length.
    '''

    message_hash = hash_message(message)
    message_bits = convert_bytes_to_bits(message_hash)

    # A valid signature carries exactly one pair of elements per message bit.
    # Reject anything else up front instead of raising IndexError (too short)
    # or feeding a malformed leaf list into merkle_root (too long).
    if len(signature) != 2 * len(message_bits):
        return False

    # Work on a copy so the caller's signature is left untouched
    expanded_public_key = list(signature)
    for i, bit in enumerate(message_bits):
        # The element selected by the bit is a raw private key element; hashing
        # it yields its public key element. Its partner already is one.
        revealed_index = 2 * i + 1 if bit else 2 * i
        expanded_public_key[revealed_index] = hash_bytes(expanded_public_key[revealed_index])

    # Ensure that the signature matches the vector commitment (Merkle root)
    return public_key == merkle_root(expanded_public_key)

In [5]:
signed_string = "Hello, world!"
public_key, private_key = generate_keypair()
signature = sign_message(signed_string, private_key)

# A genuine signature must verify against the freshly generated public key
print(
    "Signature is correct!"
    if verify_signature(signed_string, public_key, signature)
    else "Signature is incorrect!"
)

Signature is correct!


In [6]:
# Try to forge a signature for the signed message without the private key.
# The forged signature is just random elements of the right size and count.
# They will not hash to the committed public key elements, so (except with
# negligible probability) the Merkle root differs and verification fails.
random_signature = [secrets.token_bytes(security_parameter // 8) for _ in range(2 * max_message_length)]
if verify_signature(signed_string, public_key, random_signature):
    print("Forgery is correct!")
else:
    print("Forgery is incorrect!")

Forgery is incorrect!
