Project: Differential Cryptanalysis
======
### By
### Anindya Brata Choudhury

The following code is submitted as a solution for the Applied Cryptography Jan 2021 course project. This file serves as both the code and the README required for the project. It implements differential cryptanalysis as described in [this document](https://www.engr.mun.ca/~howard/PAPERS/ldc_tutorial.pdf).

## Installation
To run the code, install Jupyter Notebook by following the steps given [here](https://jupyter.org/install).

## Imports and Miscellaneous functions

This first section contains imported libraries and miscellaneous utility functions that are not important to the main project. They only help simplify the rest of the code.

[Click here to skip ahead to the main code](#Setup)

In [1]:
from fractions import Fraction
from collections import defaultdict
from itertools import product
import random

def print_table(data):
    s = "{:>3} " * len(data[0])
    print("Input\\Output", s.format(*[hx for hx in range(16)]), "\n") # Backslash escaped to avoid an invalid escape sequence
    for i, row in enumerate(data):
        print("{:<12}".format(i), s.format(*row))
        
def hex_to_binlist(h):
    # Converts a single hexadecimal number to binary tuple
    return tuple([1 if h&(1<<i) > 0 else 0 for i in range(4)][::-1])

#print(hex_to_binlist(0x5))

def binlist_to_hex(bl):
    # Converts binary tuple to a single hexadecimal number 
    hx = 0
    for b in bl:
        hx = hx << 1
        hx = hx | b
    return hx

#print(binlist_to_hex((0,1,1,0)))

def hexlist_to_binlist(hl):
    # Converts a tuple of hexadecimal numbers to binary
    bl = list()
    for h in hl:
        bl.extend(hex_to_binlist(h))
    return tuple(bl)

#print(hexlist_to_binlist((0,0xB, 0, 0)))

def binlist_to_hexlist(bl):
    # Converts a tuple of binary numbers to hexadecimal values
    return tuple((binlist_to_hex(bl[i:i+4]) for i in range(0, len(bl), 4)))

#print(binlist_to_hexlist((0,1,1,0,0,0,0,1)))

def xor_list(l1, l2):
    # Returns the XOR of two lists of numbers
    # Returns 0 if any of the XOR operands are not numbers ('unknown')
    return tuple((hx1^hx2 if type(hx1) is int and type(hx2) is int else 0 for hx1, hx2 in zip(l1, l2)))

#print(xor_list((1,2,1,2), (3,3,3,3)))

## Setup
This section defines the constants, the basic operations, and the encryption and decryption routines for the toy cipher. It contains nothing specific to cryptanalysis.

To run this code, change the key in the code cell below to a new key, then run the whole notebook. The cryptanalysis attack will run and find the new key.

### Input and constants
The input/output mapping for the SBoxes is contained in the SMap tuple.
The input/output bit position mapping for the permutation is contained in the PMap tuple.
A tuple of five arbitrarily chosen 16-bit round keys is assigned to K.
This is the key of the 4-round toy cipher; the cryptanalysis attack will try to recover its last round key.
Replace the value of K with the secret key of the cipher.
The cryptanalysis code uses K only in the chosen-plaintext step, where it encrypts chosen plaintexts to generate plaintext-ciphertext pairs.

In [2]:
K = (
    (0xC, 0x9, 0x0, 0x8),
    (0x3, 0xF, 0xA, 0xE),
    (0x5, 0x6, 0x0, 0x1),
    (0x6, 0x5, 0x2, 0x7),
    (0x3, 0x2, 0xA, 0x4),
)

#         0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15
SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)

#        1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16
PMap = ( 1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15,  4,  8, 12, 16)
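
Before running the attack with modified constants, it can help to confirm that the mappings are well formed. The sketch below is a standalone check, not part of the original notebook: the helper names `is_bijection`, `is_involution` and `random_key` are hypothetical, and `SMap` and `PMap` are copied from the cell above. It also shows one possible way to draw a fresh random key of the same shape as K.

```python
import random

# Copied from the notebook cell above
SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)
PMap = (1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16)

def is_bijection(mapping, domain):
    # True when every element of `domain` appears exactly once in `mapping`
    return sorted(mapping) == list(domain)

def is_involution(pmap):
    # True when applying the 1-based bit permutation twice restores every position
    return all(pmap[pmap[i] - 1] == i + 1 for i in range(len(pmap)))

def random_key(rounds=4, rng=random):
    # One 4-nibble subkey per round, plus the final subkey added after the last round
    return tuple(tuple(rng.randrange(16) for _ in range(4)) for _ in range(rounds + 1))

sbox_ok = is_bijection(SMap, range(16))
pmap_ok = is_bijection(PMap, range(1, 17))
pmap_self_inverse = is_involution(PMap)
example_key = random_key()
print(sbox_ok, pmap_ok, pmap_self_inverse, example_key)
```

Because this particular PMap is its own inverse, applying the permutation a second time undoes it.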

### Basic Operations
In this section, the basic substitution, permutation and add-key operations are implemented. These functions are used in the next step and simplify the implementation of the encryption and decryption operations.

In [3]:
def substitution_op(hexval, smap=SMap):
    # Returns SBox output given SBox input and SBox I/O mapping
    return smap[hexval] if type(hexval) is int else 0

#print(substitution_op(0xB))
#print(substitution_op("Q"))

def substitution_row(hexlist, smap=SMap):
    # Returns list of SBox outputs given list of SBox inputs and SBox I/O mapping
    return tuple([substitution_op(hx, smap=smap) for hx in hexlist])

#print(substitution_row((1,2,"Q",0)))

def reverse_substitution_row(hexlist, smap=SMap):
    # Returns list of SBox inputs given list of SBox outputs and SBox I/O mapping
    # Used in decryption
    smap_copy, smap = smap, list(smap)
    for i, v in enumerate(smap_copy):
        smap[v] = i
    return substitution_row(hexlist, smap=tuple(smap))

#print(reverse_substitution_row((4,13,"Q",14)))

def permutation_row(hexlist, pmap=PMap):
    # Returns output for permutation step given input and Permutation bit position mapping
    ln = len(hexlist) * 4
    result = [0] * ln
    bl = hexlist_to_binlist(hexlist)
    for i in range(ln):
        result[pmap[i] - 1] = bl[i]
    return binlist_to_hexlist(tuple(result))

#print(permutation_row((0, 2, 0, 0)))

def reverse_permutation_row(hexlist, pmap=PMap):
    # Returns input for permutation step given output and Permutation bit position mapping
    # Used in decryption
    pmap_0_base, pmap = [p - 1 for p in pmap], list(pmap)
    for i, v in enumerate(pmap_0_base):
        pmap[v] = i + 1
    return permutation_row(hexlist, pmap=tuple(pmap))

#print(reverse_permutation_row((0, 0, 4, 0)))

def add_key(hexlist, key):
    # Used to XOR round input/output with round key
    return xor_list(hexlist, key)

#print(add_key((1, "Q", 1, 2), (3,3,3,"Q")))
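
To make the bit-level effect of the permutation step concrete, here is a minimal standalone sketch. It assumes the same convention as `permutation_row` above (bit 1 is the most significant bit of the first nibble, and input bit `i` moves to position `PMap[i]`). The function name `permute` is hypothetical.

```python
# Copied from the notebook
PMap = (1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16)

def permute(nibbles, pmap=PMap):
    # Expand four nibbles into 16 bits, most significant bit first
    bits = [(n >> s) & 1 for n in nibbles for s in (3, 2, 1, 0)]
    out = [0] * 16
    for i, target in enumerate(pmap):
        out[target - 1] = bits[i]  # input bit i+1 lands at output position `target`
    # Repack the bits into four nibbles
    return tuple(8 * out[i] + 4 * out[i + 1] + 2 * out[i + 2] + out[i + 3]
                 for i in range(0, 16, 4))

moved = permute((0, 2, 0, 0))   # bit 7 is set on input
restored = permute(moved)       # a second application undoes the first
print(moved, restored)          # (0, 0, 4, 0) (0, 2, 0, 0)
```

Bit 7 of the input moves to position 10, which is the second bit of the third nibble, hence the value 4.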

### SPN Encryption Operation

In [4]:
def spn_encrypt(P, rounds=4, Key=K):
    # Returns Ciphertext given Plaintext, number of rounds, and key
    # Input key K is used when value of Key not specified
    C = P
    for r in range(rounds):
        C = add_key(C, Key[r])
        C = substitution_row(C)
        if r < rounds - 1: # Permutation operation is skipped for last round
            C = permutation_row(C)
    C = add_key(C, Key[rounds]) # Adding last round key
    return C

def spn_decrypt(C, rounds=4, Key=K):
    # Returns Plaintext given Ciphertext, number of rounds, and key
    # Input key K is used when value of Key not specified
    P = add_key(C, Key[rounds]) # Adding last round key
    # Each step of spn_encrypt is undone in reverse order, so round keys are XORed as-is
    for r in range(rounds - 1, -1, -1):
        if r < rounds - 1: # Permutation operation is skipped for last round
            P = reverse_permutation_row(P)
        P = reverse_substitution_row(P)
        P = add_key(P, Key[r])
    return P

# p = [1,2,3,4]
# c = spn_encrypt(p)
# p = spn_decrypt(c)
# print(p,c)
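
A round-trip check is a quick way to confirm that decryption undoes encryption. The sketch below is a compact standalone re-implementation of the toy cipher, written only for illustration: the names `xor`, `sub`, `perm`, `encrypt` and `decrypt` are hypothetical, while `SMap`, `PMap` and `K` are copied from the notebook. It assumes decryption reverses each encryption step in order: undo the final key addition, then for each round undo the permutation (where one was applied), the substitution and the round key addition.

```python
SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)
PMap = (1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16)
INV_SMap = tuple(SMap.index(v) for v in range(16))  # inverse S-box
K = ((0xC, 0x9, 0x0, 0x8), (0x3, 0xF, 0xA, 0xE), (0x5, 0x6, 0x0, 0x1),
     (0x6, 0x5, 0x2, 0x7), (0x3, 0x2, 0xA, 0x4))

def xor(a, b):
    return tuple(x ^ y for x, y in zip(a, b))

def sub(state, box):
    return tuple(box[n] for n in state)

def perm(state):
    bits = [(n >> s) & 1 for n in state for s in (3, 2, 1, 0)]
    out = [0] * 16
    for i, target in enumerate(PMap):
        out[target - 1] = bits[i]
    return tuple(8 * out[i] + 4 * out[i + 1] + 2 * out[i + 2] + out[i + 3]
                 for i in range(0, 16, 4))

def encrypt(p, key, rounds=4):
    state = p
    for r in range(rounds):
        state = sub(xor(state, key[r]), SMap)
        if r < rounds - 1:           # no permutation in the last round
            state = perm(state)
    return xor(state, key[rounds])

def decrypt(c, key, rounds=4):
    state = xor(c, key[rounds])
    for r in range(rounds - 1, -1, -1):
        if r < rounds - 1:
            state = perm(state)      # this PMap is its own inverse
        state = xor(sub(state, INV_SMap), key[r])
    return state

plaintext = (1, 2, 3, 4)
ciphertext = encrypt(plaintext, K)
recovered = decrypt(ciphertext, K)
print(plaintext, ciphertext, recovered)
```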

## Cryptanalysis

### SBox difference pair probabilities
By iterating through all possible values for two inputs (x1 and x2) of an SBox, we find frequencies (and thus probabilities) for all SBox difference input/output pairs.

In [5]:
differential_prob = defaultdict(int)

for x1 in range(16):
    for x2 in range(16):
        y1, y2 = substitution_op(x1), substitution_op(x2)
        dx, dy = x1 ^ x2, y1 ^ y2
        differential_prob[dx, dy] += 1 # Frequency of each dx,dy pair is counted for a single SBox

diff_prob_table = [[differential_prob[dx, dy] for dy in range(16)] for dx in range(16)]
print("The following table shows the frequency of different difference pairs for a single SBox:\n")
print_table(diff_prob_table)

The following table shows the frequency of different difference pairs for a single SBox:

Input\Output   0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  

0             16   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
1              0   0   0   2   0   0   0   2   0   2   4   0   4   2   0   0 
2              0   0   0   2   0   6   2   2   0   2   0   0   0   0   2   0 
3              0   0   2   0   2   0   0   0   0   4   2   0   2   0   0   4 
4              0   0   0   2   0   0   6   0   0   2   0   4   2   0   0   0 
5              0   4   0   0   0   2   2   0   0   0   4   0   2   0   0   2 
6              0   0   0   4   0   4   0   0   0   0   0   0   2   2   2   2 
7              0   0   2   2   2   0   2   0   0   2   2   0   0   0   0   4 
8              0   0   0   0   0   0   2   2   0   0   0   4   0   4   2   2 
9              0   2   0   0   2   0   0   4   2   0   2   2   2   0   0   0 
10             0   2   2   0   0   0   0   0   6  
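
Individual entries of the table can be spot-checked independently. The following is a minimal standalone sketch, assuming the same SMap as above; the function name `difference_table` is hypothetical. It rebuilds the counts and converts a few of them into probabilities by dividing by 16, the number of possible inputs for a fixed input difference.

```python
from fractions import Fraction

# Copied from the notebook
SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)

def difference_table(sbox):
    # table[dx][dy] = number of inputs x with sbox[x] ^ sbox[x ^ dx] == dy
    size = len(sbox)
    table = [[0] * size for _ in range(size)]
    for x in range(size):
        for dx in range(size):
            dy = sbox[x] ^ sbox[x ^ dx]
            table[dx][dy] += 1
    return table

ddt = difference_table(SMap)
prob_b_to_2 = Fraction(ddt[0xB][0x2], 16)  # input difference B, output difference 2
prob_4_to_6 = Fraction(ddt[0x4][0x6], 16)  # input difference 4, output difference 6
print(prob_b_to_2, prob_4_to_6)            # 1/2 3/8
```

Every row sums to 16, since each of the 16 inputs produces exactly one output difference.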

### Optimal SBox configurations
In this part, optimal SBox configurations are found using depth-first search. The search outputs delP and delU4 pairs that occur with high probability.

Only high-probability configurations in which two or fewer SBoxes in round 4 receive non-zero input differences are considered.

The SBox differential pairs are explored in order of decreasing probability. A minimum path probability is also maintained: configuration paths with a lower probability than this minimum are not explored any further. The minimum is updated every time an optimal configuration is found. The plaintext differences are tested in random order, since this increases the chance of encountering a high-probability differential pair early. Finding such a pair early reduces the runtime, because paths with lower probability no longer need to be explored.

The search still takes a few minutes to run, but it only needs to be run once as long as the SBox and permutation mappings stay constant. To save time, search parameters such as the pruning threshold and the number of high-probability differential pairs used per input differential can be adjusted to find approximately optimal answers almost instantly. The number of plaintext differences being tested may also be lowered to speed things up.

For demonstration, the output of a previous run is saved. To run the search, uncomment line 105 of the cell below (`#%time optimal_config = dfs()`).

In [6]:
def dfs():
# Returns a list of optimal SBox configurations
# Each config is a tuple (probability, DelP, DelU4, list_of((SBox position, SBox I/O, SBox probability)))
    optimal_config = [(Fraction(0), (0,0,0,0), (0,0,0,0), ((1,1), (0,0), Fraction(1))) for _ in range(4)]
    sorted_dif_prob = defaultdict(list)
    for dx, dy in differential_prob:
        if differential_prob[dx,dy] > 0:
            sorted_dif_prob[dx].append((Fraction(differential_prob[dx,dy], 16), dy))
    for dx in sorted_dif_prob:
        sorted_dif_prob[dx].sort(reverse=True) # Differential outputs sorted according to probability
        #sorted_dif_prob[dx] = sorted_dif_prob[dx][:3] # Limit number of paths per input differential
    minprob = Fraction(0, 256) # Change to only consider configurations that guarantee a certain probability
    sbox_positions = [(i, j) for i in range(1, 4) for j in range(1, 5)]
    Ps = list(product([hx for hx in range(16)], repeat=4)) # All possible Plaintexts are tested
    random.shuffle(Ps) # Checking P in random order, to increase probability of getting high prob pair early
    #Ps = Ps[:1000] # Limit number of plaintexts being tested
    for P in Ps:
        if P == (0,0,0,0): # delP = (0,0,0,0) is not tested, since for the same P, we'll get the same U and C
            continue
        U, V = [P], []
        # Setting up stack and conditions for depth-first search
        stack = [((1,1), (P[0],-1), Fraction(1))]
        si = [-1]
        prob = [Fraction(1)] * 2
        # Running depth-first search
        while len(stack) > 0:
            last = stack.pop()
            lasti = si.pop()
            prob.pop()
            if last[0][1] == 4: # If config of last SBox of row is changed, U,V values are updated
                U.pop()
                V.pop()
            lastdx = last[1][0]
            if lasti < len(sorted_dif_prob[lastdx]) - 1:
                lasti += 1
                lastdif = sorted_dif_prob[lastdx][lasti]
                pos = last[0]
                dif = (lastdx, lastdif[1])
                pr = lastdif[0]
                last = (pos, dif, pr)
                stack.append(last)
                si.append(lasti)
                prob.append(prob[-1] * pr)
                if pos[1] == 4: # If config of last SBox of row is changed, U,V values are updated
                    smap = [0]*16
                    v = []
                    for i in range(4):
                        dif = stack[-4 + i][1]
                        smap[dif[0]] = dif[1]
                        v.append(substitution_op(U[-1][i], smap))
                    V.append(v)
                    U.append(permutation_row(V[-1]))
                if prob[-1] < minprob: # Skip config paths of lower probability than those we already found
                    continue
                for pos in sbox_positions[len(stack):]: # Iterating through unset configs, assigning new ones
                    u = U[-1][pos[1]-1]
                    udif = sorted_dif_prob[u][0]
                    pr = udif[0]
                    next_conf = (pos, (u, udif[1]), pr)
                    stack.append(next_conf)
                    si.append(0)
                    prob.append(prob[-1] * pr)
                    if pos[1] == 4: # If config of last SBox of row is changed, U,V values are updated
                        smap = [0]*16
                        v = []
                        for i in range(4):
                            dif = stack[-4 + i][1]
                            smap[dif[0]] = dif[1]
                            v.append(substitution_op(U[-1][i], smap))
                        V.append(v)
                        U.append(permutation_row(V[-1]))

                    if prob[-1] < minprob: # Skip config path of lower probability than those we already found
                        break
                else:
                    if U[-1].count(0) >= 2: # if dU only affects <= 2 sboxes
                        for i in range(4): # Choose highest probability config for each target partial key
                            if U[-1][i] > 0 and prob[-1] > optimal_config[i][0]:
                                optimal_config[i] = (prob[-1], P, U[-1], [c for c in stack if sum(c[1]) > 0])
                                minprob = min(p[0] for p in optimal_config)
                
    return optimal_config

"""
# According to the document
optimal_config = [(Fraction(27, 1024), (0,0xB,0,0), (0,6,0,6),[
                   ((1,2), (0xB,2), Fraction(8, 16)),
                  ((2,3), (4,6), Fraction(6, 16)),
                  ((3,2), (2,5), Fraction(6, 16)),
                  ((3,3), (2,5), Fraction(6, 16))
])]
"""

#"""
# To save time during demonstration, this config from a previous output is assigned
# Otherwise, this function may take about 1-5 minutes to complete execution
optimal_config = [
    (Fraction(3, 256), (11, 0, 11, 0), [8, 0, 0, 8], [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 9), Fraction(1, 8))]),
    (Fraction(9, 256), (11, 0, 11, 0), [0, 8, 0, 8], [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 5), Fraction(3, 8))]),
    (Fraction(81, 4096), (0, 0, 15, 0), [0, 5, 5, 0], [((1, 3), (15, 4), Fraction(3, 8)), ((2, 2), (2, 5), Fraction(3, 8)), ((3, 2), (4, 6), Fraction(3, 8)), ((3, 4), (4, 6), Fraction(3, 8))]),
    (Fraction(9, 256), (11, 0, 11, 0), [0, 8, 0, 8], [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 5), Fraction(3, 8))])
]
#"""

#%time optimal_config = dfs()
print("For the following configurations, there exists a DelP for which a DelU4 occurs with high probability:\n")
for c in optimal_config:
    print(c, "\n")



CPU times: user 3min 22s, sys: 35.2 ms, total: 3min 23s
Wall time: 3min 23s
For the following configurations, there exists a DelP for which a DelU4 occurs with high probability:

(Fraction(3, 256), (11, 0, 11, 0), (8, 0, 0, 8), [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 9), Fraction(1, 8))]) 

(Fraction(9, 256), (11, 0, 11, 0), (0, 8, 0, 8), [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 5), Fraction(3, 8))]) 

(Fraction(81, 4096), (0, 15, 0, 0), (0, 6, 6, 0), [((1, 2), (15, 4), Fraction(3, 8)), ((2, 2), (4, 6), Fraction(3, 8)), ((3, 2), (4, 6), Fraction(3, 8)), ((3, 3), (4, 6), Fraction(3, 8))]) 

(Fraction(9, 256), (11, 0, 11, 0), (0, 8, 0, 8), [((1, 1), (11, 2), Fraction(1, 2)), ((1, 3), (11, 2), Fraction(1, 2)), ((2, 3), (10, 8), Fraction(3, 8)), ((3, 1), (2, 5), Fraction(3, 8))]) 
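
A configuration can be verified by hand: the probability of the whole path is the product of the probabilities of its active SBoxes, and delU4 follows from pushing the chosen output differences through the permutation. The sketch below does this for a given path. It is standalone and only illustrative: `ddt_count`, `permute` and `follow_trail` are hypothetical names, and the path is described as one tuple of chosen SBox output differences per round (0 for inactive SBoxes).

```python
from fractions import Fraction

SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)
PMap = (1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16)

def ddt_count(dx, dy):
    # Number of inputs x for which input difference dx yields output difference dy
    return sum(1 for x in range(16) if SMap[x] ^ SMap[x ^ dx] == dy)

def permute(nibbles):
    bits = [(n >> s) & 1 for n in nibbles for s in (3, 2, 1, 0)]
    out = [0] * 16
    for i, target in enumerate(PMap):
        out[target - 1] = bits[i]
    return tuple(8 * out[i] + 4 * out[i + 1] + 2 * out[i + 2] + out[i + 3]
                 for i in range(0, 16, 4))

def follow_trail(delta_p, output_diffs_per_round):
    # Returns (path probability, input difference to the following round)
    prob = Fraction(1)
    delta_u = delta_p
    for delta_v in output_diffs_per_round:
        for dx, dy in zip(delta_u, delta_v):
            prob *= Fraction(ddt_count(dx, dy), 16)  # inactive SBoxes (0 -> 0) contribute 1
        delta_u = permute(delta_v)
    return prob, delta_u

# Path from the tutorial document (the commented-out configuration in the cell above)
tutorial = follow_trail((0, 0xB, 0, 0), [(0, 2, 0, 0), (0, 0, 6, 0), (0, 5, 5, 0)])
# Second saved configuration
saved = follow_trail((0xB, 0, 0xB, 0), [(2, 0, 2, 0), (0, 0, 8, 0), (5, 0, 0, 0)])
print(tutorial)  # (Fraction(27, 1024), (0, 6, 0, 6))
print(saved)     # (Fraction(9, 256), (0, 8, 0, 8))
```

Treating the path probability as a plain product assumes the rounds behave independently, which is the usual simplification for this kind of analysis.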



### Probability for different keys
In this part of the code, for each configuration, we run the function generate_partial_keys() to find partial keys of high probability.

This function iterates through all possible partial keys in the target partial key blocks for the given configuration. For each candidate key, it stores the frequency of occurrence of right pairs. The partial key with the highest frequency is returned.

In [7]:
partial_generations = []
def generate_partial_keys(config):
    confProb, confdP, confdU4, confs = config
    # Choose random x1s, generate x2s from them using plaintext differential from config
    X1s = list(product((hx for hx in range(16)), repeat=4))
    random.shuffle(X1s)
    X1s = X1s[:800] # Limit number of chosen plaintexts to use for the attack
    X2s = [xor_list(x1, confdP) for x1 in X1s]
    C = {tuple(x): spn_encrypt(x) for x in X1s+X2s}
    # Getting rid of pairs where 0 doesn't appear at appropriate places in ciphertext
    X1s_, X1s = X1s, []
    for x1, x2 in zip(X1s_, X2s):
        c1, c2 = C[x1], C[x2]
        dc = xor_list(c1, c2)
        if all((dc[i] == 0 if confdU4[i] == 0 else dc[i] != 0 for i in range(4))):
            X1s.append(x1)
    X2s = [xor_list(x1, confdP) for x1 in X1s]
    # Generate candidate partial keys
    partials_to_test = [_ for _ in product(*[[_ for _ in range(16)] if hx > 0 else ["?"] for hx in confdU4])]
    partials_probability = defaultdict(Fraction)
    # For each partial key, go through every filtered pair and count probability for right pairs
    for partial_key in partials_to_test:
        dec_key = ((0,0,0,0), partial_key) # Key used for decryption function
        for x1, x2 in zip(X1s, X2s):
            c1, c2 = C[x1], C[x2] # Real ciphertexts are used to generate value of U by decrypting last round
            u41, u42 = spn_decrypt(c1, rounds=1, Key=dec_key), spn_decrypt(c2, rounds=1, Key=dec_key)
            du4 = xor_list(u41, u42)
            if all((u == cu for u, cu in zip(du4, confdU4))): # Count right pairs found for each tested key
                partials_probability[partial_key] += Fraction(1,len(X1s_))
    partial_generations.append((config, X1s_, X1s, partials_probability)) # Data is stored for analysis
    return max(partials_probability, key=partials_probability.get)

%time partial_keys = [generate_partial_keys(oc) for oc in optimal_config]
print("The following partial keys have been found:")
for key in partial_keys:
    print(key)

CPU times: user 1.4 s, sys: 6.64 ms, total: 1.4 s
Wall time: 1.41 s
The following partial keys have been found:
(3, '?', '?', 4)
('?', 2, '?', 4)
('?', 2, 10, '?')
('?', 2, '?', 4)
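
The core of the key-guessing step can be isolated to a single SBox. The sketch below is a simplified standalone illustration, not the notebook's implementation: `count_right_pairs` is a hypothetical helper, and the ciphertext pairs are synthetic, built so that every pair is a right pair under the assumed key nibble `true_key`. In the real attack only a fraction of the pairs are right pairs.

```python
SMap = (0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7)
INV_SMap = tuple(SMap.index(v) for v in range(16))  # inverse S-box

def count_right_pairs(cipher_pairs, key_nibble, expected_du):
    # Undo the last key addition and SBox for both ciphertext nibbles,
    # then compare the resulting input difference with the expected one
    count = 0
    for c1, c2 in cipher_pairs:
        u1 = INV_SMap[c1 ^ key_nibble]
        u2 = INV_SMap[c2 ^ key_nibble]
        if u1 ^ u2 == expected_du:
            count += 1
    return count

true_key = 0x2      # assumed last-round key nibble, for illustration only
expected_du = 0x8   # difference expected at the SBox input
# Synthetic ciphertext nibble pairs whose SBox inputs differ by exactly expected_du
pairs = [(SMap[u] ^ true_key, SMap[u ^ expected_du] ^ true_key) for u in range(16)]

counts = {k: count_right_pairs(pairs, k, expected_du) for k in range(16)}
print(counts[true_key], max(counts.values()))
```

The correct nibble reproduces the expected difference for every synthetic pair, so no other candidate can score higher.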


### Sample partial keys
Some sample partial keys, along with their probability of occurring, are given below.

In [8]:
prob_dict = partial_generations[0][-1]
prob_pair = sorted([(prob_dict[k], k) for k in prob_dict.keys()], reverse=True)
print("Sample partial keys and their probabilities:")
for p, k in prob_pair[:13:3]:
    print(k, "\t", p)

Sample partial keys and their probabilities:
(3, '?', '?', 4) 	 11/800
(3, '?', '?', 9) 	 1/160
(3, '?', '?', 15) 	 1/200
(14, '?', '?', 9) 	 3/800
(5, '?', '?', 4) 	 3/800
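
As a rough sanity check, the top frequency can be compared with the probability of the configuration it came from. The first configuration has probability 3/256 and 800 chosen plaintext pairs were used, so the arithmetic below estimates how many right pairs to expect. This is only an estimate: wrong pairs can also match by chance, and the path probability is itself an approximation.

```python
from fractions import Fraction

num_pairs = 800                       # limit set in generate_partial_keys
config_prob = Fraction(3, 256)        # probability of the first configuration
expected_right_pairs = config_prob * num_pairs
observed_right_pairs = Fraction(11, 800) * num_pairs  # top entry in the sample output

print(float(expected_right_pairs), observed_right_pairs)  # 9.375 11
```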


## Results
The resulting partial keys, which occur with high probability, are combined to recover the whole key of the last round.

In [9]:
whole_key = ["?", "?", "?", "?"]
for pk in partial_keys:
    for i in range(4):
        if type(pk[i]) is int:
            whole_key[i] = pk[i]
print("Key found:   ", whole_key)
print("Original key:", K[4])

Key found:    [3, 2, 10, 4]
Original key: (3, 2, 10, 4)


## Conclusion
The cryptanalysis code successfully attacks the spn_encrypt function, which implements the toy cipher, and recovers its last round key. It uses the techniques described in the [document](https://www.engr.mun.ca/~howard/PAPERS/ldc_tutorial.pdf) to carry out the differential cryptanalysis attack.

The manually chosen parameters could instead have been found automatically, for example by a binary search that selects parameter values based on the execution times observed during the search. This was not done because it would make the code unnecessarily convoluted and less suitable for demonstration purposes.