# Chapter 14: Think of the Children: Countermeasures, Certifications and Goodbytes

This notebook is a companion to Chapter 14 of The Hardware Hacking Handbook by Jasper van Woudenberg and Colin O'Flynn. The headings in this notebook follow the headings in the book.

© 2021. This work is licensed under a [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/). 

In [None]:
from timeit import timeit
from copy import copy

# Do some functional tests and performance tests on memcmp functions
def test_memcmp(memcmp_fn, num=None, sameval=0, number=1000000):
    """Sanity-check ``memcmp_fn`` and time it for each position of a differing byte.

    First asserts that comparing a buffer with itself returns ``sameval`` and
    that comparing two different buffers does not. Then, for every position
    ``pos`` in ``range(num)``, builds a copy of the reference buffer with byte
    ``pos`` replaced by ``b'x'`` and prints the time ``timeit`` reports for
    ``number`` calls of ``memcmp_fn(c1, c2, num)``.

    :param memcmp_fn: comparison function called as ``memcmp_fn(c1, c2, num)``.
    :param num: bytes to compare and positions to time; defaults to the length
        of the reference buffer.
    :param sameval: value ``memcmp_fn`` returns for equal buffers.
    :param number: calls per timing measurement (default keeps the original
        ``timeit`` default of 1,000,000).
    """
    # Buffers to compare
    c1 = b'01234'
    c2 = b'0123x'

    # By default, loop over length of buffer
    if num is None:
        num = len(c1)

    # Some simple functional tests
    assert memcmp_fn(c1, c1, len(c1)) == sameval
    assert memcmp_fn(c1, c2, len(c2)) != sameval

    # Human-readable form of the timed call
    print("%s(c1, c2, %d)" % (memcmp_fn.__name__, num))

    # Statement actually timed; binds the function object under a fixed name
    statement = "memcmp_fn(c1, c2, %d)" % num

    for pos in range(num):
        # Copy and change one byte
        changed = bytearray(c1)
        changed[pos] = ord('x')
        c2 = bytes(changed)

        print("Timing for diff at %d" % pos, end=": ")

        # Evaluate in a minimal namespace holding the function object itself.
        # Looking the function up by __name__ in module globals fails for
        # lambdas, closures and functions defined in other modules.
        namespace = {"memcmp_fn": memcmp_fn, "c1": c1, "c2": c2}
        print("%.3f" % timeit(statement, globals=namespace, number=number))


# Helper is named like a pytest test; keep test runners from collecting it
test_memcmp.__test__ = False

## Non-correlating / constant time everywhere
### Constant time memory compare

In [None]:
def memcmp(c1, c2, num):
    """Return 0 if the first ``num`` bytes of c1 and c2 match, 1 otherwise.

    Deliberately not constant time: ``any`` short-circuits on the first
    mismatch, so the running time reveals where the first differing byte is.
    """
    mismatch = any(c1[pos] != c2[pos] for pos in range(num))
    return int(mismatch)

test_memcmp(memcmp_fn=memcmp)  # functional checks + per-position timing of the early-exit memcmp

In [None]:
def memcmp_consttime(c1, c2, num):
    """Compare the first ``num`` bytes without an early exit.

    Every byte pair is visited and its XOR is OR-ed into one accumulator.
    The result is 0 exactly when all compared bytes are equal, nonzero otherwise.
    """
    acc = 0
    for pos in range(num):
        acc |= c1[pos] ^ c2[pos]  # nonzero XOR marks differing bits
    return acc

test_memcmp(memcmp_fn=memcmp_consttime)  # same checks/timing for the accumulate-all-bytes version

### Constant time conditionals

In [None]:
def takesLong():
    # Deliberately slow path: sum of squares of 0..9 computed in a loop (285)
    total = 0
    for k in range(10):
        total += k * k
    return total

def muchShorter():
    # Deliberately fast path: a single multiplication (100)
    base = 10
    return base * base

def leakSecret(secret):
    # Only one branch runs and the branches differ in duration,
    # so the timing reveals whether secret == 0xca
    return takesLong() if secret == 0xca else muchShorter()

# Time leakSecret() for a few values around 0xca and print each duration
for secret in range(0xc8, 0xcd):
    print("value {:02x} => duration {:.3f}".format(
        secret, timeit("leakSecret({})".format(secret), globals=globals())))

In [None]:
def dontLeakSecret(secret):
    # Always evaluate both branches so the work done does not depend on secret
    long_result = takesLong()
    short_result = muchShorter()
    # 0 (all bits clear) when secret == 0xca, -1 (all bits set) otherwise
    mask = -int(secret != 0xca)
    # Branch-free select: long_result when mask == 0, short_result when mask == -1
    return (short_result & mask) | (long_result & ~mask)

# Time dontLeakSecret() for the same values; durations should no longer single out 0xca
for secret in range(0xc8, 0xcd):
    print("value {:02x} => duration {:.3f}".format(
        secret, timeit("dontLeakSecret({})".format(secret), globals=globals())))

## Randomize access to confidential array values

In [None]:
import random

def memcmp_randorder(c1, c2, num):
    """Accumulating compare of the first ``num`` bytes, visited in a randomly rotated order.

    Processing starts at a random offset and wraps around, so which byte is
    handled at a given moment changes from call to call.

    :returns: 0 if all compared bytes are equal, nonzero otherwise. A
        non-positive ``num`` compares nothing and returns 0, like ``memcmp``.
    """
    # Initially, no diff
    diff = 0
    # Random starting point; random.randint(0, num - 1) raises ValueError for num <= 0
    rnd = random.randint(0, num - 1) if num > 0 else 0

    # Loop over buffers
    for i in range(num):
        idx = (i + rnd) % num        # Get index, wrap around if needed
        diff |= c1[idx] ^ c2[idx]    # collect diff
    return diff

test_memcmp(memcmp_randorder, num=1)  # single-position run for the randomized-order compare

## Perform decoy operations or infective computing

In [None]:
DECOY_PROBABILITY = 0.5  # Probability of a decoy round; must be < 1, otherwise the compare loop never advances

def memcmp_decoys(c1, c2, num):
    """Random-order accumulating compare interleaved with random decoy rounds.

    Each loop iteration is either a decoy round (similar XOR on all-zero dummy
    buffers, result discarded) or a real round. Only real rounds advance the
    index, so every byte is still compared exactly once.

    :returns: 0 if the first ``num`` bytes are equal, nonzero otherwise. A
        non-positive ``num`` compares nothing and returns 0, like ``memcmp``.
    """
    # Prep decoy values, initialize to 0
    decoy1 = bytes(len(c1))
    decoy2 = bytes(len(c2))

    # Init diff accumulator and random starting point
    # (random.randint(0, num - 1) raises ValueError for num <= 0)
    diff = 0
    rnd = random.randint(0, num - 1) if num > 0 else 0

    i = 0
    while i < num:
        # Get index, wrap around if needed
        idx = (i + rnd) % num

        # Decoy rounds must not contribute to diff
        tmpdiff = 0

        # Flip coin to check we have a decoy round
        do_decoy = random.random() < DECOY_PROBABILITY
        if do_decoy:
            # Similar operation on dummy data; the result is intentionally unused
            decoy = decoy1[idx] ^ decoy2[idx]
        else:
            tmpdiff = c1[idx] ^ c2[idx]  # Real operation, put in tmpdiff

        # Accumulate diff
        diff |= tmpdiff

        # Adjust index if not a decoy
        i += int(not do_decoy)
    return diff

test_memcmp(memcmp_decoys, num=1)  # single-position run for the decoy-round compare

## Use non-trivial constants

In [None]:
CONST1 = 0xC0A0B000
CONST2 = 0x03050400
SAME = CONST1 ^ CONST2
print("Consts chosen not to have overlapping bits: %x ^ %x == %x" % (CONST1, CONST2, SAME))
# The scheme needs (a) no bit set in both constants, so XOR == OR, and (b) a zero
# low byte in both, so the low byte carries only c1[idx] ^ c2[idx].
# Explicit raises rather than assert, which `python -O` strips.
if CONST1 & CONST2:
    raise ValueError("CONST1 and CONST2 must not share any set bits")
if (CONST1 | CONST2) & 0xff:
    raise ValueError("CONST1 and CONST2 must have a zero low byte")

def memcmp_nontrivial(c1, c2, num):
    """Decoy/random-order compare whose intermediates carry non-trivial constants.

    Each real round yields ``SAME | (c1[idx] ^ c2[idx])`` and each decoy round
    yields ``SAME``. Equal buffers therefore produce ``SAME`` rather than 0.

    :returns: ``SAME`` if the first ``num`` bytes are equal, ``SAME`` with nonzero
        low-byte bits otherwise. For ``num <= 0`` nothing is compared and the
        accumulator stays 0, which is not ``SAME``, so the result reads as "different".
    """
    # Prep decoy values, initialize to 0
    decoy1 = bytes(len(c1))
    decoy2 = bytes(len(c2))

    # Init diff accumulator and random starting point
    # (random.randint(0, num - 1) raises ValueError for num <= 0)
    diff = 0
    rnd = random.randint(0, num - 1) if num > 0 else 0

    i = 0
    while i < num:
        # Get index, wrap around if needed
        idx = (i + rnd) % num

        # Flip coin to check we have a decoy round
        do_decoy = random.random() < DECOY_PROBABILITY

        if do_decoy:
            decoy = (CONST1 | decoy1[idx]) ^ (CONST2 | decoy2[idx])  # Do similar operation
            tmpdiff = CONST1 | CONST2                                # == SAME: no effect on diff, still nontrivial
        else:
            tmpdiff = (CONST1 | c1[idx]) ^ (CONST2 | c2[idx])        # Real operation, put in tmpdiff
            decoy = CONST1 | CONST2                                  # Just to mimic other branch

        # Accumulate diff
        diff |= tmpdiff

        # Adjust index if not a decoy
        i += int(not do_decoy)

    return diff

test_memcmp(memcmp_nontrivial, num=1, sameval=SAME)  # "equal" is SAME here, not 0

## Status variable reuse

In [None]:
# Our non-trivial constants
SECURE_OK = 0xc001bead
SECURE_FAIL = 0xa5a5a5a5

# Some dummy functions: each just returns its argument as the status
def validate_address(a):
    return a

def validate_signature(s):
    return s

def check_fw(a, s, fault_skip):
    """Validate address, then signature, reusing one status variable for both.

    Deliberately vulnerable demo: when the signature step is skipped (simulated
    by ``fault_skip``), ``rv`` still holds SECURE_OK from the address check and
    the firmware is accepted.
    """
    print("Running with %08x, %08x, % 5s" % (a, s, fault_skip), end=": ")

    rv = validate_address(a)

    # Address stage: anything other than SECURE_OK ends the check here
    if rv == SECURE_FAIL:
        print("Address invalid")
        return
    if rv != SECURE_OK:
        print("Fault detected!")
        return

    # Simulated fault: skipping this call leaves rv == SECURE_OK
    if not fault_skip:
        rv = validate_signature(s)

    if rv == SECURE_OK:
        print("Firmware ok. Flashing!")
    elif rv == SECURE_FAIL:
        print("Signature invalid")
    else:
        print("Fault detected!")


check_fw(SECURE_OK, SECURE_OK, fault_skip=False)    # All ok
check_fw(SECURE_FAIL, SECURE_OK, fault_skip=False)  # Invalid address
check_fw(SECURE_OK, SECURE_FAIL, fault_skip=False)  # Invalid signature
check_fw(SECURE_OK, 0xbadf1, fault_skip=False)      # Fault in signature state
check_fw(SECURE_OK, SECURE_FAIL, fault_skip=True)   # Signature check skipped: rv stays SECURE_OK

## Verify control flow

In [None]:
def default_fail(input, secret, do_fault):
    """Password check whose final branch treats any unexpected result as a fault.

    The compare result is expected to be 0 (match) or 1 (mismatch). Any other
    value can only come from a fault and falls into the default-fail branch.

    :param input: candidate password (bytes).
    :param secret: expected password (bytes).
    :param do_fault: if true, simulate a fault that corrupts the compare result.
    """
    print("default_fail(%s, %s, % 5s): " % (input, secret, do_fault), end="")

    # Inputs of a different length never match. Comparing only len(input) bytes
    # would accept any prefix of the secret (even b''), and an input longer than
    # the secret would raise IndexError.
    if len(input) != len(secret):
        result = 1
    else:
        result = memcmp(input, secret, len(secret))

    if do_fault:
        # If we have a fault, just set result to something invalid
        result = 0x1243

    # Check result
    if result == 0:
        print("Access granted, my liege")
    elif result == 1:
        print("Ah ah ah, you didn't say the magic word")
    else:
        # Default-fail case. This should logically never happen! Respond to fault!
        print("Fault detected!")

default_fail(b'xxx', b'pwd', do_fault=False)  # wrong password
default_fail(b'pwd', b'pwd', do_fault=False)  # right password
default_fail(b'xxx', b'pwd', do_fault=True)   # corrupted result hits the default-fail branch

In [None]:
def double_check(input, secret, do_fault1, do_fault2):
    """Password check that repeats the comparison before granting access.

    :param input: candidate password (bytes).
    :param secret: expected password (bytes).
    :param do_fault1: if true, simulate a fault forcing the first result to 0.
    :param do_fault2: if true, simulate a fault forcing the second result to 0.
    """
    print("double_check(%s, %s, % 5s, % 5s): " % (input, secret, do_fault1, do_fault2), end="")

    # Check result. Different lengths never match: comparing only len(input)
    # bytes would accept any prefix of the secret (even b'').
    result = memcmp(input, secret, len(input)) if len(input) == len(secret) else 1

    if do_fault1:
        # Dangerous fault, setting result to 0
        result = 0

    if result == 0:

        # Do memcmp again to protect against FI; the length test is re-evaluated
        # too, so one faulted comparison cannot satisfy both checks
        result2 = memcmp(input, secret, len(input)) if len(input) == len(secret) else 1

        if do_fault2:
            # Dangerous fault, setting result to 0
            result2 = 0

        # Double check with some different logic: equivalent to
        # (result2 ^ 0xff) == 0xff, which holds only when result2 == 0
        if not result2 ^ 0xff != 0xff:
            print("Access granted, my liege")
        else:
            print("Fault2 detected!")

    elif result == 1:
        print("Ah ah ah, you didn't say the magic word")
    else:
        # Default-fail case. This should logically never happen! Respond to fault!
        print("Fault1 detected!")
        
double_check(b'xxx', b'pwd', do_fault1=False, do_fault2=False)  # no fault
double_check(b'xxx', b'pwd', do_fault1=True, do_fault2=False)   # single fault, caught by second check
double_check(b'xxx', b'pwd', do_fault1=True, do_fault2=True)    # double fault bypasses both checks

In [None]:
def double_check_wait(input, secret, do_fault1, do_fault_at):
    """Double-checked password compare with a random delay before the second check.

    :param input: candidate password (bytes).
    :param secret: expected password (bytes).
    :param do_fault1: if true, simulate a fault forcing the first result to 0.
    :param do_fault_at: simulated fault time, counted in wait-loop iterations
        from the start of the random wait loop. It hits the second comparison
        only when it equals the number of wait iterations actually run.
    """
    # do_fault_at is a counter that indicates at what point in time the simulated fault should happen,
    # relative to the start of the random wait loop. If we time it 'just right', it happens just after
    # the wait loop and hits the second comparison.
    print("double_check_wait(%s, %s, % 5s, %d): " % (input, secret, do_fault1, do_fault_at), end="")

    # Check result. Different lengths never match: comparing only len(input)
    # bytes would accept any prefix of the secret (even b'').
    result = memcmp(input, secret, len(input)) if len(input) == len(secret) else 1

    if do_fault1:
        # Dangerous fault, setting result to 0
        result = 0

    if result == 0:

        # Random wait
        wait = random.randint(0, 3)
        for _ in range(wait):
            # Did the simulated fault happen here?
            if do_fault_at == 0:
                print("Fault happened during wait loop. ", end="")
            do_fault_at -= 1

        # This is also a good point to insert some not-so-sensitive other operations
        # Just to decouple the random wait loop from the sensitive operation

        if do_fault_at == 0:
            # Simulated fault timed right!
            print("Fault timed right. ", end="")
            result2 = 0
        else:
            # Do memcmp again, re-evaluating the length test as well
            result2 = memcmp(input, secret, len(input)) if len(input) == len(secret) else 1

        # Double check with some different logic: equivalent to
        # (result2 ^ 0xff) == 0xff, which holds only when result2 == 0
        if not result2 ^ 0xff != 0xff:
            print("Access granted, my liege")
        else:
            print("Fault2 detected!")

    elif result == 1:
        print("Ah ah ah, you didn't say the magic word")
    else:
        # Default-fail case. This should logically never happen! Respond to fault!
        print("Fault1 detected!")
        
double_check_wait(b'xxx', b'pwd', do_fault1=False, do_fault_at=1)
# Try hitting one point in time many times... Re-run until you get an access granted
for i in range(10):
    double_check_wait(b'xxx', b'pwd', do_fault1=True, do_fault_at=1)

In [None]:
def check_loop_end(loop_len, fault_idx):
    """Count up to loop_len and verify afterwards that the loop really completed.

    A simulated fault at iteration ``fault_idx`` leaves the loop early; the
    post-loop counter check then reports the fault.
    """
    print("check_loop_end(%s,%s): " % (loop_len, fault_idx), end="")

    counter = 0
    # Stop at the end, or early at fault_idx (simulated premature break)
    while counter < loop_len and counter != fault_idx:
        counter += 1

    # A completed loop leaves counter exactly at loop_len
    print("All good" if counter == loop_len else "Fault detected")
    
check_loop_end(loop_len=10, fault_idx=None)  # No fault
check_loop_end(loop_len=10, fault_idx=11)    # Fault too late: that iteration never runs
check_loop_end(loop_len=10, fault_idx=5)     # Break out of loop


In [None]:
# crc_check() example for control flow integrity

from binascii import crc32

# What the CRC should be, if we have no fault
CORRECT_CRC = 0xe390d9f8

# CRC value every run of the "program" starts from
INITIAL_CRC = crc32(b"Initial CRC32 value")

# The global CRC value, updated as the "program" runs
global_crc = INITIAL_CRC

def fault():
    """React to a detected fault: report it in debug mode, otherwise stop."""
    if DEBUG:
        print("Fault detected! Continuing because in debug mode.")
    else:
        print("Fault, self-terminating!")
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would silently turn this into "keep going".
        raise AssertionError("Fault, self-terminating!")

def fi_check_update(x):
    """Fold the marker bytes ``x`` into the running control-flow CRC."""
    global global_crc

    if DEBUG:
        print("Updating with", x)
    global_crc = crc32(x, global_crc)


def fi_check(check_crc):
    """Compare the running CRC with ``check_crc``; call fault() on mismatch."""
    if DEBUG:
        print("CRC check: %08x" % check_crc)
        print("CRC is:    %08x" % global_crc)

    # First check
    if check_crc != global_crc:
        fault()

    # Fancy double check by doing an extra crc32, so skipping or corrupting
    # a single comparison is not enough to pass
    if crc32(b'x', check_crc) != crc32(b'x', global_crc):
        fault()

def func1():
    """Instrumented function: marker b'f' on entry, b'1' on exit."""
    fi_check_update(b'f')
    print("Doing func1() stuff")
    fi_check_update(b'1')

def func2():
    """Instrumented function: marker b'f' on entry, b'2' on exit."""
    fi_check_update(b'f')
    print("Doing func2() stuff")
    fi_check_update(b'2')

def crc_check():
    """Run the instrumented "program" and verify its control flow via the CRC."""
    global global_crc
    # Start every run from the same CRC so crc_check() can be called repeatedly;
    # otherwise a second call would continue from the previous run's CRC.
    global_crc = INITIAL_CRC

    # Run the "program"
    fi_check_update(b'm')
    func1()
    if not FAULT:
        # Let's say the fault skips the call to func2()
        func2()
    fi_check_update(b'a')
    fi_check(CORRECT_CRC)
    print("Fin")

# Some global flags you can play with
DEBUG = 1  # Whether to do a bunch of debug prints
FAULT = 0  # Whether to inject a fault that causes a control flow issue

crc_check()

## Detect and respond to faults

In [None]:
from collections import Counter

def memcmp_fault_detect(c1, c2, num, fault=lambda x: x):
    """Decoy/non-trivial-constant compare that also checks its intermediates for faults.

    Every per-round value and the final accumulator must carry exactly ``SAME``
    in the bits above the low byte. Anything else means a fault corrupted it,
    and ``None`` is returned.

    :param fault: fault-injection hook applied to intermediate values; the
        default is the identity, i.e. no faults.
    :returns: ``SAME`` if the first ``num`` bytes are equal, ``SAME`` with
        nonzero low-byte bits if they differ, ``None`` if a fault is detected.
        For ``num <= 0`` the accumulator stays 0, fails the final check and
        ``None`` is returned (fail closed).
    """
    # Prep decoy values, initialize to 0
    decoy1 = bytes(len(c1))
    decoy2 = bytes(len(c2))

    # Init diff accumulator and random starting point
    # (random.randint(0, num - 1) raises ValueError for num <= 0)
    diff = 0
    rnd = random.randint(0, num - 1) if num > 0 else 0

    i = 0
    while i < num:
        # Get index, wrap around if needed
        idx = (i + rnd) % num

        # Flip coin to check we have a decoy round
        do_decoy = random.random() < DECOY_PROBABILITY

        if do_decoy:
            decoy = (CONST1 | decoy1[idx]) ^ (CONST2 | decoy2[idx])  # Do similar operation
            tmpdiff = CONST1 | CONST2                                # Set tmpdiff so we still have nontrivial consts
        else:
            tmpdiff = (CONST1 | c1[idx]) ^ (CONST2 | c2[idx])        # Real operation, put in tmpdiff
            decoy = CONST1 | CONST2                                  # Just to mimic other branch

        # Simulate fault injection
        tmpdiff = fault(tmpdiff)
        # Python's `&` binds tighter than `!=`: this is (tmpdiff & ~0xff) != SAME
        if tmpdiff & ~0xff != SAME:  # True only if a fault occurred
            return None

        # Accumulate diff
        diff |= tmpdiff

        # Adjust index if not a decoy
        i += int(not do_decoy)

    # Simulate fault injection
    diff = fault(diff)
    if diff & ~0xff != SAME:  # True only if a fault occurred (or nothing was compared)
        return None

    return diff

test_memcmp(memcmp_fault_detect, num=1, sameval=SAME)  # default hook: no faults injected

## Verifying countermeasures

In [None]:
FAULT_PROB = 0.2

# Named inject_fault rather than fault so this cell does not rebind the
# zero-argument fault() that fi_check() calls in the control-flow example.
def inject_fault(value):
    """With probability FAULT_PROB, XOR a random nonzero byte into one of the low 4 bytes of value."""
    if random.random() < FAULT_PROB:
        flipvalue = random.randint(1, 255)    # Avoid 0 => no fault
        flipindex = random.randint(0, 3) * 8  # 0,8,16,24
        value ^= flipvalue << flipindex       # XOR in the fault
    return value

def result_map(output):
    """Map a memcmp_fault_detect result to None (fault detected), True (same) or False (different)."""
    if output is None:
        return None   # Fault detected
    if output == SAME:
        return True   # Same
    return False      # Different

# Run some fault simulation tests; Counter consumes the generator directly,
# so no TOTAL-sized list of results is kept in memory
TOTAL = 2000000
ctr = Counter(result_map(memcmp_fault_detect(b"abcde", b"abcdx", 5, inject_fault)) for _ in range(TOTAL))
print("Fault injection results for %02.0f%% fault probability" % (FAULT_PROB*100))
print("%d/%d (%.1f%%) fault detected" % (ctr[None],  TOTAL, ctr[None]/TOTAL*100))
print("%d/%d (%.1f%%) no effect" % (ctr[False], TOTAL, ctr[False]/TOTAL*100))
print("%d/%d (%.4f%%) memcmp bypass" % (ctr[True],  TOTAL, ctr[True]/TOTAL*100))