In [9]:
import secrets
import numpy as np
from winternitz_utils import hash_chain, hash_message, split_into_n_chunks

In [10]:
####### CONSTANTS #######

# Bit length of the value that gets signed. generate_private_key creates one
# hash chain per hash_chain_bits-sized piece of it, and also uses
# MESSAGE_BITS // 8 random bytes for each chain's secret. Presumably this
# matches the digest size of winternitz_utils.hash_message (TODO confirm);
# it is not a limit on the length of the message string itself.
MESSAGE_BITS = 256

####### TYPE DEFINITIONS #######
# Keys and signatures are lists holding one bytes value per hash chain.
PrivateKey = list[bytes]  # secret starting point of each chain
PublicKey = list[bytes]   # each chain's end point
Signature = list[bytes]   # one intermediate point per chain

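### **TODO: add a checksum to make this scheme secure**

Without a checksum, anyone who sees a valid signature can hash each of its elements further along its chain. That yields a valid signature for any message whose chunks are all greater than or equal to the corresponding chunks of the signed message.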

In [11]:
def generate_private_key(hash_chain_bits: int) -> PrivateKey:
    '''Generates a random private key.

    The key holds one independent secret per hash chain, i.e.
    MESSAGE_BITS // hash_chain_bits secrets of MESSAGE_BITS // 8 bytes each,
    drawn from the `secrets` CSPRNG.

    Args:
        hash_chain_bits: number of message bits encoded by each hash chain.
            Must be a positive divisor of MESSAGE_BITS.

    Returns:
        A list of MESSAGE_BITS // hash_chain_bits random byte strings.

    Raises:
        ValueError: if hash_chain_bits is not a positive divisor of MESSAGE_BITS.
    '''
    # Non-positive values must be rejected explicitly: 0 would raise
    # ZeroDivisionError, and a negative divisor such as -8 passes the modulo
    # check (256 % -8 == 0) and would silently produce an empty key.
    if hash_chain_bits <= 0 or MESSAGE_BITS % hash_chain_bits != 0:
        raise ValueError('hash_chain_bits must be a positive divisor of MESSAGE_BITS')

    num_chains = MESSAGE_BITS // hash_chain_bits
    return [secrets.token_bytes(MESSAGE_BITS // 8) for _ in range(num_chains)]


def generate_public_key(hash_chain_bits: int, private_key: PrivateKey) -> PublicKey:
    '''Generates a public key from a private key.

    Every private-key element is hashed 2**hash_chain_bits times (the full
    length of its chain). The resulting chain end points, in order, form the
    public key.

    Args:
        hash_chain_bits: chunk width that was used to generate private_key.
        private_key: key returned by generate_private_key.

    Returns:
        A list with one chain end point per private-key element.
    '''
    chain_length = 2**hash_chain_bits
    return [hash_chain(chain_length, secret) for secret in private_key]

def generate_keypair(hash_chain_bits: int):
    '''Creates a fresh keypair for the given chunk width.

    Args:
        hash_chain_bits: forwarded unchanged to generate_private_key and
            generate_public_key.

    Returns:
        A (private_key, public_key) tuple, where public_key is derived from
        private_key.
    '''
    secret_key = generate_private_key(hash_chain_bits)
    return secret_key, generate_public_key(hash_chain_bits, secret_key)


def sign_message(hash_chain_bits: int, message: str, private_key: PrivateKey) -> Signature:
    '''Signs a message using a private key.

    The message is hashed with `hash_message` and the digest is split with
    `split_into_n_chunks`. Signature element i is private_key[i] hashed
    message_chunks[i] times. Presumably each chunk is a hash_chain_bits-bit
    integer (TODO confirm against winternitz_utils.split_into_n_chunks).

    NOTE: no checksum is appended yet. Anyone holding a signature can hash its
    elements further along their chains and so forge signatures for messages
    whose chunks are all >= the signed ones. For the same reason, a private
    key should only ever sign a single message.

    Args:
        hash_chain_bits: chunk width; must match the value used to generate
            private_key.
        message: the message to sign.
        private_key: key returned by generate_private_key / generate_keypair.

    Returns:
        One hash-chain value per message chunk.

    Raises:
        ValueError: if private_key does not have exactly one secret per message
            chunk (e.g. it was generated with a different hash_chain_bits).
    '''
    message_chunks = split_into_n_chunks(hash_message(message), hash_chain_bits)

    # A mismatched key would otherwise raise a bare IndexError (too short) or
    # silently yield a signature that can never verify (too long).
    if len(private_key) != len(message_chunks):
        raise ValueError(
            f'private_key has {len(private_key)} elements but the message has '
            f'{len(message_chunks)} chunks; was it generated with a different hash_chain_bits?'
        )

    return [hash_chain(chunk, secret) for chunk, secret in zip(message_chunks, private_key)]


def verify_signature(hash_chain_bits: int, message: str, public_key: PublicKey, signature: Signature) -> bool:
    '''Verifies a signature using a public key and message hash.

    Each signature element is hashed the remaining
    2**hash_chain_bits - message_chunks[i] times. The results must reproduce
    the public key exactly.

    Args:
        hash_chain_bits: chunk width used when the keypair was generated.
        message: the message the signature claims to cover.
        public_key: key returned by generate_public_key / generate_keypair.
        signature: value returned by sign_message.

    Returns:
        True if the signature is valid for message under public_key. False
        otherwise, including when the signature or public key does not have
        exactly one element per message chunk.
    '''
    message_chunks = split_into_n_chunks(hash_message(message), hash_chain_bits)

    # Reject malformed input up front: a short signature would otherwise raise
    # IndexError, and extra trailing elements would be silently ignored.
    if not (len(signature) == len(public_key) == len(message_chunks)):
        return False

    chain_length = 2**hash_chain_bits
    recreated_public_key = [
        hash_chain(chain_length - chunk, element)
        for chunk, element in zip(message_chunks, signature)
    ]
    return recreated_public_key == public_key

In [12]:
# 8-bit chunks: MESSAGE_BITS (256) / 8 = 32 hash chains, each 2**8 = 256 hashes long.
HASH_CHAIN_BITS = 8

sk, pk = generate_keypair(HASH_CHAIN_BITS)
signature = sign_message(HASH_CHAIN_BITS, 'hello', sk)

# Sanity check: the signature must not verify for a different message.
assert not verify_signature(HASH_CHAIN_BITS, 'goodbye', pk, signature)

verify_signature(HASH_CHAIN_BITS, 'hello', pk, signature)

True