Chương trình SageMath cài đặt thuật toán chữ ký RSA-PSS

Hàm sinh mặt nạ (MGF – Mask Generation Function)
Trong chương trình này, ta sử dụng hàm băm SHA-1, do đó hLen = 20 (byte).

In [1]:
from Crypto.Util.number import ceil_div, long_to_bytes, size
from Crypto.Util.strxor import strxor
from Crypto.Util.py3compat import *
from Crypto.Hash import SHA1
print('done')  # Signal that the imports in this cell completed without error.

done


In [2]:
def generate_key(s):
    """Generate an RSA key pair using SageMath primitives.

    Args:
        s: forwarded to Sage's ``random_prime``; presumably the upper bound
            for each prime (``random_prime(n)`` returns a prime <= n) -- verify.

    Returns:
        ``[n, phi, [e, n], [d, n]]`` where ``n = p*q``,
        ``phi = (p-1)*(q-1)``, ``[e, n]`` is the public key and ``[d, n]``
        the private key, with ``e*d == 1 (mod phi)``.
    """
    p = random_prime(s)
    q = random_prime(s)
    # Redraw q until it differs from p: p == q does not give a valid RSA
    # modulus (phi of p*p is p*(p-1), not (p-1)**2).
    # NOTE(review): if only one prime <= s exists (very small s) this loop
    # never terminates; guard against such s if it can occur.
    while p == q:
        q = random_prime(s)
    n = p * q
    phi = (p - 1) * (q - 1)
    # Public exponent: a random e in [2, phi] that is coprime to phi.
    e = randint(2, phi)
    while gcd(e, phi) != 1:
        e = randint(2, phi)
    # xgcd(e, phi) returns (g, u, v) with g == u*e + v*phi; since g == 1,
    # u is the inverse of e modulo phi.
    d = xgcd(e, phi)[1] % phi
    return [n, phi, [e, n], [d, n]]

In [4]:
def mgf1(X, maskLen):
    """MGF1 mask generation function (RFC 8017, appendix B.2.1) over SHA-1.

    Hashes ``X || C`` for counters C = 0, 1, ... (each encoded with
    ``long_to_bytes(C, 4)``), concatenates the digests and truncates the
    result to ``maskLen`` bytes.
    """
    digest_len = SHA1.digest_size
    block_count = ceil_div(maskLen, digest_len)
    blocks = [
        SHA1.new(X + long_to_bytes(idx, 4)).digest()
        for idx in range(block_count)
    ]
    return b"".join(blocks)[:maskLen]

In [5]:
def random_octet_string(length):
    """Return ``length`` random bytes; ``encoding_message`` uses it for the PSS salt.

    The bytes come from the operating system's CSPRNG via :mod:`secrets`.
    The :mod:`random` module (Mersenne Twister) is predictable and
    unsuitable for cryptographic material. A non-positive ``length``
    yields ``b""``.
    """
    import secrets
    return secrets.token_bytes(max(length, 0))

In [10]:
def encoding_message(M, emBits):
    """EMSA-PSS encoding (RFC 8017, section 9.1.1) using SHA-1 and MGF1.

    The salt length equals the hash length (sLen = hLen = 20 bytes for SHA-1).

    Args:
        M: the message. A ``str`` is UTF-8 encoded before hashing;
            ``bytes``, ``bytearray`` and ``memoryview`` values are hashed as-is.
        emBits: maximal bit length of the integer representation of the
            encoded message (``sign`` passes ``modBits - 1``).

    Returns:
        The encoded message ``EM = maskedDB || H || 0xBC`` as ``bytes`` of
        length ``ceil(emBits / 8)``.

    Raises:
        ValueError: if ``emBits`` is too small to hold the padding, salt and
            hash (``emLen < hLen + sLen + 2``).
    """
    hashfn = SHA1
    hLen = hashfn.digest_size
    sLen = hLen
    emLen = ceil_div(emBits, 8)
    # RFC 8017 9.1.1 step 3: otherwise PS would have a negative length and
    # a later step would fail with an unrelated error.
    if emLen < hLen + sLen + 2:
        raise ValueError(
            "encoding error: emBits is too small for EMSA-PSS with SHA-1")
    data = M if isinstance(M, (bytes, bytearray, memoryview)) else str.encode(M)
    mhash = hashfn.new(data).digest()
    # M' = (8 zero octets) || mHash || salt
    padding1 = bchr(0x00) * 8
    # DB = PS || 0x01 || salt, where PS is emLen - sLen - hLen - 2 zero octets.
    padding2 = bchr(0x00) * (emLen - sLen - hLen - 2) + bchr(0x01)
    salt = random_octet_string(sLen)
    h = hashfn.new(padding1 + mhash + salt).digest()
    db = padding2 + salt
    dbMask = mgf1(h, emLen - hLen - 1)
    maskedDB = strxor(db, dbMask)
    # Clear the leftmost 8*emLen - emBits bits so the integer form of EM
    # has at most emBits bits.
    lmask = 0
    for _ in range(8 * emLen - emBits):
        lmask = lmask >> 1 | 0x80
    maskedDB = bchr(bord(maskedDB[0]) & ~lmask) + maskedDB[1:]
    return maskedDB + h + bchr(0xBC)

In [8]:
def sign(M, pr_key):
    """Produce an RSASSA-PSS signature of ``M`` (RFC 8017, section 8.1.1).

    Args:
        M: message to sign; passed unchanged to ``encoding_message``.
        pr_key: private key ``[d, n]`` (element 3 of ``generate_key``'s result).

    Returns:
        The signature as ``bytes`` of length ``k = ceil(modBits / 8)``.
    """
    # Convert to plain ints so pow() and to_bytes() behave the same whether
    # the key holds Python ints or Sage Integers.
    d = int(pr_key[0])
    n = int(pr_key[1])
    modBits = size(n)
    k = ceil_div(modBits, 8)
    em = encoding_message(M, modBits - 1)
    m = int.from_bytes(em, 'big')
    # Three-argument pow does modular exponentiation; (m**d) % n would first
    # build the full power, which is infeasible for real key sizes.
    s = pow(m, d, n)
    return s.to_bytes(k, 'big')

In [9]:
def decrypt(S, pl_key):
    """Recover the encoded message EM from a signature (RSAVP1 followed by I2OSP).

    Despite the name, this applies the public key to a signature, as in
    RSASSA-PSS verification (RFC 8017, section 8.1.2).

    Args:
        S: the signature bytes.
        pl_key: public key ``[e, n]`` (element 2 of ``generate_key``'s result).

    Returns:
        EM as ``bytes`` of length ``ceil((modBits - 1) / 8)``.

    Raises:
        ValueError: if the signature representative is not smaller than ``n``
            (RFC 8017 RSAVP1 step 1). Accepting it would let ``S + n``
            variants of a valid signature pass.
        OverflowError: propagated from ``int.to_bytes`` when the recovered
            integer does not fit in ``emLen`` bytes (the signature cannot be
            valid in that case).
    """
    e = int(pl_key[0])
    n = int(pl_key[1])
    s = int.from_bytes(S, 'big')
    if s >= n:
        raise ValueError("signature representative out of range")
    # Modular exponentiation; (s**e) % n would materialize the full power.
    m = pow(s, e, n)
    modBits = size(n)
    emLen = ceil_div(modBits - 1, 8)
    return m.to_bytes(emLen, 'big')

In [12]:
def verify(M, em, emBits):
    """EMSA-PSS verification (RFC 8017, section 9.1.2) using SHA-1 and MGF1.

    Args:
        M: the message. A ``str`` is UTF-8 encoded before hashing;
            ``bytes``, ``bytearray`` and ``memoryview`` values are hashed as-is.
        em: the encoded message, e.g. the output of ``decrypt``.
        emBits: maximal bit length of the integer form of ``em``
            (``modBits - 1`` for a signature produced by ``sign``).

    Returns:
        ``True`` if ``em`` is a consistent PSS encoding of ``M`` with
        sLen = hLen, otherwise ``False``.
    """
    hashfn = SHA1
    hLen = hashfn.digest_size
    sLen = hLen  # same salt length as encoding_message
    emLen = ceil_div(emBits, 8)
    if emLen < hLen + sLen + 2:
        return False
    # The slicing below assumes EM is exactly emLen octets; any other length
    # cannot be a valid encoding.
    if len(em) != emLen:
        return False
    if bord(em[-1]) != 0xBC:
        return False
    maskedDB = em[:emLen - hLen - 1]
    h = em[emLen - hLen - 1:-1]
    # Mask covering the leftmost 8*emLen - emBits bits, which must be zero.
    lmask = 0
    for _ in range(8 * emLen - emBits):
        lmask = lmask >> 1 | 0x80
    if bord(em[0]) & lmask:
        return False
    dbMask = mgf1(h, emLen - hLen - 1)
    db = strxor(maskedDB, dbMask)
    db = bchr(bord(db[0]) & ~lmask) + db[1:]
    # DB must begin with PS (emLen - sLen - hLen - 2 zero octets) then 0x01.
    padding2 = bchr(0x00) * (emLen - sLen - hLen - 2) + bchr(0x01)
    if not db.startswith(padding2):
        return False
    salt = db[-sLen:]
    data = M if isinstance(M, (bytes, bytearray, memoryview)) else str.encode(M)
    mhash = hashfn.new(data).digest()
    # Recompute H' = Hash((8 zero octets) || mHash || salt) and compare with H.
    padding1 = bchr(0x00) * 8
    h_ = hashfn.new(padding1 + mhash + salt).digest()
    return h == h_