### Number Theoretic Transform (NTT) with a CRT-based inverse


Reference (SageMath polynomial elements): http://doc.sagemath.org/html/en/reference/polynomial_rings/sage/rings/polynomial/polynomial_element_generic.html

In [1]:
class NTT(object):
    """Negacyclic number-theoretic transform over GF(q)[w]/(w^n + 1) (SageMath).

    The forward transform returns the values of a polynomial at the n roots
    xi^(2i+1), i = 0..n-1, of w^n + 1; the inverse interpolates those values
    back to a polynomial through a precomputed CRT basis.
    """

    def __init__(self, n=128):
        """Build the field, polynomial ring, root xi and CRT basis for size ``n``.

        Raises ValueError unless n is a power of two between 32 and 2048.
        """
        if n not in (32, 64, 128, 256, 512, 1024, 2048):
            raise ValueError("improper argument ", n)
        self.n = n

        # Smallest prime q = 1 (mod 2n), so GF(q) contains the 2n-th roots of
        # unity. Integer() keeps .is_prime() available even when n is a
        # plain Python int (e.g. when called from non-preparsed code).
        self.q = Integer(1 + 2*n)
        while not self.q.is_prime():
            self.q += 2*n

        self.F = GF(self.q)
        self.R = PolynomialRing(self.F, name="w")
        w = self.R.gen()
        self.w = w

        # Since 2n | q - 1, the roots of w^n + 1 in GF(q) are exactly the n
        # primitive 2n-th roots of unity; any one of them serves as xi.
        g = w**n + 1
        self.g = g
        xi = g.roots(multiplicities=False)[-1]
        self.xi = xi
        rs = [xi**(2*i + 1) for i in range(n)]
        self.base = crt_basis([(w - r) for r in rs])

    def ntt(self, f):
        """Return the list [f(xi^(2i+1)) for i in range(n)].

        ``f`` is converted into the ring and reduced modulo w^n + 1 first.
        This leaves its values at the roots of w^n + 1 unchanged, so
        polynomials of degree >= n are transformed correctly instead of
        having their high coefficients dropped.
        """
        def _ntt_(xi, N, coeffs):
            # Radix-2 split f(y) = f0(y^2) + y*f1(y^2), where xi is a
            # primitive 2N-th root of unity (xi^N == -1).
            if N == 1:
                return coeffs
            N_ = N // 2          # integer half; N/2 would be a Rational in Sage
            xi2 = xi**2
            ff0 = _ntt_(xi2, N_, coeffs[0::2])
            ff1 = _ntt_(xi2, N_, coeffs[1::2])

            ff = [self.F(0)] * N
            s = xi               # s runs through xi^(2i+1)
            for i in range(N_):
                a = ff0[i]
                b = s * ff1[i]
                ff[i] = a + b          # value at  xi^(2i+1)
                ff[i + N_] = a - b     # value at -xi^(2i+1) = xi^(2(i+N_)+1)
                s = s * xi2
            return ff

        u = (self.R(f) % self.g).list()
        u += [self.F(0)] * (self.n - len(u))
        return _ntt_(self.xi, self.n, u)

    def ntt_inv(self, ff):
        """Inverse transform: rebuild the polynomial from its n values via the CRT basis."""
        return sum(ff[i] * self.base[i] for i in range(self.n))

    def random_pol(self, args=None):
        """Return ``self.R.random_element(args)`` (``args`` is forwarded unchanged)."""
        return self.R.random_element(args)

In [2]:
# Test: a random polynomial must survive the ntt -> ntt_inv round trip.

T = NTT(2048)

f = T.random_pol(64)    # 64 is forwarded to Sage's random_element
ff = T.ntt(f)           # values of f at the odd powers xi^(2i+1)
fff = T.ntt_inv(ff)     # back to coefficient form through the CRT basis

print("Correto ? ", f == fff)

Correto ?  True


In [None]:
from sage.misc.prandom import randint
import hashlib

class NewHope_CPA_PKE:
    """NewHope CPA-secure public-key encryption over Z_q[x]/(x^n + 1).

    Polynomials are Python lists of ``n`` integers (coefficients mod ``q``);
    seeds, coins, keys, ciphertexts and messages are ``bytes``. A message is
    exactly 32 bytes (256 bits); each bit is spread over n/256 coefficients.

    NOTE(review): the structure follows the NewHope CPA-PKE algorithms
    (GenA, Sample, PolyBitRev, Encode/Decode, Compress/Decompress, 14-bit
    polynomial packing). However, the NTT root gamma is found by search
    rather than fixed to a reference value. Results are therefore
    self-consistent (decrypt(encrypt(m)) == m), but matching the official
    known-answer tests is not guaranteed -- confirm before relying on
    interoperability.
    """

    # parameters (class-level defaults; n may be overridden per instance)
    n = 512      # ring dimension: 512 or 1024
    k = 8        # noise bound: sample() yields HW(byte) - HW(byte) in [-8, 8]
    q = 12289    # prime modulus; q - 1 = 3 * 2^12, so 2n | q - 1 for both n

    SEEDBYTES = 32   # length of seeds, coins and messages, in bytes

    def __init__(self, n=None):
        """Precompute NTT tables; ``n`` (512 or 1024) overrides the class default."""
        if n is not None:
            if n not in (512, 1024):
                raise ValueError("n must be 512 or 1024, got %r" % (n,))
            self.n = n
        n, q = self.n, self.q

        gamma = self._primitive_2n_root(n, q)
        gamma_inv = pow(gamma, q - 2, q)          # Fermat inverse: q is prime
        n_inv = pow(n, q - 2, q)
        self._gamma = gamma
        self._omega = gamma * gamma % q           # primitive n-th root of unity
        self._omega_inv = pow(self._omega, q - 2, q)
        self._gamma_pows = [pow(gamma, j, q) for j in range(n)]
        self._inv_scale = [n_inv * pow(gamma_inv, j, q) % q for j in range(n)]

        bits = 0
        while (1 << bits) < n:
            bits += 1
        self._bitrev = [self._reverse_bits(i, bits) for i in range(n)]

    # ------------------------------------------------------------------ #
    # small helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _primitive_2n_root(n, q):
        """Return gamma with gamma^n == -1 (mod q): a primitive 2n-th root of unity.

        gamma = x^((q-1)/2n) has order dividing 2n; for a power-of-two 2n,
        gamma^n == -1 forces the order to be exactly 2n.
        """
        exp = (q - 1) // (2 * n)
        for x in range(2, q):
            gamma = pow(x, exp, q)
            if pow(gamma, n, q) == q - 1:
                return gamma
        raise ValueError("no primitive 2n-th root of unity modulo q")

    @staticmethod
    def _reverse_bits(x, bits):
        """Reverse the lowest ``bits`` bits of ``x``."""
        r = 0
        for _ in range(bits):
            r = (r << 1) | (x & 1)
            x >>= 1
        return r

    def _check_seed(self, seed):
        """Return ``seed`` as bytes, requiring exactly SEEDBYTES bytes."""
        seed = bytes(seed)
        if len(seed) != self.SEEDBYTES:
            raise ValueError("seed must be %d bytes, got %d" % (self.SEEDBYTES, len(seed)))
        return seed

    def _add(self, a, b):
        return [(x + y) % self.q for x, y in zip(a, b)]

    def _sub(self, a, b):
        return [(x - y) % self.q for x, y in zip(a, b)]

    def _mul(self, a, b):
        """Coefficient-wise product (multiplication in the NTT domain)."""
        return [(x * y) % self.q for x, y in zip(a, b)]

    def _poly_bitrev(self, s):
        """PolyBitRev: r[bitrev(i)] = s[i] (bit reversal is an involution)."""
        return [s[j] for j in self._bitrev]

    def _dft(self, a, root):
        """Cyclic length-n transform A_k = sum_j a_j * root^(j*k) mod q (iterative radix-2)."""
        n, q = self.n, self.q
        a = [a[j] % q for j in self._bitrev]      # bit-reversed copy of the input
        length = 2
        while length <= n:
            half = length // 2
            w_len = pow(root, n // length, q)     # primitive length-th root of unity
            for start in range(0, n, length):
                w = 1
                for j in range(start, start + half):
                    u = a[j]
                    t = a[j + half] * w % q
                    a[j] = (u + t) % q
                    a[j + half] = (u - t) % q
                    w = w * w_len % q
            length *= 2
        return a

    def _ntt(self, g):
        """NTT: g_hat[i] = sum_j gamma^j * g[j] * omega^(i*j) = g(gamma^(2i+1)) mod q."""
        q = self.q
        return self._dft([c * gp % q for c, gp in zip(g, self._gamma_pows)], self._omega)

    def _ntt_inv(self, g_hat):
        """Inverse of _ntt: g[j] = n^-1 * gamma^-j * sum_i g_hat[i] * omega^(-i*j) mod q."""
        q = self.q
        a = self._dft(g_hat, self._omega_inv)
        return [x * s % q for x, s in zip(a, self._inv_scale)]

    # ------------------------------------------------------------------ #
    # sampling
    # ------------------------------------------------------------------ #
    def sample(self, seed, nonce):
        """Deterministically sample a noise polynomial from (seed, nonce).

        For each 64-coefficient block i, the string seed || nonce || i is
        expanded with SHAKE256 to 128 bytes. Coefficient j of the block is
        HW(buf[2j]) - HW(buf[2j+1]) mod q.
        """
        seed = self._check_seed(seed)
        n, q = self.n, self.q
        r = [0] * n
        for i in range(n // 64):
            # A fresh hash per block; reusing one SHAKE object would chain
            # every previous extseed into the input.
            buf = hashlib.shake_256(seed + bytes([nonce, i])).digest(128)
            for j in range(64):
                hw_a = bin(buf[2 * j]).count("1")
                hw_b = bin(buf[2 * j + 1]).count("1")
                r[64 * i + j] = (hw_a + q - hw_b) % q
        return r

    def genA(self, seed):
        """Generate the public polynomial a_hat (NTT domain) from ``seed``.

        Each 64-coefficient block i comes from SHAKE128(seed || i). The
        output is read as little-endian 16-bit values, and those < 5q are
        kept (rejection sampling).
        """
        seed = self._check_seed(seed)
        n, q = self.n, self.q
        bound = 5 * q
        a = [0] * n
        for i in range(n // 64):
            xof = hashlib.shake_128(seed + bytes([i]))
            nblocks = 1
            buf = xof.digest(168 * nblocks)
            ctr = j = 0
            while ctr < 64:
                if j + 1 >= len(buf):
                    # squeeze one more 168-byte block; digest() output is
                    # prefix-stable, so the bytes already consumed are unchanged
                    nblocks += 1
                    buf = xof.digest(168 * nblocks)
                val = buf[j] | (buf[j + 1] << 8)
                if val < bound:
                    a[64 * i + ctr] = val % q
                    ctr += 1
                j += 2
        return a

    # ------------------------------------------------------------------ #
    # CPA-PKE algorithms
    # ------------------------------------------------------------------ #
    def keyGen(self, seed=None):
        """Return (pk, sk). ``seed`` (32 bytes) makes it deterministic; default is random."""
        if seed is None:
            import secrets
            seed = secrets.token_bytes(self.SEEDBYTES)
        z = hashlib.shake_256(self._check_seed(seed)).digest(2 * self.SEEDBYTES)
        publicseed, noiseseed = z[:self.SEEDBYTES], z[self.SEEDBYTES:]

        a_hat = self.genA(publicseed)
        s_hat = self._ntt(self._poly_bitrev(self.sample(noiseseed, 0)))
        e_hat = self._ntt(self._poly_bitrev(self.sample(noiseseed, 1)))
        b_hat = self._add(self._mul(a_hat, s_hat), e_hat)

        return self.encodePK(b_hat, publicseed), self.encodePolynomial(s_hat)

    def encrypt(self, pk, m, coin):
        """Encrypt the 32-byte message ``m`` under ``pk`` using the 32-byte ``coin``."""
        coin = self._check_seed(coin)
        b_hat, publicseed = self.decodePK(pk)

        a_hat = self.genA(publicseed)
        s = self._poly_bitrev(self.sample(coin, 0))
        e = self._poly_bitrev(self.sample(coin, 1))
        e2 = self.sample(coin, 2)

        t_hat = self._ntt(s)
        u_hat = self._add(self._mul(a_hat, t_hat), self._ntt(e))
        v = self.encode(m)
        v2 = self._add(self._add(self._ntt_inv(self._mul(b_hat, t_hat)), e2), v)
        h = self.compress(v2)

        return self.encodeC(u_hat, h)

    def decrypt(self, c, sk):
        """Recover the 32-byte message from ciphertext ``c`` with secret key ``sk``."""
        u_hat, h = self.decodeC(c)
        s_hat = self.DecodePolynomial(sk)
        v2 = self.decompress(h)
        return self.decode(self._sub(v2, self._ntt_inv(self._mul(u_hat, s_hat))))

    # ------------------------------------------------------------------ #
    # serialization
    # ------------------------------------------------------------------ #
    def encodePolynomial(self, p):
        """Pack n coefficients (reduced mod q) as 14-bit values: 4 coefficients -> 7 bytes."""
        q = self.q
        r = bytearray()
        for i in range(self.n // 4):
            t0, t1, t2, t3 = (p[4 * i + j] % q for j in range(4))
            r += bytes([
                t0 & 0xff,
                ((t0 >> 8) | (t1 << 6)) & 0xff,
                (t1 >> 2) & 0xff,
                ((t1 >> 10) | (t2 << 4)) & 0xff,
                (t2 >> 4) & 0xff,
                ((t2 >> 12) | (t3 << 2)) & 0xff,
                (t3 >> 6) & 0xff,
            ])
        return bytes(r)

    def DecodePolynomial(self, v):
        """Unpack 7n/4 bytes produced by encodePolynomial into n coefficients."""
        v = bytes(v)
        r = [0] * self.n
        for i in range(self.n // 4):
            v0, v1, v2, v3, v4, v5, v6 = v[7 * i:7 * i + 7]
            r[4 * i + 0] = v0 | ((v1 & 0x3f) << 8)
            r[4 * i + 1] = (v1 >> 6) | (v2 << 2) | ((v3 & 0x0f) << 10)
            r[4 * i + 2] = (v3 >> 4) | (v4 << 4) | ((v5 & 0x03) << 12)
            r[4 * i + 3] = (v5 >> 2) | (v6 << 6)
        return r

    def encodePK(self, b, publicseed):
        """Public key = encodePolynomial(b) || publicseed (7n/4 + 32 bytes)."""
        return self.encodePolynomial(b) + self._check_seed(publicseed)

    def decodePK(self, pk):
        """Split a public key into (b_hat, publicseed)."""
        pk = bytes(pk)
        cut = 7 * self.n // 4
        return self.DecodePolynomial(pk[:cut]), pk[cut:cut + self.SEEDBYTES]

    def encode(self, m):
        """Map each bit of the 32-byte message to q//2 (bit 1) or 0 in n/256 coefficients."""
        m = bytes(m)
        if len(m) != 32:
            raise ValueError("message must be 32 bytes, got %d" % len(m))
        n, half = self.n, self.q // 2
        v = [0] * n
        for i in range(32):
            for j in range(8):
                val = half if (m[i] >> j) & 1 else 0
                for off in range(0, n, 256):      # 0, 256 (, 512, 768 for n = 1024)
                    v[8 * i + j + off] = val
        return v

    def decode(self, v):
        """Recover 32 message bytes: bit i is 1 when its copies lie close to (q-1)/2."""
        n, q = self.n, self.q
        center = (q - 1) // 2
        m = bytearray(32)
        for i in range(256):
            t = sum(abs(v[i + off] % q - center) for off in range(0, n, 256))
            t -= q if n == 1024 else q // 2
            if t < 0:                             # small total distance -> bit 1
                m[i >> 3] |= 1 << (i & 7)
        return bytes(m)

    def compress(self, v):
        """Round each coefficient to 3 bits (round(8*x/q) mod 8); 8 coefficients -> 3 bytes."""
        q = self.q
        h = bytearray()
        for l in range(self.n // 8):
            t = [((((v[8 * l + j] % q) << 3) + q // 2) // q) & 7 for j in range(8)]
            h += bytes([
                (t[0] | (t[1] << 3) | (t[2] << 6)) & 0xff,
                ((t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7)) & 0xff,
                ((t[5] >> 1) | (t[6] << 2) | (t[7] << 5)) & 0xff,
            ])
        return bytes(h)

    def decompress(self, a):
        """Inverse of compress: unpack 3-bit values t and map them to round(t*q/8)."""
        a = bytes(a)
        q = self.q
        r = []
        for l in range(self.n // 8):
            a0, a1, a2 = a[3 * l:3 * l + 3]
            t = [
                a0 & 7,
                (a0 >> 3) & 7,
                (a0 >> 6) | ((a1 << 2) & 4),
                (a1 >> 1) & 7,
                (a1 >> 4) & 7,
                (a1 >> 7) | ((a2 << 1) & 6),
                (a2 >> 2) & 7,
                a2 >> 5,
            ]
            r.extend((x * q + 4) >> 3 for x in t)
        return r

    def encodeC(self, u, h):
        """Ciphertext = encodePolynomial(u) || h (7n/4 + 3n/8 bytes)."""
        return self.encodePolynomial(u) + bytes(h)

    def decodeC(self, c):
        """Split a ciphertext into (u_hat, h)."""
        c = bytes(c)
        cut = 7 * self.n // 4
        return self.DecodePolynomial(c[:cut]), c[cut:cut + 3 * self.n // 8]


class NewHope_CPA_KEM:
    """CPA-secure key encapsulation built on a NewHope CPA-PKE object.

    The PKE object must provide ``keyGen()``, ``encrypt(pk, m, coin)`` and
    ``decrypt(c, sk)`` operating on 32-byte messages and coins.
    """

    SEEDBYTES = 32   # random seed, message and shared-secret length in bytes

    def __init__(self, pke=None):
        """Use ``pke`` if given, otherwise a default-parameter NewHope_CPA_PKE()."""
        self.pke = pke if pke is not None else NewHope_CPA_PKE()

    def gen(self):
        """Return a fresh (pk, sk) key pair from the underlying PKE."""
        (pk, sk) = self.pke.keyGen()
        return (pk, sk)

    def encapsulation(self, pk, seed=None):
        """Return (c, ss): a ciphertext and the 32-byte shared secret it encapsulates.

        ``seed`` (32 bytes) makes encapsulation deterministic; by default it is
        drawn from the OS CSPRNG.
        """
        if seed is None:
            import secrets
            seed = secrets.token_bytes(self.SEEDBYTES)

        # One 64-byte squeeze split in two. Calling digest(32) twice would
        # return the same 32 bytes, making the message equal to the coin.
        buf = hashlib.shake_256(bytes(seed)).digest(2 * self.SEEDBYTES)
        k, coin = buf[:self.SEEDBYTES], buf[self.SEEDBYTES:]

        c = self.pke.encrypt(pk, k, coin)
        return (c, self._shared_secret(k))

    def decapsulation(self, c, sk):
        """Return the 32-byte shared secret encapsulated in ``c``."""
        k = self.pke.decrypt(c, sk)
        return self._shared_secret(k)

    @staticmethod
    def _shared_secret(k):
        # Both sides must hash exactly the same input (the message k alone),
        # otherwise encapsulation and decapsulation disagree.
        return hashlib.shake_256(bytes(k)).digest(32)
