Skip to content

Commit afb5e27

Browse files
committed
Fix side-channel leakage in RSA decryption
1 parent ee91c67 commit afb5e27

17 files changed

+350
-35
lines changed

lib/Crypto/Cipher/PKCS1_OAEP.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,8 @@ def decrypt(self, ciphertext):
167167
raise ValueError("Ciphertext with incorrect length.")
168168
# Step 2a (O2SIP)
169169
ct_int = bytes_to_long(ciphertext)
170-
# Step 2b (RSADP)
171-
m_int = self._key._decrypt(ct_int)
172-
# Complete step 2c (I2OSP)
173-
em = long_to_bytes(m_int, k)
170+
# Step 2b (RSADP) and step 2c (I2OSP)
171+
em = self._key._decrypt(ct_int)
174172
# Step 3a
175173
lHash = self._hashObj.new(self._label).digest()
176174
# Step 3b

lib/Crypto/Cipher/PKCS1_v1_5.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,8 @@ def decrypt(self, ciphertext, sentinel, expected_pt_len=0):
176176
# Step 2a (O2SIP)
177177
ct_int = bytes_to_long(ciphertext)
178178

179-
# Step 2b (RSADP)
180-
m_int = self._key._decrypt(ct_int)
181-
182-
# Complete step 2c (I2OSP)
183-
em = long_to_bytes(m_int, k)
179+
# Step 2b (RSADP) and Step 2c (I2OSP)
180+
em = self._key._decrypt(ct_int)
184181

185182
# Step 3 (not constant time when the sentinel is not a byte string)
186183
output = bytes(bytearray(k))

lib/Crypto/Math/_IntegerBase.py

+20
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,23 @@ def random_range(cls, **kwargs):
390390
)
391391
return norm_candidate + min_inclusive
392392

393+
@staticmethod
394+
@abc.abstractmethod
395+
def _mult_modulo_bytes(term1, term2, modulus):
396+
"""Multiply two integers, take the modulo, and encode as big endian.
397+
This specialized method is used for RSA decryption.
398+
399+
Args:
400+
term1 : integer
401+
The first term of the multiplication, non-negative.
402+
term2 : integer
403+
The second term of the multiplication, non-negative.
404+
modulus: integer
405+
The modulus, a positive odd number.
406+
:Returns:
407+
A byte string, with the result of the modular multiplication
408+
encoded in big endian mode.
409+
It is as long as the modulus would be, with zero padding
410+
on the left if needed.
411+
"""
412+
pass

lib/Crypto/Math/_IntegerBase.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,8 @@ class IntegerBase:
6060
def random(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
6161
@classmethod
6262
def random_range(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
63+
@staticmethod
64+
def _mult_modulo_bytes(term1: Union[IntegerBase, int],
65+
term2: Union[IntegerBase, int],
66+
modulus: Union[IntegerBase, int]) -> bytes: ...
6367

lib/Crypto/Math/_IntegerCustom.py

+50-6
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,18 @@
4141
from Crypto.Random.random import getrandbits
4242

4343
c_defs = """
44-
int monty_pow(const uint8_t *base,
45-
const uint8_t *exp,
46-
const uint8_t *modulus,
47-
uint8_t *out,
48-
size_t len,
49-
uint64_t seed);
44+
int monty_pow(uint8_t *out,
45+
const uint8_t *base,
46+
const uint8_t *exp,
47+
const uint8_t *modulus,
48+
size_t len,
49+
uint64_t seed);
50+
51+
int monty_multiply(uint8_t *out,
52+
const uint8_t *term1,
53+
const uint8_t *term2,
54+
const uint8_t *modulus,
55+
size_t len);
5056
"""
5157

5258

@@ -116,3 +122,41 @@ def inplace_pow(self, exponent, modulus=None):
116122
result = bytes_to_long(get_raw_buffer(out))
117123
self._value = result
118124
return self
125+
126+
@staticmethod
127+
def _mult_modulo_bytes(term1, term2, modulus):
128+
129+
# With modular reduction
130+
mod_value = int(modulus)
131+
if mod_value < 0:
132+
raise ValueError("Modulus must be positive")
133+
if mod_value == 0:
134+
raise ZeroDivisionError("Modulus cannot be zero")
135+
136+
# C extension only works with odd moduli
137+
if (mod_value & 1) == 0:
138+
raise ValueError("Odd modulus is required")
139+
140+
# C extension only works with non-negative terms smaller than modulus
141+
if term1 >= mod_value or term1 < 0:
142+
term1 %= mod_value
143+
if term2 >= mod_value or term2 < 0:
144+
term2 %= mod_value
145+
146+
modulus_b = long_to_bytes(mod_value)
147+
numbers_len = len(modulus_b)
148+
term1_b = long_to_bytes(term1, numbers_len)
149+
term2_b = long_to_bytes(term2, numbers_len)
150+
out = create_string_buffer(numbers_len)
151+
152+
error = _raw_montgomery.monty_multiply(
153+
out,
154+
term1_b,
155+
term2_b,
156+
modulus_b,
157+
c_size_t(numbers_len)
158+
)
159+
if error:
160+
raise ValueError("monty_multiply failed with error: %d" % error)
161+
162+
return get_raw_buffer(out)

lib/Crypto/Math/_IntegerGMP.py

+20
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,26 @@ def jacobi_symbol(a, n):
749749
raise ValueError("n must be positive odd for the Jacobi symbol")
750750
return _gmp.mpz_jacobi(a._mpz_p, n._mpz_p)
751751

752+
@staticmethod
753+
def _mult_modulo_bytes(term1, term2, modulus):
754+
if not isinstance(term1, IntegerGMP):
755+
term1 = IntegerGMP(term1)
756+
if not isinstance(term2, IntegerGMP):
757+
term2 = IntegerGMP(term2)
758+
if not isinstance(modulus, IntegerGMP):
759+
modulus = IntegerGMP(modulus)
760+
761+
if modulus < 0:
762+
raise ValueError("Modulus must be positive")
763+
if modulus == 0:
764+
raise ZeroDivisionError("Modulus cannot be zero")
765+
if (modulus & 1) == 0:
766+
raise ValueError("Odd modulus is required")
767+
768+
numbers_len = len(modulus.to_bytes())
769+
result = ((term1 * term2) % modulus).to_bytes(numbers_len)
770+
return result
771+
752772
# Clean-up
753773
def __del__(self):
754774

lib/Crypto/Math/_IntegerNative.py

+12
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,15 @@ def jacobi_symbol(a, n):
368368
n1 = n % a1
369369
# Step 8
370370
return s * IntegerNative.jacobi_symbol(n1, a1)
371+
372+
@staticmethod
373+
def _mult_modulo_bytes(term1, term2, modulus):
374+
if modulus < 0:
375+
raise ValueError("Modulus must be positive")
376+
if modulus == 0:
377+
raise ZeroDivisionError("Modulus cannot be zero")
378+
if (modulus & 1) == 0:
379+
raise ValueError("Odd modulus is required")
380+
381+
number_len = len(long_to_bytes(modulus))
382+
return long_to_bytes((term1 * term2) % modulus, number_len)

lib/Crypto/PublicKey/RSA.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from Crypto import Random
3939
from Crypto.Util.py3compat import tobytes, bord, tostr
4040
from Crypto.Util.asn1 import DerSequence, DerNull
41+
from Crypto.Util.number import bytes_to_long
4142

4243
from Crypto.Math.Numbers import Integer
4344
from Crypto.Math.Primality import (test_probable_prime,
@@ -198,10 +199,11 @@ def _decrypt(self, ciphertext):
198199
h = ((m2 - m1) * self._u) % self._q
199200
mp = h * self._p + m1
200201
# Step 4: Compute m = m' * (r**(-1)) mod n
201-
result = (r.inverse(self._n) * mp) % self._n
202-
# Verify no faults occurred
203-
if ciphertext != pow(result, self._e, self._n):
204-
raise ValueError("Fault detected in RSA decryption")
202+
# then encode into a big endian byte string
203+
result = Integer._mult_modulo_bytes(
204+
r.inverse(self._n),
205+
mp,
206+
self._n)
205207
return result
206208

207209
def has_private(self):

lib/Crypto/SelfTest/Math/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ def get_tests(config={}):
3838
from Crypto.SelfTest.Math import test_Numbers
3939
from Crypto.SelfTest.Math import test_Primality
4040
from Crypto.SelfTest.Math import test_modexp
41+
from Crypto.SelfTest.Math import test_modmult
4142
tests += test_Numbers.get_tests(config=config)
4243
tests += test_Primality.get_tests(config=config)
4344
tests += test_modexp.get_tests(config=config)
45+
tests += test_modmult.get_tests(config=config)
4446
return tests
4547

4648
if __name__ == '__main__':

lib/Crypto/SelfTest/Math/test_Numbers.py

+28
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,34 @@ def test_hex(self):
696696
v1, = self.Integers(0x10)
697697
self.assertEqual(hex(v1), "0x10")
698698

699+
def test_mult_modulo_bytes(self):
700+
modmult = self.Integer._mult_modulo_bytes
701+
702+
res = modmult(4, 5, 19)
703+
self.assertEqual(res, b'\x01')
704+
705+
res = modmult(4 - 19, 5, 19)
706+
self.assertEqual(res, b'\x01')
707+
708+
res = modmult(4, 5 - 19, 19)
709+
self.assertEqual(res, b'\x01')
710+
711+
res = modmult(4 + 19, 5, 19)
712+
self.assertEqual(res, b'\x01')
713+
714+
res = modmult(4, 5 + 19, 19)
715+
self.assertEqual(res, b'\x01')
716+
717+
modulus = 2**512 - 1 # 64 bytes
718+
t1 = 13**100
719+
t2 = 17**100
720+
expect = b"\xfa\xb2\x11\x87\xc3(y\x07\xf8\xf1n\xdepq\x0b\xca\xf3\xd3B,\xef\xf2\xfbf\xcc)\x8dZ*\x95\x98r\x96\xa8\xd5\xc3}\xe2q:\xa2'z\xf48\xde%\xef\t\x07\xbc\xc4[C\x8bUE2\x90\xef\x81\xaa:\x08"
721+
self.assertEqual(expect, modmult(t1, t2, modulus))
722+
723+
self.assertRaises(ZeroDivisionError, modmult, 4, 5, 0)
724+
self.assertRaises(ValueError, modmult, 4, 5, -1)
725+
self.assertRaises(ValueError, modmult, 4, 5, 4)
726+
699727

700728
class TestIntegerInt(TestIntegerBase):
701729

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#
2+
# SelfTest/Math/test_modmult.py: Self-test for custom modular multiplication
3+
#
4+
# ===================================================================
5+
#
6+
# Copyright (c) 2023, Helder Eijs <helderijs@gmail.com>
7+
# All rights reserved.
8+
#
9+
# Redistribution and use in source and binary forms, with or without
10+
# modification, are permitted provided that the following conditions
11+
# are met:
12+
#
13+
# 1. Redistributions of source code must retain the above copyright
14+
# notice, this list of conditions and the following disclaimer.
15+
# 2. Redistributions in binary form must reproduce the above copyright
16+
# notice, this list of conditions and the following disclaimer in
17+
# the documentation and/or other materials provided with the
18+
# distribution.
19+
#
20+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23+
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24+
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25+
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26+
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29+
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30+
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31+
# POSSIBILITY OF SUCH DAMAGE.
32+
# ===================================================================
33+
34+
"""Self-test for the custom modular multiplication"""
35+
36+
import unittest
37+
38+
from Crypto.SelfTest.st_common import list_test_cases
39+
40+
from Crypto.Util.number import long_to_bytes, bytes_to_long
41+
42+
from Crypto.Util._raw_api import (create_string_buffer,
43+
get_raw_buffer,
44+
c_size_t)
45+
46+
from Crypto.Math._IntegerCustom import _raw_montgomery
47+
48+
49+
class ExceptionModulus(ValueError):
50+
pass
51+
52+
53+
def monty_mult(term1, term2, modulus):
54+
55+
if term1 >= modulus:
56+
term1 %= modulus
57+
if term2 >= modulus:
58+
term2 %= modulus
59+
60+
modulus_b = long_to_bytes(modulus)
61+
numbers_len = len(modulus_b)
62+
term1_b = long_to_bytes(term1, numbers_len)
63+
term2_b = long_to_bytes(term2, numbers_len)
64+
65+
out = create_string_buffer(numbers_len)
66+
error = _raw_montgomery.monty_multiply(
67+
out,
68+
term1_b,
69+
term2_b,
70+
modulus_b,
71+
c_size_t(numbers_len)
72+
)
73+
74+
if error == 17:
75+
raise ExceptionModulus()
76+
if error:
77+
raise ValueError("monty_multiply() failed with error: %d" % error)
78+
79+
return get_raw_buffer(out)
80+
81+
82+
modulus1 = 0xd66691b20071be4d66d4b71032b37fa007cfabf579fcb91e50bfc2753b3f0ce7be74e216aef7e26d4ae180bc20d7bd3ea88a6cbf6f87380e613c8979b5b043b200a8ff8856a3b12875e36e98a7569f3852d028e967551000b02c19e9fa52e83115b89309aabb1e1cf1e2cb6369d637d46775ce4523ea31f64ad2794cbc365dd8a35e007ed3b57695877fbf102dbeb8b3212491398e494314e93726926e1383f8abb5889bea954eb8c0ca1c62c8e9d83f41888095c5e645ed6d32515fe0c58c1368cad84694e18da43668c6f43e61d7c9bca633ddcda7aef5b79bc396d4a9f48e2a9abe0836cc455e435305357228e93d25aaed46b952defae0f57339bf26f5a9
83+
84+
85+
class TestModMultiply(unittest.TestCase):
86+
87+
def test_small(self):
88+
self.assertEqual(b"\x01", monty_mult(5, 6, 29))
89+
90+
def test_large(self):
91+
numbers_len = (modulus1.bit_length() + 7) // 8
92+
93+
t1 = modulus1 // 2
94+
t2 = modulus1 - 90
95+
expect = b'\x00' * (numbers_len - 1) + b'\x2d'
96+
self.assertEqual(expect, monty_mult(t1, t2, modulus1))
97+
98+
def test_zero_term(self):
99+
numbers_len = (modulus1.bit_length() + 7) // 8
100+
expect = b'\x00' * numbers_len
101+
self.assertEqual(expect, monty_mult(0x100, 0, modulus1))
102+
self.assertEqual(expect, monty_mult(0, 0x100, modulus1))
103+
104+
def test_larger_term(self):
105+
t1 = 2**2047
106+
expect_int = 0x8edf4071f78e3d7ba622cdbbbef74612e301d69186776ae6bf87ff38c320d9aebaa64889c2f67de2324e6bccd2b10ad89e91fd21ba4bb523904d033eff5e70e62f01a84f41fa90a4f248ef249b82e1d2729253fdfc2a3b5b740198123df8bfbf7057d03e15244ad5f26eb9a099763b5c5972121ec076b0bf899f59bd95f7cc129abddccf24217bce52ca0f3a44c9ccc504765dbb89734205f3ae6a8cc560494a60ea84b27d8e00fa24bdd5b4f1d4232edb61e47d3d984c1fa50a3820a2e580fbc3fc8bc11e99df53b9efadf5a40ac75d384e400905aa6f1d88950cd53b1c54dc2222115ad84a27260fa4d978155c1434c551de1ee7361a17a2f79d4388f78a5d
107+
res = bytes_to_long(monty_mult(t1, t1, modulus1))
108+
self.assertEqual(res, expect_int)
109+
110+
111+
def get_tests(config={}):
112+
tests = []
113+
tests += list_test_cases(TestModMultiply)
114+
return tests
115+
116+
117+
if __name__ == '__main__':
118+
def suite():
119+
return unittest.TestSuite(get_tests())
120+
unittest.main(defaultTest='suite')

lib/Crypto/SelfTest/PublicKey/test_RSA.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _exercise_primitive(self, rsaObj):
279279
ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
280280

281281
# Test decryption
282-
plaintext = rsaObj._decrypt(ciphertext)
282+
plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))
283283

284284
# Test encryption (2 arguments)
285285
new_ciphertext2 = rsaObj._encrypt(plaintext)
@@ -304,7 +304,7 @@ def _check_decryption(self, rsaObj):
304304
ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
305305

306306
# Test plain decryption
307-
new_plaintext = rsaObj._decrypt(ciphertext)
307+
new_plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))
308308
self.assertEqual(plaintext, new_plaintext)
309309

310310

0 commit comments

Comments
 (0)