# SSE - Demo - Ferrara Justin
- Notebook to run in a sage environnment
- Proof of concept de l'article : **Keegan Ryan, Kaiwen He, George Arnold Sullivan, Nadia Heninger**. *Passive SSH Key Compromise via Lattices*. Cryptology ePrint Archive, Paper 2023/1711, 2023. [DOI: 10.1145/3576915.3616629](https://doi.org/10.1145/3576915.3616629), [URL: https://eprint.iacr.org/2023/1711](https://eprint.iacr.org/2023/1711).

In [8]:
from hashlib import sha256
from math import gcd
from random import randint
from Crypto.Util.number import getPrime, inverse
import numpy as np

# Import lll
#from lll import lll

def key_gen():
    phi = 2
    e = 2
    while gcd(phi, e) != 1 : 
        p = random_prime(2**1024, proof = False)
        q = random_prime(2**1024, proof = False)
        n = p*q
        phi = (p-1) * (q-1)
        e = 65537
    d = inverse_mod(e, phi)
    return (e, d, n, p, q)

def sign(m, d, p, q, n, bug = False):
    dp = d % (p-1)
    dq = d % (q-1)
    h = int.from_bytes(sha256(m).digest(), byteorder = "big")

    n1 = 0
    if bug:
        # n1 = (p-6) * p
        rand_p = randint(1, p)
        sp = power_mod(h, dp, rand_p)
        
        n1 = q * rand_p
        
        print("Real n1 :", n1)
        print("\n")
    else:
        sp = power_mod(h, dp, p)

    sq = power_mod(h, dq, q)
    signature = crt([sp, sq], [p, q])

    # PKCS#1 V1.5 signature
    prefix_sha256 = b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
    pad_len = 256 - len(h.to_bytes(32, byteorder="big")) - len(prefix_sha256) - 3
    pad = b'\x00\x01' + b'\xff' * pad_len + b'\x00' + prefix_sha256 + h.to_bytes(32, byteorder="big")

    print("Signature (int):", signature)

    return signature, pad

def verify(m, s, e, n):
    sign_calc = power_mod(s, e, n)
    sign_calc_bytes = int(sign_calc).to_bytes(256, byteorder="big")

    # Extraction du hash et comparaison
    h_calculated = sha256(m).digest()
    h_extracted = sign_calc_bytes[-32:]

    print("Calculated hash from m:", h_calculated.hex())
    print("Extracted hash of pad :", h_extracted.hex())

    return h_calculated == h_extracted


def attack(s, e, n, pad):

    # Known part of the padding
    a = pad[:-32]
    sign_calc = power_mod(s, e, n)
    sign_calc_bytes = int(sign_calc).to_bytes(256, byteorder="big")
    
    h_calculated = int.from_bytes(sha256(sign_calc_bytes).digest(), byteorder='big')
    h_extracted = int.from_bytes(pad[-32:], byteorder='big')
    
    print("h_calculated :", h_calculated)
    print("h_extracted :", h_extracted)
    
    a_int = int.from_bytes(a, byteorder='big')
    
    n1 = (a_int - h_calculated)
    
    print("Calculated n1 :", n1)
    print("\n")

    # Construct the matrix
    log_r = 256            # TODO change
    B = Matrix(ZZ, 3, 3)
    B[0, 0] = -2**(2 * log_r)
    B[0, 1] = 2**log_r * n1
    B[0, 2] = 0
    B[1, 0] = 0
    B[1, 1] = -2**log_r
    B[1, 2] = n1
    B[2, 0] = 0
    B[2, 1] = 0
    B[2, 2] = n
    
    
    # Réduction LLL sur la matrice
    reduced_matrix = B.LLL()

    # Affichage de la matrice réduite
    print("Matrice réduite :", reduced_matrix)
    
    print("-------------------")
    
    
    for row in reduced_matrix:

        # Construction du polynôme g(x)
        R = PolynomialRing(ZZ, 'x')
        x = R.gen()
        g = row[2] + row[1] * x + row[0] * x**2

        print("Polynôme g(x) :", g)

        # Recherche des racines rationnelles
        roots = g.roots()  # Recherche des racines dans les entiers

        print("Roots :", roots)

        for r1, _ in roots:
            p = gcd(n, n1 - r1)
            if 1 < p < n:
                print("-------------------")
                print("Attack success !!!")
                print("n1 :", n1)
                print("r1 :", r1)
                print("p :", p)
                return p
    

    print("Attack failed.")
    return None



(e, d, n, p, q) = key_gen()
m = b"This message is signed with RSA-CRT!"
(s, pad) = sign(m, d, p, q, n, True)
# print("e = %s" % str(e))
# print("n = %s" %str(n))
# print("s = %s" %s)

print("-------------------")
print("Verify signature: ", verify(m, s, e, n))

print("-------------------")
print("Attack")
print("-------------------")
new_p = attack(s, e, n, pad)

print("Correct p founded:", p == new_p or q == new_p)

Real n1 : 22360547374661483384536154957969013317380551675130914489239631424181274445413279619189947353767176783379110658780678820852148010208583615517793649289511982294793446959998551436754146968096582941478046030913897299403485269534098324681477566655974041138375913720953810242311565514122799498300106951606112922166472618634126686774816178755727687095373750821858704084175114681426658567358813987007082763766896871814671024317511451461138679658254729200634874996139292758063885052961784689922546828730983911734910377069971385965944600637018427740413789342191822345423896348131323242925189214644300860436852331673336365300295


Signature (int): 26689521202949465990274092717106137076722071703810258967027446941607975997176096884164562738203099834030284963075541254210486413505216522986023777616430309870381018724664959399034250878694434103654067523163388753469292852293539331863071029565768732207320548759811185158498181977648031726210013142847574211075943107650028185709791234721143641977801765653