In [1]:
import hashlib
import secrets
import numpy as np
import matplotlib.pyplot as plt
from winternitz_utils import hash_chain, hash_message, split_into_int_chunks
from winternitz_utils import Signature, PublicKey, PrivateKey

In [2]:
####### CONSTANTS #######

# Bit length of the value that is actually signed. Messages of any length are
# passed through `hash_message` first, so this is presumably the size of that
# digest rather than a limit on the raw message. It must be divisible by the
# `hash_chain_bits` used for key generation.
MESSAGE_BITS = 256

In [4]:
def generate_private_key(hash_chain_bits: int) -> PrivateKey:
    '''Generates a random private key.

    The key is a list of ``MESSAGE_BITS // hash_chain_bits`` independent random
    secrets, one per ``hash_chain_bits``-bit chunk of the signed value. Each
    secret is ``MESSAGE_BITS // 8`` bytes from ``secrets.token_bytes``
    (a cryptographically secure source).

    Args:
        hash_chain_bits: number of bits encoded by each hash chain; must be a
            positive divisor of ``MESSAGE_BITS``.

    Returns:
        The private key as a list of random byte strings.

    Raises:
        ValueError: if ``hash_chain_bits`` is not a positive divisor of
            ``MESSAGE_BITS``.
    '''
    # Check positivity first: 0 would make the modulo raise ZeroDivisionError,
    # and a negative value passes the modulo test (e.g. 256 % -8 == 0) but
    # would produce an empty key because range() of a negative count is empty.
    if hash_chain_bits <= 0 or MESSAGE_BITS % hash_chain_bits != 0:
        raise ValueError('hash_chain_bits must be a positive divisor of MESSAGE_BITS')

    num_chains = MESSAGE_BITS // hash_chain_bits
    secret_size_bytes = MESSAGE_BITS // 8
    return [secrets.token_bytes(secret_size_bytes) for _ in range(num_chains)]


def generate_public_key(hash_chain_bits: int, private_key: PrivateKey) -> PublicKey:
    '''Derives the public key that corresponds to ``private_key``.

    Each private-key element is pushed through ``hash_chain`` with a chain
    length of ``2**hash_chain_bits`` (the full chain). The chain ends, in the
    same order, form the public key. The chain length doubles with every extra
    bit, so cost grows exponentially in ``hash_chain_bits`` (see the timing
    table in the Results section).
    '''
    return [hash_chain(2**hash_chain_bits, secret) for secret in private_key]


def generate_keypair(hash_chain_bits: int):
    '''Creates a matching ``(private_key, public_key)`` pair.

    A fresh private key is generated first, and the public key is then derived
    from it using the same ``hash_chain_bits``.
    '''
    secret = generate_private_key(hash_chain_bits)
    return secret, generate_public_key(hash_chain_bits, secret)


def sign_message(hash_chain_bits: int, message: str, private_key: PrivateKey) -> Signature:
    '''Produces a Winternitz-style signature for ``message``.

    The message is hashed and cut into ``hash_chain_bits``-bit integer chunks.
    The value of chunk ``i`` is how many times ``private_key[i]`` is run
    through ``hash_chain``, and the resulting partial-chain values, in order,
    are the signature. ``private_key`` should come from
    ``generate_private_key`` with the same ``hash_chain_bits``; a key with too
    few elements raises ``IndexError``.
    '''
    chunks = split_into_int_chunks(hash_message(message), hash_chain_bits)
    return [hash_chain(chunk, private_key[index]) for index, chunk in enumerate(chunks)]


def verify_signature(hash_chain_bits: int, message: str, public_key: PublicKey, signature: Signature) -> bool:
    '''Verifies a signature against a public key and message.

    The message is hashed and split into ``hash_chain_bits``-bit chunks, exactly
    as in ``sign_message``. Signature element ``i`` is then hashed the remaining
    ``2**hash_chain_bits - chunk_i`` times. The signature is valid if every
    completed chain equals the corresponding public-key element.

    Args:
        hash_chain_bits: bits per chunk; must match the value used for key
            generation and signing.
        message: the message that was supposedly signed.
        public_key: the signer's public key.
        signature: the signature to check (treated as untrusted input).

    Returns:
        True if the signature is valid for ``message`` under ``public_key``.
        False otherwise, including when ``signature`` has the wrong number of
        elements.

    NOTE(review): this scheme has no Winternitz checksum chains. It is assumed
    that ``hash_chain(a, hash_chain(b, x)) == hash_chain(a + b, x)``, which this
    check already relies on. Under that assumption, anyone holding a valid
    signature can hash element ``i`` forward ``k`` more times to obtain a valid
    element for chunk value ``chunk_i + k``. Signatures can therefore be forged
    for any message whose chunks are all >= the signed message's chunks.
    Confirm whether this is acceptable (e.g. a teaching demo).
    '''
    message_hash = hash_message(message)
    message_chunks = split_into_int_chunks(message_hash, hash_chain_bits)

    # The signature is caller/attacker supplied. Reject a wrong element count
    # up front instead of raising IndexError (too short) or silently ignoring
    # trailing elements (too long).
    if len(signature) != len(message_chunks):
        return False

    chain_length = 2**hash_chain_bits
    recreated_public_key = [
        hash_chain(chain_length - chunk, sig_element)
        for chunk, sig_element in zip(message_chunks, signature)
    ]
    return recreated_public_key == public_key

# Results

Runtime of `hash_chain` for increasing chain lengths, measured with `%timeit` in Python on an M1 MacBook.

| Hash Chain Length | Time per call (mean ± std. dev.) |
| --- | --- |
| $2^0$ | 768 ns ± 6.7 ns |
| $2^1$ | 1.55 µs ± 11.1 ns |
| $2^2$ | 2.76 µs ± 43.1 ns |
| $2^3$ | 5.2 µs ± 67.1 ns |
| $2^4$ | 9.94 µs ± 50.9 ns |
| $2^5$ | 19.8 µs ± 43.3 ns |
| $2^6$ | 39.2 µs ± 172 ns |
| $2^7$ | 77.2 µs ± 414 ns |
| $2^8$ | 154 µs ± 1.14 µs |
| $2^9$ | 307 µs ± 2.23 µs |
| $2^{10}$ | 627 µs ± 6.36 µs |
| $2^{11}$ | 1.24 ms ± 8.59 µs |
| $2^{12}$ | 2.48 ms ± 33 µs |
| $2^{13}$ | 4.94 ms ± 33.7 µs |
| $2^{14}$ | 9.95 ms ± 70.9 µs |
| $2^{15}$ | 20.1 ms ± 73.5 µs |
| $2^{16}$ | 39.4 ms ± 124 µs |
| $2^{17}$ | 80.6 ms ± 413 µs |
| $2^{18}$ | 161 ms ± 1.39 ms |
| $2^{19}$ | 317 ms ± 4.48 ms |
| $2^{20}$ | 645 ms ± 5.36 ms |
| $2^{21}$ | 1.29 s ± 5.88 ms |

<!-- PYTHON CODE TO PLOT RESULTS 

```python
# Hash-chain lengths benchmarked in the table above: 2^0 .. 2^21
hash_chain_length = 2 ** np.arange(0, 22)

# Mean %timeit result for each chain length, converted to seconds.
# Named time_seconds (not `time`) to avoid shadowing the stdlib `time` module.
NS, US, MS = 1e-9, 1e-6, 1e-3
time_seconds = np.array([
    768 * NS,
    1.55 * US,
    2.76 * US,
    5.2 * US,
    9.94 * US,
    19.8 * US,
    39.2 * US,
    77.2 * US,
    154 * US,
    307 * US,
    627 * US,
    1.24 * MS,
    2.48 * MS,
    4.94 * MS,
    9.95 * MS,
    20.1 * MS,
    39.4 * MS,
    80.6 * MS,
    161 * MS,
    317 * MS,
    645 * MS,
    1.29,
])
assert len(time_seconds) == len(hash_chain_length), "one timing per chain length"

# Both axes span about six orders of magnitude. On linear axes only the last
# few points are visible, so use log-log: a straight line of slope 1 means the
# runtime grows linearly with chain length.
fig, ax = plt.subplots()
ax.loglog(hash_chain_length, time_seconds, marker='o')
ax.set(
    title='hash_chain runtime vs. chain length (M1 MacBook)',
    xlabel='Hash chain length',
    ylabel='Time per call (s)',
)
ax.grid(True, which='both', alpha=0.3)
plt.show()
```

-->