# Chapter 12: Bench time: Differential Power Analysis

This notebook is the companion to Chapter 12 of The Hardware Hacking Handbook by Jasper van Woudenberg and Colin O'Flynn.

- If you'd like to use the pre-recorded traces, set `SCOPETYPE` to `FILE` and `PLATFORM` to `CWLITEXMEGA`. 
- If you'd like to use ChipWhisperer hardware, set `SCOPETYPE` to `OPENADC` and `PLATFORM` to `CWLITEXMEGA`.

If you are using your own target hardware, the bootloader code can be found at http://localhost:8888/tree/hardware/victims/firmware/bootloader-aes256 when the VM is running.

© 2021. This work is licensed under a [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/). 


In [None]:
# Where traces come from: 'FILE' replays pre-recorded projects, 'OPENADC' uses ChipWhisperer hardware
SCOPETYPE = 'FILE'
PLATFORM = 'CWLITEXMEGA'

# The pre-recorded trace files were captured on CWLITEXMEGA, so force that platform in FILE mode
PLATFORM = 'CWLITEXMEGA' if SCOPETYPE == 'FILE' else PLATFORM

# Imports needed for this notebook
from bokeh.io import output_notebook
from bokeh.palettes import Category20
from bokeh.plotting import figure, show
from Crypto.Cipher import AES
from tqdm import tnrange,tqdm

import chipwhisperer.analyzer as cwa
import chipwhisperer as cw
import numpy as np
import random
import shutil
import time


In [None]:
# Basic RSA encryption/decryption
#
# Toy "RSA-16" used only to obfuscate the expected answers in this notebook.
# Note the naming: encrypt() raises to the private exponent and decrypt() to the
# public one, so anyone holding `public` can reveal values produced with `private`.
def encrypt(private_key, plaintext):
    """Return [m^d mod n for m in plaintext], where private_key = (d, n)."""
    d, n = private_key
    # Three-argument pow() does modular exponentiation without first building
    # the (huge) integer m**d, which the naive (m ** d) % n form would do.
    return [pow(m, d, n) for m in plaintext]

def decrypt(public_key, ciphertext):
    """Return [c^e mod n for c in ciphertext], where public_key = (e, n)."""
    e, n = public_key
    return [pow(c, e, n) for c in ciphertext]


# Our RSA-16 private and public key (n = p * q = 47053)
p = 211
q = 223

public = (36077, 47053)
private = (29693, 47053)

## Obtaining and building the bootloader code

For this lab, we'll be using the `bootloader-aes256` project, which we'll build as follows:

In [None]:
%%bash -s "$PLATFORM" 
# Build the bootloader-aes256 firmware; $1 is the PLATFORM value passed in via -s
cd ../hardware/victims/firmware/bootloader-aes256
make PLATFORM=$1 CRYPTO_TARGET=NONE

Next, we'll flash the code onto the target.

In [None]:
# Flash the firmware built above onto the target (real hardware only)
if SCOPETYPE == 'OPENADC':
    # Helper notebook; presumably defines scope, target, prog and reset_target,
    # which this and later cells use without defining them — TODO confirm
    %run "Helper_Scripts/Setup_Generic.ipynb"

    fw_path = "../hardware/victims/firmware/bootloader-aes256/bootloader-aes256-{}.hex".format(PLATFORM)
    cw.program_target(scope, prog, fw_path)

## Running the target and capturing traces

### Calculating the CRC
Some CRC code from pycrc:

In [None]:
# Class Crc
#############################################################
# These CRC routines are copy-pasted from pycrc, which are:
# Copyright (c) 2006-2013 Thomas Pircher <tehpeh@gmx.net>
#
class Crc(object):
    """
    A base class for CRC routines.

    Cut-down version of pycrc's Crc: only the bit-by-bit algorithm, and no
    input/output reflection. With the default xor_in = xor_out = 0x0000,
    width=16 and poly=0x1021 compute CRC-16/XMODEM.
    """

    def __init__(self, width, poly, xor_in=0x0000, xor_out=0x0000):
        """The Crc constructor.

        The parameters are as follows:
            width    -- number of bits in the CRC register (e.g. 16)
            poly     -- generator polynomial without its implicit top bit (e.g. 0x1021)
            xor_in   -- initial register value of the direct algorithm (default 0x0000)
            xor_out  -- value XORed into the final CRC (default 0x0000)

        Reflection of input/output data is not supported.
        """
        self.Width = width
        self.Poly = poly

        self.MSB_Mask = 0x1 << (self.Width - 1)
        self.Mask = ((self.MSB_Mask - 1) << 1) | 1

        self.XorIn = xor_in
        self.XorOut = xor_out

        self.DirectInit = self.XorIn
        # bit_by_bit() processes an augmented message, so it starts from the
        # equivalent "non-direct" register value
        self.NonDirectInit = self.__get_nondirect_init(self.XorIn)
        # Kept from pycrc for compatibility; not used by bit_by_bit()
        if self.Width < 8:
            self.CrcShift = 8 - self.Width
        else:
            self.CrcShift = 0

    def __get_nondirect_init(self, init):
        """
        return the non-direct init if the direct algorithm has been selected.
        """
        crc = init
        for i in range(self.Width):
            bit = crc & 0x01
            if bit:
                crc ^= self.Poly
            crc >>= 1
            if bit:
                crc |= self.MSB_Mask
        return crc & self.Mask


    def bit_by_bit(self, in_data):
        """
        Classic simple and slow CRC implementation.  This function iterates bit
        by bit over the augmented input message and returns the calculated CRC
        value at the end.

        in_data: a str (each character converted with ord(); only its low 8 bits
        are used) or an iterable of byte values (bytes, bytearray, list of ints).
        """
        # If the input data is a string, convert to bytes.
        if isinstance(in_data, str):
            in_data = [ord(c) for c in in_data]

        register = self.NonDirectInit
        for octet in in_data:
            # Shift the message in MSB first
            for i in range(8):
                topbit = register & self.MSB_Mask
                register = ((register << 1) & self.Mask) | ((octet >> (7 - i)) & 0x01)
                if topbit:
                    register ^= self.Poly

        # Augmentation: push Width zero bits through the register
        for i in range(self.Width):
            topbit = register & self.MSB_Mask
            register = ((register << 1) & self.Mask)
            if topbit:
                register ^= self.Poly

        return register ^ self.XorOut

# CRC used to frame messages for the bootloader
bl_crc = Crc(width = 16, poly=0x1021)

### Communicating with the Bootloader

In [None]:
# Synchronize with the target
def target_sync(loud=False):
    """Reset the target and poll it until the bootloader answers with 0xA1.

    Does nothing when SCOPETYPE is 'FILE', because there is no hardware to talk to.
    loud: if True, print every non-empty raw response received from the target.
    NOTE: there is no retry limit; this loops forever if the target never answers.
    """
    if SCOPETYPE == 'FILE':
        return  # Replaying recorded traces: nothing to synchronize with

    reset_target(scope)

    # Keep poking the bootloader until the first byte of its reply is 0xA1
    while True:
        target.write('\0xxxxxxxxxxxxxxxxxx')
        time.sleep(0.05)
        reply = target.read()
        if not reply:
            continue
        if loud:
            print("Target said: %s" % reply.encode("utf-8").hex())
        if ord(reply[0]) == 0xA1:
            break

target_sync(True) # True => print target responses

In [None]:
ktp = cw.ktp.Basic() # Random message/key generation

def random_message():
    """Return the next (key, text) tuple from the module-level `ktp` generator."""
    key_and_text = ktp.next()
    return key_and_text
    
# Wrap the text in a message buffer for target
def prep_message(text):
    """Frame `text` for the bootloader.

    Layout: a leading 0x00, then the text bytes, then the CRC-16 of the text
    (computed with bl_crc), high byte first. Returns a list.
    """
    frame = [0x00, *text]
    checksum = bl_crc.bit_by_bit(text)
    frame += [checksum >> 8, checksum & 0xFF]
    return frame

In [None]:
# Hardware sanity check: send one random message and show the bootloader's reply
if SCOPETYPE != 'FILE':
    # clear serial buffer
    target.read()

    # Send a random message
    key, text = random_message()
    target.write(prep_message(text))
    time.sleep(0.1)

    # Check for correct response (A4); an empty read means the target did not answer
    response = target.read()
    if response:
        print("Response: {:02X}".format(ord(response[0])))
    else:
        print("No response from target")

### Capturing Traces

In [None]:
# Capture traces, or load them from file, depending on SCOPETYPE
def capture(num_traces, name, reset=False, msg_gen=random_message, export=False):
    """Capture `num_traces` traces into project projects/<name>, or load it from file.

    In FILE mode, projects/<name>.zip is imported and returned unchanged.
    Otherwise, for each trace: optionally resync the target, send one framed
    message built from `msg_gen()`, and store the trace only if the bootloader
    answers with 0xA4. Traces with a missing or bad response are skipped, so
    the project may contain fewer than `num_traces` traces.

    reset:   resync (reset) the target before every trace
    msg_gen: callable returning a (key, text) tuple
    export:  also export the project to projects/<name>.zip
    Returns the ChipWhisperer project.
    """
    projname = "projects/" + name
    zipname = projname + ".zip"
    if SCOPETYPE == 'FILE':
        return cw.import_project(zipname, overwrite=True)

    # Delete stale data from previous runs; a missing directory is not an error
    shutil.rmtree(projname + "_data", ignore_errors=True)

    project = cw.create_project(projname, overwrite=True)

    # Loop over all traces
    for _ in tnrange(num_traces, desc='Capturing traces'):
        if reset:
            target_sync(False)

        # Create input data
        key, text = msg_gen()
        message = prep_message(text)

        # clear serial buffer
        target.read()

        # Get scope ready and send message
        scope.arm()
        target.write(message)
        if scope.capture():
            print('Timeout happened during acquisition')

        # Check results; an empty read would otherwise crash on response[0]
        response = target.read()
        if not response:
            print("No response from target, skipping trace")
            continue
        if ord(response[0]) != 0xA4:
            # Bad response, just skip
            print("Bad response: {:02X}".format(ord(response[0])))
            continue

        # Add trace to project
        project.traces.append(cw.Trace(scope.get_last_trace(), text, "", key))

    project.save()

    if export:
        project.export(zipname)
    return project

In [None]:
# Config scope and capture overview trace
if SCOPETYPE != 'FILE':
    scope.adc.samples = 24400
    scope.clock.adc_src = "clkgen_x1" # Slower sampling to get more time
    scope.adc.basic_mode = "rising_edge"
# Three traces are enough for a visual overview; in FILE mode this loads projects/aeskey_overview.zip
overview = capture(3, "aeskey_overview")

In [None]:
from bokeh.io import export_svgs

# Plot (up to) 3 traces from the given project, zoom into x_range
def plot3(project, x_range=None, svgfile=None):
    """Overlay the first three traces of `project` in red, blue and green.

    Fewer lines are drawn if the project holds fewer traces (capture() may skip some).
    x_range: optional (start, end) sample range to zoom into; None lets bokeh auto-range.
    svgfile: if given, also export the figure as fig/12_<svgfile> (creating fig/ if needed).
    """
    import os

    output_notebook()
    p = figure(x_range=x_range)

    colors = ["red", "blue", "green"]
    for idx in range(min(len(colors), len(project.traces))):
        wave = project.traces[idx].wave
        # Each trace gets its own x axis so differing lengths cannot break the plot
        p.line(range(len(wave)), wave, line_color=colors[idx])
    show(p)

    if svgfile:
        os.makedirs("fig", exist_ok=True)
        p.output_backend = "svg"
        export_svgs(p, filename=f"fig/12_{svgfile}")

plot3(overview, [0,24400], "4.svg")



### Capturing detailed traces

In [None]:
# Setup scope. 
if SCOPETYPE != 'FILE':
    scope.adc.samples = 24400
    scope.clock.adc_src = "clkgen_x4" # Faster sampling
    scope.adc.basic_mode = "rising_edge"
# 200 traces, used for both the round-14 and the round-13 key attacks below
project = capture(200, "aeskey_r14r13")
plot3(project)

# We "guess" this leaks
# (inverse S-box output; used below against the round-14 key)
leak_model = cwa.leakage_models.inverse_sbox_output 

# Perform CPA analysis
# (this only creates the attack object; it is run once point_range has been chosen)
attack = cwa.cpa(project, leak_model) 

## Analysis
### 14th Round Key

Set `attack.point_range` to the range of samples where you think round 14 is executed.

In [None]:
# Run attack on round 14 range
cb = cwa.get_jupyter_callback(attack)
# Sample window containing round 14 for each platform (values obfuscated with the toy RSA)
if PLATFORM in ("CWLITEARM", "CW308_STM32F3"):
    attack.point_range = decrypt(public, [4792, 39132])
elif PLATFORM in ("CWLITEXMEGA", "CW303"):
    attack.point_range = decrypt(public, [8492, 21014])
attack_results = attack.run(cb)

In [None]:
# Expected round-14 key, obfuscated with the toy RSA
key = decrypt(public, [4782, 41021, 41021, 34848, 40659, 46642, 23307, 21303,
                       34180, 29318, 38236, 36358, 5628, 8565, 10190, 44112])

# Recover round 14 key: keep the top-ranked guess for every subkey byte
rec_key = []
for bnum in attack_results.find_maximums():
    best = bnum[0]  # best[0] is the guessed byte, best[2] its correlation
    rec_key.append(best[0])
    print("Best Guess = 0x{:02X}, Correlation = {}".format(best[0], best[2]))

# key correct?
print("Correct k14!" if rec_key == key else "Go fix k14 first!")

### 13th Round Key

#### Resyncing Traces

Set `k13range` to the sample range over which you think the alignment should be done, and `max_shift` to the maximum amount of misalignment to correct.

In [None]:
if PLATFORM in ("CWLITEXMEGA", "CW303"):
    # Sample window containing round 13 (obfuscated with the toy RSA)
    k13range = decrypt(public, [46957, 16007])
    zoom = (k13range[0] + 2000, k13range[0] + 3000)

    # Before: traces are not aligned in this window
    plot3(project, zoom, "6_left.svg")

    # Align all traces to trace 0 with ResyncSAD over the round-13 window
    resync_traces = cwa.preprocessing.ResyncSAD(project)
    resync_traces.enabled = True
    resync_traces.ref_trace = 0
    resync_traces.target_window = (k13range[0], k13range[1])
    resync_traces.max_shift = decrypt(public, [40659])[0]
    projsync = resync_traces.preprocess()
    attack.change_project(projsync)

    # After: same window on the synchronized traces
    plot3(projsync, zoom, "6_right.svg")

#### Leakage model

In [None]:
# This class implements an AES-256 round-13 leakage model
class AES256_Round13_Model(cwa.AESLeakageHelper):
    """Round-13 leakage model built on the already-recovered round-14 key.

    Peels off round 14 using the global `rec_key`, applies InvMixColumns and
    InvShiftRows of round 13, and leaks the inverse S-box output of the guessed
    key byte.
    """
    def leakage(self, pt, ct, guess, bnum):
        k14 = rec_key
        # Undo round 14: AddRoundKey, then InvShiftRows and InvSubBytes
        state = [k14[i] ^ pt[i] for i in range(16)]
        x14 = self.inv_subbytes(self.inv_shiftrows(state))
        # Round 13: InvMixColumns, then InvShiftRows
        state = self.inv_shiftrows(self.inv_mixcolumns(x14))
        # Leak after the inverse S-box of round 13
        return self.inv_sbox((state[bnum] ^ guess[bnum]))
    
# Set up new leakage model
leak_model = cwa.leakage_models.new_model(AES256_Round13_Model)
attack.leak_model = leak_model

#### Running the Attack
Set `attack.point_range` to the samples that contain round 13.

In [None]:
# Sample window containing round 13 for each platform (obfuscated with the toy RSA)
if PLATFORM in ("CWLITEARM", "CW308_STM32F3"):
    attack.point_range = decrypt(public, [15849, 5765])
elif PLATFORM in ("CWLITEXMEGA", "CW303"):
    attack.point_range = decrypt(public, [46957, 16007])

# Run attack on round 13, in range defined above
cb = cwa.get_jupyter_callback(attack)
attack_results = attack.run(cb)

In [None]:
# Recover round 13 key: keep the top-ranked guess for every subkey byte
rec_key2 = []
for bnum in attack_results.find_maximums():
    best = bnum[0]  # best[0] is the guessed byte, best[2] its correlation
    print("Best Guess = 0x{:02X}, Corr = {}".format(best[0], best[2]))
    rec_key2.append(best[0])

# The leakage model XORs the guess after InvMixColumns/InvShiftRows, so the
# guessed bytes are the round-13 key in that transformed form; undo it
real_key2 = cwa.aes_funcs.mixcolumns(cwa.aes_funcs.shiftrows(rec_key2))

print("Recovered:" + "".join(" {:02X}".format(subkey) for subkey in real_key2))

In [None]:
# Combined key material: the round-13 key followed by the round-14 key
rec_key_comb = real_key2.copy()
rec_key_comb.extend(rec_key)

print("Key:" + "".join(" {:02X}".format(subkey) for subkey in rec_key_comb))

In [None]:
# Roll the round-13/14 key material back through the key schedule; the
# round-0 and round-1 keys together form the 256-bit bootloader key
btldr_key = leak_model.key_schedule_rounds(rec_key_comb, 13, 0)
btldr_key.extend(leak_model.key_schedule_rounds(rec_key_comb, 13, 1))
print("Key:" + "".join(" {:02X}".format(subkey) for subkey in btldr_key))

# Check key against the expected value (obfuscated with the toy RSA)
real_btldr_key_enc = [15336, 3529, 42394, 42472, 30505, 12484, 32645, 3152,
                      22549, 31026, 5560, 38283, 37002, 22386, 45783, 5737,
                      14, 43638, 31122, 17972, 19453, 6921, 23470, 43009,
                      39379, 3529, 33128, 16722, 31089, 42985, 5516, 28658]

print("Correct key!" if decrypt(public, real_btldr_key_enc) == btldr_key else "Try again!")

## Recovering the IV
### What to capture

In [None]:
# Get an overview trace after the falling edge, which is where the IV is used
if SCOPETYPE != 'FILE':
    scope.adc.samples = 24400
    scope.clock.adc_src = "clkgen_x4"
    # Falling edge here, instead of the rising edge used for the key attack
    scope.adc.basic_mode = "falling_edge"

### Getting the first trace

In [None]:
# Three traces for a visual overview of the IV processing
iv_overview = capture(num_traces=3, name="iv_overview")
plot3(iv_overview, x_range=(0, 2500), svgfile="8.svg")

### Getting the rest of the traces

Zoom into the picture above to find the range for capturing traces for IV, and set it in `scope.adc.samples`.

In [None]:
# Number of traces differs per platform
if PLATFORM in ("CWLITEARM", "CW308_STM32F3"):
    N = 100
elif PLATFORM in ("CWLITEXMEGA", "CW303"):
    N = 500

# Set up scope and capture
if SCOPETYPE != 'FILE':
    scope.adc.samples = decrypt(public, [42169])[0]
    scope.clock.adc_src = "clkgen_x4"
    scope.adc.basic_mode = "falling_edge"
# reset=True: reset before every trace so each capture starts from the initial IV
project_iv = capture(N, "iv_full", reset=True)

In [None]:
# We're going to use numpy, so convert traces and text.
# Use every trace actually stored in the project: capture() skips traces with a
# bad response, so there may be fewer than N of them.
num_stored = len(project_iv.traces)
trace_array = np.asarray([project_iv.traces[i].wave for i in range(num_stored)])
# uint8 so that bytes(row) below yields exactly one byte per element
textin_array = np.asarray([project_iv.traces[i].textin for i in range(num_stored)], dtype=np.uint8)

# Get some info
num_traces = len(trace_array)
trace_len = len(trace_array[0])
num_bytes = len(textin_array[0])

# Calculate dr array by decrypting the ciphertexts using the known key
knownkey = bytes(btldr_key)
aes = AES.new(knownkey, AES.MODE_ECB)
dr = []
for ct in textin_array:
    pt = aes.decrypt(bytes(ct))
    # First 16 decrypted bytes as a list of ints
    dr.append(list(pt[:16]))


### Analysis
#### Doing the 1-Bit Attack

In [None]:
# Calculate the difference trace for given byte and bit
def get_diff(byte, bit):
    """Difference-of-means trace for one bit of the decrypted data.

    Splits all traces on bit `bit` of byte `byte` of `dr` (bit 0 is the MSB,
    bit 7 the LSB), averages each group sample by sample and returns
    mean(bit == 1) - mean(bit == 0). Reads globals dr, trace_array, num_traces.
    """
    shift = 7 - bit
    zeros, ones = [], []
    for idx in range(num_traces):
        bucket = ones if (dr[idx][byte] >> shift) & 0x01 else zeros
        bucket.append(trace_array[idx])

    return np.average(ones, axis=0) - np.average(zeros, axis=0)

In [None]:
bit = 0 # Bit to plot; get_diff() extracts (value >> (7 - bit)) & 1, so bit 0 is the MSB

# Plot the difference trace of this bit for every byte
output_notebook()
p = figure()
# Category20 palettes only exist for 3..20 colors; with 16 bytes this is Category20[16]
palette = Category20[min(max(num_bytes, 3), 20)]
for byte in range(0,num_bytes):
    d = get_diff(byte, bit)
    xrange = range(len(d))
    p.line(xrange, d, line_color=palette[byte % len(palette)])
show(p)
p.output_backend = "svg"
export_svgs(p, filename="fig/12_9.svg")

In [None]:
# Finds num peaks in the given trace. 
# Return peaks and their location 
def findpeaks(trace, num=5):
    """Return (values, indices) of the `num` samples of `trace` with the largest magnitude.

    Both arrays are ordered from largest to smallest absolute value; `trace`
    must be a numpy array (it is fancy-indexed with the result).
    """
    descending_magnitude = -np.abs(trace)
    peak_idx = np.argsort(descending_magnitude)[:num]
    return trace[peak_idx], peak_idx

# Obtain start and slope of the leak-location line (obfuscated with the toy RSA)
if PLATFORM in ("CWLITEARM", "CW308_STM32F3"):
    start, slope = decrypt(public, [38167, 3529])
elif PLATFORM in ("CWLITEXMEGA", "CW303"):
    start, slope = decrypt(public, [22386, 23307])

Play with the `start` and `slope` variables until the line fits through some significant peaks, ideally peaks that include both red circles (negative peaks) and green stars (positive peaks).

In [None]:
peak_per_byte = 10 # Number of highest peaks per byte to plot
base_radius = 10   # Scale the size of the dots in the graph

# Set start and slope here
# (start,slope) = (0, 0)
# Expected leak location (sample index) of each byte: start + slope * byte
locations =  [d*slope+start for d in range(0,num_bytes)]

# Find peaks
peaks = [findpeaks(get_diff(byte,bit),peak_per_byte) for byte in range(0,num_bytes)]

# Plot the peaks per byte. 
output_notebook()
p = figure(x_axis_label="Key byte", y_axis_label="Time")

# Create x and y coordinates. For y, we extract the location of the peak.
x = np.array([i for i in range(0,num_bytes) for _ in range(0,peak_per_byte)])
y = np.array([peaks[byte][1][peak] for byte in range(0,num_bytes) for peak in range(0,peak_per_byte)])

# Red = negative peak, green = positive peak, black=no peak
pal = np.array(["red","black","green"])
color = np.array([int(np.sign(peaks[byte][0][peak])+1) for byte in range(0,num_bytes) for peak in range(0,peak_per_byte)])

# Marker size is proportional to the peak's magnitude. The sign is already
# encoded by color/marker, so take the absolute value: otherwise every
# negative peak would be drawn with a negative radius.
radius = np.array([base_radius * abs(peaks[byte][0][peak]) for byte in range(0,num_bytes) for peak in range(0,peak_per_byte)])

# Draw red circle for each negative peak, black square for no peak, and green star for positive peak
p.circle(x[color==0], y[color==0], fill_color=None, line_color=pal[0], radius=radius[color==0])
p.square(x[color==1], y[color==1], fill_color=None, line_color=pal[1], size=radius[color==1]*100)
p.star  (x[color==2], y[color==2], fill_color=None, line_color=pal[2], size=radius[color==2]*100)

# Plot a line for the exact linear relation
x = range(0,num_bytes)
y = locations
p.line(x, y, color="black")

show(p)
p.output_backend = "svg"
export_svgs(p, filename="fig/12_10.svg")

#### The Other 127

In [None]:
flip = 0 # Set to 0 iff negative peaks are a bit value of 0
print("Bits are flipped:", flip)

btldr_IV = [0] * 16

# Loop over all bytes
for byte in range(16):
    location = locations[byte]
    iv = 0
    print("IV byte {:02d}:".format(byte), end = " ")

    # Loop over all bits, MSB (bit 0) first
    for bit in range(8):
        # Extract a vector of a plaintext bit for each trace. Byte and bit indicate which bit to extract.
        pt_bits = [((dr[i][byte] >> (7-bit)) & 0x01) for i in range(num_traces)]

        # Split the sample at `location` of every trace into 2 groups
        grouped_points = [[] for _ in range(2)]
        for i in range(num_traces):
            grouped_points[pt_bits[i]].append(trace_array[i][location])

        # Calculate diff of means
        means = [np.average(group) for group in grouped_points]
        diff = means[1] - means[0]

        # Set bit depending on sign of diff, then apply the polarity flip
        iv_bit = (1 if diff > 0 else 0) ^ flip
        iv = (iv << 1) | iv_bit

        # Print the bit actually stored, so the bits agree with the hex value below
        print(iv_bit, end = " ")

    print("{:02X}".format(iv))
    btldr_IV[byte] = iv

# Check IV against the expected value (obfuscated with the toy RSA)
real_btldr_IV = decrypt(public, [8565, 22386, 5509, 33004, 9842, 41989, 2894, 24955,
                                 10931, 10114, 12531, 46642, 21602, 22696, 18667, 29186])
if real_btldr_IV == btldr_IV:
    print("You got the IV!")
else:
    print("Bummer, no IV. Please come again.") 

## Attacking the Signature
### Power traces

In [None]:
# For capturing traces; iterate over all signature bytes
def next_sig_byte():
    """Message generator for capture(): one ciphertext per trial signature byte.

    Builds a 16-byte plaintext from the already-known signature bytes
    btldr_sig[0:byte_idx], the trial value byte_val at position byte_idx and
    zeros, then CBC-encrypts it: XOR with the global `iv`, encrypt with the
    global `aes`, and store the ciphertext back into `iv` (in place) as the
    next chaining value. Increments the global byte_val.
    Returns (key, ct): a dummy key from ktp (only to satisfy capture()) and the ciphertext.
    """
    global byte_val

    # Plaintext block: known signature prefix, trial byte, zeros after
    text = [0] * 16
    text[0:byte_idx] = btldr_sig[0:byte_idx]
    text[byte_idx] = byte_val

    # CBC chaining: XOR in the previous ciphertext (the IV) before encrypting
    for pos, chain_byte in enumerate(iv):
        text[pos] ^= chain_byte

    ct = aes.encrypt(bytes(text))

    # The new ciphertext becomes the IV for the next block
    iv[:] = ct[:]
    byte_val += 1

    # Get "key" to satisfy framework
    dummy_key, _ = ktp.next()
    return dummy_key, ct

In [None]:
# Initialize signature variables
# Copy of the recovered IV; next_sig_byte() overwrites it in place after each block
iv = btldr_IV.copy() 
knownkey = bytes(btldr_key)
aes = AES.new(knownkey, AES.MODE_ECB)
btldr_sig = [0] * 4
# next_sig_byte() varies position byte_idx, starting at value byte_val
byte_idx = 0
byte_val = 0

# Capture settings
if SCOPETYPE != 'FILE':
    scope.adc.basic_mode = "falling_edge"
    scope.adc.samples = 24000
    scope.adc.offset = 0
N = 256 # Number of traces

# Capture 256 traces for all 256 values of one signature byte
target_sync()
project_sig = capture(N, "sig_one", msg_gen=next_sig_byte)
plot3(project_sig)

### Analysis

Based on the overview trace above, select a `sign_range` in which to check for SPA differences with the mean.

In [None]:
# Analysis range (sample indices, obfuscated with the toy RSA)
sign_range = decrypt(public,[34264, 16887])
sign_range = range(sign_range[0],sign_range[1])

# Number of results to show (and plot)
numprint = 5 

# Guess a signature based on timing differences. Set plot to true to show the top differing traces 
def guess_signature(project,plot=False,svgfile=None):
    """Rank candidate signature-byte values by how much their trace deviates from the mean.

    Trace i is taken to belong to candidate byte value i (next_sig_byte starts
    at byte_val = 0 and counts up). For every trace, the Pearson correlation
    with the mean trace over `sign_range` is computed; the least-correlated
    trace is the best guess.

    Prints the `numprint` lowest correlations and their indices. With plot=True,
    also plots those traces minus the mean (exported to fig/12_<svgfile> if given).
    Returns all trace indices sorted by increasing correlation.
    """
    # Use the traces actually stored: capture() may have skipped some
    count = len(project.traces)
    if count != 256:
        print("Warning: expected 256 traces (one per byte value), got", count)
    traces = np.asarray([project.traces[i].wave for i in range(count)])

    # Calculate correlation between mean trace (reference) and individual byte guesses
    mean = np.average(traces, axis=0) # Reference trace
    corr = [np.corrcoef(mean[sign_range], trace[sign_range])[0, 1] for trace in traces]

    # Sort to get correlation and bytes
    corr_sort = np.sort(corr)
    corr_sort_idx = np.argsort(corr)

    # Print it
    print("Correlation values:  ", corr_sort[0:numprint])
    print("Signature byte guess:", corr_sort_idx[0:numprint])

    if plot:
        # Plot it
        output_notebook()
        p = figure()
        # Category20 palettes only exist for 3..20 colors
        palette = Category20[min(max(numprint, 3), 20)]
        for j, i in enumerate(corr_sort_idx[:numprint]):
            p.line(range(len(traces[i])), traces[i]-mean, line_color=palette[j % len(palette)])
        show(p)
        if svgfile:
            p.output_backend = "svg"
            export_svgs(p, filename=f"fig/12_{svgfile}")

    return corr_sort_idx

# Guess signature for 1 byte
corr_sort_idx = guess_signature(project_sig,True, "11.svg")


### All 4 bytes

In [None]:
# Initialize 
btldr_sig = [0] * 4
iv = btldr_IV.copy() # Make a copy so next_sig_byte can modify it
knownkey = bytes(btldr_key)
aes = AES.new(knownkey, AES.MODE_ECB)

# Scope and target settings
N = 256 # Number of traces
if SCOPETYPE != 'FILE':
    scope.adc.samples = 24000
    scope.adc.offset = 0
target_sync()

# Loop over 4 sig bytes
# (iv is not reset between bytes: next_sig_byte keeps chaining from the last ciphertext)
for bnum in range(4):
    # Set byte_idx and byte_val for next_sig_byte
    byte_idx = bnum
    byte_val = 0
    
    # Capture
    project_sig = capture(N, "sig_byte" + str(bnum), msg_gen=next_sig_byte)
        
    # Analyze and print
    # Index 0 = least correlated trace = best guess; later bytes reuse it via btldr_sig
    btldr_sig[bnum] = guess_signature(project_sig)[0]
    print("Signature guess:      ", btldr_sig[0:bnum+1])

In [None]:
# Check it's correct
if btldr_sig == decrypt(public, [0, 18667, 10050, 15766]):
    print("You got the signature too!")
else:
    print("No signature for you.")

In [None]:
# Clean up the hardware session (nothing to do when replaying from file);
# dis() presumably disconnects scope and target — TODO confirm
if SCOPETYPE != 'FILE':
    scope.dis()
    target.dis()