#CX4010 Course Project
##Objective
Create a simple SPN implementation with 2 rounds of 4x4 bit S-boxes and 8 bit p-boxes.

##Generating Key. 
A random Key is generated using os.urandom

In [None]:
from os import urandom

In [None]:
def keyGeneration():
  key = urandom(1)
  key_int = int.from_bytes(key,"big")
  print(bin(key_int))
  return key_int

##S-Box
S-box provide one-to-one substitution of blocks of data. it provides no-linearity in the encryption, but also allows decryption by being invertible.
For this SPN network, 4x4 s-boxes with a fixed mapping is used. We use the mapping (4 A 9 2 D 8 0 E 6 B 1 C 7 F 5 3) in one of the s-boxes from a standard russian block cipher GOST R 34.12-2015. See https://en.wikipedia.org/wiki/GOST_(block_cipher)

In [None]:
class Sbox:
  input = [0b0000,0b0001,0b0010,0b0011,0b0100,0b0101,0b0110,0b0111,0b1000,0b1001,0b1010,0b1011,0b1100,0b1101,0b1110,0b1111]
  output = [0x4,0xA,0x9,0x2,0xD,0x8, 0x0, 0xE, 0x6, 0xB, 0x1,0xC, 0x7, 0xF, 0x5,0x3]

                        
  def __init__(self):
    pass

  
  def forward(self,inputBits):
    index = self.input.index(inputBits)
    return self.output[index]
  
  def backward(self,outputBits):
    index = self.output.index(outputBits)
    return self.input[index]


In [None]:
## testing s-box
box1 = Sbox()
ct = box1.forward(0b0010)
print(bin(ct))
pt = box1.backward(ct)
print(bin(pt))
#as the decrypted plaintext is the same as input for encryption, s-box works as intended

0b1001
0b10


##P-box
in SPN network, P-box permutates the bits across different blocks from output of different s-boxes, so that confusion and diffusion is achived as changes in plaintext and key are spread across multiple bits over different rounds. 8 bit p-boxes with fixed permutation is used.

In [None]:
class Pbox:
  bitMask = [0b00000001,0b00000010,0b00000100,0b00001000,0b00010000,0b00100000,0b01000000,0b10000000]
  permutation = [7,3,5,4,2,6,0,1]

                        
  def __init__(self):
    pass

  
  def forward(self,inputBits):
    output = 0
    for i, mask in enumerate(self.bitMask):
      bit = inputBits & mask
      output += int(bool(bit))*self.bitMask[self.permutation[i]]
    return output
    
  
  def backward(self,outputBits):
    output = 0
    for i, mask in enumerate(self.bitMask):
      bit = outputBits & mask
      output += int(bool(bit))*self.bitMask[self.permutation.index(i)]
    return output

In [None]:
##testing p-box
box2 = Pbox()
ct = box2.forward(0b01010010)
print(bin(ct))
pt = box2.backward(ct)
print(bin(pt))

0b1101
0b1010010


## Key extention
Key is extention is needed so that each round of the SPN has a different sub-key. There are a variety of compliated key extention algorithms for standard block cipher such as AES and DES as they require a larger number of subkeys. Here a simple key extension algo is used as only two subkeys is needed.

In [None]:
def keyExtend(key):
  temp = key>>4
  key1 = temp^key
  temp = (key&0b00001111)<<4
  key2 = temp^key
  return [key1,key2]


In [76]:
keyList = keyExtend(0b01111010)
print(type(keyList))
print(bin(keyList[0]))
print(bin(keyList[1]))

<class 'list'>
0b1111101
0b11011010


##SPN network

In [None]:
class SPN:
  s_box = Sbox()
  p_box = Pbox()

  def __init__(self,key):
    self.keyList = keyExtend(key)

    
  def encrypt(self,msg):
    temp = msg
    # print(bin(temp))
    for key in self.keyList:
      # apply key
      # print("key:"+bin(key))

      temp = temp^key
      # print(bin(temp))

      # apply s-box
      leftHalf = temp>>4
      buffer = self.s_box.forward(leftHalf)
      # print(bin(buffer))
      rightHalf = temp&0b00001111
      buffer = (buffer<<4)+self.s_box.forward(rightHalf)
      # print(bin(self.s_box.forward(rightHalf)))
      # print(bin(buffer))
      temp = buffer

      # apply p-box
      temp = self.p_box.forward(temp)
      # print("pbox")
      # print(bin(temp))

    return temp

   
  def decrypt(self,msg):
    # for key in keyList.reverse():
    temp = msg
    print(bin(temp))
    for key in reversed(self.keyList):

      # apply p-box
      temp = self.p_box.backward(temp)
      print("pbox")
      print(bin(temp))

      # apply s-box
      leftHalf = temp>>4
      buffer = self.s_box.backward(leftHalf)
      # print(bin(buffer))
      rightHalf = temp&0b00001111
      buffer = (buffer<<4)+self.s_box.backward(rightHalf)
      # print(bin(self.s_box.backward(rightHalf)))
      # print(bin(buffer))
      temp = buffer

      # apply key
      
      # print("key:"+bin(key))
      temp = temp^key
      # print(bin(temp))


    return temp



In [None]:
test = SPN(0b00010010)
input = 0b00100011
ct = test.encrypt(input)
print(bin(ct))
pt = test.decrypt(ct)
print("decrypted plaintext")
print(bin(pt))
print("reversible:",input==pt)

0b10010010
0b10010010
pbox
0b10001001
pbox
0b100100
decrypted plaintext
0b100011
reversible: True


# Linear Cryptanalysis

## Functions used for linear cryptanalysis

In [None]:
from math import fabs, ceil
import multiprocessing
import concurrent.futures

In [None]:
def initialize(num_p_c_pairs, sbox_bits , num_sboxes, num_rounds, min_bias, max_blocks_to_bf, do_sbox_param, do_inv_sbox_param, do_pbox_param):
    global NUM_P_C_PAIRS, SBOX_BITS , NUM_SBOXES, NUM_ROUNDS, MIN_BIAS, MAX_BLOCKS_TO_BF, do_sbox, do_inv_sbox, do_pbox
    NUM_P_C_PAIRS = num_p_c_pairs
    SBOX_BITS = sbox_bits
    NUM_SBOXES = num_sboxes
    NUM_ROUNDS = num_rounds
    MIN_BIAS = min_bias
    MAX_BLOCKS_TO_BF = max_blocks_to_bf
    do_sbox = do_sbox_param
    do_inv_sbox = do_inv_sbox_param
    do_pbox = do_pbox_param

def analize_cipher():
    
    # analize the sbox and create bias table
    table = create_bias_table()
    table_sorted = sorted(table, key=lambda elem: fabs(elem[2]), reverse=True)

    # take the best max_size results
    max_size = 300
    table_len = len(table_sorted)
    if table_len > max_size:
        print('\n[*] reducing bias table size from {:d} to {:d}\n'.format(table_len, max_size))
        table_sorted = table_sorted[:max_size]

    # calculate all possible linear aproximations with a bias greater than 0
    linear_aproximations = get_linear_aproximations(table_sorted)
    # sort the list from the best approximations to the worst
    linear_aproximations_sorted = sort_linear_aproximations(linear_aproximations)
    # return the sorted list of approximations
    return linear_aproximations_sorted


# create the bias table for the sbox
def create_bias_table():

    tablesize = 1 << SBOX_BITS
   
    table = []
    for x in range(1, tablesize):
        for y in range(1, tablesize):
            matches = 0
            for num in range(tablesize):
                # calculate the parity of the number before going in the sbox
                in_mask  = apply_mask(num, x)
                # calculate the parity of the number after going out the sbox
                out_mask = apply_mask(do_sbox(num), y)
                # if the parity is the same in both cases, add a match
                if in_mask == out_mask:
                    matches += 1

            # calculate the bias
            bias = (matches / tablesize) - 1/2
            # if the bias is greater than 0, save the 'x' and 'y' combination
            if bias > MIN_BIAS:
                table.append( [x, y, bias] )

    # return the table of biases that are greater than 0
    return table

def sort_linear_aproximations(linear_aproximations):

    sorted_linear_aproximations = []
    for linear_aproximation in linear_aproximations:
        biases = linear_aproximation['biases']
        
        # calculate the resulting bias following the Piling-Up Lemma
        resulting_bias = 1
        for _, _, bias in biases:
            resulting_bias *= bias
        resulting_bias *= 1 << (len(biases) - 1)

        # construct the element of the list resulting list
        x, _, _ = biases[0]
        _, num_sbox = linear_aproximation['start']
        entry = [resulting_bias, [num_sbox, num_to_bits(x)], linear_aproximation['state']]
        # keep the entry only if has a bias grater than MIN_BIAS
        if resulting_bias > MIN_BIAS:
            sorted_linear_aproximations.append( entry )

    # sort and return the result
    sorted_linear_aproximations = sorted(sorted_linear_aproximations, key=lambda elem: fabs(elem[0]), reverse=True)
    return sorted_linear_aproximations

# calculate all the possible linear aproximations given a bias table
def get_linear_aproximations(bias_table, current_states=None, depth=1):

    # run for NUM_ROUNDS - 1 times
    if depth == NUM_ROUNDS:
        # delete elements that involve more than MAX_BLOCKS_TO_BF final sboxes
        current_states = [elem for elem in current_states if len(elem['state']) <= MAX_BLOCKS_TO_BF]
        if len(current_states) == 0:
            exit('No linear aproximations found! May be MIN_BIAS is too high or MAX_BLOCKS_TO_BF too low.')
        # return the linear aproximations that reach to no more than MAX_BLOCKS_TO_BF sboxes
        return current_states

    # at the beginnig, only one sbox can be chosen
    if depth == 1:
        # for each bias and each sbox, calculate which sboxes are reached (in the lower layer)
        # this will be the next step's new initial state
        current_states = []
        for x, y, bias in bias_table:

            for num_sbox in range(1, NUM_SBOXES + 1):

                sboxes_reached = get_destination(num_sbox, y)

                entry = {}
                entry['start']  = [depth, num_sbox]
                entry['biases'] = [[x, y, bias]]
                entry['state']  = sboxes_reached

                current_states.append( entry )
        # call the function recursevely with the new current state and new depth
        return get_linear_aproximations(bias_table, current_states, depth + 1)

    else:
        # for each set of possible states it will do the following:
        #   for each sbox that we last reached,
        #   it will calculate all possible moves according to the bias table.
        #   then it will calculate all possible the combinations of choices
        # this set of combinations, will be our next 'current_states'
        # lastly, it will call itself recursevely
        next_states = []
        for current_state in current_states:

            curr_pos = current_state['state']

            # calculate all possible moves from 'curr_sbox'
            possible_step_per_sbox = {}
            num_possible_step_per_sbox = {}
            start_sboxes = {}
            total_combinations = 1
            num_start_sboxes = 0
            for curr_sbox in curr_pos:

                inputs  = curr_pos[curr_sbox]
                Y_input = bits_to_num(inputs)

                possible_steps = []

                # only use the biases which input matches the current sbox
                possible_biases = [ elem for elem in bias_table if elem[0] == Y_input ]
                for x, y, bias in possible_biases:

                    sboxes_reached = get_destination(curr_sbox, y)

                    step = {'to': sboxes_reached, 'path': [x, y, bias]}

                    possible_steps.append(step)

                num_possible_step_per_sbox[curr_sbox] = len(possible_steps)
                if len(possible_steps) > 0:
                   total_combinations *= len(possible_steps)
                   possible_step_per_sbox[curr_sbox] = possible_steps
                   start_sboxes[num_start_sboxes] = curr_sbox
                   num_start_sboxes += 1

            if total_combinations == 0:
                continue

            # combine all the possible choises of each sbox in all possible ways
            # for example, if there are 2 sboxes and each has 4 possible moves
            # then calculate all 16 (4x4) possible combinations.

            possible_steps_combinations = []

            for comb_num in range(total_combinations):
                new_comb = []
                new_comb.append( possible_step_per_sbox[start_sboxes[0]][comb_num % num_possible_step_per_sbox[start_sboxes[0]]] )

                for sbox in start_sboxes:
                    if sbox == 0:
                        continue
                    real_sbox = start_sboxes[sbox]

                    mod = 1
                    for prev_sbox in range(sbox):
                        mod *= num_possible_step_per_sbox[start_sboxes[prev_sbox]]

                    index = (comb_num / mod) % num_possible_step_per_sbox[real_sbox]
                    index = int(index)

                    new_comb.append( possible_step_per_sbox[real_sbox][index] )
                possible_steps_combinations.append(new_comb)

            # now, for each combination, check to which sboxes we reached and what are their inputs
            # this will be the next state
            for possible_step in possible_steps_combinations:

                # save the first sbox and the previous biases
                entry = {}
                entry['start'] = current_state['start']
                entry['biases'] = current_state['biases'].copy()
                entry['state'] = {}

                # add the new biases
                for elem in possible_step:
                    entry['biases'].append( elem['path'] )

                    # add the final sboxes and their inputs
                    for destination in elem['to']:
                        if destination not in entry['state']:
                            entry['state'][destination] = []

                        new_bits = elem['to'][destination]
                        entry['state'][destination] += new_bits


                # calculate the resulting bias following the Piling-Up Lemma
                biases = entry['biases']
                resulting_bias = 1
                for _, _, bias in biases:
                    resulting_bias *= bias
                resulting_bias *= 1 << (len(biases) - 1)
                if resulting_bias >= MIN_BIAS:
                    # update the next_states
                    next_states.append( entry )


        return get_linear_aproximations(bias_table, next_states, depth + 1)


def apply_mask(value, mask):
    #retrieve the parity of mask/value
    interValue = value & mask
    total = 0
    while interValue > 0:
        temp = interValue & 1
        interValue = interValue >> 1
        if temp == 1:
            total = total ^ 1
    return total


# from an sbox and the "output" of a bias y,
# calculate which sboxs will be reached and in which bits
def get_destination(num_sbox, y):
    # pass 'y' through the permutation
    offset = (NUM_SBOXES - (num_sbox-1) - 1) * SBOX_BITS
    Y = y << offset
    # do_pbox is supposed to transpose the state, make sure is well defined!
    permuted = do_pbox(Y)

    sboxes_reached = {}
    # sboxes go from 1 to NUM_SBOXES from left to right
    # bits go from 1 to SBOX_BITS from left to right
    for sbox in range(1, NUM_SBOXES + 1):
        for bit in range(SBOX_BITS):
            bits_offset = ((NUM_SBOXES - (sbox-1) - 1) * SBOX_BITS) + bit
            # if 'sbox' has a 1 in the position 'bit' then take note of that
            if permuted & (1 << bits_offset) != 0:
                if sbox not in sboxes_reached:
                    sboxes_reached[sbox] = []
                sboxes_reached[sbox].append(SBOX_BITS - bit)
    # return which sboxes where reached and in which bit
    return sboxes_reached

# convert a list of bits to an integer
def bits_to_num(inputbits):
    Y_input = 0
    for input_pos in inputbits:
        Y_input |= 1 << (SBOX_BITS - input_pos)
    return Y_input

# convert an integer to a list of it's bits
def num_to_bits(num):
    bits = []
    for index in range(SBOX_BITS):
        if (1 << index) & num > 0:
            bits.append( SBOX_BITS - index )
    return bits

# this function eliminates the linear approximations that have
# a bias below the MIN_BIAS threshold and then sorts the results



def bit(num, n):
    # get the nth bit of num
    return (num >> (SBOX_BITS - n)) & 1

def get_xor(plaintext, ciphertext, key, linear_aproximation):
    _, p_data, c_data = linear_aproximation

    # get the plaintext block
    p_block_num, p_bits = p_data
    pt = plaintext >> ((NUM_SBOXES - p_block_num) * SBOX_BITS)
    pt = pt & ((1 << SBOX_BITS) - 1)

    # calculate the plaintext's part of the xor
    xor_pt = 0
    for b in p_bits:
        xor_pt = xor_pt ^ bit(pt, b)

    # for each final sbox, get the according ciphertext block
    xor_u = 0
    i = len(c_data) - 1
    for c_block_num in c_data:
        c_bits = c_data[c_block_num]

        # get the ciphertext block
        ct = ciphertext >> ((NUM_SBOXES - c_block_num) * SBOX_BITS)
        ct = ct & ((1 << SBOX_BITS) - 1)

        # get the key block that corresponds with the sbox
        k = key >> (i * SBOX_BITS)
        k = k & ((1 << SBOX_BITS) - 1)

        # xor the key and the ciphertext to get v (the sbox output)
        v = ct ^ k

        # get the sbox input
        # do_inv_sbox is supposed to calculate the inverse of the substitution, make sure is well defined!
        u = do_inv_sbox(v)

        # calculate the input of the sbox's part of the xor
        for b in c_bits:
            xor_u = xor_u ^ bit(u, b)

        i -= 1
    # return the result of the full xor
    return xor_pt ^ xor_u

def get_biases_for_key_space(keystart, keyend, p_c_pairs, linear_aproximation):

    try:
        hits = [0] * (keyend - keystart)
    except OverflowError:
        exit('the amount of key bits to brute force is too large.')

    # get the result of the aproximation for each possible key
    for key in range(keystart, keyend):
        for plaintext, ciphertext in p_c_pairs:
            xor = get_xor(plaintext, ciphertext, key, linear_aproximation)
            if xor == 0:
                hits[keystart - key] += 1

    result = {'start': keystart, 'end': keyend, 'hits': hits}
    return result

def get_biases(p_c_pairs, linear_aproximation):

    # calculate how many key bits must be brute forced
    key_bits = len(linear_aproximation[2]) * SBOX_BITS
    try:
        # get the key's maximum size
        key_max  = 1 << key_bits
    except MemoryError:
        exit('the amount of key bits to brute force is too large.')

    num_cores = multiprocessing.cpu_count()

    sub_key_space = key_max // num_cores

    bias_lists = []

    # run in num_cores threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cores) as executor:

        future_list = []
        # divide the key space into num_cores parts
        for core in range(num_cores):
            start = sub_key_space * core
            end   = start + sub_key_space

            future = executor.submit(get_biases_for_key_space, start, end, p_c_pairs, linear_aproximation)
            future_list.append(future)

        # get the result for each thread
        for future in concurrent.futures.as_completed(future_list):
            bias_lists.append(future.result())

    try:
        hits = [0] * key_max
    except OverflowError:
        exit('the amount of key bits to brute force is too large.')

    # join all the results
    for result in bias_lists:
        start = result['start']
        end   = result['end']
        array = result['hits']
        for hit in range(start, end):
            hits[hit] = array[start - hit]

    # calculate the bias for each key
    bias = [ fabs(num_hits - float(NUM_P_C_PAIRS/2)) / float(NUM_P_C_PAIRS) for num_hits in hits ]

    return bias


## Demonstration of linear cryptanalysis

In [None]:
NUM_P_C_PAIRS = 256
SBOX_BITS  = 4
NUM_SBOXES = 2
NUM_ROUNDS = 2
MIN_BIAS = 0.008
MAX_BLOCKS_TO_BF = 3

s_box = Sbox()
p_box = Pbox()

initialize(NUM_P_C_PAIRS,
      SBOX_BITS,
      NUM_SBOXES,
      NUM_ROUNDS,
      MIN_BIAS,
      MAX_BLOCKS_TO_BF,
      s_box.forward,# do_sbox,
      s_box.backward,# do_inv_sbox,
      p_box.forward# do_pbox
      )

print('Cipher being analysed...\n')
linear_aproximations = analize_cipher()

print('The followings are the best linear aproximations:')

for i in range(10):
    try:
         print(linear_aproximations[i])
    except IndexError:
         break



print('\nThis linear aproximation with the best bias will be used:')
linear_aproximation = linear_aproximations[0]


end_sboxs = ', '.join(list(map(str, linear_aproximation[2])))
print('ε: {:f}\nstart: sbox n°{:d}\nend sboxes:{}'.format(linear_aproximation[0], linear_aproximation[1][0], end_sboxs))
Nl = ceil( pow(pow(linear_aproximation[0], -1), 2))

print('\nThe following key blocks will be recovered:{}'.format(' '.join(list(map(str, list(linear_aproximation[2].keys()))))))


key = keyGeneration()
print("\nOriginal key: ",bin(key))
subkey = keyExtend(key)
subkey = [bin(n) for n in subkey]
print("sub key:",subkey)


spn = SPN(key)
p_c_pairs = []
for pt in range(NUM_P_C_PAIRS):
    p_c_pairs.append( [pt, spn.encrypt(pt)] )

print('\nBreaking the cypher...\n')
# obtain the biases given the p/c pairs and the linear aproximation
biases = get_biases(p_c_pairs, linear_aproximation)

print("biases",biases)
# get the key with the most hits
maxResult, maxIdx = 0, 0
for rIdx, result in enumerate(biases):
    if result > maxResult:
        maxResult = result
        maxIdx    = rIdx
print("maxResult",maxResult)
print("maxIdx",maxIdx)

print('\nResult:')
bits_found = '{:b}'.format(maxIdx).zfill(len(linear_aproximation[2])*SBOX_BITS)
bits_found = [bits_found[i:i+SBOX_BITS] for i in range(0, len(bits_found), SBOX_BITS)]

blocks_num = list(linear_aproximation[2].keys())

zipped = list(zip(blocks_num, bits_found))

print('Key bits might be:')
for num_block, bits in zipped:
    print('block {:d}: {}'.format(num_block, bits))



Cipher being analysed...

The followings are the best linear aproximations:
[0.25, [1, [4]], {1: [2]}]
[0.25, [2, [4]], {2: [1]}]
[0.25, [1, [4]], {2: [3]}]
[0.25, [2, [4]], {1: [4]}]
[0.25, [1, [4, 2]], {1: [2], 2: [4]}]
[0.25, [2, [4, 2]], {1: [3], 2: [1]}]
[0.25, [1, [3, 2]], {1: [2], 2: [4, 2]}]
[0.25, [2, [3, 2]], {1: [3, 1], 2: [1]}]
[0.25, [1, [3, 2]], {1: [2], 2: [3]}]
[0.25, [2, [3, 2]], {1: [4], 2: [1]}]

This linear aproximation with the best bias will be used:
ε: 0.250000
start: sbox n°1
end sboxes:1

The following key blocks will be recovered:1
0b10111011

Original key:  0b10111011
sub key: ['0b10110000', '0b1011']

Breaking the cypher...

biases [0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125]
maxResult 0.0625
maxIdx 0

Result:
Key bits might be:
block 1: 0000
