# Lab 4: Fault Attacks
In this lab we will implement a differential fault analysis (DFA) attack on PRESENT, using faults injected into a PRESENT implementation.

In [None]:
import numpy as np
import tqdm.notebook as tqdm
from reference import present
from random import randint

# Theoretical attack

## Faulting round 28 (0-index)
We will implement a differential fault analysis on PRESENT based on **single bitflips** injected into the Sbox input of **round 28** (0-index, i.e. `i = 29` in the 1-indexed pseudocode below). The differentials are provided and loaded below into the variable `diffs`, a list of 64-bit `ct ^ ct'` values (correct ciphertext XOR faulty ciphertext). The plaintext and ciphertext are also provided and loaded into the variables `pt` and `ct` respectively, also as 64-bit values.

We will use the ciphertext differentials to retrieve the full key used for encryption. This is done in two steps: first we retrieve the last round key, and then we use it to recover the master key.

*All steps and algorithms used in this lab, as well as further details on differential fault analysis of PRESENT, can be found in the paper [here](https://link.springer.com/article/10.1186/1687-6180-2013-145).*

**Note** that you can use the reference PRESENT implementation (imported as `present` from `reference`) to help you throughout this lab.

In [None]:
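# Convert to Python ints: mixing numpy uint64 scalars with Python ints in shifts/XORs can overflow or raise errors
diffs = np.fromfile("traces/04-PRESENT-DIFFS", dtype=np.uint64).tolist()
pt = int(np.fromfile("traces/04-PRESENT-PT", dtype=np.uint64)[0])
ct = int(np.fromfile("traces/04-PRESENT-CT", dtype=np.uint64)[0])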

## Step 1: Retrieving the last round key
<img src="https://www.researchgate.net/profile/Kiran-Vg/publication/305618598/figure/fig2/AS:388762954682373@1469699719334/Block-diagram-of-present-cipher.png" alt="drawing" width="350"/>

The diagram above shows the PRESENT block cipher at a high level. Below is the accompanying pseudocode, based on the original publication [here](https://link.springer.com/chapter/10.1007/978-3-540-74735-2_31).

```
generateRoundKeys()
for i = 1 to 31 do
    addRoundKey(state,K_{i})
    sBoxLayer(state)
    pLayer(state)
end for
addRoundKey(state,K_{32})
```
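For experimenting outside the notebook, the pseudocode can be turned into a small, self-contained Python sketch of PRESENT-80 (80-bit key). This is our own illustration, not the lab's `reference` module: the helper names (`sbox_layer`, `p_layer`, `round_keys`, `encrypt`) are hypothetical, and it assumes the bit numbering of the PRESENT specification, where Sbox `i` acts on bits `4i` to `4i+3`. It is checked against the all-zero test vector from the PRESENT paper.

```python
# Minimal PRESENT-80 sketch (hypothetical helper names, not the lab's reference module).
SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]
MASK80 = (1 << 80) - 1


def sbox_layer(state: int) -> int:
    """Apply the 4-bit Sbox to each of the 16 nibbles (Sbox i acts on bits 4i..4i+3)."""
    return sum(SBOX[(state >> (4 * i)) & 0xF] << (4 * i) for i in range(16))


def p_layer(state: int) -> int:
    """Move bit i to position 16*(i % 4) + i // 4 (equivalently 16*i mod 63, bit 63 fixed)."""
    out = 0
    for i in range(64):
        out |= ((state >> i) & 1) << (16 * (i % 4) + i // 4)
    return out


def round_keys(key: int) -> list[int]:
    """Return K_1..K_32: each round key is the leftmost 64 bits of the 80-bit key register."""
    keys = [key >> 16]
    for counter in range(1, 32):
        key = ((key << 61) | (key >> 19)) & MASK80               # rotate left by 61
        key = (SBOX[key >> 76] << 76) | (key & ((1 << 76) - 1))  # Sbox on the top nibble
        key ^= counter << 15                                     # XOR counter into bits 19..15
        keys.append(key >> 16)
    return keys


def encrypt(plaintext: int, key: int) -> int:
    keys = round_keys(key)
    state = plaintext
    for k in keys[:31]:  # i = 1..31: addRoundKey, sBoxLayer, pLayer
        state = p_layer(sbox_layer(state ^ k))
    return state ^ keys[31]  # final addRoundKey with K_32


print(hex(encrypt(0, 0)))  # 0x5579c1387b228445 (test vector from the PRESENT paper)
```

The later sketches in this notebook reuse the same conventions and redefine the pieces they need, so each one runs on its own.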

With the differentials you have been given, we can determine the input to the last Sbox layer. We do this with two algorithms:
- First, we set up a dictionary that records, for each Sbox of the last layer, which output differentials were observed. The last round key is XORed into both the correct and the faulty ciphertext, so it cancels in `ct ^ ct'`; and because the pLayer is a bit permutation, applying the inverse pLayer to a ciphertext differential gives the Sbox layer output *differential*. If the output of a specific Sbox was altered, we add that Sbox output differential to the list for that Sbox in the dictionary.
- Second, we use these per-Sbox output differentials to determine the Sbox input values. We *know* that exactly one input bit of each affected Sbox was flipped, which produced the output differentials determined above. So, for each Sbox, we look for the input value for which flipping a single bit explains every observed output differential. We start with the set of all possible input values, called `L`. For each output differential, we go through all possible input values; if flipping a single bit of an input value produces the current output differential, we add that value to a *second* set of possible values, called `M`. Once all input values have been considered, we replace `L` by the intersection of `L` and `M`. After repeating this for all output differentials of that Sbox, only one possible input value should remain in `L`, provided enough faults hit that Sbox.
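The claim that every last-round Sbox sees at most one flipped input bit follows from how the pLayer wires nibbles together, and it can be checked empirically. The sketch below is our own illustration (the helper `last_round_input_diff` and the random round keys are hypothetical stand-ins): it flips one random bit at the round-28 Sbox input, runs the Sbox and pLayer of rounds 28 and 29 plus the following key additions, and inspects every nibble of the resulting last-round Sbox input difference.

```python
import random

SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]


def sbox_layer(state: int) -> int:
    """Apply the 4-bit Sbox to every nibble (Sbox i acts on bits 4i..4i+3)."""
    return sum(SBOX[(state >> (4 * i)) & 0xF] << (4 * i) for i in range(16))


def p_layer(state: int) -> int:
    """PRESENT pLayer: bit i moves to position 16*(i % 4) + i // 4."""
    return sum(((state >> i) & 1) << (16 * (i % 4) + i // 4) for i in range(64))


def last_round_input_diff(x: int, bit: int, next_keys: tuple[int, int]) -> int:
    """Last-round Sbox input difference after flipping `bit` of the round-28 Sbox input `x`.

    `next_keys` are stand-ins for the round keys added before rounds 29 and 30 (0-index).
    """
    good, faulty = x, x ^ (1 << bit)
    for k in next_keys:  # Sbox + pLayer of rounds 28 and 29, each followed by addRoundKey
        good = p_layer(sbox_layer(good)) ^ k
        faulty = p_layer(sbox_layer(faulty)) ^ k
    return good ^ faulty


random.seed(1)
for _ in range(1000):
    keys = (random.getrandbits(64), random.getrandbits(64))
    d = last_round_input_diff(random.getrandbits(64), random.randrange(64), keys)
    # Each nibble is 0 (Sbox unaffected) or a single bit (exactly one flipped input bit)
    assert d != 0 and all(((d >> (4 * i)) & 0xF) in (0, 1, 2, 4, 8) for i in range(16))
print("every last-round Sbox input has at most one flipped bit")
```

Faulting later (e.g. round 29) would leave multi-bit differences in some last-round Sbox inputs, which is why the single-bitflip reasoning of algorithm 2 depends on the fault being injected in round 28.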

### Algorithm 1
Determining the Sbox output differentials.
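Algorithm 1 relies on two building blocks: an inverse pLayer and a split of a 64-bit value into the 16 Sbox nibbles. The minimal sketch below shows both on a simulated last round, where the same round key is XORed into the correct and the faulty ciphertext and therefore cancels. The helper names are hypothetical, and it assumes Sbox `i` sits in bits `4i` to `4i+3`; make sure this matches the nibble order you use in the lab.

```python
import random


def p_layer(state: int) -> int:
    """PRESENT pLayer: bit i moves to position 16*(i % 4) + i // 4."""
    return sum(((state >> i) & 1) << (16 * (i % 4) + i // 4) for i in range(64))


def inv_p_layer(state: int) -> int:
    """Inverse pLayer: bit 16*(i % 4) + i // 4 moves back to position i."""
    return sum(((state >> (16 * (i % 4) + i // 4)) & 1) << i for i in range(64))


def nibbles(state: int) -> list[int]:
    """Split a 64-bit value into 16 nibbles, index 0 = least significant nibble."""
    return [(state >> (4 * i)) & 0xF for i in range(16)]


random.seed(0)
k = random.getrandbits(64)                 # stand-in for the last round key
sbox_out = random.getrandbits(64)          # last Sbox layer output (correct run)
faulty_sbox_out = sbox_out ^ (0xA << 8)    # pretend only Sbox 2 changed, by 0xA
ct, ct_faulty = p_layer(sbox_out) ^ k, p_layer(faulty_sbox_out) ^ k
diff_per_sbox = nibbles(inv_p_layer(ct ^ ct_faulty))
print(diff_per_sbox)  # 10 (= 0xA) at index 2, zeros elsewhere
```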

In [None]:
# algorithm 1: collect the observed output differentials of every last-round Sbox
sbox_out_diffs = {i: [] for i in range(X)}
for diff in diffs:
    # the last round key cancels in ct ^ ct', so undoing the pLayer gives the Sbox layer output differential
    sbox_layer_out_diff = present....(diff)
    for sbox_index in range(X):
        # calculate the output differential for the specific sbox
        sbox_out_diff = ...

        if sbox_out_diff > 0:  # this Sbox's output changed, so add its differential to the group
            sbox_out_diffs[sbox_index].append(sbox_out_diff)

# Keep only the unique values
sbox_out_diffs = {k: np.unique(v).tolist() for k, v in sbox_out_diffs.items()}

In [None]:
sbox_out_diffs

### Algorithm 2
Determining the Sbox input states based on the output differentials
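To see the candidate reduction of algorithm 2 on a single Sbox in isolation, the sketch below (our own illustration; `candidates` and `secret` are hypothetical) simulates a few single-bit faults on a 4-bit Sbox input and intersects the candidate sets `L` and `M` exactly as described above. The true input always survives; whether it is the *only* survivor depends on which bits happened to be faulted, which is why several faults per Sbox are needed in practice.

```python
SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]


def candidates(out_diffs: list[int]) -> set[int]:
    """Sbox inputs for which a single-bit flip explains every observed output differential."""
    L = set(range(16))  # all 4-bit input values
    for d in out_diffs:
        # M: inputs s where flipping some bit i changes the Sbox output by exactly d
        M = {s for s in range(16) for i in range(4) if SBOX[s] ^ SBOX[s ^ (1 << i)] == d}
        L &= M
    return L


secret = 0x7  # hypothetical Sbox input that the attacker does not know
observed = [SBOX[secret] ^ SBOX[secret ^ (1 << bit)] for bit in (0, 2)]  # faults on bits 0 and 2
print(observed, candidates(observed))
```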

In [None]:
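# algorithm 2
last_round_state = 0  # rebuilt from the Sbox *outputs* of the last Sbox layer
for sbox_index in range(X):
    L = set(range(NUM_INPUT_STATES))  # start from every possible Sbox input value
    for sbox_out_diff in sbox_out_diffs[sbox_index]:
        M = set()  # input values that can explain this output differential
        for s in range(NUM_INPUT_STATES):
            for i in range(4): # There are 4 bits which can be flipped
                if sbox_out_diff == ...TODO OUTPUT DIFFERENTIAL FOR STATE s WITH BITFLIP AT INDEX i...:
                    M.add(s)
        L = L.intersection(M)
    print(sbox_index, L)
    assert len(L) == 1, f"Sbox {sbox_index} still has {len(L)} candidates, more faults are needed"
    # Sbox 0 ends up in the most significant nibble; this must match the nibble order of algorithm 1
    last_round_state = last_round_state << 4
    last_round_state |= present.sbox[L.pop()]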

### Calculating the round key
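You have now completely determined the input to the final Sbox layer; the code above stores the corresponding Sbox outputs in `last_round_state`. Use this together with the given ciphertext to calculate the last round key: since the last round computes `ct = pLayer(sBoxLayer(state)) ^ K_32`, the key follows from XORing `ct` with the pLayer of the Sbox layer output.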
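The relation can be checked on a simulated last round. In the sketch below, `recover_last_round_key` is a hypothetical helper, and the random state and key stand in for the real ones; it assumes the specification's bit order (Sbox `i` in bits `4i` to `4i+3`).

```python
import random

SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]


def sbox_layer(state: int) -> int:
    """Apply the 4-bit Sbox to every nibble (Sbox i acts on bits 4i..4i+3)."""
    return sum(SBOX[(state >> (4 * i)) & 0xF] << (4 * i) for i in range(16))


def p_layer(state: int) -> int:
    """PRESENT pLayer: bit i moves to position 16*(i % 4) + i // 4."""
    return sum(((state >> i) & 1) << (16 * (i % 4) + i // 4) for i in range(64))


def recover_last_round_key(sbox_layer_output: int, ciphertext: int) -> int:
    """ct = pLayer(sbox_layer_output) ^ K_32, so K_32 = ct ^ pLayer(sbox_layer_output)."""
    return ciphertext ^ p_layer(sbox_layer_output)


# Simulated last round with a random Sbox input state and a random K_32
random.seed(2)
k32 = random.getrandbits(64)
last_sbox_input = random.getrandbits(64)
ct = p_layer(sbox_layer(last_sbox_input)) ^ k32
print(recover_last_round_key(sbox_layer(last_sbox_input), ct) == k32)  # True
```

If your `last_round_state` stores the nibbles in a different order than this convention, reorder them before applying the pLayer.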

## Step 2: Now what?
We have used differential fault analysis to figure out the final round key. Is this enough to know the master key? What remaining information do we need to determine the master key, and what are our options to determine it?

### Option 1: DFA of one more round
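We can repeat the differential analysis one round further back, in exactly the same way as for the last round. This works because the bitflips were injected two rounds before the last round, so every Sbox of the second-to-last layer also sees at most one flipped input bit. At a high level this consists of the following steps:
1. Calculate the Sbox input differential of the final round (from the now-known Sbox inputs and the observed output differentials), then undo the pLayer of the previous round, as before, to obtain the Sbox output differential of the second-to-last round.
2. Reduce the candidate input states in the same way as before to determine the exact Sbox input state of the second-to-last round.
3. Use this input state to determine the second-to-last round key (the key XORed between the two Sbox layers equals the last-round Sbox input XOR the pLayer of the second-to-last Sbox layer's output), which can be combined with the last round key to determine the master key.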
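Step 1 can be sketched as follows, under the same bit conventions as the PRESENT specification (hypothetical helper names, our own illustration). Knowing the last Sbox layer's output `v` and an output differential `dv`, the input differential is `S^-1(v) ^ S^-1(v ^ dv)` nibble by nibble; undoing the pLayer then gives the Sbox output differential of the second-to-last round, because the round key added in between cancels. The simulation compares this with the true difference.

```python
import random

SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]
INV_SBOX = [SBOX.index(y) for y in range(16)]


def nibble_map(table: list[int], state: int) -> int:
    """Apply a 4-bit lookup table to each of the 16 nibbles of a 64-bit state."""
    return sum(table[(state >> (4 * i)) & 0xF] << (4 * i) for i in range(16))


def p_layer(state: int) -> int:
    return sum(((state >> i) & 1) << (16 * (i % 4) + i // 4) for i in range(64))


def inv_p_layer(state: int) -> int:
    return sum(((state >> (16 * (i % 4) + i // 4)) & 1) << i for i in range(64))


def prev_round_out_diff(last_sbox_out: int, out_diff: int) -> int:
    """Second-to-last Sbox output differential from the last Sbox layer's output and its differential."""
    # Invert the last Sbox layer on both the correct and the faulty output
    in_diff = nibble_map(INV_SBOX, last_sbox_out) ^ nibble_map(INV_SBOX, last_sbox_out ^ out_diff)
    # The round key added between the two Sbox layers cancels; undo the pLayer
    return inv_p_layer(in_diff)


# Simulate the last two Sbox layers with one flipped bit at the second-to-last Sbox input
random.seed(3)
k = random.getrandbits(64)  # stand-in for the round key between the two layers
y = random.getrandbits(64)
z, z_f = nibble_map(SBOX, y), nibble_map(SBOX, y ^ (1 << 5))
v, v_f = nibble_map(SBOX, p_layer(z) ^ k), nibble_map(SBOX, p_layer(z_f) ^ k)
print(prev_round_out_diff(v, v ^ v_f) == z ^ z_f)  # True
```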

### Option 2: Brute force
We do not yet know all the information needed to calculate the master key, but knowing the last round key drastically reduces the unknown part. PRESENT-80 uses an 80-bit key register and the last round key is its leftmost 64 bits, so only the remaining 16 bits are unknown: just 2^16 = 65,536 candidates. This makes it feasible to brute force the unknown part of the key register. At a high level this consists of the following steps:
1. Write a function that computes the master key from the last round key and a partial key guess (for the unknown part of the key register).
2. Try guesses for the unknown part of the key register until the resulting master key encrypts the known `pt` to the known `ct`.
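One possible way to invert the key schedule is sketched below; it is our own illustration, not the reference implementation, and the helper names are hypothetical. Each forward update rotates the 80-bit register left by 61, passes the top nibble through the Sbox and XORs the round counter into bits 19..15, so the inverse XORs the counter back out, applies the inverse Sbox to the top nibble and rotates right by 61. The check runs the forward schedule on a random master key and walks it back.

```python
import random

SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2]
INV_SBOX = [SBOX.index(y) for y in range(16)]
MASK80 = (1 << 80) - 1


def key_update(key: int, counter: int) -> int:
    """One forward PRESENT-80 key-register update."""
    key = ((key << 61) | (key >> 19)) & MASK80                  # rotate left by 61
    key = (SBOX[key >> 76] << 76) | (key & ((1 << 76) - 1))     # Sbox on the top nibble
    return key ^ (counter << 15)                                # counter into bits 19..15


def inverse_key_update(key: int, counter: int) -> int:
    """Undo key_update: XOR the counter out, invert the Sbox, rotate right by 61."""
    key ^= counter << 15
    key = (INV_SBOX[key >> 76] << 76) | (key & ((1 << 76) - 1))
    return ((key >> 61) | (key << 19)) & MASK80


def master_from_final_register(final_reg: int) -> int:
    """Walk back through all 31 updates, counters 31 down to 1."""
    for counter in range(31, 0, -1):
        final_reg = inverse_key_update(final_reg, counter)
    return final_reg


random.seed(4)
master = random.getrandbits(80)
reg = master
for counter in range(1, 32):
    reg = key_update(reg, counter)
last_round_key, low16 = reg >> 16, reg & 0xFFFF  # K_32 and the 16 bits the brute force guesses
print(master_from_final_register((last_round_key << 16) | low16) == master)  # True
```

In the brute force, only `low16` is unknown, so each guess costs 31 inverse updates plus one trial encryption of `pt`.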

In [None]:
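def last_round_key_to_master(last_round_key: int, guess: int) -> int:
    """Rebuild the 80-bit master key from the last round key and a guess for the 16 unknown key-register bits."""
    assert 0 <= guess < 2**16
    assert 0 <= last_round_key < 2**64
    # The last round key is the top 64 bits of the final key register; convert to a Python int
    # first so the shift cannot overflow a numpy uint64
    key_reg = (int(last_round_key) << 16) | guess
    for i in range(31, 0, -1):
        # TODO do the inverse of the key schedule, you can use the reference implementation for inspiration and verifying correctness
        pass
    return key_reg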

In [None]:
for partguess in tqdm.trange(2**16):
    keyguess = last_round_key_to_master(last_round_key, partguess)
    if ...keyguess is correct...:
        print(f"Found full key using partial guess {partguess}: {keyguess:#022x}")
        break
else:
    print("No key found: check the last round key and the inverse key schedule")