## Alisa Todorova

In [1]:
from sage.all import *

### 1. Implement the signature generation algorithm using the Chinese Remainder Theorem (CRT) using the Sage library. 
More precisely, to compute s = md mod N, compute and s=m^d mod N, compute 
s_p=s mod p=m^(d mod p−1) mod p and s_q=s mod q=m^(d mod q−1) mod q

Recover s mod N from s_p and s_q using the CRT.

In [2]:
# Generates a signature using the given parameters
def sig_gen(m, d, p, q):
    # Compute s_p and s_q as shown above
    s_p = power_mod(m, d % (p - 1), p)
    s_q = power_mod(m, d % (q - 1), q)

    # Use CRT to recover s mod N from s_p and s_q
    s = crt([s_p, s_q], [p, q])

    return s

In [3]:
# This implementation is based on slide 7 of Lecture Slides "The RSA cryptosystem - Part 1: encryption and signature"

# Generates an RSA key, where input n is the bitsize n of the RSA modulus N.
def keyGen(n=512):

    # Generate two large distinct primes p and q of same bit-size k/2, where k is a parameter.
    p = random_prime(2**n - 1, lbound=2**(n - 1) + 1)
    q = random_prime(2**n - 1, lbound=2**(n - 1) + 1)

    Nn = p * q

    phi = (p - 1) * (q - 1)

    # Select a random integer e such that gcd(e, φ) = 1
    e = randint(2, phi - 1)
    while gcd(e, phi) != 1:
        e = randint(2, phi - 1)

    # Compute the unique integer d such that e · d ≡ 1 (mod φ)
    d = inverse_mod(e, phi)

    return Nn, p, q, e, d

In [4]:
# Test
m = 12345
Nn, p, q, e, d = keyGen()
signature = sig_gen(m, d, p, q)
print("Signature:", signature)

Signature: 63874662285719623577018983810926457154901160128099550747990999863732549471534666833883323917193716118633294972489378743227334500102294416856314482551248570269850534883348079224857778519862365240891544638984807241945068599477956409429253797700215232381581136650746597003057832543424170885393152540069036482892


In [5]:
# Verifies the signature
def sig_ver(s, e, Nn):
    return pow(s, e, Nn)

recovered_m = sig_ver(signature, e, Nn)
print("Recovered message:", recovered_m)

# Check if the signature is valid
print("Signature is valid:", recovered_m==m)

Recovered message: 12345
Signature is valid: True


### 2. Assume that an error occurs during the computation of s_p, that is, an incorrect value s′_p   ̸=  s_p is computed while s_q is correctly computed. Explain and implement how to recover the factorization of N from s, following the Bellcore attack [BDL97].

The Bellcore attack is a fault attack on RSA signature scheme. If a fault occurs during the computation of s_p (see method sig_gen_error), we can recover the factorization of N by computing gcd(s_correct - s_error, N). We can do this because (s_correct - s_error) is a multiple of p, but not N, so the greatest common divisor of (s_correct - s_error) and N will give us p.

In [6]:
# Generates a signature with an error occuring during the computation of s_p
def sig_gen_error(m, d, p, q):
    # Compute s_p with an error
    s_p_wrong = power_mod(m, d % (p - 1), p) + 1  # +1 introduces the fault
    # s_q stays as shown above
    s_q = power_mod(m, d % (q - 1), q)

    # Use CRT to recover s mod N from s_p_wrong and s_q, i.e., signature with error
    s_error = crt([s_p_wrong, s_q], [p, q])

    return s_error, s_p_wrong, s_q

In [7]:
# Performs the Bellcore attack to factorize N from s using the given parameters
def bellcore_attack(s_correct, s_error, Nn):

    p_recovered = gcd(s_correct - s_error, Nn)
    q_recovered = Nn / p_recovered

    return p_recovered, q_recovered

In [8]:
# Test

m = 12345
Nn, p, q, e, d = keyGen()

# Correct signature
s_correct = sig_gen(m, d, p, q)
print("Correct Signature:", s_correct)

# Signature with an error in s_p
s_error, s_p_wrong, s_q = sig_gen_error(m, d, p, q)
print("Signature with an error in s_p:", s_error)

print("p:", p)
print("q:", q)

# Perform the Bellcore attack
p_recovered, q_recovered = bellcore_attack(s_correct, s_error, Nn)
print("Recovered p:", p_recovered)
print("Recovered q:", q_recovered)

# Check if the factors are equal
print("Recovered Factors are equal:", Nn == p_recovered * q_recovered)

Correct Signature: 86831619450489827034913761224272859671263343474209933604835146022496588342881104064016581256456383212651526295753159184589524539226424522889176828433349250262665637518432712957442950390814941705865180758736160423467658364374455977892457453901630356512760350909647938231371140942894326739872406195144996554838
Signature with an error in s_p: 87787177196987406742253019351124767083092709164964372148265695537513077387456247541759016654760787620614084170302328457665170274241950329322324313873180573583534820303751842882208514090176065830179295111509230101411353584407103795937522306400367124266868338360614770455782205122813321976255385071908734322152
p: 10936774892146756035607236157249614656265408945648882949122640973150121322916498468708567675060512140928792381795800036035423225942155654730120669517540883
q: 8036679004723843968702063591716442726014524285792087452334531773594954909599539006813907886890932026523829050573562033123975933649919579366158165453653633
Recovered p: 8

### 3. How could such error be detected? Propose and implement a simple method to detect such error.

In order to detect an error in s_p (or s_q), we can recompute s from s_p and s_q using the Chinese Remainder Theorem (CRT), and then compare this recomputed s with the original s. If they are not equal, then an error has occurred during the computation of s_p or s_q.

In [9]:
# Detects errors in the values of s, s_p, and s_q
def detect_error(s, s_p, s_q, p, q):
    error_messages = []

    # Recompute s from s_p and s_q using the CRT
    s_recomputed = crt([s_p, s_q], [p,q])

    # Correct s_p and s_q
    s_p_correct = power_mod(m, d % (p - 1), p)
    s_q_correct = power_mod(m, d % (q - 1), q)

    # Detect errors
    if s_recomputed != s:
        error_messages.append("Error in s")
    
    if s_p_correct != s_p:
        error_messages.append("Error in s_p")
    if s_q_correct != s_q:
        error_messages.append("Error in s_q")
    
    if not error_messages:
        error_messages.append("No errors detected")

    return error_messages

In [10]:
# Test

m = 12345
Nn, p, q, e, d = keyGen()

# Correct signature
s_correct = sig_gen(m, d, p, q)
print("Correct Signature:", s_correct)

# Signature with an error in s_p
s_error, s_p_wrong, s_q = sig_gen_error(m, d, p, q)

# Check if any errors can be detected
error_detected = detect_error(s_correct, s_p_wrong, s_q, p,q)
print("Error Detected:", error_detected)

Correct Signature: 26572717191185844328890907562784603483440058082633917246668337158650884023145644696731180539260473351894541189739106226353102123602935554227786285429084692754788418602728882598054262301305850911265916691395087174738729300004913850920648728864033945553436402887018591058051348989448514554172410261016864966870
Error Detected: ['Error in s', 'Error in s_p']


Another way we can know for sure whether there're any errors, is to verify the generated signature against the original message and the public key. 

In [11]:
# Verifies the signature
def sig_ver(s, e, Nn):
    return pow(s, e, Nn)

print("Original message:", m)
# Message we want to recover
recovered_m = sig_ver(s_error, e, Nn)
print("Recovered message:", recovered_m)

# Check if the signature is valid
if recovered_m != m:
    print("Signature is invalid! There is an error in the signature.")

Original message: 12345
Recovered message: 7555553609145045416614731334175904102287415981949891917489262947563325527474544320687948800879833129678382960476695655182377204548752557915846564628756981441982832228888497346841264136130409474665365933708770228473459404321427263355112527172405730511023089768273979945291244878458641430983783787486893099312
Signature is invalid! There is an error in the signature.
