# SSE - Demo - Ferrara Justin
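- Notebook to be run in a SageMath environment (the code relies on Sage built-ins such as `power_mod`, `crt`, `Matrix`, `QQ` and `PolynomialRing`, and on the Sage preparser syntax `R.<x> = ...`).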
- Proof of concept for the paper: **Keegan Ryan, Kaiwen He, George Arnold Sullivan, Nadia Heninger**. *Passive SSH Key Compromise via Lattices*. Cryptology ePrint Archive, Paper 2023/1711, 2023. [DOI: 10.1145/3576915.3616629](https://doi.org/10.1145/3576915.3616629), [URL: https://eprint.iacr.org/2023/1711](https://eprint.iacr.org/2023/1711).

## PKCS#1 V1.5 Signatures

In [1]:
from Crypto.Signature.pkcs1_15 import new as pkcs1_15_new
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256, SHA512
from hashlib import sha256, sha512
from math import gcd
from Crypto.Util.number import getPrime, inverse


def key_gen(len_n):
    """Generate a fresh RSA key whose modulus is ``len_n`` bits, via PyCryptodome.

    Returns:
        (e, d, n, p, q): public exponent, private exponent, modulus and the
        two prime factors of the modulus.
    """
    rsa_key = RSA.generate(len_n)
    return rsa_key.e, rsa_key.d, rsa_key.n, rsa_key.p, rsa_key.q


# DER-encoded DigestInfo prefixes (RFC 8017, section 9.2, note 1) placed in
# front of the raw digest inside an EMSA-PKCS1-v1_5 encoded message.
_DIGEST_INFO_PREFIX_SHA256 = b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
_DIGEST_INFO_PREFIX_SHA512 = b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'

# RFC 8017 requires the 0xff padding string PS to be at least 8 bytes long.
_MIN_PS_LEN = 8


def _emsa_pkcs1_v1_5_encode(len_n, digest_info_prefix, h):
    """Return 0x00 || 0x01 || PS || 0x00 || DigestInfo prefix || h.

    The result is ceil(len_n / 8) bytes long, i.e. the byte length of a
    len_n-bit modulus (the same width `verify` serialises s**e mod n to).

    Raises:
        ValueError: if the modulus is too short to hold the digest, the
            DigestInfo prefix, the three fixed bytes and 8 bytes of 0xff.
    """
    em_len = (len_n + 7) // 8  # ceil: correct even if len_n is not a multiple of 8
    pad_len = em_len - len(h) - len(digest_info_prefix) - 3
    if pad_len < _MIN_PS_LEN:
        # Without this check a negative pad_len silently yields b'' and a
        # wrongly sized encoding.
        raise ValueError('Modulus too short for this hash !')
    return b'\x00\x01' + b'\xff' * pad_len + b'\x00' + digest_info_prefix + h


def pkcs1_v1_5_pad_sha256(len_n, m):
    """Return the EMSA-PKCS1-v1_5 encoding of SHA-256(m) for a len_n-bit modulus."""
    return _emsa_pkcs1_v1_5_encode(len_n, _DIGEST_INFO_PREFIX_SHA256, sha256(m).digest())


def pkcs1_v1_5_pad_sha256_mean(len_n):
    """Return the SHA-256 PKCS#1 v1.5 encoding with the digest replaced by 2**255.

    0x80 followed by 31 zero bytes is the midpoint of the 256-bit digest
    range, so any real SHA-256 digest differs from it by at most 2**255.
    The attack uses this as the centre of the unknown digest.
    """
    h = b'\x80' + 31 * b'\x00'  # SHA-256 digest size is 32 bytes
    return _emsa_pkcs1_v1_5_encode(len_n, _DIGEST_INFO_PREFIX_SHA256, h)


def pkcs1_v1_5_pad_sha512(len_n, m):
    """Return the EMSA-PKCS1-v1_5 encoding of SHA-512(m) for a len_n-bit modulus."""
    return _emsa_pkcs1_v1_5_encode(len_n, _DIGEST_INFO_PREFIX_SHA512, sha512(m).digest())


def pkcs1_v1_5_pad_sha512_mean(len_n):
    """Return the SHA-512 PKCS#1 v1.5 encoding with the digest replaced by 2**511.

    0x80 followed by 63 zero bytes is the midpoint of the 512-bit digest
    range, so any real SHA-512 digest differs from it by at most 2**511.
    """
    h = b'\x80' + 63 * b'\x00'  # SHA-512 digest size is 64 bytes
    return _emsa_pkcs1_v1_5_encode(len_n, _DIGEST_INFO_PREFIX_SHA512, h)


def sign(m, d, p, q, n, hash_alg, bug=False):
    """Sign ``m`` with RSA-CRT over a PKCS#1 v1.5 encoding.

    Args:
        m: message bytes to hash and sign.
        d: private exponent.
        p, q: prime factors of the modulus.
        n: the modulus.
        hash_alg: "SHA256" or "SHA512".
        bug: when True, the mod-q exponentiation is reduced modulo ``q - 22``
            instead of ``q``, simulating a faulty CRT half.

    Returns:
        (signature, encoded): the CRT-recombined signature and the padded
        message bytes that were exponentiated.

    Raises:
        ValueError: if ``hash_alg`` is not supported.
    """
    # CRT exponents.
    exp_p = d % (p - 1)
    exp_q = d % (q - 1)

    if hash_alg == "SHA256":
        encode = pkcs1_v1_5_pad_sha256
    elif hash_alg == "SHA512":
        encode = pkcs1_v1_5_pad_sha512
    else:
        raise ValueError('Hash format unknown !')

    encoded = encode(n.bit_length(), m)
    representative = int.from_bytes(encoded, byteorder="big")

    # Fault injection: only the mod-q half uses the wrong modulus, so the
    # mod-p half stays correct.
    modulus_q = q - 22 if bug else q
    half_p = power_mod(representative, exp_p, p)
    half_q = power_mod(representative, exp_q, modulus_q)

    return crt([half_p, half_q], [p, q]), encoded

def verify(s, pad, e, n):
    """Check that ``s`` is an RSA signature whose public-key image equals ``pad``.

    Computes s**e mod n and serialises it big-endian to the byte length of n
    (ceil(bit_length / 8)), then compares the result with ``pad``. ``pad`` is
    the already-encoded message (as returned by ``sign``); the message itself
    is not re-hashed here.

    Uses the built-in three-argument ``pow`` on ``int()``-converted arguments.
    This works both in Sage (with Sage ``Integer`` values) and in plain Python,
    without needing Sage's ``power_mod``.

    Returns:
        True if the recovered encoding equals ``pad``, False otherwise.
    """
    modulus = int(n)
    byte_length = (modulus.bit_length() + 7) // 8
    recovered = pow(int(s), int(e), modulus)
    return recovered.to_bytes(byte_length, byteorder="big") == pad


def _check_signature_round(hash_alg, hash_cls, faulty):
    """Run 10 sign/compare/verify iterations on fresh 2048-bit keys.

    With ``faulty`` False, our signature must equal PyCryptodome's byte for
    byte and must verify. With ``faulty`` True (``sign(..., bug=True)``), it
    must differ and must fail verification.
    """
    for i in range(10):
        # Key generation with PyCryptodome.
        key = RSA.generate(2048)
        n, e, d, p, q = key.n, key.e, key.d, key.p, key.q

        # Deterministic per-iteration message (not random).
        message = f"Test message {i}".encode()

        # Signature with our own implementation.
        signature_custom, pad = sign(message, d, p, q, n, hash_alg, faulty)

        # Reference signature with PyCryptodome.
        signature_pycrypto = pkcs1_15_new(key).sign(hash_cls.new(message))

        # Serialise to the modulus byte length rather than a hard-coded 256.
        byte_length = (n.bit_length() + 7) // 8
        custom_bytes = int(signature_custom).to_bytes(byte_length, byteorder="big")

        if faulty:
            assert custom_bytes != signature_pycrypto, "Error: faulty signature unexpectedly matches"
            assert not verify(signature_custom, pad, e, n), "Error: faulty signature unexpectedly verifies"
        else:
            assert custom_bytes == signature_pycrypto, "Error signatures doesn't match"
            assert verify(signature_custom, pad, e, n), "Error: signature does not verify"

        print("Signature", i, " ok")


def testSignature(hash_alg):
    """Cross-check ``sign``/``verify`` against PyCryptodome's PKCS#1 v1.5 signer.

    Runs 10 correct signatures and then 10 faulty signatures, each on a
    fresh 2048-bit key, printing progress as it goes.

    Raises:
        ValueError: if ``hash_alg`` is not "SHA256" or "SHA512".
        AssertionError: on any mismatch.
    """
    print("!!! Warning:This test is not a complete test of the signature !!!")
    print("Testing signatures for ", hash_alg)

    hash_classes = {"SHA256": SHA256, "SHA512": SHA512}
    if hash_alg not in hash_classes:
        raise ValueError('Hash format unknown !')
    hash_cls = hash_classes[hash_alg]

    _check_signature_round(hash_alg, hash_cls, faulty=False)

    print("Testing wrong signatures for ", hash_alg)
    _check_signature_round(hash_alg, hash_cls, faulty=True)

    print("Tests passed for ", hash_alg, " !")


# Run the self-test for both supported hash algorithms.
for _alg in ("SHA256", "SHA512"):
    testSignature(_alg)

Testing signatures for  SHA256
Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok
Testing wrong signatures for  SHA256
Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok
Tests passed for  SHA256  !
Testing signatures for  SHA512
Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok
Testing wrong signatures for  SHA512
Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok
Tests passed for  SHA512  !


## Attack code

In [2]:
def attack(s, e, n, pad, hash_alg, verbose=False):
    """Try to factor n from one faulty RSA-CRT PKCS#1 v1.5 signature.

    The signature ``s`` is assumed correct modulo p and wrong modulo q. Write
    the signed encoding as ``mean_pad + delta``, where ``mean_pad`` carries the
    midpoint digest and |delta| <= 2**log_h. Then ``n1 = s**e mod n - mean_pad``
    satisfies ``n1 - delta == 0 (mod p)``. The polynomials x*(n1 - x), n1 - x
    and the constant n all vanish at x = delta modulo p. Their coefficient
    vectors (scaled by X**2, X, 1 with X = 2**log_h) form a 3x3 lattice, which
    is LLL-reduced. When the reduced vectors are short enough, they give a
    polynomial with delta as a rational root, and gcd(n, n1 - delta) reveals
    the factor.

    Args:
        s: the (faulty) signature.
        e: public exponent.
        n: modulus.
        pad: unused here; kept for interface compatibility.
        hash_alg: "SHA256" or "SHA512".
        verbose: print the candidate roots and the recovered factor.

    Returns:
        A non-trivial factor of n (expected to be p), or None on failure.

    Raises:
        ValueError: if ``hash_alg`` is not supported.
    """
    if hash_alg == "SHA256":
        log_h, mean_fn = 255, pkcs1_v1_5_pad_sha256_mean
    elif hash_alg == "SHA512":
        log_h, mean_fn = 511, pkcs1_v1_5_pad_sha512_mean
    else:
        raise ValueError('Hash format unknown !')

    mean_int = int.from_bytes(mean_fn(n.bit_length()), byteorder='big')
    n1 = power_mod(s, e, n) - mean_int

    X = 2**log_h  # bound on |delta|
    # Rows: coefficients of (x^2, x, 1), scaled by (X^2, X, 1), for
    # -x^2 + n1*x, -x + n1 and n.
    B = Matrix(QQ, [
        [-X**2, X * n1, 0],
        [0, -X, n1],
        [0, 0, n],
    ])

    # Same ring as the preparser form ``R.<x> = PolynomialRing(QQ)``.
    R = PolynomialRing(QQ, 'x')
    x = R.gen()

    for row in B.LLL():
        # Undo the column scaling to get back an ordinary polynomial in x.
        g = R(row[0] / X**2 * x**2 + row[1] / X * x + row[2])
        roots = g.roots()

        if verbose:
            print("Roots :", roots)

        for r1, _ in roots:
            candidate = gcd(n, n1 - int(r1) % n)
            if not (1 < candidate < n):
                continue
            if verbose:
                print("-------------------")
                print("Attack success !!!")
                print("r1 :", r1)
                print("p :", candidate)
            return candidate

    if verbose:
        print("Attack failed.")
    return None

# Demo: produce one signature with an injected CRT fault, then recover p from it.
hash_alg = "SHA256"
m = b"This message is signed with RSA-CRT!"
e, d, n, p, q = key_gen(2048)
s, pad = sign(m, d, p, q, n, hash_alg, bug=True)

print("-------------------")
print("Signature valid: ", verify(s, pad, e, n))

print("-------------------")
print("Attack")
print("-------------------")
new_p = attack(s, e, n, pad, hash_alg, verbose=True)

print("Correct p founded:", new_p == p)

-------------------
Signature valid:  False
-------------------
Attack
-------------------
Roots : [(13211956665292931308949066983794306422363417166108650900646494265497385318803, 1), (-1753615932881276906251817973529014357932800025322412989789029237286801694805010018066437491502305391878495867287367515923286519265650203603780826960286646268034468017504445476702789051144410319/24154868634327141756302829194346689460878689177316474915119301712042589629800885813786795400027970851178903354819045, 1)]
-------------------
Attack success !!!
r1 : 13211956665292931308949066983794306422363417166108650900646494265497385318803
p : 150354461926490064821228477574493384680350032316704444657305540829456043007346591619899746672359868329847685441583830919812717453753939389642719037329730321454293367178531339779652207932821611999553976040605442986748705656551102993577476209867498567863135553201391577143208949531197689249413410371324862023733
Correct p founded: True


## Stats

In [3]:
from prettytable import PrettyTable

# Success rate of the attack per (hash algorithm, modulus size) pair.
key_sizes = [1024, 1280, 1408, 1536, 1792, 2048, 3072]
hash_algorithms = ["SHA256", "SHA512"]

table = PrettyTable()
table.field_names = ["Key Size", "Hash Alg", "Success Rate (%)"]

attempts = 10  # number of trials per configuration, to get a stable estimate
m = b"This message is signed with RSA-CRT!"

for hash_alg in hash_algorithms:
    for size in key_sizes:
        success_count = 0
        for _ in range(attempts):
            e, d, n, p, q = key_gen(size)
            s, pad = sign(m, d, p, q, n, hash_alg, bug=True)
            new_p = attack(s, e, n, pad, hash_alg)
            if new_p == p:
                success_count += 1

        success_rate = (success_count / attempts) * 100
        table.add_row([size, hash_alg, success_rate])

print(table)

+----------+----------+------------------+
| Key Size | Hash Alg | Success Rate (%) |
+----------+----------+------------------+
|   1024   |  SHA256  |        0         |
|   1280   |  SHA256  |        0         |
|   1408   |  SHA256  |        0         |
|   1536   |  SHA256  |       100        |
|   1792   |  SHA256  |       100        |
|   2048   |  SHA256  |       100        |
|   3072   |  SHA256  |       100        |
|   1024   |  SHA512  |        0         |
|   1280   |  SHA512  |        0         |
|   1408   |  SHA512  |        0         |
|   1536   |  SHA512  |        0         |
|   1792   |  SHA512  |        0         |
|   2048   |  SHA512  |        0         |
|   3072   |  SHA512  |       100        |
+----------+----------+------------------+
