## Import the analysed cipher

In [None]:
from tabulate import tabulate
from present import *

## Initialise cipher and helper functions

In [None]:
# 128-bit master key (the same 64-bit hex pattern written twice) and the
# attacked PRESENT instance; the second argument is presumably the round
# count (2 rounds) -- confirm against the `present` module.
presKey = bytes.fromhex("0123456789abcdef" * 2)
present = Present(presKey, 2)

### Encrypt with the target cipher, then undo the final bit permutation
def presentEncrypt(p):
    """Encrypt the 8-byte block *p* and strip the last pLayer.

    The final permutation does not help the per-S-box analysis, so the
    ciphertext is read as a big-endian integer and pLayer_dec is applied.
    """
    rawCipher = present.encrypt(p)
    asInt = int.from_bytes(rawCipher, 'big')
    return pLayer_dec(asInt)

### Re-apply the bit permutation, then run the real decryption
def presentDecrypt(c):
    """Intended inverse of presentEncrypt: pLayer the integer *c* back,
    convert it to an 8-byte block and decrypt it with the cipher."""
    permutedBlock = number2string_N(pLayer(c), 8)
    return present.decrypt(permutedBlock)

### PRESENT decryption restricted to the single S-box at nibble `shift`
def prstDecrypt(block, roundkeys, shift):
    """Decrypt *block* using nibble-sized round keys placed at nibble *shift*.

    Each entry of *roundkeys* is shifted left by ``60 - 4*shift`` bits
    (shift 0 = most significant nibble) before being XORed into the state.
    The last key is pLayer-ed and added first, the state is run back through
    pLayer_dec and sBoxLayer_dec, and ``roundkeys[0]`` is added last.  The
    result is converted back with ``number2string_N(..., 8)``.
    """
    rounds = 2
    offset = 60 - 4 * shift
    state = string2number(block)
    for idx in range(rounds - 1):
        laterKey = roundkeys[-idx - 1] << offset
        state = addRoundKey(state, pLayer(laterKey))
        state = sBoxLayer_dec(pLayer_dec(state))
    firstKey = roundkeys[0] << offset
    return number2string_N(addRoundKey(state, firstKey), 8)


# Differential and impossible-differential cryptanalysis

In [None]:
### Toy single-S-box cipher: c = box[p ^ k0] ^ k1
def cipher(p, box, k=(0,2)):
    """Encrypt nibble *p* through *box* with whitening keys ``k[0]`` (before)
    and ``k[1]`` (after) the S-box."""
    preWhite, postWhite = k[0], k[1]
    return box[p ^ preWhite] ^ postWhite

### Inverse of the toy cipher: p = box^-1[c ^ k1] ^ k0
def decipher(c, box, k=(0,2)):
    """Decrypt nibble *c*, inverting *box* by a linear ``index`` search."""
    sboxOutput = c ^ k[1]
    sboxInput = box.index(sboxOutput)
    return sboxInput ^ k[0]

### Check a key candidate against every one-nibble plaintext
def checkKey(keys, box, shift):
    """Return True if *keys* = (k0, k1), applied at nibble *shift*, decrypts
    the encryption of every plaintext nibble 0..len(box)-1 back to itself.

    Stops at the first mismatch.
    """
    lo, hi = shift * 4, shift * 4 + 4

    def roundTrip(nibble):
        # Encrypt the nibble placed at position `shift`, then decrypt it with
        # the candidate keys and read the same nibble back.
        ct = presentEncrypt((nibble << (60 - 4 * shift)).to_bytes(8,'big'))
        pt = prstDecrypt(number2string_N(pLayer(ct),8), keys, shift)
        bits = bin(int.from_bytes(pt,"big"))[2:].rjust(64,"0")
        return int(bits[lo:hi],2)

    return all(roundTrip(n) == n for n in range(len(box)))

### Derive a (k0, k1) candidate for nibble `shift` from the predicted right pairs
def calcKey(box, pairs, text, shift):
    """Try each predicted S-box input in *pairs* as the value entering the S-box.

    For a candidate input ``val`` and plaintext nibble *text*:
    ``k0 = val ^ text`` and ``k1 = box[val] ^ ct`` where ``ct`` is the
    ciphertext nibble at the same position *shift*.  Returns the first
    ``(k0, k1)`` accepted by checkKey, or None if no candidate survives.
    """
    # Read the ciphertext nibble from position `shift`, the same slice that
    # findPairs and checkKey use. The previous `[:4]` always took the top
    # nibble, which is only correct for shift == 0. The value does not depend
    # on `val`, so it is computed once.
    ct = presentEncrypt((text << (60 - shift * 4)).to_bytes(8, 'big'))
    ctNibble = int(bin(ct)[2:].rjust(64, '0')[shift * 4:shift * 4 + 4], 2)
    for val in pairs:
        key0 = val ^ text
        key1 = box[val] ^ ctNibble
        if checkKey((key0, key1), box, shift):
            return (key0, key1)
    return None

### Find S-box inputs and plaintext nibbles that follow a given differential
def findPairs(box, inXor, outXor, shift):
    """Return ``(pairs, textPairs)`` for the characteristic ``inXor -> outXor``.

    ``pairs`` holds every S-box input ``i`` with
    ``box[i] ^ box[i ^ inXor] == outXor``.  ``textPairs`` holds every
    plaintext nibble ``i`` (placed at nibble position *shift*) whose
    ciphertext nibble XOR that of ``i ^ inXor`` equals *outXor*.  Only the
    first member of each pair is stored; its partner is ``i ^ inXor``.
    """
    def ctNibble(t):
        # Ciphertext nibble at position `shift` for plaintext nibble `t`.
        ct = presentEncrypt((t << (60 - shift * 4)).to_bytes(8, 'big'))
        return int(bin(ct)[2:].rjust(64, '0')[shift * 4:shift * 4 + 4], 2)

    # Encrypt each of the len(box) plaintexts once instead of twice per
    # iteration.
    ctNibbles = [ctNibble(t) for t in range(len(box))]
    pairs = list()
    textPairs = list()
    for i in range(len(box)):
        j = i ^ inXor
        if box[i] ^ box[j] == outXor:
            pairs.append(i)
        if ctNibbles[i] ^ ctNibbles[j] == outXor:
            textPairs.append(i)
    return tuple(pairs), tuple(textPairs)

### Find pairs for an impossible path (inputs that do NOT follow inXor -> outXor)
def findImpPairs(box, inXor, outXor, shift):
    """Return ``(pairs, textPairs)`` for inputs that do NOT follow the differential.

    ``pairs`` holds every S-box input ``i`` with
    ``box[i] ^ box[i ^ inXor] != outXor``.  ``textPairs`` holds every
    plaintext nibble ``i`` (at nibble position *shift*) whose ciphertext
    nibble XOR that of ``i ^ inXor`` differs from *outXor*.  Only the first
    member of each pair is stored.
    """
    def ctNibble(t):
        # Ciphertext nibble at position `shift` for plaintext nibble `t`.
        ct = presentEncrypt((t << (60 - shift * 4)).to_bytes(8, 'big'))
        return int(bin(ct)[2:].rjust(64, '0')[shift * 4:shift * 4 + 4], 2)

    # Encrypt each plaintext once rather than twice per loop iteration.
    ctNibbles = [ctNibble(t) for t in range(len(box))]
    pairs = list()
    textPairs = list()
    for i in range(len(box)):
        j = i ^ inXor
        if box[i] ^ box[j] != outXor:
            pairs.append(i)
        if ctNibbles[i] ^ ctNibbles[j] != outXor:
            textPairs.append(i)
    return tuple(pairs), tuple(textPairs)

### List every non-zero profile entry, most probable first
def findCommon(profile):
    """Collect the non-zero entries of *profile* as characteristic tuples.

    Each tuple is ``(probability, "(count/total)", in_hex, out_hex)`` with
    ``total = profile[0][0]``.  The result is ordered by descending
    probability, then ascending input value, then ascending output value.
    """
    found = [
        (count / profile[0][0], f'({count}/{profile[0][0]})', hex(row)[2:], hex(col)[2:])
        for row, line in enumerate(profile)
        for col, count in enumerate(line)
        if count > 0
    ]
    found.sort(key=lambda entry: (1 - entry[0], int(entry[2], 16), int(entry[3], 16)))
    return found
        
### Build the difference distribution table (DDT)
def XORprofile(box):
    """Return ``profile`` with ``profile[dx][dy]`` = number of ordered input
    pairs ``(x, y)`` such that ``x ^ y == dx`` and ``box[x] ^ box[y] == dy``."""
    size = len(box)
    profile = [[0] * size for _ in range(size)]
    for x, sx in enumerate(box):
        for y, sy in enumerate(box):
            profile[x ^ y][sx ^ sy] += 1
    return profile

### Parse S-box text: whitespace-separated hexadecimal entries
def parseBox(text):
    """Return the S-box described by *text* as a tuple of ints."""
    return tuple(int(token, 16) for token in text.split())

## Difference distribution table (DDT) generation

In [None]:
# Build the DDT of the S-box in present.sbx and write three artefacts: the
# full table, the ranked list of non-zero characteristics, and a KEYS line
# derived from the third-ranked characteristic (charact[2]).
with open('present.sbx', 'r') as file, \
        open('present_prof2', "w") as profOut, \
        open('present_char2', "w") as charOut, \
        open('present_keys2', "a") as keysOut:
    box = parseBox(file.read())
    profile = XORprofile(box)
    print(
        tabulate(
            profile,
            headers=[format(n, 'x') for n in range(profile[0][0])],
            showindex=[format(n, 'x') for n in range(profile[0][0])],
            tablefmt='rst',
        ),
        file=profOut,
    )
    charact = findCommon(profile)
    print(
        tabulate(charact, showindex=[f'{n}:' for n in range(1, len(charact) + 1)], tablefmt='rst'),
        file=charOut,
    )
    # Input difference of charact[2] with its lowest bit flipped, as two hex digits.
    print(f'KS = KEYS(01, {int(charact[2][2], 16) ^ 1:02x}, {charact[2][3]})', file=keysOut)

## Differential key recovery

In [None]:
# For every nibble position, try characteristics in ranked order (skipping
# the trivial charact[0]) until one yields a key pair accepted by checkKey.
rdyKeys = list()
for i in range(16):
    # Reset per nibble: if a characteristic has no text pairs, the inner loop
    # never assigns `key`. Without this reset the previous nibble's key would
    # be reused and appended again (or a NameError raised on nibble 0).
    key = None
    for j in range(1, min(26, len(charact))):
        pairs, textPairs = findPairs(box, int(charact[j][2], 16), int(charact[j][3], 16), i)
        print(f'Pairs {i}, {j}: {pairs}, {textPairs}')
        for text in textPairs:
            key = calcKey(box, pairs, text, i)
            if key is not None:
                print(f'Charact {j} used: {charact[j]}')
                break
        if key is not None:
            break
    print(f'Key {i}: {key}')
    # NOTE(review): a nibble with no key is simply skipped, so later nibbles
    # shift position when rdyKeys is joined into full subkeys.
    if key is not None:
        rdyKeys.append(key)

In [None]:
### Join the recovered 4-bit partial keys (nibble 0 first) into full subkeys;
### key1 is passed through pLayer to undo the permutation stripped earlier.
key0 = int(''.join(f'{pair[0]:04b}' for pair in rdyKeys), 2)
key1 = pLayer(int(''.join(f'{pair[1]:04b}' for pair in rdyKeys), 2))
(key0, key1)

# Attempted 2-round differential cryptanalysis

In [None]:
### Substitution-permutation layer of PRESENT (no key addition)
def spLayer(block):
    """Apply sBoxLayer and then pLayer to the integer state *block*."""
    return pLayer(sBoxLayer(block))

### Exhaustively find input pairs following an SP-layer differential
def findPairsSP(inXor, outXor):
    """Return every ``(x, x ^ inXor)`` with ``x`` in ``range(2**32)`` whose
    spLayer outputs XOR to *outXor*.

    NOTE(review): this walks 2**32 inputs in pure Python; expect it to run
    for a very long time.
    """
    return tuple(
        (x, x ^ inXor)
        for x in range(2**32)
        if spLayer(x) ^ spLayer(x ^ inXor) == outXor
    )

### Find input pairs following an S-box-layer differential (inputs 0..31)
def findPairsS(inXor, outXor):
    """Return every ``(x, x ^ inXor)`` with ``x < 2**5`` whose sBoxLayer
    outputs XOR to *outXor*."""
    # NOTE(review): 2**5 spans one nibble plus one bit of the next; 2**4 (one
    # S-box) or 2**8 (two S-boxes) may have been intended -- confirm.
    return tuple(
        (x, x ^ inXor)
        for x in range(2**5)
        if sBoxLayer(x) ^ sBoxLayer(x ^ inXor) == outXor
    )

In [None]:
### Attempt to generate pairs based on a known characteristic
# NOTE(review): findPairs as defined earlier in this notebook takes
# (box, inXor, outXor, shift); this two-argument call raises TypeError, so the
# cell stops here. Confirm which helper was intended.
pairs = findPairs(int('4004',16), int('400000004',16))
# NOTE(review): findPairsSP loops over all 2**32 inputs; expect a very long run.
pairs0 = findPairsSP(int('4004',16), int('900000009',16))
pairs1 = findPairsS(int('9',16), int('4',16))
# Prefix each value's hex digits with its own last hex digit
# (e.g. 0x1a -> 0xa1a, 0x5 -> 0x55); presumably meant to copy the nibble into
# a neighbouring S-box position -- TODO confirm intent.
pairs1 = tuple(map(lambda x: (int(hex(x[0])[2:][-1] + hex(x[0])[2:],16),int(hex(x[1])[2:][-1] + hex(x[1])[2:],16)), pairs1))

# Linear cryptanalysis

In [None]:
### Helper function for checking bit parity
def checkBitParity(val):
    """Return the parity (0 or 1) of the number of set bits in *val*.

    Raises:
        ValueError: if *val* is negative.  The previous shift loop never
            terminated for negative input because ``-1 >> 1 == -1``.
    """
    if val < 0:
        raise ValueError(f'checkBitParity expects a non-negative int, got {val}')
    return bin(val).count('1') & 1

In [None]:
def testKey(keys, p, c, shift):
    """Return True if decrypting each ciphertext nibble in *c* under *keys*
    (at nibble position *shift*) gives back the matching plaintext in *p*.

    Stops at the first mismatch.
    """
    lo = shift * 4
    for plain, ct in zip(p, c):
        block = number2string_N(pLayer(ct << (60 - shift * 4)), 8)
        recovered = prstDecrypt(block, keys, shift)
        bits = bin(int.from_bytes(recovered, "big"))[2:].rjust(64, "0")
        if plain != int(bits[lo:lo + 4], 2):
            return False
    return True

### Check a guessed first-round key against all known text pairs
def checkKey(key0, box, p, c, shift):
    """For guess *key0*, derive ``key1 = box[p[i] ^ key0] ^ c[i]`` from each
    known pair in order and return the first ``(key0, key1)`` accepted by
    testKey, or None."""
    candidates = ((key0, box[text ^ key0] ^ c[idx]) for idx, text in enumerate(p))
    return next((cand for cand in candidates if testKey(cand, p, c, shift)), None)

### Try each graded key guess until one yields a consistent key pair
def calcKey(keys, box, p, c, shift):
    """Run checkKey on each candidate in *keys* (in order) and return the
    first non-None ``(key0, key1)``, or None if none succeeds."""
    results = (checkKey(guess, box, p, c, shift) for guess in keys)
    return next((res for res in results if res is not None), None)

### Find the key guesses with the highest grade magnitude
def findMaxKeys(keys):
    """Return the indices of *keys* whose squared grade equals the maximum
    squared grade, so large positive and large negative scores rank alike."""
    squared = [grade ** 2 for grade in keys]
    best = max(squared)
    return [idx for idx, value in enumerate(squared) if value == best]

### Grade all keys based on an approximation
def gradeKeys(p, c, box, inApprox, outApprox):
    """Score every first-round key guess against a linear approximation.

    For each guess ``k`` in ``range(len(box))`` and each known pair
    ``(p[i], c[i])``, the guess gains +1 when
    ``parity(box[p[i] ^ k] & inApprox) == parity(c[i] & outApprox)`` and
    loses 1 otherwise.  Returns the list of scores indexed by key guess.
    findMaxKeys then picks the guesses with the largest squared score.
    """
    # The key space matches the S-box size instead of a hard-coded 16, so
    # other box widths no longer index out of range.
    scores = [0] * len(box)
    for key in range(len(scores)):
        for i, char in enumerate(p):
            mid = box[char ^ key]
            if checkBitParity(mid & inApprox) == checkBitParity(c[i] & outApprox):
                scores[key] += 1
            else:
                scores[key] -= 1
    return scores

## Linear approximation table (LAT)

In [None]:
### Rank the positive LAT entries
def findCommon(profile):
    """Return the positive entries of *profile* as
    ``(fraction, "(count/total)", in_hex, out_hex)`` tuples with
    ``total = profile[0][0]``, ordered by descending fraction, then input,
    then output.

    NOTE(review): for a LAT, entries well below total/2 are equally biased
    but are ranked last here -- confirm this ordering is intended.
    """
    charact = []
    for rowIdx, row in enumerate(profile):
        charact.extend(
            (val / profile[0][0], f'({val}/{profile[0][0]})', format(rowIdx, 'x'), format(colIdx, 'x'))
            for colIdx, val in enumerate(row)
            if val > 0
        )
    return sorted(charact, key=lambda t: (1 - t[0], int(t[2], 16), int(t[3], 16)))

### Generate the linear approximation table (LAT)
def linearProfile(box):
    """Return ``profile`` where, for non-zero masks, ``profile[inMask][outMask]``
    counts inputs ``x`` with parity(x & inMask) == parity(box[x] & outMask).

    Row 0 and column 0 stay zero except ``profile[0][0]``, which is set to
    ``len(box)`` (the denominator findCommon divides by).
    """
    size = len(box)
    profile = [[0] * size for _ in range(size)]
    for outMask in range(1, size):
        for inMask in range(1, size):
            profile[inMask][outMask] = sum(
                1
                for x in range(size)
                if checkBitParity(x & inMask) == checkBitParity(box[x] & outMask)
            )
    profile[0][0] = size
    return profile

def parseBox(text):
    """Parse whitespace-separated hexadecimal S-box entries into a tuple of ints."""
    tokens = text.split()
    return tuple([int(tok, base=16) for tok in tokens])

In [None]:
# Load the S-box, print it and its LAT to stdout, and rank the approximations.
# NOTE(review): the three present/lin_* files are opened (and lin_prof /
# lin_char truncated) but nothing is written to them in this cell.
with open('present.sbx', 'r') as file, \
        open('present/lin_prof', "w") as profOut, \
        open('present/lin_char', "w") as charOut, \
        open('present/lin_keys', "a") as keysOut:
    box = parseBox(file.read())
    print(tabulate([box], showindex=True, headers=range(len(box)), tablefmt='grid'))
    profile = linearProfile(box)
    print(
        tabulate(
            profile,
            headers=[format(n, 'x') for n in range(len(profile[0]))],
            showindex=[format(n, 'x') for n in range(len(profile))],
            tablefmt='rst',
        )
    )
    charact = findCommon(profile)

## Linear key recovery

In [None]:
# Linear key recovery: for every nibble position, grade all key guesses with
# successive approximations (skipping charact[0]) until checkKey confirms a
# (key0, key1) pair consistent with all 16 known text pairs.
rdyKeys = list()
maxProb = float(charact[1][0])
for i in range(16):
    ### Generate all text pairs: every nibble value at position i
    p = tuple(range(16))
    c = tuple(
        int(bin(presentEncrypt((x << (60 - i * 4)).to_bytes(8, 'big')))[2:].rjust(64, '0')[(i * 4):(i * 4) + 4], 2)
        for x in p
    )
    print(f'Pairs {i}: {p}, {c}')
    ### Attempt key recovery
    # Initialise so that a short `charact` cannot leave `key` unbound or
    # stale from the previous nibble.
    key = None
    for j in range(1, len(charact)):
        keyGrades = gradeKeys(p, c, box, int(charact[j][2], 16), int(charact[j][3], 16))
        maxKeys = findMaxKeys(keyGrades)
        print(f'Max {i}, {j}: {maxKeys}')
        key = calcKey(maxKeys, box, p, c, i)
        if key is not None:
            print(f'Charact {j} used: {charact[j]}')
            print(f'Key {i}: {key}')
            break
    if key is not None:
        rdyKeys.append(key)
    else:
        # Report failures explicitly; a missing nibble misaligns the key
        # reassembled from rdyKeys.
        print(f'Key {i}: {key}')

In [None]:
### Reassemble the full subkeys from the recovered 4-bit pieces (nibble 0
### first); key1 goes back through pLayer to undo the stripped permutation.
key0 = int(''.join(format(pair[0], '04b') for pair in rdyKeys), 2)
key1 = pLayer(int(''.join(format(pair[1], '04b') for pair in rdyKeys), 2))