# Implementação do KEM NTRU Prime

## Parameters

In [1]:
def verifyW(p, q, w, indice):
    """Shrink the weight ``w`` until it passes two size checks.

    The candidate weight is recomputed as ``p // divisor`` with a divisor
    that starts at ``indice`` and grows by one per step:

    1. first until ``3*w <= 2*p`` holds,
    2. then until ``16*w + 1 <= q`` holds.

    The second phase does not re-check the first condition.

    Args:
        p: degree parameter.
        q: modulus parameter.
        w: initial candidate weight; it is tested as given before any division.
        indice: divisor corresponding to the initial ``w``.

    Returns:
        The candidate weight at which both phases stopped.
    """
    def exceeds_p_bound(weight):
        return 3 * weight > 2 * p

    def exceeds_q_bound(weight):
        return 16 * weight + 1 > q

    divisor = indice
    while exceeds_p_bound(w):
        divisor += 1
        w = p // divisor
    while exceeds_q_bound(w):
        divisor += 1
        w = p // divisor
    return w

In [2]:
# Scheme parameters. Both p and q are next_prime(120), i.e. 127.
p = next_prime(120)
q = next_prime(120)
# Initial weight guess (number of nonzero coefficients used by smallW);
# verifyW below lowers it until 2*p >= 3*w and q >= 16*w + 1.
w = p//4
# Centering threshold for the rounding helpers: lifted coefficients above d
# are shifted down by q. `/` yields an exact Sage rational; it is integral
# here because q is odd.
d = (q-1)/2

w = verifyW(p,q,w,4)

# NOTE: this cell uses Sage preparser syntax: `R_.<x> = ...` also binds the
# generator x, and `^` means exponentiation (not XOR).

# Z[x]/(x^p - x - 1). R is not referenced by the code visible in this notebook.
R_.<x>  = ZZ[]
R       = R_.quotient(x^p-x-1)

# GF(3)[x]/(x^p - x - 1): used for g, its inverse g1 and the recovered r.
R3_.<x> = GF(3)[]
R3      = R3_.quotient(x^p-x-1)

# GF(q)[x]/(x^p - x - 1): public key h and the product pk*r are computed here.
Rq_.<x> = GF(q)[]
Rq      = Rq_.quotient(x^p-x-1)

# Each `.<x>` line above rebinds x, so here x is the generator of GF(q)[x]:
# this checks irreducibility of x^p - x - 1 over GF(q) only.
# A failure is only printed; execution continues with the same rings.
if (x^p-x-1).is_irreducible() is False:
    print("Error\n")

## Funções auxiliares

In [3]:
def round_next_3(inp,pol=None):
    """Round every coefficient of ``inp`` to the nearest multiple of 3.

    Each coefficient is lifted to an integer, centered with the module-level
    ``d`` and ``q`` (values greater than ``d`` are shifted down by ``q``) and
    then replaced by ``3*round(a/3)``.

    Args:
        inp: a quotient-ring element (e.g. of ``Rq``) or an iterable of
            liftable coefficients. Ring elements are lifted and replaced by
            their coefficient list (not padded to length ``p``); objects
            that Sage cannot lift are iterated as-is.
        pol: optional callable (e.g. a ring) applied to the rounded list.

    Returns:
        ``pol(rounded)`` when ``pol`` is given, otherwise the list of rounded
        coefficients.
    """
    try:
        inp = lift(inp).list()
    except (ArithmeticError, AttributeError):
        # Sage's lift() raises ArithmeticError for objects without .lift()
        # (e.g. a plain list): treat such input as a coefficient sequence.
        # Narrowed from a bare except so interrupts and real bugs surface.
        pass

    def f(u):
        a = lift(u)
        # Move the representative from [0, q) into the centered range.
        a = a if a <= d else a - q
        return 3 * round(a / 3)

    pr = [f(u) for u in inp]
    # Only None means "no conversion"; never truth-test a ring object.
    if pol is not None:
        return pol(pr)
    return pr
    
    
def round3(inp,pol=None):
    """Map every coefficient of ``inp`` to a value in {-1, 0, 1} modulo 3.

    Each coefficient is lifted to an integer and centered with the
    module-level ``d`` and ``q`` (values greater than ``d`` are shifted down
    by ``q``). It is then reduced mod 3, and a residue of 2 is reported as -1.

    Args:
        inp: a quotient-ring element (e.g. of ``Rq``) or an iterable of
            liftable coefficients. Ring elements are lifted and replaced by
            their coefficient list; objects that Sage cannot lift are
            iterated as-is.
        pol: optional callable (e.g. a ring) applied to the resulting list.

    Returns:
        ``pol(values)`` when ``pol`` is given, otherwise the list of values.
    """
    try:
        inp = lift(inp).list()
    except (ArithmeticError, AttributeError):
        # Sage's lift() raises ArithmeticError for objects without .lift()
        # (e.g. a plain list): treat such input as a coefficient sequence.
        # Narrowed from a bare except so interrupts and real bugs surface.
        pass

    def f(a):
        u = lift(a)
        # Centering must happen before "% 3": shifting by q changes the
        # residue mod 3 whenever 3 does not divide q.
        u = u if u <= d else u - q
        u = u % 3
        return u if u < 2 else -1

    pr = [f(a) for a in inp]
    # Only None means "no conversion"; never truth-test a ring object.
    if pol is not None:
        return pol(pr)
    return pr

def R3_to_small(inp):
    """Return the ``p`` coefficients of an R3 element with 2 rewritten as -1.

    The lifted polynomial's coefficient list is copied into a zero-filled
    list of length ``p``. Coefficients that compare below 2 (0 and 1) are
    kept as they are, and anything else becomes the integer -1.
    """
    padded = [0] * p
    for idx, coeff in enumerate(lift(inp).list()):
        padded[idx] = coeff
    return [coeff if coeff < 2 else -1 for coeff in padded]

In [4]:
import random as rn

def small(p):
    """Return ``p`` coefficients drawn uniformly at random from {-1, 0, 1}.

    The result feeds secret key material (g in KeyGen), so it is drawn from
    ``random.SystemRandom`` (OS entropy). The default ``random`` generator is
    a Mersenne Twister whose state can be reconstructed from its outputs.
    """
    rng = rn.SystemRandom()
    return [rng.choice((-1, 0, 1)) for _ in range(p)]

def smallW(p,w):
    """Return a random length-``p`` list with exactly ``w`` entries in {-1, 1}.

    The other ``p - w`` entries are 0, and positions are shuffled uniformly.
    The result is used for secret polynomials (f in KeyGen, r in
    Encapsulate), so all randomness comes from ``random.SystemRandom``
    (OS entropy) rather than the predictable default generator.

    Raises:
        ValueError: if ``w`` is outside ``0..p``. Such a ``w`` would otherwise
            silently produce a list whose length is not ``p``.
    """
    if not 0 <= w <= p:
        raise ValueError("weight w must satisfy 0 <= w <= p")
    rng = rn.SystemRandom()
    u = [rng.choice((-1, 1)) for _ in range(w)] + [0] * (p - w)
    rng.shuffle(u)
    return u

In [5]:
def verifyG():
    """Draw ternary polynomials until one is invertible in R3, and return it.

    Candidates come from ``small(p)``. The returned value is the raw
    coefficient list, not the R3 element.
    """
    while True:
        candidate = small(p)
        if R3(candidate).is_unit():
            return candidate

In [6]:
def pesosR(vec):
    """Return the Hamming weight of ``vec``, i.e. how many entries are nonzero."""
    return sum(1 for entry in vec if entry != 0)

## Key generation, encapsulation and decapsulation

In [7]:
import hashlib

def KeyGen():
    """Generate a key pair.

    Returns:
        A dict with:
          - ``'f'``:  the weight-``w`` polynomial f as an element of Rq.
          - ``'g1'``: the inverse of g in R3, where g is an invertible ternary
            polynomial from verifyG.
          - ``'pk'``: the public key h = g / (3*f), computed in Rq.

        ``'f'`` and ``'g1'`` are the values later passed to Decapsulate.
    """
    g = verifyG()
    F = Rq(smallW(p, w))
    # Dict values are evaluated left to right, so g1 is computed before pk.
    return {
        'f': F,
        'g1': 1 / R3(g),
        'pk': Rq(g) / (3 * F),
    }

def Encapsulate(pk):
    """Encapsulate a fresh shared key under the public key ``pk``.

    A random weight-``w`` ternary list r is sampled. The ciphertext is
    ``round_next_3(pk * r)``, and SHA-512 of ``str(r)`` is split into two
    32-byte halves.

    Returns:
        A dict with:
          - ``'C'``: first 32 digest bytes (checked by Decapsulate).
          - ``'c'``: the rounded ciphertext coefficients.
          - ``'K'``: last 32 digest bytes, the shared key.
    """
    r = smallW(p, w)
    ciphertext = round_next_3(pk * Rq(r))
    digest = hashlib.sha512(str(r).encode('utf-8')).digest()
    return {'C': digest[:32], 'c': ciphertext, 'K': digest[32:]}
    
def Decapsulate(C,c,f,g1):
    """Recover the shared key from ciphertext ``c`` using the secret key.

    Steps:
      1. Compute ``round3(3*f * c)`` in Rq and multiply the result by ``g1``
         in R3.
      2. Convert the product to a list r1 in {-1, 0, 1}. r1 is intended to
         equal Encapsulate's r.
      3. Hash ``str(r1)`` with SHA-512. The first 32 bytes must match ``C``
         and the last 32 bytes are the key.

    Args:
        C: 32-byte confirmation value from Encapsulate (must be bytes).
        c: ciphertext coefficients from Encapsulate.
        f: secret polynomial (Rq element) from KeyGen.
        g1: inverse of g in R3 from KeyGen.

    Returns:
        ``{'r1': r1, 'k': key}`` on success. Returns ``False`` if r1 does not
        have weight ``w`` or the confirmation hash does not match.
    """
    import hmac  # constant-time comparison of secret-derived digests

    a = round3(Rq(3*f)*Rq(c))
    e = R3(a)*g1
    r1 = R3_to_small(e)

    digest = hashlib.sha512(str(r1).encode('utf-8')).digest()
    CLinha = digest[:32]
    KLinha = digest[32:]

    if pesosR(r1) != w:
        return False
    # compare_digest avoids leaking, through timing, how many leading bytes
    # of the recomputed confirmation hash matched.
    if not hmac.compare_digest(CLinha, C):
        return False
    return {'r1': r1, 'k': KLinha}

## Test

In [8]:
import base64

def run():
    """Run one KEM round trip and report whether both sides agree.

    Returns:
        True when all of the following hold, otherwise False:
          - Decapsulate accepts the ciphertext.
          - Re-encrypting the recovered r1 reproduces the same ciphertext.
          - The encapsulated and decapsulated keys are equal.
    """
    keys   = KeyGen()
    crypto = Encapsulate(keys['pk'])
    decryp = Decapsulate(crypto['C'], crypto['c'], keys['f'], keys['g1'])
    # Decapsulate returns False on rejection; indexing it would raise TypeError.
    if not decryp:
        return False
    if crypto['c'] != round_next_3(keys['pk'] * Rq(decryp['r1'])):
        # Builtin False instead of the Sage-only alias `false`.
        return False
    # Equal byte strings have equal base64 encodings, so compare directly.
    return crypto['K'] == decryp['k']

In [9]:
# Execute one full round trip: KeyGen -> Encapsulate -> Decapsulate.
run()

# Output recorded in the exported notebook for the call above:
True