In [280]:
import falcon

from common import q
from numpy import set_printoptions
from math import sqrt
from fft import fft, ifft, sub, neg, add_fft, mul_fft
from ntt import sub_zq, mul_zq, div_zq
from ffsampling import gram, ffldl_fft, ffsampling_fft
from ntrugen import ntru_gen
from encoding import compress, decompress
# https://pycryptodome.readthedocs.io/en/latest/src/hash/shake256.html
from Crypto.Hash import SHAKE256, SHA512
# Randomness
from os import urandom
from rng import ChaCha20
# For debugging purposes
import sys
if sys.version_info >= (3, 4):
    from importlib import reload  # Python 3.4+ only.


set_printoptions(linewidth=200, precision=5, suppress=True)

logn = {
    2: 1,
    4: 2,
    8: 3,
    16: 4,
    32: 5,
    64: 6,
    128: 7,
    256: 8,
    512: 9,
    1024: 10
}


# Bytelength of the signing salt and header
HEAD_LEN = 1
SALT_LEN = 40
SEED_LEN = 56


# Parameter sets for Falcon:
# - n is the dimension/degree of the cyclotomic ring
# - sigma is the std. dev. of signatures (Gaussians over a lattice)
# - sigmin is a lower bounds on the std. dev. of each Gaussian over Z
# - sigbound is the upper bound on ||s0||^2 + ||s1||^2
# - sig_bytelen is the bytelength of signatures
Params = {
    # FalconParam(2, 2)
    2: {
        "n": 2,
        "sigma": 144.81253976308423,
        "sigmin": 1.1165085072329104,
        "sig_bound": 101498,
        "sig_bytelen": 44,
    },
    # FalconParam(4, 2)
    4: {
        "n": 4,
        "sigma": 146.83798833523608,
        "sigmin": 1.1321247692325274,
        "sig_bound": 208714,
        "sig_bytelen": 47,
    },
    # FalconParam(8, 2)
    8: {
        "n": 8,
        "sigma": 148.83587593064718,
        "sigmin": 1.147528535373367,
        "sig_bound": 428865,
        "sig_bytelen": 52,
    },
    # FalconParam(16, 4)
    16: {
        "n": 16,
        "sigma": 151.78340713845503,
        "sigmin": 1.170254078853483,
        "sig_bound": 892039,
        "sig_bytelen": 63,
    },
    # FalconParam(32, 8)
    32: {
        "n": 32,
        "sigma": 154.6747794602761,
        "sigmin": 1.1925466358390344,
        "sig_bound": 1852696,
        "sig_bytelen": 82,
    },
    # FalconParam(64, 16)
    64: {
        "n": 64,
        "sigma": 157.51308555044122,
        "sigmin": 1.2144300507766141,
        "sig_bound": 3842630,
        "sig_bytelen": 122,
    },
    # FalconParam(128, 32)
    128: {
        "n": 128,
        "sigma": 160.30114421975344,
        "sigmin": 1.235926056771981,
        "sig_bound": 7959734,
        "sig_bytelen": 200,
    },
    # FalconParam(256, 64)
    256: {
        "n": 256,
        "sigma": 163.04153322607107,
        "sigmin": 1.2570545284063217,
        "sig_bound": 16468416,
        "sig_bytelen": 356,
    },
    # FalconParam(512, 128)
    512: {
        "n": 512,
        "sigma": 165.7366171829776,
        "sigmin": 1.2778336969128337,
        "sig_bound": 34034726,
        "sig_bytelen": 666,
    },
    # FalconParam(1024, 256)
    1024: {
        "n": 1024,
        "sigma": 168.38857144654395,
        "sigmin": 1.298280334344292,
        "sig_bound": 70265242,
        "sig_bytelen": 1280,
    },
}

In [248]:
import random

def seeded_rng(i):
    random_generator = random.Random(42)
    return bytes(random_generator.getrandbits(8) for _ in range(i))

In [249]:
n = 512
sk = falcon.SecretKey(n)
pk = falcon.PublicKey(sk)

In [250]:
byte_string = b'test'
signature = sk.sign(byte_string, seeded_rng)
sk.verify(byte_string, signature)

True

In [251]:
def sign_pk_recovery(sk, message, randombytes=urandom):
    """
    Sign a message. The message MUST be a byte string or byte array.
    Optionally, one can select the source of (pseudo-)randomness used
    (default: urandom).
    """
    int_header = 0x30 + logn[sk.n]
    header = int_header.to_bytes(1, "little")

    salt = randombytes(SALT_LEN)
    hashed = sk.hash_to_point(message, salt)

    # We repeat the signing procedure until we find a signature that is
    # short enough (both the Euclidean norm and the bytelength)
    while(1):
        if (randombytes == urandom):
            s = sk.sample_preimage(hashed)
        else:
            seed = randombytes(SEED_LEN)
            s = sk.sample_preimage(hashed, seed=seed)
        norm_sign = sum(coef ** 2 for coef in s[0])
        norm_sign += sum(coef ** 2 for coef in s[1])
        # Check the Euclidean norm
        if norm_sign <= sk.signature_bound:
            # Check that s2 (s[1]) is invertible
            ntt.intt(ntt.ntt(s[1]))
            # try:
            #     ntt.intt(ntt.ntt(s[1]))
            # except: # unsure what error it would cause if it's not invertible
            #     continue
            
            # signature is (compress(s1), compress(s2), r)
            enc_s1 = compress(s[0], sk.sig_bytelen - HEAD_LEN - SALT_LEN)
            enc_s2 = compress(s[1], sk.sig_bytelen - HEAD_LEN - SALT_LEN)
            
            
            # Check that the encoding is valid (sometimes it fails)
            if enc_s1 and enc_s2:
                return header + enc_s1 + enc_s2 + salt
            
            
            # enc_s = compress(s[1], sk.sig_bytelen - HEAD_LEN - SALT_LEN)
            # # Check that the encoding is valid (sometimes it fails)
            # if (enc_s is not False):
            #     return header + salt + enc_s

In [276]:
def verify_pk_recovery(sk, pk, message, signature):
    """
    Verify a signature.
    """
    # # Unpack the salt and the short polynomial s1
    # salt = signature[HEAD_LEN:HEAD_LEN + SALT_LEN]
    # enc_s = signature[HEAD_LEN + SALT_LEN:]
    # s1 = decompress(enc_s, sk.sig_bytelen - HEAD_LEN - SALT_LEN, sk.n)
    signature_length = sk.sig_bytelen - HEAD_LEN - SALT_LEN
    
    enc_s1 = signature[HEAD_LEN: signature_length + HEAD_LEN]
    enc_s2 = signature[signature_length + HEAD_LEN: 2*signature_length + HEAD_LEN]
    salt = signature[2*signature_length + HEAD_LEN:]
    
    # Need to unpack polynomial s1 and s2
    s1 = decompress(enc_s1, signature_length, sk.n)
    s2 = decompress(enc_s2, signature_length, sk.n)
    
    ## Check that s1 and s2 are valid
    if not s1 or not s2:
        print("Invalid encoding")
        return False

    # Check that the (s1, s2) is short
    norm_sign = sum(coef ** 2 for coef in s1)
    norm_sign += sum(coef ** 2 for coef in s2)
    if norm_sign > sk.signature_bound:
        print("Squared norm of signature is too large:", norm_sign)
        return False
    
    # Check that pk = H(inverse(s2)*(HashToPoint(r||m, q, n) - s1))
    
    hash_to_point_message = sk.hash_to_point(message, salt)
    recovered_pk = div_zq(sub_zq(hash_to_point_message, s1),s2)
    
    return recovered_pk == pk.h

    # If all checks are passed, accept
    return True

The encoded values are concatenated into a bit sequence of 14n bits, which is then represented as ⌈14n/8⌉ bytes.

In [170]:
int_header = 0x30 + logn[512]
header = int_header.to_bytes(1, "little")

In [253]:
signature_pkr = sign_pk_recovery(sk, byte_string, seeded_rng)

In [277]:
verify_pk_recovery(sk, pk, byte_string, signature_pkr)

True

In [278]:
pk.h

[1696,
 1014,
 2238,
 4305,
 11150,
 6416,
 9051,
 409,
 7413,
 9994,
 8422,
 10963,
 7602,
 1,
 6827,
 10648,
 7691,
 3004,
 2754,
 7198,
 7206,
 11643,
 10601,
 11753,
 6951,
 8817,
 7995,
 7351,
 8291,
 10098,
 3975,
 2106,
 9602,
 11214,
 8703,
 1914,
 6378,
 8901,
 3741,
 12241,
 5819,
 2144,
 11121,
 10029,
 11776,
 5586,
 1289,
 4210,
 5454,
 9980,
 10150,
 485,
 2943,
 9723,
 8358,
 3169,
 7745,
 5006,
 2849,
 9435,
 2784,
 11720,
 10715,
 3555,
 5822,
 4072,
 4111,
 6647,
 4793,
 8908,
 7624,
 2628,
 59,
 11622,
 7391,
 6537,
 1188,
 11644,
 8758,
 5775,
 8181,
 1510,
 2587,
 7465,
 12200,
 10339,
 9684,
 1418,
 3117,
 4153,
 8867,
 2865,
 5077,
 5683,
 5009,
 8964,
 4904,
 9739,
 8397,
 5889,
 8020,
 11417,
 8805,
 195,
 10683,
 5628,
 5433,
 9199,
 1498,
 5173,
 823,
 5882,
 2769,
 7204,
 9219,
 10536,
 7323,
 7767,
 2620,
 5030,
 3787,
 516,
 9821,
 7881,
 4647,
 1861,
 8389,
 8796,
 7982,
 9432,
 865,
 3848,
 6778,
 8026,
 3789,
 10036,
 12184,
 2225,
 9714,
 7204,
 11380,

In [285]:
tuple(pk.h)

(1696,
 1014,
 2238,
 4305,
 11150,
 6416,
 9051,
 409,
 7413,
 9994,
 8422,
 10963,
 7602,
 1,
 6827,
 10648,
 7691,
 3004,
 2754,
 7198,
 7206,
 11643,
 10601,
 11753,
 6951,
 8817,
 7995,
 7351,
 8291,
 10098,
 3975,
 2106,
 9602,
 11214,
 8703,
 1914,
 6378,
 8901,
 3741,
 12241,
 5819,
 2144,
 11121,
 10029,
 11776,
 5586,
 1289,
 4210,
 5454,
 9980,
 10150,
 485,
 2943,
 9723,
 8358,
 3169,
 7745,
 5006,
 2849,
 9435,
 2784,
 11720,
 10715,
 3555,
 5822,
 4072,
 4111,
 6647,
 4793,
 8908,
 7624,
 2628,
 59,
 11622,
 7391,
 6537,
 1188,
 11644,
 8758,
 5775,
 8181,
 1510,
 2587,
 7465,
 12200,
 10339,
 9684,
 1418,
 3117,
 4153,
 8867,
 2865,
 5077,
 5683,
 5009,
 8964,
 4904,
 9739,
 8397,
 5889,
 8020,
 11417,
 8805,
 195,
 10683,
 5628,
 5433,
 9199,
 1498,
 5173,
 823,
 5882,
 2769,
 7204,
 9219,
 10536,
 7323,
 7767,
 2620,
 5030,
 3787,
 516,
 9821,
 7881,
 4647,
 1861,
 8389,
 8796,
 7982,
 9432,
 865,
 3848,
 6778,
 8026,
 3789,
 10036,
 12184,
 2225,
 9714,
 7204,
 11380,

In [296]:
pk_bytes = b''.join(int.to_bytes(num, length = 14, byteorder='little', signed=False) for num in pk.h)

In [294]:
int.to_bytes?

In [291]:
len(pk_bytes)

2048

In [297]:
test = SHA512.new(pk_bytes)

In [300]:
len(test.digest())

64