# Estruturas Criptográficas - TP4-1
### PG53721 - Carlos Machado
### PG54249 - Tiago Oliveira
### Enunciado - Implementação Dilithium FIPS204


In [6]:
import hashlib, os

### Params

In [7]:

class MLDSA44_PARAMS:
    """ML-DSA-44 parameter set (values from FIPS 204, Table 1)."""

    q = 8380417
    d = 13
    tau = 39
    lmbda = 128
    gamma1 = 2**17
    # Integer division: gamma2 is an integer in FIPS 204; `/` would give a float.
    gamma2 = (q - 1) // 88
    dim = (4, 4)  # (k, l) as listed in FIPS 204 Table 1
    eta = 2
    beta = 78  # equals tau * eta
    omega = 80

class MLDSA65_PARAMS:
    """ML-DSA-65 parameter set (values from FIPS 204, Table 1)."""

    q = 8380417
    d = 13
    tau = 49
    lmbda = 192
    gamma1 = 2**19
    # Integer division: gamma2 is an integer in FIPS 204; `/` would give a float.
    gamma2 = (q - 1) // 32
    dim = (6, 5)  # (k, l) as listed in FIPS 204 Table 1
    eta = 4
    beta = 196  # equals tau * eta
    omega = 55
    
class MLDSA87_PARAMS:
    """ML-DSA-87 parameter set (values from FIPS 204, Table 1)."""

    q = 8380417
    d = 13
    tau = 60
    lmbda = 256
    gamma1 = 2**19
    # Integer division: gamma2 is an integer in FIPS 204; `/` would give a float.
    gamma2 = (q - 1) // 32
    dim = (8, 7)  # (k, l) as listed in FIPS 204 Table 1
    eta = 2
    beta = 120  # equals tau * eta
    omega = 75
    


### Auxiliary Functions

#### Conversion Between Data Types

In [8]:

def integer_to_bits(x, l):
    """Encode the integer ``x`` as ``l`` little-endian bytes.

    NOTE: despite its name this returns a ``bytes`` object of length ``l``
    (not a list of bits); the seed-expansion helpers concatenate its result
    with byte seeds. Raises ``OverflowError`` if ``x`` does not fit.
    """
    encoded = x.to_bytes(length=l, byteorder="little")
    return encoded

def bits_to_integer(y):
    """Decode ``y`` as a little-endian unsigned integer.

    NOTE: despite its name, ``y`` is interpreted as a sequence of bytes
    (values 0..255), not as a list of individual bits.
    """
    value = int.from_bytes(y, byteorder="little")
    return value

def bits_to_bytes(b):
    """Pack a little-endian bit list into bytes (FIPS 204 BitsToBytes).

    Bit ``i`` of ``b`` becomes bit ``i % 8`` of output byte ``i // 8``.
    ``len(b)`` is expected to be a multiple of 8; trailing bits beyond the
    last full byte cause an ``IndexError``.
    """
    out = bytearray(len(b) // 8)
    for pos, bit in enumerate(b):
        byte_idx, shift = divmod(pos, 8)
        out[byte_idx] += bit * (1 << shift)
    return bytes(out)

def bytes_to_bits(B):
    """Unpack bytes into a little-endian bit list (FIPS 204 BytesToBits).

    Bit ``j`` of input byte ``i`` is placed at index ``8 * i + j``.
    """
    return [(byte >> shift) & 1 for byte in B for shift in range(8)]

def coef_from_three_bytes(b0, b1, b2, params):
    """Build a candidate coefficient from three bytes (FIPS 204 CoefFromThreeBytes).

    The top bit of ``b2`` is discarded, the three bytes are combined
    little-endian into a 23-bit value, and the value is accepted only if it
    is below ``params.q``.

    Returns:
        The integer in ``[0, q)``, or ``None`` when it is rejected.
    """
    top = b2 - 128 if b2 > 127 else b2  # clear bit 7 of b2
    candidate = top * 65536 + b1 * 256 + b0
    return candidate if candidate < params.q else None

def coef_from_half_byte(b, params):
    """Map a 4-bit value to a coefficient in [-eta, eta] (FIPS 204 CoefFromHalfByte).

    For ``eta == 2`` values below 15 map to ``2 - (b % 5)``; for ``eta == 4``
    values below 9 map to ``4 - b``. Anything else is rejected.

    Returns:
        The coefficient, or ``None`` when ``b`` is rejected.
    """
    eta = params.eta
    if eta == 2 and b < 15:
        return 2 - b % 5
    if eta == 4 and b < 9:
        return 4 - b
    return None

def simple_bit_pack(w, b):
    """Encode 256 coefficients in [0, b] as bytes (FIPS 204 SimpleBitPack).

    Each coefficient ``w[i]`` is written as ``c = bitlen(b)`` little-endian
    bits at bit offset ``i * c``. The whole bit string is then packed
    little-endian into bytes, which is equivalent to
    ``BitsToBytes(IntegerToBits(w[0], c) || ... || IntegerToBits(w[255], c))``.

    Args:
        w: sequence with at least 256 integer coefficients; only the first
            256 are encoded.
        b: non-negative integer upper bound on the coefficients.

    Returns:
        ``bytes`` of length ``32 * b.bit_length()``.
    """
    c = b.bit_length()
    mask = (1 << c) - 1
    acc = 0
    for i in range(256):
        # IntegerToBits(x, c) keeps only the low c bits of x.
        acc |= (int(w[i]) & mask) << (i * c)
    return acc.to_bytes(32 * c, "little")

def bit_pack(w, a, b):
    """Encode 256 coefficients in [-a, b] as bytes (FIPS 204 BitPack).

    Each coefficient is stored as ``b - w[i]``, a value in ``[0, a + b]``,
    using ``c = bitlen(a + b)`` little-endian bits at bit offset ``i * c``.
    The bit string is packed little-endian into bytes.

    Args:
        w: sequence with at least 256 integer coefficients; only the first
            256 are encoded.
        a: bound on the magnitude of negative coefficients.
        b: bound on positive coefficients.

    Returns:
        ``bytes`` of length ``32 * (a + b).bit_length()``.
    """
    c = (a + b).bit_length()
    mask = (1 << c) - 1
    acc = 0
    for i in range(256):
        acc |= (int(b - w[i]) & mask) << (i * c)
    return acc.to_bytes(32 * c, "little")

def simple_bit_unpack(v, b):
    """Decode 256 coefficients packed by SimpleBitPack (FIPS 204 SimpleBitUnpack).

    Coefficient ``i`` is the ``c = bitlen(b)``-bit little-endian field that
    starts at bit ``i * c`` of the little-endian bit string of ``v``.

    Args:
        v: bytes-like object (or list of byte values) holding at least
            ``32 * b.bit_length()`` bytes; extra trailing bytes are ignored.
        b: non-negative integer upper bound used when packing.

    Returns:
        A list of 256 integers in ``[0, 2**c)``.

    Raises:
        ValueError: if ``v`` is too short to hold 256 fields of ``c`` bits.
    """
    c = b.bit_length()
    if len(v) * 8 < 256 * c:
        raise ValueError("input too short for 256 packed coefficients")
    acc = int.from_bytes(bytes(v), "little")
    mask = (1 << c) - 1
    return [(acc >> (i * c)) & mask for i in range(256)]


def bit_unpack(v, a, b):
    """Decode 256 coefficients packed by BitPack (FIPS 204 BitUnpack).

    Each ``c = bitlen(a + b)``-bit little-endian field ``x`` at bit offset
    ``i * c`` is mapped back to the coefficient ``b - x``.

    Args:
        v: bytes-like object (or list of byte values) holding at least
            ``32 * (a + b).bit_length()`` bytes; extra bytes are ignored.
        a: bound on the magnitude of negative coefficients.
        b: bound on positive coefficients.

    Returns:
        A list of 256 integers.

    Raises:
        ValueError: if ``v`` is too short to hold 256 fields of ``c`` bits.
    """
    c = (a + b).bit_length()
    if len(v) * 8 < 256 * c:
        raise ValueError("input too short for 256 packed coefficients")
    acc = int.from_bytes(bytes(v), "little")
    mask = (1 << c) - 1
    return [b - ((acc >> (i * c)) & mask) for i in range(256)]


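# TODO: hint_bit_pack (FIPS 204 HintBitPack) is not implemented yet; its expected behavior was unclear, so the incomplete draft below is left commented out.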
#def hint_bit_pack(h):
#    y = bytearray()
#    for i in range(0,len(h)):
#        for j in range(0,255):
#            if h(i)

#### Hashing and Pseudorandom Sampling

In [9]:
def shake256(v, d):
    """Return ``d`` bits (``d // 8`` bytes) of SHAKE256 output for ``v``.

    Args:
        v: input message; a ``str`` is UTF-8 encoded, while bytes-like
            objects (``bytes``, ``bytearray``, ``memoryview``) are hashed
            as-is so byte seeds can be used directly.
        d: requested output length in bits.

    Returns:
        ``bytes`` of length ``d // 8``.
    """
    data = v.encode("utf-8") if isinstance(v, str) else v
    shake = hashlib.shake_256()
    shake.update(data)
    # The original never returned the digest; callers index into it.
    return shake.digest(d // 8)
    
def H(rho, d):
    """Hash function H of FIPS 204, instantiated with SHAKE256.

    ``d`` is the requested output length in bits; all work is delegated to
    ``shake256``.
    """
    digest = shake256(rho, d)
    return digest

def H_byte(rho, j):
    """Return byte ``j`` (0-based) of the output stream H(rho).

    The first ``j + 1`` bytes (``8 * (j + 1)`` bits) are produced on every
    call, so the hash is recomputed each time.
    """
    n_bits = (j + 1) * 8
    return H(rho, n_bits)[j]

def shake128(v, d):
    """Return ``d`` bits (``d // 8`` bytes) of SHAKE128 output for ``v``.

    Args:
        v: input message; a ``str`` is UTF-8 encoded, while bytes-like
            objects (``bytes``, ``bytearray``, ``memoryview``) are hashed
            as-is so byte seeds can be used directly.
        d: requested output length in bits.

    Returns:
        ``bytes`` of length ``d // 8``.
    """
    data = v.encode("utf-8") if isinstance(v, str) else v
    shake = hashlib.shake_128()
    shake.update(data)
    # The original never returned the digest; callers index into it.
    return shake.digest(d // 8)

def H128(rho, d):
    """Hash function H128 of FIPS 204, instantiated with SHAKE128.

    ``d`` is the requested output length in bits; all work is delegated to
    ``shake128``.
    """
    digest = shake128(rho, d)
    return digest

def H128_byte(rho, c):
    """Return byte ``c`` (0-based) of the output stream H128(rho).

    Mirrors ``H_byte``. ``H128`` takes its length in bits, so
    ``8 * (c + 1)`` bits are requested and byte ``c`` is returned. The
    original asked for only ``c + 1`` bits and indexed ``c % 32``.
    """
    return H128(rho, 8 * (c + 1))[c]
    
def sample_in_ball(rho, params):
    """Sample the challenge polynomial c (FIPS 204 SampleInBall).

    Runs ``params.tau`` Fisher-Yates-style swap steps. Positions are taken
    from bytes 8, 9, ... of H(rho), rejecting any byte larger than the
    current index ``i``. The sign of the new entry comes from bit
    ``i + tau - 256`` of H(rho), with bits read little-endian. The result
    has exactly ``tau`` entries equal to +1 or -1 and zeros elsewhere.

    Args:
        rho: seed passed to ``H`` / ``H_byte``.
        params: parameter set providing ``tau``.

    Returns:
        A list of 256 integers.
    """
    tau = params.tau
    c = [0] * 256
    # The first 8 bytes (64 bits) of H(rho) hold the sign bits. tau <= 64
    # for every parameter set in this file.
    sign_bits = int.from_bytes(H(rho, 64), "little")
    k = 8
    for i in range(256 - tau, 256):
        j = H_byte(rho, k)
        while j > i:
            k += 1
            j = H_byte(rho, k)
        c[i] = c[j]
        c[j] = (-1) ** ((sign_bits >> (i + tau - 256)) & 1)
        k += 1
    return c

def rej_NTT_poly(rho, params):
    """Sample a polynomial with coefficients in [0, q) by rejection (RejNTTPoly).

    Consecutive byte triples of the H128(rho) stream are turned into
    candidates by ``coef_from_three_bytes``. Rejected (``None``) candidates
    are overwritten by the next accepted one.

    Returns:
        A list of 256 integers.
    """
    coeffs = [0] * 256
    filled = 0
    pos = 0
    while filled < 256:
        candidate = coef_from_three_bytes(
            H128_byte(rho, pos),
            H128_byte(rho, pos + 1),
            H128_byte(rho, pos + 2),
            params,
        )
        pos += 3
        coeffs[filled] = candidate
        if candidate is not None:
            filled += 1
    return coeffs

def rej_bounded_poly(rho, params):
    """Sample a polynomial with coefficients in [-eta, eta] (RejBoundedPoly).

    Each byte of the H(rho) stream yields two 4-bit candidates, low nibble
    first. Each candidate is mapped through ``coef_from_half_byte`` and kept
    only if accepted and the polynomial is not yet full.

    Returns:
        A list of 256 integers.
    """
    poly = [0] * 256
    count = 0
    idx = 0
    while count < 256:
        byte = H_byte(rho, idx)
        idx += 1
        high, low = divmod(byte, 16)
        candidates = (
            coef_from_half_byte(low, params),
            coef_from_half_byte(high, params),
        )
        for candidate in candidates:
            if candidate is not None and count < 256:
                poly[count] = candidate
                count += 1
    return poly

def expand_A(rho, params):
    """Expand the seed ``rho`` into the public matrix A (FIPS 204 ExpandA).

    Entry ``A[r][s]`` is ``rej_NTT_poly(rho || s || r)``, where ``s`` and
    ``r`` are each encoded as a single byte (IntegerToBytes(., 1)).

    Args:
        rho: seed; assumed to be ``bytes`` because it is concatenated with
            ``bytes`` here (TODO confirm with callers).
        params: parameter set whose ``dim`` is ``(k, l)`` as in FIPS 204
            Table 1. The original read it as ``(l, k)``, which is wrong for
            ML-DSA-65/87.

    Returns:
        A ``k x l`` nested list of polynomials.
    """
    k, l = params.dim
    return [
        [rej_NTT_poly(rho + bytes([s, r]), params) for s in range(l)]
        for r in range(k)
    ]

def expand_S(rho, params):
    """Sample the secret vectors s1 and s2 (FIPS 204 ExpandS).

    ``s1[r]`` uses seed ``rho || IntegerToBytes(r, 2)`` and ``s2[r]`` uses
    ``rho || IntegerToBytes(r + l, 2)``. Each polynomial comes from
    ``rej_bounded_poly``.

    Args:
        rho: seed; assumed to be ``bytes`` because it is concatenated with
            ``bytes`` here (TODO confirm with callers).
        params: parameter set whose ``dim`` is ``(k, l)`` as in FIPS 204
            Table 1.

    Returns:
        Tuple ``(s1, s2)``: ``s1`` has ``l`` polynomials, ``s2`` has ``k``.
    """
    k, l = params.dim
    s1 = [rej_bounded_poly(rho + r.to_bytes(2, "little"), params) for r in range(l)]
    s2 = [
        rej_bounded_poly(rho + (r + l).to_bytes(2, "little"), params)
        for r in range(k)
    ]
    return s1, s2

def expand_mask(rho, u, params):
    """Sample the masking vector y (FIPS 204 ExpandMask).

    For each ``r`` in ``0..l-1`` the seed is ``rho`` followed by the 2-byte
    little-endian encoding of ``u + r``. A ``32 * c``-byte slice of H of
    that seed is decoded with ``bit_unpack(v, gamma1 - 1, gamma1)``, where
    ``c = 1 + bitlen(gamma1 - 1)``.

    Args:
        rho: seed; assumed to be ``bytes`` because it is concatenated with
            ``bytes`` here (TODO confirm with callers).
        u: non-negative integer counter (mu in FIPS 204 ExpandMask).
        params: parameter set whose ``dim`` is ``(k, l)`` as in FIPS 204
            Table 1.

    Returns:
        A list of ``l`` polynomials.
    """
    gamma1 = params.gamma1
    c = 1 + (gamma1 - 1).bit_length()
    _, l = params.dim
    chunk = 32 * c  # bytes needed for 256 coefficients of c bits each
    y = []
    for r in range(l):
        # IntegerToBytes(u + r, 2): low 16 bits, little-endian.
        nonce = ((u + r) & 0xFFFF).to_bytes(2, "little")
        # NOTE(review): this keeps the original offset, reading bytes
        # [32*r*c, 32*r*c + 32*c) of H(rho || nonce) as in the FIPS 204
        # draft text. The final FIPS 204 ExpandMask reads the first 32*c
        # bytes instead; confirm which revision is targeted.
        offset = r * chunk
        # One hash call instead of one per output byte.
        stream = H(rho + nonce, 8 * (offset + chunk))
        y.append(bit_unpack(stream[offset:offset + chunk], gamma1 - 1, gamma1))
    return y