# Sphincs

In [1]:
# Import math Library
from math import log2, floor, ceil
from cryptography.hazmat.primitives import hashes
import copy

## Auxiliar Functions

Termos:
* ceil(x) floor(x) log2(x) log(x)-> len => int x
* Trunc_l =  x[:l]  floor(log(n,2).n() / 8) => bit-string x

In [2]:
'''
y-byte string containing the binary representation of x in
big-endian byte-order

y: size byte-string
x: int 
'''
def toByte(x,y):    
    return int(x).to_bytes(y, 'big')

'''
X : len_x-byte string
w : element of {4,16,256}
Outputs a array of out_len integers between 0 and w-1

8 ∗ len_X/ log(w).
'''
def base_w(X,w,out_len):    
    in_ = 0
    out_ = 0
    total_ = 0
    bits = 0
    
    basew = [0] * out_len
    
    for i in range(out_len):
        if(bits == 0):
            total = X[in_]
            in_ += 1
            bits += 8
        bits -= log2(w)
        basew[out_] = (total >> int(bits)) and (w - 1)
        out_ += 1
    return basew

## Cryptographic (Hash) Function Families

In [3]:
# Variáveis Globais:

# SPHINCS+ - 128f:
n = 16
h = 66
d = 22
t = 64
k = 33
w = 16
bitsec = 128
sec_level = 1
sig_bytes = 17088

len1 = ceil( (8*n) / log2(w) )
len2 = floor( log2(len1*(w - 1)) / log2(w) ) + 1
len_ = len1 + len2  


WOTS_HASH = 0
WOTS_PK = 1
TREE = 2
FORS_TREE = 3
FORS_ROOTS = 4

RANDOMIZE = True

### Tweakable Hash Functions:

* context information in form of an address ADRS...
* ln-byte message M to an n-byte hash value md using an n-byte seed PK.seed and a 32-byte address

    ```
    T_l :: Bn x B32 x Bln -> Bn
    md = T_l(PK.seed, ADRS, message)
    
    F = T_1
    H = T_2
    ```

In [4]:
## Simple Variant
# PK_seed, ADRS, M1, M2, M: byte string

def F(PK_seed, ADRS, M1):
    shake = hashes.Hash(hashes.SHAKE256(int(n)))
    shake.update(PK_seed + ADRS + M1)
    return shake.finalize()

def H(PK_seed, ADRS, Ma): # Ma = M1 + M2
    shake = hashes.Hash(hashes.SHAKE256(int(n)))
    shake.update(PK_seed + ADRS + Ma)
    return shake.finalize()

    
    # M is a ln-byte message
def T_l(PK_seed, ADRS, M): 
    shake = hashes.Hash(hashes.SHAKE256(int(n)))
    shake.update(PK_seed + ADRS + M)
    return shake.finalize()


### PRF and Message Digest: 

* PRF for pseudorandom key generation
* PRF_msg to generate randomness for the message compression
* To compress the message to be signed, SPHINCS + uses an additional keyed hash function H_msg that can process arbitrary length messages


In [5]:
# SPHINCS+ - SHAKE 256:

# R, PK_seed, PK_root, M : byte-strings
def H_msg(R, PK_seed, PK_root, M, m):
    shake = hashes.Hash(hashes.SHAKE256(int(m)))
    shake.update( R + PK_seed + PK_root + M)
    return shake.finalize()

# SEED, ADRS : byte-strings
def PRF(SEED, ADRS):
    shake = hashes.Hash(hashes.SHAKE256(int(n)))
    shake.update(SEED + ADRS)
    return shake.finalize()
    
# SK_prf, OptRand, M: byte-strings
def PRF_msg(SK_prf, OptRand, M):
    shake = hashes.Hash(hashes.SHAKE256(int(n)))
    shake.update(SK_prf + OptRand + M)
    return shake.finalize()

### Hash Function Address Scheme (Structure of ADRS):

An address ADRS is a 32-byte value that follows a defined structure. 

1. Used for the hashes in WOTS+ schemes
2. Used for compression of the WOTS+ public key
3. Used for hashes within the main Merkle tree construction
4. Used for the hashes in the Merkle tree in FORS
5. Used for the compression of the tree roots of FORS

word being 32 bits (4 bytes)

All fields within these addresses encode unsigned integers

When describing the generation of addresses we use set methods that take positive integers and set the bits of a field to the binary representation of that integer, in big-endian notation.

we refer to them respectively using the constants WOTS_HASH, WOTS_PK, TREE, FORS_TREE and FORS_ROOTS.

In [6]:
class ADRS:
    
    # [0:4]: layer address
    # [4:16]: tree address
    # [16:20]: type {0,1,2,3,4}
    '''
    0: WOTS + hash address
    1: compression of the WOTS + public key
    2: hash tree address
    3: FORS address
    4: compression of FORS tree roots
    '''
    
    '''
    0: WOTS_HASH
        [20:24]: key pair address
        [24:28]: chain address
        [28:32]: hash address
    '''
    
    '''
    1: WOTS_PK
        [20:24]: key pair address
        [24:32]: padding = 0
    '''
    
    '''
    2: TREE
        [20:24]: padding = 0
        [24:28]: tree height
        [28:32]: tree index
    '''
    
    '''
    3: FORS_TREE
        [20:24]: key pair address
        [24:28]: tree height
        [28:32]: tree index
    '''
    
    '''
    4: FORS_ROOTS
        [20:24]: key pair address
        [24:32]: padding
    '''
    
    def __init__(self):
        self.adsr = bytearray(toByte(0,32))

    def setLayerAddress(self, val):
        self.adsr[0:4] = val
    
    def setTreeAddress(self, val):
        self.adsr[4:16] = val
    
    def setType(self, type_):
        self.adsr[16:20] = type_
        self.adsr[20:32] = toByte(0, 12)
        
    def setKeyPairAddress(self, key_pair_address):
        self.adsr[20:24] = key_pair_address
        
    def getKeyPairAddress(self):
        return self.adsr[20:24]
    
    def setChainAddress(self, i):
        self.adsr[24:28] = i
        
    def setHashAddress(self, val):
        self.adsr[28:32] = val
    
    def setTreeHeight(self, value):
        self.adsr[24:28] = value
        
    def getTreeHeight(self):
        return self.adsr[24:28]
    
    def setTreeIndex(self, value):
        self.adsr[28:32] = value
    
    def getTreeIndex(self):
        return self.adsr[28:32]
    
    def getBytes(self):
        return bytes(self.adsr)


## WOTS+ One-Time Signatures

* each private key MUST NOT be used to sign more than a single message. In particular, if a private key is used to sign two different messages, the scheme becomes insecure

* tailored to the use inside of SPHINCS+

* Tópicos:
    * Parameters
    * Chaining Functions: formas main building block of the WOTS+ scheme
    * Algorithm for key generation and signing 
    * Comput WOTS+ public key from a WOTS+ signature


In [7]:
class WOTS:
    # n: the security parameter (message, private_key, public_key, signature length
    # w: the Winternitz parameter it is an element of the set {4, 16, 256}

    #len1 = ceil( (8*n) / log2(w) )
    #len2 = floor( log2(len1*(w - 1)) / log2(w) ) + 1
    #len_ = len1 + len2
    def __init__(self, n, w, len1, len2, len_):
        self.n = n
        self.w = w
        self.len1 = len1
        self.len2 = len2
        self.len_ = len_
    
    # Input: Input string X, start index i, number of steps s, public seed PK.seed, address ADRS
    # Output: value of F iterated s times on X
    
    # ADRS must have the first seven 32-bit set to encode the address of this chain.
    # ADRS 32-byte define in Section 2.7.3
    # PK_seed n-byte string
    
    def chain(self, X, i, s, PK_seed, ADRS):
        if s == 0:
            return bytes(X)
        if (i + s) > (self.w - 1):
            return -1
        
        # tmp array byte[n]
        tmp = self.chain(X, i , s-1, PK_seed, ADRS)
        
        pos = i + s - 1
        ADRS.setHashAddress(toByte(pos, 4)) # address is updated to encode the current position
        tmp = F(PK_seed, ADRS.getBytes(), tmp)
        return tmp
        
        
    
    # Each n-byte string in the WOTS + private key is derived from a secret seed SK.seed which is part of the SPHINCS + secret key and a WOTS + address ADRS using PRF.
    # Input: secret seed SK.seed, address ADRS
    # Output: WOTS+ private key sk : length len_ array of n-byte strings
    def wots_SKgen(self, SK_seed, ADRS):
        sk = []
        for i in range(self.len_): # POSSIVEL ERRO?
            ADRS.setChainAddress(toByte(i,4))
            ADRS.setHashAddress(toByte(0,4))
            sk.append(PRF(SK_seed, ADRS.getBytes())) # Returns a byte-string size n            
        return sk
    
    
    
    #Input: secret seed SK.seed, address ADRS, public seed PK.seed
    #Output: WOTS+ public key pk
    def wots_PKgen(self, SK_seed, PK_seed, ADRS):
        wotspkADRS = copy.deepcopy(ADRS) # copy address to create OTS public key address
        
        tmp = bytes()
        #tmp é uma byte-string de len bytes
        
        for i in range(self.len_): 
            ADRS.setChainAddress(toByte(i,4))
            ADRS.setHashAddress(toByte(0,4))
            sk = PRF(SK_seed, ADRS.getBytes())
            tmp += bytes(self.chain(sk, 0, self.w - 1, PK_seed, ADRS))
        
        wotspkADRS.setType(toByte(WOTS_PK, 4))
        wotspkADRS.setKeyPairAddress(ADRS.getKeyPairAddress())
        pk = T_l(PK_seed, wotspkADRS.getBytes(),  tmp)
        return pk
    
    
    def wots_sign(self, M, SK_seed, PK_seed, ADRS):
        # convert message to base w
        msg = base_w(M, self.w, self.len1)
        
        # compute checksum
        csum = 0
        for i in range(self.len1):
            csum += self.w - 1 - msg[i]
    
        # convert csum to base w
        if( (log2(self.w) % 8) != 0):
            csum = csum << ( 8 - ( ( self.len2 * log2(self.w) ) % 8 ))

        len_2_bytes = ceil( ( self.len2 * log2(self.w) ) / 8 )
        msg += base_w(toByte(csum, len_2_bytes),self.w, self.len2)
        sig = []
        for i in range(self.len_):
            ADRS.setChainAddress(toByte(i, 4))
            ADRS.setHashAddress(toByte(0,4))
            sk = PRF(SK_seed, ADRS.getBytes())
            sig += [self.chain(sk, 0, msg[i], PK_seed,ADRS)]
    
        return sig
    

    #Input: Message M, WOTS+ signature sig, address ADRS, public seed PK.seed
    #Output: WOTS+ public key pk_sig derived from sig
    def wots_pkFromSig(self, sig, M, PK_seed, ADRS):
        csum = 0
        wotspkADRS = copy.deepcopy(ADRS)
         
        # convert message to base w
        msg = base_w(M, self.w, self.len1)
        
        # compute checksum
        for i in range(self.len1):
            csum += self.w - 1 - msg[i]
        
        # convert csum to base w
        csum = csum << ( 8 - ( ( self.len2 * log2(self.w) ) % 8 ))
        len_2_bytes = ceil( ( self.len2 * log2(self.w) ) / 8 )
        msg += base_w(toByte(csum, len_2_bytes), self.w, self.len2)
        
        tmp = bytes()
        for i in range(self.len_):
            ADRS.setChainAddress(toByte(i, 4))
            tmp += self.chain(sig[i], msg[i], self.w - 1 - msg[i], PK_seed,ADRS)
        
        wotspkADRS.setType(toByte(WOTS_PK, 4))
        wotspkADRS.setKeyPairAddress(ADRS.getKeyPairAddress())
        pk_sig = T_l(PK_seed, wotspkADRS.getBytes(), tmp)
        return pk_sig

## The SPHINCS+ Hypertree

1. WOTS+ gets combined with a bnary hash tree, leading to a fized input-length version of the XMSS
2. hoe to go to a hypertree from there. The hypertree might be viewed as a fixed input-length version of multi-tree XMSS.


### (Fixed Input-Length) XMSS:

It authenticates 2^h WOTS + public keys using a binary tree of height h' .

Each node in the binary tree is an n-byte value which is the tweakable hash of the concatenation of its two child nodes. The leaves are the WOTS + public keys. The XMSS
public key is the root node of the tree. In SPHINCS + , the XMSS secret key is the single secret seed that is used to generate all WOTS + secret keys.

An XMSS signature in the context of SPHINCS + consists of the WOTS + signature on the message and the so-called authentication path. The latter is a vector of tree nodes that allow a verifier to compute a value for the root of the tree starting from a WOTS + signature. A verifier computes the root value and verifies its correctness.

In [8]:
class XMSS:
    # h : the height of the tree
    # n : the length in bytes of messages
    # w: the Winternitz parameter for WOTS+
    # SIG_XMSS: XMSS signatures
    def __init__ (self, h_, n, w):
        self.h_ = h_
        self.n = n
        self.w = w
        self.wots = WOTS(n,w, len1, len2, len_)
        
        
    # XMSS private key is the single secret seed SK_seed
    
    
    # It is REQUIRED that s % 2^z = 0
    # The treehash algorithm described here uses a stack holding up to (z − 1) nodes, with the usual stack functions push() and pop()
    # Input: Secret seed SK_seed, start index s, target node height z, public seed PK_seed, address ADRS
    # Output: n-byte root node - top node on Stack
    def treehash(self, SK_seed, s, z, PK_seed, ADRS):
        if s % (1 << z) != 0 :
            return -1
    
        stack = []
    
        for i in range(2**z):
            ADRS.setType(toByte(WOTS_HASH, 4))
            ADRS.setKeyPairAddress(toByte(s + i, 4))
            node = self.wots.wots_PKgen(SK_seed, PK_seed, ADRS)
            
            ADRS.setType(toByte(TREE, 4))
            ADRS.setTreeHeight(toByte(1, 4))
            ADRS.setTreeIndex(toByte(s + i, 4))
            
            if len(stack) > 0:
                while bytes(stack[len(stack)-1][1]) == bytes(ADRS.getTreeHeight()): # !!!!!!!!!!!!!!!!!
                    tree_idx = int.from_bytes(ADRS.getTreeIndex(),'big')
                    
                    ADRS.setTreeIndex(toByte((( tree_idx - 1) // 2), 4))
                    aux = stack.pop()[0] + node
                    node = H(PK_seed, ADRS.getBytes(), aux)
                    tree_h = int.from_bytes(ADRS.getTreeHeight(),'big')
                    ADRS.setTreeHeight(toByte( tree_h + 1 ,4))

                    if len(stack) <= 0:
                        break
                    
            stack.append((node, ADRS.getTreeHeight())) #push
            
        return stack.pop()[0]

    
    # the XMSS public key PK is the root of the binary hash tree
    # The latter encodes the position of this XMSS instance within the SPHINCS + structure
    # Input: Secret seed SK.seed, public seed PK.seed, address ADRS
    # Output: XMSS public key PK
    def xmss_PKgen(self, SK_seed, PK_seed, ADRS):
        pk = self.treehash(SK_seed, 0, self.h_, PK_seed, ADRS)
        return pk

    
    # XMSS Signature: ((len + h' ) ∗ n)-byte string:
    #    WOTS+ signature sig taking len - n bytes
    #    AUTH: autentication path taking h' * n bytes
    
    # Input: n-byte message M, secret seed SK.seed, index idx, public seed PK.seed, address ADRS
    # Output: XMSS signature SIG_XMSS = (sig || AUTH)
    def xmss_sign(self, M, SK_seed, idx, PK_seed, ADRS):
        AUTH = []
    
        # build authentication path
        for j in range(self.h_):
            k_aux = floor(idx // 2**j)
            k = k_aux ^^ 1 # If it is bytes, i need to change this
            AUTH += [self.treehash(SK_seed, k * 2**j, j, PK_seed, ADRS)]
            
        ADRS.setType(toByte(WOTS_HASH, 4))
        ADRS.setKeyPairAddress(toByte(idx, 4))
        
        sig = self.wots.wots_sign(M, SK_seed, PK_seed, ADRS)
        SIG_XMSS = sig + AUTH
        return SIG_XMSS

    
    # Input: index idx, XMSS signature SIG_XMSS = (sig || AUTH), n-byte message M, public seed PK.seed, address ADRS
    # Output: n-byte root value node[0]
    def xmss_pkFromSig(self, idx, SIG_XMSS, M, PK_seed, ADRS):
        # compute WOTS+ pk from WOTS+ sig
        
        ADRS.setType(toByte(WOTS_HASH, 4))
        ADRS.setKeyPairAddress(toByte(idx,4))
        
        sig = SIG_XMSS[:len_]
        AUTH = SIG_XMSS[len_:]
        
        
        node0 = self.wots.wots_pkFromSig(sig, M, PK_seed, ADRS)
        node1 = 0
        
        # compute root from WOTS+ pk and AUTH
        ADRS.setType(toByte(TREE,4))
        ADRS.setTreeIndex(toByte(idx,4))
        
        for i in range(self.h_):
            ADRS.setTreeHeight(toByte(i+1, 4))
            
            tree_idx = int.from_bytes(ADRS.getTreeIndex(),'big')
            if (floor(idx / 2**i) % 2) == 0:
                ADRS.setTreeIndex( toByte( tree_idx // 2 , 4) )
                node1 = H(PK_seed, ADRS.getBytes(), node0 + AUTH[i])
            else:
                ADRS.setTreeIndex( toByte( (tree_idx - 1) // 2 , 4) )
                node1 = H(PK_seed, ADRS.getBytes(), AUTH[i] + node0)
            node0 = node1
        return node0

### HT: The Hypertree

In [9]:
class HT:
        
    def __init__(self, h, n, w, d):
        self.h = h
        self.n = n
        self.w = w
        self.d = d
        self.len_ = len_
        self.xmss = XMSS(h/d, n, w)
        
        
    # Input: Private seed SK.seed, public seed PK.seed
    # Output: HT public key PK_HT
    def ht_PKgen(self, SK_seed, PK_seed):
        adrs = ADRS()
        adrs.setLayerAddress(toByte(self.d-1, 4))
        adrs.setTreeAddress(toByte(0, 12))
        root = self.xmss.xmss_PKgen(SK_seed, PK_seed, adrs)
        return root
    
    # Input: Message M, private seed SK.seed, public seed PK.seed, tree index idx_tree, leaf index idx_leaf
    # Output: HT signature SIG_HT
    def ht_sign(self, M, SK_seed, PK_seed, idx_tree, idx_leaf):
        # init
        adrs = ADRS()
                
        adrs.setLayerAddress(toByte(0, 4))
        adrs.setTreeAddress(toByte(idx_tree, 12))
        
        SIG_tmp = self.xmss.xmss_sign(M, SK_seed, idx_leaf, PK_seed, adrs)
        SIG_HT = SIG_tmp 
        root = self.xmss.xmss_pkFromSig(idx_leaf, SIG_tmp, M, PK_seed, adrs)
        
        for j in range(1, d):
            # idx_leaf_aux = toByte(idx_tree, 4)[-(self.h / self.d):] # !:  least significant bits of idx_tree
            # idx_leaf = int.from_bytes(idx_leaf_aux, 'big')
            # idx_tree_aux = toByte(idx_tree, 4)[:(self.h - (j + 1) * (self.h / self.d))] # !: most significant bits of idx_tree
            # idx_tree = int.from_bytes(idx_tree_aux, 'big')
            
            idx_leaf = idx_tree % 2**(self.h//self.d)
            idx_tree = idx_tree >> (self.h//self.d)
            
            adrs.setLayerAddress(toByte(j, 4))
            adrs.setTreeAddress(toByte(idx_tree, 12))
            
            SIG_tmp = self.xmss.xmss_sign(root, SK_seed, idx_leaf, PK_seed, adrs)
            SIG_HT = SIG_HT + SIG_tmp
            if j < self.d - 1:
                root = self.xmss.xmss_pkFromSig(idx_leaf, SIG_tmp, root, PK_seed, adrs)
        
        return SIG_HT
    
    
    
    # Input: Message M, signature SIG_HT, public seed PK.seed, tree index idx_tree, leaf index idx_leaf, HT public key PK_HT.
    # Output: Boolean
    def ht_verify(self, M, SIG_HT, PK_seed, idx_tree, idx_leaf, PK_HT):
        # init
        adrs = ADRS()
        
        h_ = self.h//self.d
        
        # verify
        SIG_tmp = SIG_HT[:(h_ + len_)]
        
        adrs.setLayerAddress(toByte(0,4))
        adrs.setTreeAddress(toByte(idx_tree,12))
        node = self.xmss.xmss_pkFromSig(idx_leaf, SIG_tmp, M, PK_seed, adrs)
        
        for j in range(1, self.d):
            # idx_leaf_aux = toByte(idx_tree, 4)[-(self.h / self.d):]
            # idx_leaf = int.from_bytes(idx_leaf_aux, 'big')# (self.h / self.d) least significant bits of idx_tree
            # idx_tree_aux = toByte(idx_tree, 4)[:(self.h - (j + 1) * (self.h / self.d))] # !: most significant bits of idx_tree
            # idx_tree = int.from_bytes(idx_tree_aux, 'big')
            idx_leaf = idx_tree % 2**(self.h//self.d)
            idx_tree = idx_tree >> (self.h//self.d)
            
            
            SIG_tmp = SIG_HT[ (h_ + len_)*j : (h_ + len_)*(j+1) ]# .getXMSSSignature(j)
            
            adrs.setLayerAddress(toByte(j, 4))
            adrs.setTreeAddress(toByte(idx_tree,12))
            node = self.xmss.xmss_pkFromSig(idx_leaf, SIG_tmp, node, PK_seed, adrs)
            
        if ( node == PK_HT ):
            return True
        else:
            return False

## FORS: Forest Of Random Subsets:

In [10]:
class FORS:
    
    def __init__(self, n, k, t):
        self.n = n
        self.k = k
        self.t = t
        
    #Input: secret seed SK.seed, address ADRS, secret key index idx = it+j
    #Output: FORS private key sk
    def fors_SKgen(self, SK_seed, ADRS, idx):
        ADRS.setTreeHeight(toByte(0,4))
        ADRS.setTreeIndex(toByte(idx, 4))
        sk = PRF(SK_seed, ADRS.getBytes())
        return sk
    
    
    # Input: Secret seed SK_seed, start index s, target node height z, public seed PK_seed, address ADRS
    # Output: n-byte root node - top node on Stack
    
    def fors_treehash(self, SK_seed, s, z, PK_seed, ADRS):
        if s % (1 << z) != 0:
            return -1
        
        stack = []
        
        for i in range(2**z):
            ADRS.setTreeHeight(toByte(0,4))
            ADRS.setTreeIndex(toByte( (s+i), 4 ))
            sk = PRF(SK_seed, ADRS.getBytes())
            node = F(PK_seed, ADRS.getBytes(), sk)
            
            ADRS.setTreeHeight( toByte(1, 4) )
            ADRS.setTreeIndex( toByte((s+i), 4) )
            
            if len(stack) > 0:
                treeidx =  int.from_bytes(ADRS.getTreeIndex(),'big')
                while bytes(stack[ len(stack) - 1 ]['height']) == bytes(ADRS.getTreeIndex()):
                    ADRS.setTreeIndex( toByte((( treeidx - 1) // 2), 4) )
                    aux = stack.pop()['node'] + node
                    node = H(PK_seed, ADRS.getBytes(), aux)
                    treeheight = int.from_bytes(ADRS.getTreeHeight(),'big')
                    ADRS.setTreeHeight( toByte((treeheight + 1) ,4) )
                    
                    if len(stack) <= 0:
                        break
           
            stack.append({'node': node, 'height': ADRS.getTreeHeight()})
        
        return stack.pop()['node']
    
    
    def fors_PKgen(self, SK_seed, PK_seed, ADRS):
        forspkADRS = copy.deepcopy(ADRS); # copy address to create FTS public key address
        
        a = log2(self.t)
        
        root = bytes()
        for i in range(self.k):
            root += self.fors_treehash(SK_seed, i * self.t, int(a), PK_seed, ADRS)
        
        forspkADRS.setType(toByte(FORS_ROOTS, 4))
        forspkADRS.setKeyPairAddress(ADRS.getKeyPairAddress())
        pk = T_l(PK_seed, forspkADRS.getBytes(), root) # root k*n
        
        return pk
        
    
    def fors_sign(self, M, SK_seed, PK_seed, ADRS):
        a = int(log2(self.t))
        
        SIG_FORS = []
        m_int = int.from_bytes(M, 'big')
        
        # compute signature elements
        for i in range(self.k):
            # get next index
            idx = (m_int >> ((self.k - 1 - i) * a)) % self.t
            
            # pick private key element
            ADRS.setTreeHeight( toByte(0,4) )
            ADRS.setTreeIndex( toByte( (i*self.t + idx),4 ) )
            SIG_FORS += [PRF(SK_seed, ADRS.getBytes())]
                        
            AUTH = []
            
            # compute auth path
            for j in range(a):
                s = floor( idx // 2**j ) ^^ 1
                AUTH += [self.fors_treehash(SK_seed, i * self.t + s * 2**j, j, PK_seed, ADRS)]
                            
            SIG_FORS += AUTH
        
        return SIG_FORS
    
    def aux(self, SIG_FORS):
        sigs = []
        a = int(log2(self.t))
        
        for i in range(self.k):
            sigs.append([])
            sigs[i].append(SIG_FORS[(a+1) * i])
            sigs[i].append(SIG_FORS[((a+1) * i + 1):((a+1) * (i+1))])
        return sigs
    
    
    def fors_pkFromSig(self, SIG_FORS, M, PK_seed, ADRS):
        m_int = int.from_bytes(M, 'big')
        a = int(log2(self.t))
        
        sigs = self.aux(SIG_FORS)
                
        root = bytes()
        
        # compute roots
        for i in range(self.k):
            # get next index
            idx = (m_int >> (self.k - 1 - i) * a) % self.t
            
            # compute leaf
            sk = sigs[i][0]
            
            ADRS.setTreeHeight(toByte(0,4))
            ADRS.setTreeIndex(toByte( i*self.t + idx , 4 ))
            node = [0] * 2
            node[0] = F(PK_seed, ADRS.getBytes(), sk)
            
            # compute root from leaf and AUTH
            auth = sigs[i][1]# !!
            
            ADRS.setTreeIndex( toByte(i*self.t+idx , 4) )
            
            for j in range(a):
                ADRS.setTreeHeight(toByte(j+1, 4))
                treeidx = int.from_bytes(ADRS.getTreeIndex(),'big')
                
                if floor(idx / 2**j ) % 2 == 0:
                   ADRS.setTreeIndex( toByte( treeidx // 2 , 4))
                   node[1] = H(PK_seed, ADRS.getBytes(), node[0] + auth[j] )
                else:
                   ADRS.setTreeIndex( toByte( ((treeidx - 1) // 2), 4 ))
                   node[1] = H(PK_seed, ADRS.getBytes(), auth[j] + node[0])
                node[0] = node[1]
                
            root += node[0]
        
        forspkADRS = copy.deepcopy(ADRS)
        forspkADRS.setType(toByte(FORS_ROOTS, 4))
        forspkADRS.setKeyPairAddress(ADRS.getKeyPairAddress())
        pk = T_l(PK_seed, forspkADRS.getBytes(), root)
        return pk
        

## SPHINCS

In [11]:
class sphincs:
    # Private Key: contains two elements
    def __init__(self, n, w, h, d, k, t):
        self.n = n # Parametro de segurança em bytes
        self.w = w # Parametro de Winternitz
        self.h = h # Altura da hypertree
        self.d = d # Número de Camadas da hypertree
        self.k = k # Número de árvores na FORS
        self.t = t # Número de folhas na FORS
        
        self.a = log2(int(self.t))
        m = floor( (self.k*self.a + 7)/8 ) + floor( (self.h - self.h/self.d + 7 )/8 ) + floor( (self.h/self.d + 7 )/8 )
        #len1 = len1
        #len2 = floor( ( log2(len1(self.w - 1)) )/log2(w) ) + 1
        #len_ = len1 + len2
        
        self.fors = FORS(n, k, t)
        self.ht = HT(h, n, w, d)
        
        
    # function that with the input i returns i-bytes of cryptographically strong randomness
    def sec_rand(self,size):
        return os.urandom(size)
    
    # return (secret key, public key)
    def spx_keygen(self):
        SK_seed = self.sec_rand(self.n)
        SK_prf = self.sec_rand(self.n)
        PK_seed = self.sec_rand(self.n)
        PK_root = self.ht.ht_PKgen(SK_seed, PK_seed) 
        return ((SK_seed, SK_prf, PK_seed, PK_root), (PK_seed, PK_root))
    
    
    # Input: Message M, private key SK = (SK.seed, SK.prf, PK.seed, PK.root)
    # Output: SPHINCS+ signatura SIG
    def spx_sign(self, M,SK):
        (SK_seed, SK_prf, PK_seed, PK_root) = SK
    
        # init 
        adrs = ADRS() 
        
        # generate randomizer
        opt = PK_seed
        if(RANDOMIZE):
            opt = self.sec_rand(self.n)
        R = PRF_msg(SK_prf, opt, M)
        SIG = [R] # first element: 1 n-byte string
        
        ka = self.k*self.a
        p1 = floor((ka + 7) / 8)
        p2 = floor((h - h // d + 7) / 8)
        p3 = floor((h // d + 7) / 8)
        
        
        # compute message digest and index
        digest = H_msg(R,  PK_seed, PK_root, M, (p1+p2+p3))
        
        tmp_md = digest[:p1]                       # first floor((ka +7)/ 8) bytes of digest;
        tmp_idx_tree = digest[p1:(p1+p2)]          # next floor((h - h/d +7)/ 8) bytes of digest
        tmp_idx_leaf = digest[(p1+p2):len(digest)]  # next floor((h/d +7)/ 8) bytes of digest
        
        # md = tmp_md[:ka]                                   # first ka bits of tmp_md
        # idx_tree = tmp_idx_tree[:(self.h - self.h/self.d)] # first h - h/d bits of tmp_idx_tree
        # idx_leaf =  tmp_idx_leaf[:(self.h/self.d)]         # first h/d bits of tmp_idx_leaf
        
        
        md_int = int.from_bytes(tmp_md, 'big') >> int(len(tmp_md) * 8 - self.k * self.a)
        md = md_int.to_bytes(ceil(self.k * self.a / 8), 'big')

        idx_tree = int.from_bytes(tmp_idx_tree, 'big') >> (len(tmp_idx_tree) * 8 - (self.h - self.h // self.d))
        idx_leaf = int.from_bytes(tmp_idx_leaf, 'big') >> (len(tmp_idx_leaf) * 8 - (self.h // self.d))
        
        
        # FORS sign
        adrs.setLayerAddress(toByte(0,4))
        adrs.setTreeAddress(toByte(idx_tree, 12))
        adrs.setType(toByte(FORS_TREE, 4))
        adrs.setKeyPairAddress(toByte(idx_leaf, 4))
        
        SIG_FORS = self.fors.fors_sign(md, SK_seed, PK_seed, adrs)
        SIG += [SIG_FORS]
        
        # get FORS public key
        PK_FORS = self.fors.fors_pkFromSig(SIG_FORS, md, PK_seed, adrs)
        
        # sign FORS public key with HT
        adrs.setType(toByte(TREE, 4))
        SIG_HT = self.ht.ht_sign(PK_FORS, SK_seed, PK_seed, idx_tree, idx_leaf)
        SIG += [SIG_HT]
        return SIG
    
    
    # Input: Message M, signature SIG, public key PK
    # Output: Boolean
    def spx_verify(self,M, SIG, PK):
        (PK_seed, PK_root) = PK
        
        # init
        adrs = ADRS()
        R = SIG[0]
        SIG_FORS = SIG[1]
        SIG_HT = SIG[2]
                
        ka = self.k*self.a
        
        p1 = floor((ka + 7)/ 8)
        p2 = floor((h - h//d + 7) / 8)
        p3 = floor((h//d + 7) / 8)
        
        # compute message digest and index
        digest = H_msg(R, PK_seed, PK_root, M, (p1+p2+p3))
        tmp_md = digest[:p1]   #first floor((ka +7)/ 8) bytes of digest
        tmp_idx_tree = digest[p1:(p1+p2)] # next floor((h - h/d +7)/ 8) bytes of digest
        tmp_idx_leaf = digest[(p1+p2):(p1+p2+p3)] # next floor((h/d +7)/ 8) bytes of digest
        
        # md = tmp_md[:ka]                                   # first ka bits of tmp_md
        # idx_tree = tmp_idx_tree[:(self.h - self.h/self.d)] # first h - h/d bits of tmp_idx_tree
        # idx_leaf = tmp_idx_leaf[:(self.h/self.d)]          # first h/d bits of tmp_idx_leaf
        
        md_int = int.from_bytes(tmp_md, 'big') >> int(len(tmp_md) * 8 - k * self.a)
        md = md_int.to_bytes(math.ceil(k * self.a / 8), 'big')

        idx_tree = int.from_bytes(tmp_idx_tree, 'big') >> (len(tmp_idx_tree) * 8 - (h - h // d))
        idx_leaf = int.from_bytes(tmp_idx_leaf, 'big') >> (len(tmp_idx_leaf) * 8 - (h // d))
            
        #compute FORS public key
        adrs.setLayerAddress(toByte(0,4))
        adrs.setTreeAddress(toByte(idx_tree, 12))
        adrs.setType(toByte(FORS_TREE, 4))
        adrs.setKeyPairAddress(toByte(idx_leaf, 4))
        
        PK_FORS = self.fors.fors_pkFromSig(SIG_FORS, md, PK_seed, adrs)
        
        # verify HT signature
        adrs.setType(toByte(TREE, 4))
        return self.ht.ht_verify(PK_FORS, SIG_HT, PK_seed, idx_tree, idx_leaf, PK_root)

## TESTES:

In [12]:
import os
SK_seed = os.urandom(n)
PK_seed = os.urandom(n)
M = "ESTA FRASE E UM TESTE DO WOTS".encode('utf-8')


# TESTES WOTS:
wots = WOTS(n, w, len1, len2, len_)
adsr = ADRS()
sk = wots.wots_SKgen(SK_seed, adsr)
pk = wots.wots_PKgen(SK_seed, PK_seed, adsr)
sign = wots.wots_sign(M, SK_seed, PK_seed, adsr)
pk_sign = wots.wots_pkFromSig(sign, M, PK_seed, adsr)
print(pk == pk_sign)
print("_______________")

# TESTES XMSS
adsr_xmss = ADRS()
xmss = XMSS(h//d, n, w)
pk = xmss.xmss_PKgen(SK_seed, PK_seed, adsr_xmss)
print(f"PK: {pk}")
SIG_XMSS = xmss.xmss_sign(M, SK_seed, 10, PK_seed, adsr_xmss)
pk_sig = xmss.xmss_pkFromSig(10, SIG_XMSS, M, PK_seed, adsr_xmss)
print(f"PK: {pk_sig}")
print(pk == pk_sig)
print("_______________")
# 
# # TESTES HT
ht = HT(h, n, w, d)
pk = ht.ht_PKgen(SK_seed, PK_seed)
sign = ht.ht_sign(M, SK_seed, PK_seed, 1, 1)
res = ht.ht_verify(M, sign, PK_seed, 1, 1, pk) 
print(res)
print("_______________")
# 
# # TESTES FOR

fors = FORS(n, k, t)
adrs_fors = ADRS()
sk = fors.fors_SKgen(SK_seed, adrs_fors , 10)
pk = fors.fors_PKgen(SK_seed, PK_seed,adrs_fors)
SIG_FORS = fors.fors_sign(M, SK_seed, PK_seed, adrs_fors)
pk_sig = fors.fors_pkFromSig(SIG_FORS, M, PK_seed, adrs_fors)
print(pk == pk_sig)
print("_______________")

# TEST spx
spx = sphincs(n, w, h, d, k, t)
SK, PK = spx.spx_keygen()

M1 = "ESTA FRASE E UM TESTE DO SPHINCS. ESPERO QUE ESTEJA CORRETA".encode('utf-8')
M2 = "ESTA FRASE E UM TESTE DO SPHINCS. ESPERO ERRO".encode('utf-8')

SIG = spx.spx_sign(M1,SK)
res1 = spx.spx_verify(M1, SIG, PK)
print(res1)

SIG2 = spx.spx_sign(M2, SK)
res2 = spx.spx_verify(M2, SIG, PK)
print(res2)

SK2, PK2 = spx.spx_keygen()
SIG3 = spx.spx_sign(M2, SK) 
res2 = spx.spx_verify(M2, SIG3, PK2) # esperava-se que tivesse sido assinado pelo 2, mas foi assinado pelo 1
print(res2)


True
_______________
PK: b'\xfb\xea\x93;\xbc\xea\xd6\xd2\xa04c\x9c7\x81\xc2\xbd'
PK: b'c\x87\xca\xa6\x90\x0e\x99V\xe3\xe0\xbb&\xd4>,\x16'
False
_______________
True
_______________
False
_______________
True
False
False
