Skip to content

Commit

Permalink
Use constant-time (faster) padding decoding also for OAEP
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Dec 27, 2023
1 parent 519e7ae commit 0deea1b
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 66 deletions.
38 changes: 16 additions & 22 deletions lib/Crypto/Cipher/PKCS1_OAEP.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from Crypto.Signature.pss import MGF1
import Crypto.Hash.SHA1

from Crypto.Util.py3compat import bord, _copy_bytes
from Crypto.Util.py3compat import _copy_bytes
import Crypto.Util.number
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
from Crypto.Util.strxor import strxor
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
from Crypto.Util.strxor import strxor
from Crypto import Random
from ._pkcs1_oaep_decode import oaep_decode


class PKCS1OAEP_Cipher:
"""Cipher object for PKCS#1 v1.5 OAEP.
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
if mgfunc:
self._mgf = mgfunc
else:
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)

self._label = _copy_bytes(None, None, label)
self._randfunc = randfunc
Expand Down Expand Up @@ -105,7 +107,7 @@ def encrypt(self, message):

# See 7.1.1 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits, 8) # Convert from bits to bytes
k = ceil_div(modBits, 8) # Convert from bits to bytes
hLen = self._hashObj.digest_size
mLen = len(message)

Expand Down Expand Up @@ -159,20 +161,18 @@ def decrypt(self, ciphertext):

# See 7.1.2 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits,8) # Convert from bits to bytes
k = ceil_div(modBits, 8) # Convert from bits to bytes
hLen = self._hashObj.digest_size

# Step 1b and 1c
if len(ciphertext) != k or k<hLen+2:
if len(ciphertext) != k or k < hLen+2:
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP) and step 2c (I2OSP)
em = self._key._decrypt_to_bytes(ct_int)
# Step 3a
lHash = self._hashObj.new(self._label).digest()
# Step 3b
y = em[0]
# y must be 0, but we MUST NOT check it here in order not to
# allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
maskedSeed = em[1:hLen+1]
Expand All @@ -185,22 +185,17 @@ def decrypt(self, ciphertext):
dbMask = self._mgf(seed, k-hLen-1)
# Step 3f
db = strxor(maskedDB, dbMask)
# Step 3g
one_pos = hLen + db[hLen:].find(b'\x01')
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen)
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x)
for x in db[hLen:one_pos]:
invalid |= bord(x)
if invalid != 0:
# Step 3b + 3g
res = oaep_decode(em, lHash, db)
if res <= 0:
raise ValueError("Incorrect decryption.")
# Step 4
return db[one_pos + 1:]
return db[res:]


def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
"""Return a cipher object :class:`PKCS1OAEP_Cipher`
that can be used to perform PKCS#1 OAEP encryption or decryption.
:param key:
The key object to use to encrypt or decrypt the message.
Expand Down Expand Up @@ -234,4 +229,3 @@ def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
if randfunc is None:
randfunc = Random.get_random_bytes
return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)

31 changes: 3 additions & 28 deletions lib/Crypto/Cipher/PKCS1_v1_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,7 @@
from Crypto import Random
from Crypto.Util.number import bytes_to_long, long_to_bytes
from Crypto.Util.py3compat import bord, is_bytes, _copy_bytes

from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
c_uint8_ptr)


_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode",
"""
int pkcs1_decode(const uint8_t *em, size_t len_em,
const uint8_t *sentinel, size_t len_sentinel,
size_t expected_pt_len,
uint8_t *output);
""")


def _pkcs1_decode(em, sentinel, expected_pt_len, output):
if len(em) != len(output):
raise ValueError("Incorrect output length")

ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em),
c_size_t(len(em)),
c_uint8_ptr(sentinel),
c_size_t(len(sentinel)),
c_size_t(expected_pt_len),
c_uint8_ptr(output))
return ret
from ._pkcs1_oaep_decode import pkcs1_decode


class PKCS115_Cipher:
Expand Down Expand Up @@ -113,7 +89,6 @@ def encrypt(self, message):
continue
ps.append(new_byte)
ps = b"".join(ps)
assert(len(ps) == k - mLen - 3)
# Step 2b
em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
# Step 3a (OS2IP)
Expand Down Expand Up @@ -182,14 +157,14 @@ def decrypt(self, ciphertext, sentinel, expected_pt_len=0):
# Step 3 (not constant time when the sentinel is not a byte string)
output = bytes(bytearray(k))
if not is_bytes(sentinel) or len(sentinel) > k:
size = _pkcs1_decode(em, b'', expected_pt_len, output)
size = pkcs1_decode(em, b'', expected_pt_len, output)
if size < 0:
return sentinel
else:
return output[size:]

# Step 3 (somewhat constant time)
size = _pkcs1_decode(em, sentinel, expected_pt_len, output)
size = pkcs1_decode(em, sentinel, expected_pt_len, output)
return output[size:]


Expand Down
41 changes: 41 additions & 0 deletions lib/Crypto/Cipher/_pkcs1_oaep_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
c_uint8_ptr)


_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode",
"""
int pkcs1_decode(const uint8_t *em, size_t len_em,
const uint8_t *sentinel, size_t len_sentinel,
size_t expected_pt_len,
uint8_t *output);
int oaep_decode(const uint8_t *em,
size_t em_len,
const uint8_t *lHash,
size_t hLen,
const uint8_t *db,
size_t db_len);
""")


def pkcs1_decode(em, sentinel, expected_pt_len, output):
if len(em) != len(output):
raise ValueError("Incorrect output length")

ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em),
c_size_t(len(em)),
c_uint8_ptr(sentinel),
c_size_t(len(sentinel)),
c_size_t(expected_pt_len),
c_uint8_ptr(output))
return ret


def oaep_decode(em, lHash, db):
ret = _raw_pkcs1_decode.oaep_decode(c_uint8_ptr(em),
c_size_t(len(em)),
c_uint8_ptr(lHash),
c_size_t(len(lHash)),
c_uint8_ptr(db),
c_size_t(len(db)))
return ret
79 changes: 74 additions & 5 deletions src/pkcs1_decode.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ STATIC size_t safe_select_idx(size_t in1, size_t in2, uint8_t choice)
* - in1[] is NOT equal to in2[] where neq_mask[] is 0xFF.
* Return non-zero otherwise.
*/
STATIC uint8_t safe_cmp(const uint8_t *in1, const uint8_t *in2,
STATIC uint8_t safe_cmp_masks(const uint8_t *in1, const uint8_t *in2,
const uint8_t *eq_mask, const uint8_t *neq_mask,
size_t len)
{
Expand Down Expand Up @@ -187,7 +187,7 @@ STATIC size_t safe_search(const uint8_t *in1, uint8_t c, size_t len)
return result;
}

#define EM_PREFIX_LEN 10
#define PKCS1_PREFIX_LEN 10

/*
* Decode and verify the PKCS#1 padding, then put either the plaintext
Expand Down Expand Up @@ -222,13 +222,13 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
if (NULL == em || NULL == output || NULL == sentinel) {
return -1;
}
if (len_em_output < (EM_PREFIX_LEN + 2)) {
if (len_em_output < (PKCS1_PREFIX_LEN + 2)) {
return -1;
}
if (len_sentinel > len_em_output) {
return -1;
}
if (expected_pt_len > 0 && expected_pt_len > (len_em_output - EM_PREFIX_LEN - 1)) {
if (expected_pt_len > 0 && expected_pt_len > (len_em_output - PKCS1_PREFIX_LEN - 1)) {
return -1;
}

Expand All @@ -240,7 +240,7 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
memcpy(padded_sentinel + (len_em_output - len_sentinel), sentinel, len_sentinel);

/** The first 10 bytes must follow the pattern **/
match = safe_cmp(em,
match = safe_cmp_masks(em,
(const uint8_t*)"\x00\x02" "\x00\x00\x00\x00\x00\x00\x00\x00",
(const uint8_t*)"\xFF\xFF" "\x00\x00\x00\x00\x00\x00\x00\x00",
(const uint8_t*)"\x00\x00" "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
Expand Down Expand Up @@ -283,3 +283,72 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
free(padded_sentinel);
return result;
}

/*
* Decode and verify the OAEP padding in constant time.
*
* The function returns the number of bytes to ignore at the beginning
* of db (the rest is the plaintext), or -1 in case of problems.
*/

EXPORT_SYM int oaep_decode(const uint8_t *em,
size_t em_len,
const uint8_t *lHash,
size_t hLen,
const uint8_t *db,
size_t db_len) /* em_len - 1 - hLen */
{
int result;
size_t one_pos, search_len, i;
uint8_t wrong_padding;
uint8_t *eq_mask = NULL;
uint8_t *neq_mask = NULL;
uint8_t *target_db = NULL;

if (NULL == em || NULL == lHash || NULL == db) {
return -1;
}

if (em_len < 2*hLen+2 || db_len != em_len-1-hLen) {
return -1;
}

/* Allocate */
eq_mask = (uint8_t*) calloc(1, db_len);
neq_mask = (uint8_t*) calloc(1, db_len);
target_db = (uint8_t*) calloc(1, db_len);
if (NULL == eq_mask || NULL == neq_mask || NULL == target_db) {
result = -1;
goto cleanup;
}

/* Step 3g */
search_len = db_len - hLen;

one_pos = safe_search(db + hLen, 0x01, search_len);
if (SIZE_T_MAX == one_pos) {
result = -1;
goto cleanup;
}

memset(eq_mask, 0xAA, db_len);
memcpy(target_db, lHash, hLen);
memset(eq_mask, 0xFF, hLen);

for (i=0; i<search_len; i++) {
eq_mask[hLen + i] = propagate_ones(i < one_pos);
}

wrong_padding = em[0];
wrong_padding |= safe_cmp_masks(db, target_db, eq_mask, neq_mask, db_len);
set_if_match(&wrong_padding, one_pos, search_len);

result = wrong_padding ? -1 : (int)(hLen + 1 + one_pos);

cleanup:
free(eq_mask);
free(neq_mask);
free(target_db);

return result;
}
22 changes: 11 additions & 11 deletions src/test/test_pkcs1.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ void set_if_match(uint8_t *flag, size_t term1, size_t term2);
void set_if_no_match(uint8_t *flag, size_t term1, size_t term2);
void safe_select(const uint8_t *in1, const uint8_t *in2, uint8_t *out, uint8_t choice, size_t len);
size_t safe_select_idx(size_t in1, size_t in2, uint8_t choice);
uint8_t safe_cmp(const uint8_t *in1, const uint8_t *in2,
uint8_t safe_cmp_masks(const uint8_t *in1, const uint8_t *in2,
const uint8_t *eq_mask, const uint8_t *neq_mask,
size_t len);
size_t safe_search(const uint8_t *in1, uint8_t c, size_t len);
Expand Down Expand Up @@ -80,57 +80,57 @@ void test_safe_select_idx(void)
assert(safe_select_idx(0x100004, 0x223344, 1) == 0x223344);
}

void test_safe_cmp(void)
void test_safe_cmp_masks(void)
{
uint8_t res;

res = safe_cmp(onezero, onezero,
res = safe_cmp_masks(onezero, onezero,
(uint8_t*)"\xFF\xFF",
(uint8_t*)"\x00\x00",
2);
assert(res == 0);

res = safe_cmp(onezero, zerozero,
res = safe_cmp_masks(onezero, zerozero,
(uint8_t*)"\xFF\xFF",
(uint8_t*)"\x00\x00",
2);
assert(res != 0);

res = safe_cmp(onezero, oneone,
res = safe_cmp_masks(onezero, oneone,
(uint8_t*)"\xFF\xFF",
(uint8_t*)"\x00\x00",
2);
assert(res != 0);

res = safe_cmp(onezero, oneone,
res = safe_cmp_masks(onezero, oneone,
(uint8_t*)"\xFF\x00",
(uint8_t*)"\x00\x00",
2);
assert(res == 0);

/** -- **/

res = safe_cmp(onezero, onezero,
res = safe_cmp_masks(onezero, onezero,
(uint8_t*)"\x00\x00",
(uint8_t*)"\xFF\xFF",
2);
assert(res != 0);

res = safe_cmp(oneone, zerozero,
res = safe_cmp_masks(oneone, zerozero,
(uint8_t*)"\x00\x00",
(uint8_t*)"\xFF\xFF",
2);
assert(res == 0);

res = safe_cmp(onezero, oneone,
res = safe_cmp_masks(onezero, oneone,
(uint8_t*)"\x00\x00",
(uint8_t*)"\x00\xFF",
2);
assert(res == 0);

/** -- **/

res = safe_cmp(onezero, oneone,
res = safe_cmp_masks(onezero, oneone,
(uint8_t*)"\xFF\x00",
(uint8_t*)"\x00\xFF",
2);
Expand Down Expand Up @@ -158,7 +158,7 @@ int main(void)
test_set_if_no_match();
test_safe_select();
test_safe_select_idx();
test_safe_cmp();
test_safe_cmp_masks();
test_safe_search();
return 0;
}

0 comments on commit 0deea1b

Please sign in to comment.