# SSE - Demo - Ferrara Justin
- Notebook to be run in a SageMath environment
- Proof of concept of the paper: **Keegan Ryan, Kaiwen He, George Arnold Sullivan, Nadia Heninger**. *Passive SSH Key Compromise via Lattices*. Cryptology ePrint Archive, Paper 2023/1711, 2023. [DOI: 10.1145/3576915.3616629](https://doi.org/10.1145/3576915.3616629), [URL: https://eprint.iacr.org/2023/1711](https://eprint.iacr.org/2023/1711).

## PKCS#1 V1.5 Signatures

In [1]:
from Crypto.Signature.pkcs1_15 import new as pkcs1_15_new
from Crypto.PublicKey import RSA
from hashlib import sha256
from Crypto.Hash import SHA256
from math import gcd
from random import randint
from Crypto.Util.number import getPrime, inverse
import numpy as np


def key_gen():
    """Generate an RSA key pair with e = 65537 and a 2047/2048-bit modulus.

    Both primes are drawn from (2**1023, 2**1024), so each has exactly 1024
    bits and n = p*q has 2047 or 2048 bits. sign() and verify() encode
    messages on a fixed 256 bytes, which is only valid for such a modulus.

    Returns:
        (e, d, n, p, q): public exponent, private exponent, modulus and the
        two (distinct) prime factors.
    """
    e = 65537
    while True:
        # lbound forces full-size primes; without it random_prime may return
        # a much smaller prime and n can end up shorter than the encoding.
        p = random_prime(2**1024, proof=False, lbound=2**1023)
        q = random_prime(2**1024, proof=False, lbound=2**1023)
        if p == q:
            continue
        phi = (p - 1) * (q - 1)
        # e must be invertible modulo phi for the private exponent to exist.
        if gcd(phi, e) == 1:
            break
    n = p * q
    d = inverse_mod(e, phi)
    return (e, d, n, p, q)

def sign(m, d, p, q, n, bug = False):
    """Sign ``m`` with RSA-CRT using EMSA-PKCS1-v1_5 encoding over SHA-256.

    Args:
        m: message bytes to sign.
        d: private exponent.
        p, q: the two prime factors of the modulus.
        n: the modulus; not used by the computation, kept for the signature.
        bug: when True, deliberately corrupt the mod-p half of the CRT
            computation (it is reduced modulo p - 22 instead of p), so the
            result is correct modulo q but, in general, not modulo p.

    Returns:
        (signature, pad): the CRT-recombined signature and the 256-byte
        encoded message it is meant to sign.

    The encoding length is fixed at 256 bytes, i.e. a 2048-bit modulus.
    """
    k = 256  # encoded message length in bytes
    # DER-encoded DigestInfo prefix for SHA-256.
    prefix_sha256 = b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
    digest = sha256(m).digest()

    # EM = 0x00 0x01 || 0xff * ps_len || 0x00 || DigestInfo prefix || H(m)
    ps_len = k - len(digest) - len(prefix_sha256) - 3
    pad = b'\x00\x01' + b'\xff' * ps_len + b'\x00' + prefix_sha256 + digest
    pad_int = int.from_bytes(pad, byteorder="big")

    # CRT private exponents.
    dp = d % (p - 1)
    dq = d % (q - 1)

    # Fault injection: reduce the mod-p exponentiation modulo the wrong value.
    p_modulus = p - 22 if bug else p
    sp = power_mod(pad_int, dp, p_modulus)
    sq = power_mod(pad_int, dq, q)
    signature = crt([sp, sq], [p, q])

    return signature, pad

def verify(s, pad, e, n):
    """Check a signature against its expected encoded message.

    Raises ``s`` to the public exponent ``e`` modulo ``n``, writes the result
    on 256 big-endian bytes and returns whether it equals ``pad``.
    """
    recovered = int(power_mod(s, e, n))
    return pad == recovered.to_bytes(256, byteorder="big")


def _sign_with_both(message, bug):
    """Sign ``message`` under a fresh 2048-bit PyCryptodome key, twice.

    Returns (custom, reference): the 256-byte signature produced by sign()
    (with the given ``bug`` flag) and the one produced by PyCryptodome's
    PKCS#1 v1.5 signer for the same key and message.
    """
    key = RSA.generate(2048)
    signer = pkcs1_15_new(key)

    signature_custom, _pad = sign(message, key.d, key.p, key.q, key.n, bug=bug)
    reference = signer.sign(SHA256.new(message))

    return int(signature_custom).to_bytes(256, byteorder="big"), reference


def testSignature(rounds=10):
    """Sanity-check sign() against PyCryptodome's PKCS#1 v1.5 signer.

    First checks that ``rounds`` correct signatures match the reference
    byte-for-byte, then that ``rounds`` faulty (bug=True) signatures differ
    from it. Raises AssertionError on the first mismatch.
    """
    print("!!! Warning:This test is not a complete test of the signature !!!")

    for i in range(rounds):
        custom, reference = _sign_with_both(f"Test message {i}".encode(), bug=False)
        assert custom == reference, "Error signatures doesn't match"
        print("Signature", i, " ok")

    print("Testing wrong signatures")
    for i in range(rounds):
        custom, reference = _sign_with_both(f"Test message {i}".encode(), bug=True)
        # A faulty signature must NOT equal the reference one.
        assert custom != reference, "Error: faulty signature unexpectedly matches the reference"
        print("Signature", i, " ok")

testSignature()


Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok
Testing wrong signatures
Signature 0  ok
Signature 1  ok
Signature 2  ok
Signature 3  ok
Signature 4  ok
Signature 5  ok
Signature 6  ok
Signature 7  ok
Signature 8  ok
Signature 9  ok


## Attack code

In [10]:
def attack(s, e, n, pad):
    """Recover a prime factor of ``n`` from one faulty RSA-CRT signature.

    ``s`` is expected to be correct modulo one prime (call it q) but wrong
    modulo the other, as produced by ``sign(..., bug=True)``. The 256-byte
    PKCS#1 v1.5 encoding is known except for the 32-byte SHA-256 digest.
    Writing the encoding as ``a + r``, with ``a`` the known part (digest
    replaced by its midpoint 2**255), gives

        n1 = (s**e mod n) - a  ==  r  (mod q),   -2**255 <= r < 2**255.

    So r is a small root of f(x) = n1 - x modulo the unknown q. It is found
    with a Coppersmith / Howgrave-Graham lattice spanned by x*f(x), f(x) and
    n, after which q = gcd(n, n1 - r).

    Args:
        s: the (faulty) signature.
        e, n: the public key.
        pad: unused (the attacker does not need the encoded message); kept
            for the call signature.

    Returns:
        A non-trivial factor of ``n``, or None if the attack failed.
    """
    # Known part of the encoding: digest replaced by 0x80 00..00 == 2**255.
    prefix_sha256 = b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
    digest_len = 32
    ps_len = 256 - digest_len - len(prefix_sha256) - 3  # == 202
    a = (b'\x00\x01' + b'\xff' * ps_len + b'\x00' + prefix_sha256
         + b'\x80' + (digest_len - 1) * b'\x00')
    a_int = int.from_bytes(a, byteorder='big')

    n1 = power_mod(s, e, n) - a_int

    print("Calculated n1 :", n1)
    print("\n")

    # |r| <= 2**255, so X = 2**255 bounds the root.
    log_r = 255
    X = 2**log_r

    # Column j holds the coefficient of x^(2-j), scaled by X^(2-j):
    #   row 0: x*f(x) = -x^2 + n1*x
    #   row 1:   f(x) =        -x + n1
    #   row 2:      n
    B = Matrix(QQ, 3, 3)
    B[0, 0] = -X**2
    B[0, 1] = X * n1
    B[1, 1] = -X
    B[1, 2] = n1
    B[2, 2] = n

    # LLL reduction of the basis.
    reduced_matrix = B.LLL()

    print("-------------------")

    R = PolynomialRing(QQ, 'x')
    x = R.gen()
    for row in reduced_matrix:
        # Undo the column scaling to recover g(x) = c2*x^2 + c1*x + c0.
        g = R(row[0] / X**2 * x**2 + row[1] / X * x + row[2])

        print("Polynôme g(x) :", g)

        # Rational roots of g, as (root, multiplicity) pairs.
        roots = g.roots()

        print("Roots :", roots)

        for r1, _ in roots:
            # The wanted root is an integer; skip non-integral rationals.
            if r1 not in ZZ:
                continue
            p = gcd(n, n1 - int(r1))
            if 1 < p < n:
                print("-------------------")
                print("Attack success !!!")
                print("n1 :", n1)
                print("r1 :", r1)
                print("p :", p)
                return p

    print("Attack failed.")
    return None



# Demo: generate a key, produce one faulty signature, then factor n from it.
(e, d, n, p, q) = key_gen()
m = b"This message is signed with RSA-CRT!"
# bug=True injects the CRT fault, so verification below is expected to fail.
(s, pad) = sign(m, d, p, q, n, True)

print("-------------------")
print("Verify signature: ", verify(s, pad, e, n))

print("-------------------")
print("Attack")
print("-------------------")
new_p = attack(s, e, n, pad)

# attack() returns None on failure, which compares unequal to both primes.
print("Correct p found:", new_p in (p, q))

-------------------
Verify signature:  False
-------------------
Attack
-------------------
Calculated n1 : 21069101706519537947320311613862610416857741528004057128128374859741050152138922744827569584742755865582642547973675302281460172953180943986473632887946596332241308714911670572235506404378343844806318173396563260394233982296179615895307605317731227985043577493746415331201719805011701902566262463612342100259681998168607120840824272128238509801766418177487778500591626841638255437944558149458352684139388823853142284250634045474080887124272807362979164755836609413423627010284431984248709160860619700328679302413104731257125480147725731452147539205659813488796409916179806791734330084218316687059317743031093663119029


-------------------
Polynôme g(x) : -1013491448474378592299413430595646496997455649774075823461755912977696718002456979625379322717457001598790593689885136348221617922112658624619089532374722582800705155639141000642057007117756413257621856720070524241063224460796371953380