In [1]:
import hashlib
import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# Short Private Key Lamport Signatures

Instead of generating every private key element from fresh randomness (256 pairs of random numbers), we can seed a [CSPRNG](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator) with a single short random seed (a key and a nonce) and then use the CSPRNG to generate the rest of the private key. This is a lot faster, and the stored private key takes far less memory.

### Key Size Comparison

| Scheme | Private Key | Public Key | Total Size |
| --- | --- | --- | --- |
| Lamport: Naive | $32 * 512 = 16384$ bytes | $32 * 512 = 16384$ bytes | $32768$ bytes |
| Lamport: Short Private Key | $32 + 16 = 48$ bytes | $32 * 512 = 16384$ bytes | $16432$ bytes |

An expanded key holds two 32-byte elements for each of the 256 message-hash bits, i.e. 512 elements. By using the *short private key* scheme we can reduce the size of the private key by ~341x when compared to the *naive* scheme.

<!-- (48 / 16384) * 100 = 0.29 percent -->

### Signature Size Comparison

| Scheme | Signature |
| --- | --- |
| Lamport: Naive | $32 * 256 = 8192$ bytes |
| Lamport: Short Private Key | $32 * 256 = 8192$ bytes |

There is no difference in signature size between the two schemes. The signature size is always $32 * 256 = 8192$ bytes, as it must contain one 32-byte private key element for each of the 256 message-hash bits.


In [2]:
####### CONSTANTS #######

# Size in bits of the random CSPRNG key (generate_private_key draws security_parameter // 8 bytes)
# Provides roughly 128 bits of security (see Grover's algorithm)
security_parameter = 256 # bits, i.e. 32 bytes

# Number of message-hash bits that get signed; the expanded key holds two elements per bit
max_message_length = 256 # bits, i.e. 32 bytes

### Helper Functions

In [3]:
def hash_key_element(key_element: bytes) -> bytes:
    '''Returns the 32-byte SHA3-256 digest of one key element'''

    hasher = hashlib.sha3_256()
    hasher.update(key_element)
    return hasher.digest()


def expand_private_key(private_key: dict) -> list:
    '''Derives the full list of private key elements from the short private key (CSPRNG key and nonce)'''

    # AES-256 in counter mode acts as the CSPRNG: encrypting zeros yields the raw keystream
    aes_ctr = Cipher(algorithms.AES256(private_key['key']), modes.CTR(private_key['nonce']))
    keystream = aes_ctr.encryptor()
    zero_block = bytes(32)

    # Two 32-byte elements are produced for every message-hash bit
    elements = []
    for _ in range(max_message_length * 2):
        elements.append(keystream.update(zero_block))

    return elements


def hash_message(message: str) -> bytes:
    '''Returns the SHA3-256 digest of the encoded message text'''

    encoded_message = message.encode()
    return hashlib.sha3_256(encoded_message).digest()


def choose_key_elements(message_hash: bytes, key: list) -> list:
    '''Picks one key element per message hash bit, walking each byte from its most significant bit'''

    selected = []
    for byte_pos in range(max_message_length // 8):
        hash_byte = message_hash[byte_pos]

        for offset in range(8):
            # Value (0 or 1) of the bit at `offset`, counted from the most significant bit
            key_bit = (hash_byte >> (7 - offset)) & 1

            # Elements are stored in pairs: 2*n is used for bit value 0, 2*n + 1 for bit value 1
            bit_index = byte_pos * 8 + offset
            selected.append(key[2 * bit_index + key_bit])

    return selected

## AES CSPRNG

For this example implementation we will use the [AES](https://en.wikipedia.org/wiki/Advanced_Encryption_Standard) cipher in [Counter Mode](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Counter_(CTR)) as our [CSPRNG](https://en.wikipedia.org/wiki/Cryptographically_secure_pseudorandom_number_generator#Designs_based_on_cryptographic_primitives). 

In [4]:
def generate_private_key() -> dict:
    '''Creates a short Lamport private key: a random CSPRNG key plus a random 16-byte nonce'''

    key_size_bytes = security_parameter // 8

    return {
        # Secret CSPRNG key (can be reused as long as nonce changes)
        'key': secrets.token_bytes(key_size_bytes),
        # Random value that must be used only once
        'nonce': secrets.token_bytes(16),
    }


def generate_public_key(private_key) -> list:
    '''Generates a public Lamport key from a private key

    Accepts either form of the private key:
      - a dict holding the CSPRNG 'key' and 'nonce' (short form), which is expanded first
      - a list of already expanded private key elements

    Returns a list with the hash of every private key element, in the same order.
    '''

    # Only the short (dict) form has to be run through the CSPRNG
    if isinstance(private_key, dict):
        private_key_elements = expand_private_key(private_key)
    else:
        private_key_elements = private_key

    # Generate public key by hashing each private key element
    return [hash_key_element(element) for element in private_key_elements]


def generate_keypair() -> tuple:
    '''Creates a fresh Lamport keypair, returned as (public_key, private_key)'''

    # The short private key comes first; the public key is derived from it
    secret = generate_private_key()

    return (generate_public_key(secret), secret)


def sign_message(message: str, private_key: dict):
    '''Produces a Lamport signature: one private key element for each bit of the message hash'''

    # The signature commits to the hash of the message, not to the message itself
    digest = hash_message(message)

    # Rebuild the full private key, then reveal the element selected by each hash bit
    return choose_key_elements(digest, expand_private_key(private_key))


def verify_signature(message: str, public_key: list, signature: list) -> bool:
    '''Checks a Lamport signature against a message and a public key'''

    # The message hash bits select which public key hashes the signature has to match
    expected_hashes = choose_key_elements(hash_message(message), public_key)

    # Hashing every revealed signature element must reproduce exactly those hashes
    actual_hashes = [hash_key_element(element) for element in signature]

    return expected_hashes == actual_hashes

In [5]:
# End-to-end check: sign a message, then verify the signature with the matching public key
signed_string = "Hello, world!"
public_key, private_key = generate_keypair()
signature = sign_message(signed_string, private_key)
print(
    "Signature is correct!"
    if verify_signature(signed_string, public_key, signature)
    else "Signature is incorrect!"
)

Signature is correct!
