# Estruturas Criptográficas - TP4-1
### PG53721 - Carlos Machado
### PG54249 - Tiago Oliveira
### Enunciado - Implementação Dilithium FIPS204


In [7]:
import hashlib, os

### Params

In [8]:

class MLDSA44_PARAMS:
    """ML-DSA-44 parameter set (FIPS 204).

    All values are class attributes; ``dim`` mirrors ``(k, l)``.
    """
    q = 8380417
    d = 13
    tau = 39
    lmbda = 128
    gamma1 = 2**17
    # Integer division keeps gamma2 an exact int (95232); true division made it
    # a float that propagated into decompose / use_hint / w1_encode.
    gamma2 = (q - 1) // 88
    dim = (4, 4)
    k = 4
    l = 4
    eta = 2
    beta = 78
    omega = 80

class MLDSA65_PARAMS:
    """ML-DSA-65 parameter set (FIPS 204).

    All values are class attributes; ``dim`` mirrors ``(k, l)``.
    """
    q = 8380417
    d = 13
    tau = 49
    lmbda = 192
    gamma1 = 2**19
    # Integer division keeps gamma2 an exact int (261888) instead of a float.
    gamma2 = (q - 1) // 32
    dim = (6, 5)
    k = 6
    l = 5
    eta = 4
    beta = 196
    omega = 55
    
class MLDSA87_PARAMS:
    """ML-DSA-87 parameter set (FIPS 204).

    All values are class attributes; ``dim`` mirrors ``(k, l)``.
    """
    q = 8380417
    d = 13
    tau = 60
    lmbda = 256
    gamma1 = 2**19
    # Integer division keeps gamma2 an exact int (261888) instead of a float.
    gamma2 = (q - 1) // 32
    dim = (8, 7)
    k = 8
    l = 7
    eta = 2
    beta = 120
    omega = 75
    
# Active parameter set read by the helpers below (ML-DSA-44 by default).
# Rebind to MLDSA65_PARAMS() or MLDSA87_PARAMS() to switch security level.
PARAMS = MLDSA44_PARAMS()

### Auxiliary Functions

#### Conversion Between Data Types

In [12]:

def integer_to_bits(x, l):
    """Return the ``l`` least-significant bits of ``x``, least-significant first.

    Mirrors FIPS 204 IntegerToBits(x, l). The result is a ``bytes`` object
    whose items are each 0 or 1 (one item per bit), so it can be concatenated
    onto a ``bytearray`` bit buffer and fed to ``bits_to_bytes``. Callers pass
    a *bit* count as ``l``; the previous version returned ``l`` whole bytes.
    Bits beyond position ``l`` are dropped (the value is taken mod 2**l).
    """
    return bytes((x >> i) & 1 for i in range(l))

def bits_to_integer(y, alpha=None):
    """Interpret a little-endian bit array as a non-negative integer.

    Mirrors FIPS 204 BitsToInteger(y, alpha). ``y`` is a sequence of bits
    (each 0 or 1), least-significant bit first. ``alpha`` optionally limits
    decoding to the first ``alpha`` bits; by default all of ``y`` is used.
    The unpack helpers call this with two arguments, which the previous
    one-argument, bytes-based version rejected with a TypeError.
    """
    value = 0
    for bit in reversed(y[:alpha]):
        value = 2 * value + bit
    return value

def bits_to_bytes(b):
    """Pack a little-endian bit array into bytes (FIPS 204 BitsToBytes).

    Bit ``i`` of ``b`` becomes bit ``i % 8`` of byte ``i // 8``. The output
    length is ceil(len(b) / 8), so a trailing partial byte is zero-padded
    instead of raising IndexError as the floor-sized buffer did.
    """
    B = bytearray((len(b) + 7) // 8)
    for i in range(len(b)):
        B[i // 8] += b[i] * 2 ** (i % 8)
    return bytes(B)

def bytes_to_bits(B):
    """Expand each byte of ``B`` into 8 bits, least-significant bit first.

    Returns a list of 0/1 ints of length 8 * len(B) (FIPS 204 BytesToBits).
    """
    return [(octet >> shift) & 1 for octet in B for shift in range(8)]

def coef_from_three_bytes(b0, b1, b2):
    """Build a candidate coefficient from three bytes (FIPS 204 CoeffFromThreeBytes).

    The top bit of ``b2`` is cleared and the bytes are combined
    little-endian. The value is returned if it is below q; otherwise None
    signals rejection.
    """
    top = b2 - 128 if b2 > 127 else b2
    candidate = b0 + b1 * 256 + top * 65536
    if candidate >= PARAMS.q:
        return None
    return candidate

def coef_from_half_byte(b):
    """Map a 4-bit value to a coefficient in [-eta, eta], or None to reject.

    For eta == 2, values below 15 map to 2 - (b mod 5). For eta == 4, values
    below 9 map to 4 - b. Anything else is rejected.
    """
    eta = PARAMS.eta
    if eta == 2 and b < 15:
        return 2 - b % 5
    if eta == 4 and b < 9:
        return 4 - b
    return None

def simple_bit_pack(w, b):
    """Pack 256 coefficients in [0, b] into bytes (FIPS 204 SimpleBitPack).

    Each coefficient ``w[i]`` occupies bitlen(b) bits, least-significant bit
    first, starting at bit offset i * bitlen(b). The output is
    32 * bitlen(b) bytes long.

    The previous version packed ``b - w[i]`` (that is BitPack's rule, not
    SimpleBitPack's) and appended whole bytes instead of bits.

    Raises:
        ValueError: if a coefficient lies outside [0, b].
    """
    width = b.bit_length()
    packed = 0
    for i in range(256):
        coef = w[i]
        if not 0 <= coef <= b:
            raise ValueError("simple_bit_pack: coefficient out of range [0, b]")
        packed |= coef << (i * width)
    return packed.to_bytes(32 * width, 'little')

def bit_pack(w, a, b):
    """Pack 256 coefficients in [-a, b] into bytes (FIPS 204 BitPack).

    Each coefficient is stored as the non-negative value ``b - w[i]`` using
    bitlen(a + b) bits, least-significant bit first. The output is
    32 * bitlen(a + b) bytes long. The previous version appended whole bytes
    per coefficient instead of bits.

    Raises:
        ValueError: if a coefficient lies outside [-a, b].
    """
    width = (a + b).bit_length()
    packed = 0
    for i in range(256):
        value = b - w[i]
        if not 0 <= value <= a + b:
            raise ValueError("bit_pack: coefficient out of range [-a, b]")
        packed |= value << (i * width)
    return packed.to_bytes(32 * width, 'little')

def simple_bit_unpack(v, b):
    """Inverse of SimpleBitPack: decode 256 coefficients from ``v`` (FIPS 204).

    Coefficient ``i`` is read from bits [i*c, (i+1)*c) of ``v`` (little-endian),
    where c = bitlen(b). The previous version indexed ``z[i*c] + j`` instead
    of ``z[i*c + j]`` and called ``bits_to_integer`` with an unsupported arity.

    Raises:
        ValueError: if ``v`` holds fewer than 256 * c bits.
    """
    width = b.bit_length()
    if len(v) * 8 < 256 * width:
        raise ValueError("simple_bit_unpack: input too short")
    packed = int.from_bytes(v, 'little')
    mask = (1 << width) - 1
    return [(packed >> (i * width)) & mask for i in range(256)]


def bit_unpack(v, a, b):
    """Inverse of BitPack: decode 256 coefficients in [-a, b] (FIPS 204 BitUnpack).

    Each c-bit field (c = bitlen(a + b)) holds ``b - w[i]``, so the
    coefficient is recovered as ``b`` minus the field value. The previous
    version produced only 255 coefficients, had the ``z[i*c] + j`` indexing
    bug, and never subtracted from ``b``. ``v`` may be ``bytes`` or any
    iterable of byte values accepted by ``int.from_bytes``.

    Raises:
        ValueError: if ``v`` holds fewer than 256 * c bits.
    """
    width = (a + b).bit_length()
    if len(v) * 8 < 256 * width:
        raise ValueError("bit_unpack: input too short")
    packed = int.from_bytes(v, 'little')
    mask = (1 << width) - 1
    return [b - ((packed >> (i * width)) & mask) for i in range(256)]


def hint_bit_pack(h):
    """Encode a k x 256 hint vector into omega + k bytes (FIPS 204 HintBitPack).

    The first ``omega`` bytes list the positions of the nonzero hints,
    polynomial by polynomial. Byte ``omega + i`` records how many positions
    have been written after polynomial ``i``. The previous version evaluated
    ``y[omega + i]`` without assigning it, so those counters stayed zero.

    Raises:
        ValueError: if ``h`` has more than ``omega`` nonzero entries.
    """
    omega, k = PARAMS.omega, PARAMS.k
    y = bytearray(omega + k)
    index = 0
    for i in range(k):
        for j in range(256):
            if h[i][j] != 0:
                # Without this guard, positions would overwrite the counters.
                if index >= omega:
                    raise ValueError("hint_bit_pack: more than omega nonzero hints")
                y[index] = j
                index += 1
        y[omega + i] = index
    return y

def hint_bit_unpack(y):
    """Decode omega + k bytes into a k x 256 hint vector (FIPS 204 HintBitUnpack).

    Returns None for any malformed encoding:
    - a counter decreases or exceeds omega;
    - positions within one polynomial are not strictly increasing;
    - unused position bytes are nonzero.

    The strict-ordering check was missing. It is what makes the encoding
    unique, so a valid signature cannot be re-encoded differently.
    """
    omega, k = PARAMS.omega, PARAMS.k
    h = [[0] * 256 for _ in range(k)]
    index = 0
    for i in range(k):
        end = y[omega + i]
        if end < index or end > omega:
            return None
        first = index
        while index < end:
            if index > first and y[index - 1] >= y[index]:
                return None
            h[i][y[index]] = 1
            index += 1
    # Leftover position slots must be zero padding.
    if any(y[index:omega]):
        return None
    return h

#### Encodings of ML-DSA Keys and Signatures

In [13]:
def pk_encode(p, t1):
    """Serialize a public key: bits_to_bytes(p) followed by each packed t1[i].

    ``p`` is converted with ``bits_to_bytes``, so it is expected as a bit
    array. Each of the k polynomials in ``t1`` is packed with
    ``simple_bit_pack`` using bound 2**(bitlen(q-1) - d) - 1.
    """
    width = (PARAMS.q - 1).bit_length() - PARAMS.d
    bound = 2 ** width - 1
    chunks = [bits_to_bytes(p)]
    for idx in range(PARAMS.k):
        chunks.append(simple_bit_pack(t1[idx], bound))
    return b''.join(chunks)

def pk_decode(pk):
    """Parse a public key into (rho, t1); inverse of pk_encode.

    The first 32 bytes are returned as a bit list ``rho``. Then k polynomials
    of 32 * bl bytes each are unpacked, where bl = bitlen(q-1) - d. The
    previous version referenced an undefined name ``z`` (NameError) instead
    of the collected slices.
    """
    bl = (PARAMS.q - 1).bit_length() - PARAMS.d
    poly_len = 32 * bl
    rho = bytes_to_bits(pk[:32])
    t1 = []
    for i in range(PARAMS.k):
        start = 32 + poly_len * i
        t1.append(simple_bit_unpack(pk[start:start + poly_len], 2 ** bl - 1))
    return rho, t1

def sk_encode(rho, K, tr, s1, s2, t0):
    """Serialize a secret key.

    The output is rho || K || tr (each converted from a bit array with
    bits_to_bytes), followed by the packed s1 (l polys) and s2 (k polys),
    both in [-eta, eta]. Last come the packed t0 (k polys) in
    [-(2^(d-1) - 1), 2^(d-1)].
    """
    eta = PARAMS.eta
    t0_a = 2 ** (PARAMS.d - 1) - 1
    t0_b = 2 ** (PARAMS.d - 1)
    pieces = [bits_to_bytes(rho), bits_to_bytes(K), bits_to_bytes(tr)]
    pieces += [bit_pack(s1[idx], eta, eta) for idx in range(PARAMS.l)]
    pieces += [bit_pack(s2[idx], eta, eta) for idx in range(PARAMS.k)]
    pieces += [bit_pack(t0[idx], t0_a, t0_b) for idx in range(PARAMS.k)]
    return b''.join(pieces)

def sk_decode(sk):
    """Parse a secret key into (rho, K, tr, s1, s2, t0); inverse of sk_encode.

    Layout:
    - rho: 32 bytes; K: 32 bytes; tr: 64 bytes. Each is returned as a bit list.
    - l + k polynomials of 32 * bitlen(2*eta) bytes each (s1, then s2).
    - k polynomials of 32 * d bytes each (t0).

    The previous version used a stride of 32 * (2*eta) bytes for s1/s2 instead
    of 32 * bitlen(2*eta), which misaligned every field after tr.
    """
    eta = PARAMS.eta
    s_len = 32 * (2 * eta).bit_length()
    t0_len = 32 * PARAMS.d
    t0_a = 2 ** (PARAMS.d - 1) - 1
    t0_b = 2 ** (PARAMS.d - 1)

    rho = bytes_to_bits(sk[:32])
    K = bytes_to_bits(sk[32:64])
    tr = bytes_to_bits(sk[64:128])

    base = 128
    s1 = []
    for _ in range(PARAMS.l):
        s1.append(bit_unpack(sk[base:base + s_len], eta, eta))
        base += s_len

    s2 = []
    for _ in range(PARAMS.k):
        s2.append(bit_unpack(sk[base:base + s_len], eta, eta))
        base += s_len

    t0 = []
    for _ in range(PARAMS.k):
        t0.append(bit_unpack(sk[base:base + t0_len], t0_a, t0_b))
        base += t0_len

    return rho, K, tr, s1, s2, t0

def sig_encode(c, z, h):
    """Serialize a signature.

    The output is bits_to_bytes(c), then each z[i] packed in
    [-(gamma1 - 1), gamma1], then the hint encoding of h.
    """
    lo, hi = PARAMS.gamma1 - 1, PARAMS.gamma1
    parts = [bits_to_bytes(c)]
    parts.extend(bit_pack(z[idx], lo, hi) for idx in range(PARAMS.l))
    parts.append(hint_bit_pack(h))
    return b''.join(parts)

def sig_decode(sigma):
    """Parse a signature into (c, z, h); inverse of sig_encode.

    Layout:
    - c: lmbda // 4 bytes, returned as a bit list.
    - l polynomials of 32 * (1 + bitlen(gamma1 - 1)) bytes each (z).
    - omega + k hint bytes.

    ``h`` is None if the hint encoding is malformed. The previous version
    sliced with a float (``lmbda / 4``) and cut the hint bytes at absolute
    offset omega + k instead of relative to the hint start.
    """
    c_len = PARAMS.lmbda // 4
    z_len = 32 * (1 + (PARAMS.gamma1 - 1).bit_length())

    c = bytes_to_bits(sigma[:c_len])

    base = c_len
    z = []
    for _ in range(PARAMS.l):
        z.append(bit_unpack(sigma[base:base + z_len], PARAMS.gamma1 - 1, PARAMS.gamma1))
        base += z_len

    h = hint_bit_unpack(sigma[base:base + PARAMS.omega + PARAMS.k])

    return c, z, h

def w1_encode(w1):
    """Encode the k polynomials of w1 as a bit array (bytearray of 0/1).

    Each w1[i] is packed with simple_bit_pack using bound
    (q - 1) // (2 * gamma2) - 1, and the packed bytes are expanded back to
    bits. The bound is forced to an int because simple_bit_pack needs
    ``bit_length``. The previous version computed it with true division
    (a float), and also did ``bytearray += list``, which raises TypeError.
    """
    bound = int((PARAMS.q - 1) // (2 * PARAMS.gamma2)) - 1
    nw1 = bytearray()
    for i in range(PARAMS.k):
        nw1.extend(bytes_to_bits(simple_bit_pack(w1[i], bound)))
    return nw1

#### Hashing and Pseudorandom Sampling

In [16]:
def shake256(v, d):
    """Return the first ``d // 8`` bytes of SHAKE256(v).

    ``d`` is an output length in bits. ``v`` may be a ``str`` (UTF-8 encoded,
    as before) or anything ``bytes()`` accepts (bytes, bytearray, or a list
    of byte values). The previous version discarded the digest and returned
    None, and it rejected non-str seeds.
    """
    data = v.encode('utf-8') if isinstance(v, str) else bytes(v)
    return hashlib.shake_256(data).digest(d // 8)
    
def H(rho, d):
    """Hash function H: SHAKE256 of ``rho`` truncated to ``d`` bits."""
    digest = shake256(rho, d)
    return digest

def H_byte(rho, j):
    """Return byte ``j`` (0-based) of H(rho), hashing just enough output to reach it."""
    return H(rho, 8 * j + 8)[j]

def shake128(v, d):
    """Return the first ``d // 8`` bytes of SHAKE128(v).

    ``d`` is an output length in bits. ``v`` may be a ``str`` (UTF-8 encoded,
    as before) or anything ``bytes()`` accepts (bytes, bytearray, or a list
    of byte values). The previous version discarded the digest and returned
    None, and it rejected non-str seeds.
    """
    data = v.encode('utf-8') if isinstance(v, str) else bytes(v)
    return hashlib.shake_128(data).digest(d // 8)

def H128(rho, d):
    """Hash function H128: SHAKE128 of ``rho`` truncated to ``d`` bits."""
    digest = shake128(rho, d)
    return digest

def H128_byte(rho, c):
    """Return byte ``c`` (0-based) of H128(rho).

    Requests 8 * (c + 1) bits so byte ``c`` is present. The previous version
    requested only ``c + 1`` bits and read index ``c % 32``, returning
    unrelated bytes or raising IndexError.
    """
    return H128(rho, 8 * (c + 1))[c]
    
def sample_in_ball(rho):
    """Sample a 256-coefficient polynomial with exactly tau entries in {-1, 1}.

    Follows FIPS 204 SampleInBall. Bytes 0..7 of H(rho) provide 64 sign bits.
    Position bytes are read by rejection sampling starting at byte 8.
    Coefficient j receives sign (-1)**bit, using bit (i + tau - 256) of those
    8 bytes (little-endian). The previous version raised -1 to a bytes object
    (TypeError) at an unrelated index.
    """
    tau = PARAMS.tau
    c = [0] * 256
    signs = int.from_bytes(H(rho, 64), 'little')
    k = 8

    for i in range(256 - tau, 256):
        # Reject position bytes larger than the current index i.
        while H_byte(rho, k) > i:
            k += 1

        j = H_byte(rho, k)
        c[i] = c[j]
        c[j] = -1 if (signs >> (i + tau - 256)) & 1 else 1
        k += 1

    return c

def rej_NTT_poly(rho):
    """Rejection-sample a polynomial (in NTT form) with coefficients in [0, q).

    Consumes the H128(rho) stream three bytes at a time. A candidate is kept
    only when coef_from_three_bytes accepts it; rejected candidates are
    overwritten by the next attempt.
    """
    coeffs = [0] * 256
    filled = 0
    offset = 0

    while filled < 256:
        candidate = coef_from_three_bytes(
            H128_byte(rho, offset),
            H128_byte(rho, offset + 1),
            H128_byte(rho, offset + 2),
        )
        coeffs[filled] = candidate
        offset += 3
        if candidate is not None:
            filled += 1

    return coeffs

def rej_bounded_poly(rho):
    """Rejection-sample a polynomial with coefficients in [-eta, eta].

    Each byte of H(rho) supplies two candidates: its low nibble, then its
    high nibble. Each nibble is mapped through coef_from_half_byte and kept
    while fewer than 256 coefficients have been filled.
    """
    poly = [0] * 256
    filled = 0
    pos = 0

    while filled < 256:
        octet = H_byte(rho, pos)
        for nibble in (octet % 16, octet // 16):
            value = coef_from_half_byte(nibble)
            if value is not None and filled < 256:
                poly[filled] = value
                filled += 1
        pos += 1

    return poly

def expand_A(rho):
    """Expand seed ``rho`` into the k x l matrix A (in NTT form).

    Entry A[r][s] is rej_NTT_poly(rho || s || r), with s and r each encoded
    as a single byte.

    Fixes over the previous version:
    - ``dim`` is (k, l), but was read as (l, k), so the matrix was
      transposed whenever k != l.
    - Indices were encoded as 8 bytes each instead of one byte.
    - An extra argument was passed to rej_NTT_poly (TypeError).

    NOTE(review): rho is assumed to be a bytes-like seed (the original
    ``rho + bytes`` concatenation already required that). pk_decode returns
    rho as a bit list, so confirm the caller converts it.
    """
    k, l = PARAMS.k, PARAMS.l
    seed = bytes(rho)
    return [
        [rej_NTT_poly(seed + bytes((s, r))) for s in range(l)]
        for r in range(k)
    ]

def expand_S(rho):
    """Expand seed ``rho`` into secret vectors s1 (length l) and s2 (length k).

    s1[r] = rej_bounded_poly(rho || r) and s2[r] = rej_bounded_poly(rho || r + l),
    with each index encoded as 2 little-endian bytes (16 bits). The previous
    version swapped k and l (via ``dim``) and encoded indices as 16 bytes.

    NOTE(review): rho is assumed to be a bytes-like seed; confirm against the caller.
    """
    l, k = PARAMS.l, PARAMS.k
    seed = bytes(rho)
    s1 = [rej_bounded_poly(seed + r.to_bytes(2, 'little')) for r in range(l)]
    s2 = [rej_bounded_poly(seed + (r + l).to_bytes(2, 'little')) for r in range(k)]
    return s1, s2

def expand_mask(rho, u):
    """Expand ``rho`` and counter ``u`` into the masking vector (length l).

    For each r, n = (u + r) mod 2**16 is encoded as 2 little-endian bytes.
    The 32*c bytes starting at offset 32*r*c of H(rho || n) are unpacked with
    BitUnpack(., gamma1 - 1, gamma1), where c = 1 + bitlen(gamma1 - 1).

    The previous version:
    - used ``dim[0]`` (which is k) as l;
    - encoded n as 16 bytes;
    - re-hashed once per output byte.
    The hash is now computed once per r; the bytes read are the same as the
    original per-byte indexing.

    NOTE(review): the 32*r*c offset follows the original indexing; the final
    FIPS 204 ExpandMask reads the first 32*c bytes (offset 0). Confirm which
    revision this notebook targets. rho is assumed to be bytes-like.
    """
    c = 1 + (PARAMS.gamma1 - 1).bit_length()
    seed = bytes(rho)
    s = []

    for r in range(PARAMS.l):
        n = ((u + r) & 0xFFFF).to_bytes(2, 'little')
        stream = H(seed + n, 8 * 32 * c * (r + 1))
        v = stream[32 * r * c:32 * c * (r + 1)]
        s.append(bit_unpack(v, PARAMS.gamma1 - 1, PARAMS.gamma1))

    return s

#### High Order / Low Order Bits and Hints

In [5]:
def modpm(n, a):
    """Centered reduction: the representative of n mod a closest to zero.

    For even ``a`` the result lies in (-a/2, a/2]. For odd ``a`` it lies in
    [-(a-1)/2, (a-1)/2]. The previous ``(n % a) - a`` always returned a value
    in [-a, 0), which breaks power2round and decompose.
    """
    r = n % a
    if r > a // 2:
        r -= a
    return r

def power2round(r):
    """Split r mod q into (r1, r0) with r = r1 * 2**d + r0 and r0 = r mod± 2**d.

    r1 is computed with integer division (r+ - r0 is an exact multiple of
    2**d), so both results stay ints instead of the previous float r1.
    """
    rp = r % PARAMS.q
    r0 = modpm(rp, 2 ** PARAMS.d)
    return (rp - r0) // 2 ** PARAMS.d, r0

def decompose(r):
    """Split r into high and low parts (r1, r0) relative to 2 * gamma2.

    r+ = r mod q and r0 = r+ mod± 2*gamma2. In the wrap-around case
    r+ - r0 == q - 1, r1 is set to 0 and r0 is decremented by 1. Otherwise
    r1 = (r+ - r0) / (2 * gamma2), an exact integer division.

    The previous version computed ``r & q`` (bitwise AND) instead of
    ``r % q`` and used true division, which made r1 a float.
    """
    two_gamma2 = int(2 * PARAMS.gamma2)
    rp = r % PARAMS.q
    r0 = modpm(rp, two_gamma2)

    if rp - r0 == PARAMS.q - 1:
        r1 = 0
        r0 = r0 - 1
    else:
        r1 = (rp - r0) // two_gamma2

    return r1, r0

def high_bits(r):
    """Return the high part r1 from decompose(r)."""
    high, _low = decompose(r)
    return high

def low_bits(r):
    """Return the low part r0 from decompose(r)."""
    _high, low = decompose(r)
    return low

def make_hint(z, r):
    """Return True when adding z to r changes the high bits of r."""
    before = high_bits(r)
    after = high_bits(r + z)
    return before != after

def use_hint(h, r):
    """Recover the high bits of r + z from r and the hint bit h.

    With m = (q - 1) / (2 * gamma2): if h == 1, r1 is moved one step toward
    the sign of r0 (mod m); otherwise r1 is returned unchanged. m is computed
    as an int; the previous true division made it a float, so the results
    were floats.
    """
    m = int((PARAMS.q - 1) // (2 * PARAMS.gamma2))
    r1, r0 = decompose(r)
    if h == 1 and r0 > 0:
        return (r1 + 1) % m
    if h == 1 and r0 <= 0:
        return (r1 - 1) % m
    return r1


#### NTT and Inverse NTT

In [18]:
def bit_rev_7(r):
    """Reverse the bits of r written as a 7-digit binary string."""
    digits = format(r, '07b')
    return int(''.join(reversed(digits)), 2)

def _bit_rev_8(r):
    """Reverse the bits of r written as an 8-digit binary string."""
    return int('{:08b}'.format(r)[::-1], 2)


# ML-DSA NTT twiddles: ZETA[m] = 1753 ** brv8(m) mod q for m in [0, 256), where
# 1753 is the 512th root of unity mod q used by FIPS 204. The previous table used
# 17 with 7-bit reversal; those are the ML-KEM (q = 3329) constants and do not
# give an NTT for ML-DSA's q.
ZETA = [pow(1753, _bit_rev_8(m), PARAMS.q) for m in range(256)]

# Odd powers of 17 (ML-KEM-style base-case multiplication factors). Kept
# unchanged for backward compatibility; ntt / ntt_inv below do not use them.
GAMMA = [pow(17, 2 * bit_rev_7(k_value) + 1, PARAMS.q) for k_value in range(128)]


def ntt(f):
    """Forward NTT of a 256-coefficient polynomial (FIPS 204 NTT).

    Runs 8 Cooley-Tukey layers (length 128 down to 1) using ZETA[1..255].
    Returns a new list of coefficients reduced mod q; ``f`` is not modified.
    The previous version aliased and mutated ``f``, and stopped one layer
    short.
    """
    q = PARAMS.q
    w = list(f)
    m = 0
    length = 128

    while length >= 1:
        for start in range(0, 256, 2 * length):
            m += 1
            zeta = ZETA[m]
            for j in range(start, start + length):
                t = zeta * w[j + length] % q
                w[j + length] = (w[j] - t) % q
                w[j] = (w[j] + t) % q
        length //= 2

    return w


def ntt_inv(fc):
    """Inverse NTT (FIPS 204 NTT^-1); returns a new list, ``fc`` is not modified.

    Runs 8 Gentleman-Sande layers (length 1 up to 128) using -ZETA[255..1],
    then scales every coefficient by 256^-1 mod q = 8347681. The previous
    scale 3303 is 128^-1 mod 3329, an ML-KEM constant.
    """
    q = PARAMS.q
    w = list(fc)
    m = 256
    length = 1

    while length < 256:
        for start in range(0, 256, 2 * length):
            m -= 1
            zeta = -ZETA[m]
            for j in range(start, start + length):
                t = w[j]
                w[j] = (t + w[j + length]) % q
                w[j + length] = zeta * (t - w[j + length]) % q
        length *= 2

    return [(felem * 8347681) % q for felem in w]
