In [161]:
import time
import os
import sh
import subprocess

from multiprocessing import cpu_count, Pool, Manager
from contextlib import closing
from collections import Counter

In [162]:
# Attack parameters: number of plaintext pairs to encrypt, size of the 16-bit
# block/key space, and every candidate last-round key to test.
texts = 2000
key_limit = 2**16
keys = list(range(key_limit))

# Candidate input differences: every non-zero value of exactly one 4-bit nibble
# (15 values x 4 nibbles = 60), in ascending order.
possible_alphas = [digit << (4 * nibble) for nibble in range(4) for digit in range(1, 16)]
alpha_length = len(possible_alphas)

In [163]:
class Heys:
    """Heys toy substitution-permutation cipher on 16-bit blocks.

    Encryption applies ``rounds`` rounds of
    ``x -> permutation(substitution(x ^ k_i))`` followed by a final XOR with
    ``k_rounds``. The ``rounds + 1`` round keys are consecutive 16-bit chunks of
    the master ``key``, least significant chunk first.
    """

    # Default 4-bit S-box. A tuple, so the shared default can never be mutated.
    DEFAULT_S_BLOCK = (0x3, 0x8, 0xb, 0x5, 0x6, 0x4, 0xe, 0xa,
                       0x2, 0xc, 0x1, 0x7, 0x9, 0xf, 0xd, 0x0)

    def __init__(self, key, s_block=DEFAULT_S_BLOCK, s_block_reverse=None, rounds=6):
        """Build the cipher.

        Args:
            key: master key; ``rounds + 1`` 16-bit round keys are taken from it.
            s_block: 16-entry S-box (a permutation of 0..15).
            s_block_reverse: inverse S-box. When omitted it is derived from
                ``s_block``. A custom ``s_block`` therefore no longer pairs with
                a mismatched default inverse, which would break ``decrypt``.
            rounds: number of keyed rounds.

        Raises:
            ValueError: if the inverse has to be derived and ``s_block`` is not
                a permutation of 0..15.
        """
        # Copies so later mutation by the caller cannot affect this instance.
        self.s_block = list(s_block)
        if s_block_reverse is None:
            s_block_reverse = self.invert_s_block(self.s_block)
        self.s_block_reverse = list(s_block_reverse)
        self.rounds = rounds
        self.key = self.get_round_keys(key, rounds)

    def encrypt(self, text):
        """Encrypt the 16-bit block ``text``."""
        for i in range(self.rounds):
            text = self.round(text, self.key[i])
        return text ^ self.key[self.rounds]

    def decrypt(self, c_text):
        """Invert ``encrypt``: undo rounds ``rounds..1``, then strip ``k_0``."""
        for i in range(self.rounds, 0, -1):
            c_text = self.round_reverse(c_text, self.key[i])
        return c_text ^ self.key[0]

    def round(self, text_block, round_key):
        """One forward round: key XOR, nibble substitution, bit transposition."""
        return self.permutation(self.substitution(text_block ^ round_key, self.s_block))

    def round_reverse(self, c_text_block, round_key):
        """One backward step: ``substitution(permutation(x ^ k), s_block_reverse)``.

        ``permutation`` is its own inverse, so this undoes ``round`` once the
        key that followed that round has been XORed off.
        """
        return self.substitution(self.permutation(c_text_block ^ round_key), self.s_block_reverse)

    @staticmethod
    def substitution(block, s_block):
        """Replace each of the four 4-bit nibbles of ``block`` via ``s_block``."""
        result = 0
        for i in range(4):
            shift = 4 * i
            result |= s_block[(block >> shift) & 0xf] << shift
        return result

    @staticmethod
    def permutation(block):
        """Move bit ``4*i + j`` to bit ``4*j + i`` (transpose of a 4x4 bit matrix)."""
        result = 0
        for i in range(4):
            for j in range(4):
                result |= ((block >> (4 * i + j)) & 1) << (4 * j + i)
        return result

    @staticmethod
    def get_round_keys(key, rounds):
        """Split ``key`` into ``rounds + 1`` 16-bit chunks, least significant first."""
        return [(key >> (16 * i)) & 0xffff for i in range(rounds + 1)]

    @staticmethod
    def invert_s_block(s_block):
        """Return the inverse of a 16-entry S-box.

        Raises:
            ValueError: if ``s_block`` is not a permutation of 0..15.
        """
        if sorted(s_block) != list(range(16)):
            raise ValueError("s_block must be a permutation of 0..15")
        inverse = [0] * 16
        for x, y in enumerate(s_block):
            inverse[y] = x
        return inverse


In [164]:
# Local cipher model used by the differential search (round) and by the key
# recovery (round_reverse). The 112-bit key yields 7 round keys, each 0x3030.
cypher = Heys(key=0x3030303030303030303030303030)

In [165]:
def get_gamma_q_pairs(betta):
    """One-round differential distribution for input difference ``betta``.

    For every input ``x`` in ``range(key_limit)``, XOR the outputs of
    ``cypher.round`` (round key 0x0001) on ``x`` and ``x ^ betta``.

    Returns:
        list of ``(gamma, q)`` pairs, where ``q`` is the fraction of inputs that
        produced output difference ``gamma``. Pairs are ordered by the first
        occurrence of each ``gamma``.
    """
    tally = Counter(
        cypher.round(x, 0x0001) ^ cypher.round(x ^ betta, 0x0001)
        for x in range(key_limit)
    )
    return [(gamma, hits / key_limit) for gamma, hits in tally.items()]

In [166]:
def diff_search(alpha, p_break=(1, 0.1, 0.008, 0.004, 0.001, 0.0004), rounds=6, transition=None):
    """Branch-and-bound search for high-probability differentials.

    Starting from input difference ``alpha`` with probability 1.0, each step
    propagates every surviving difference ``beta`` through one round. It uses
    ``transition(beta)`` (a list of ``(gamma, q)`` pairs), sums the
    probabilities of identical output differences, and keeps only those whose
    probability exceeds ``p_break[t]``.

    Args:
        alpha: input difference.
        p_break: per-step pruning thresholds. It needs at least ``rounds``
            entries; index 0 is unused.
        rounds: number of layers; ``rounds - 1`` transitions are applied.
        transition: one-round transition function. Defaults to
            ``get_gamma_q_pairs``.

    Returns:
        list of ``(gamma, probability)`` tuples for the last layer, ordered by
        when each ``gamma`` was first reached.

    Raises:
        ValueError: if ``rounds < 1``.
    """
    if rounds < 1:
        raise ValueError("rounds must be >= 1")
    if transition is None:
        transition = get_gamma_q_pairs

    layer = [(alpha, 1.0)]
    for t in range(1, rounds):
        # A dict keeps first-seen order and gives O(1) updates. The original
        # rebuilt the whole layer list for every gamma.
        acc = {}
        for beta, p in layer:
            for gamma, q in transition(beta):
                if gamma in acc:
                    acc[gamma] += p * q
                else:
                    acc[gamma] = p * q
        layer = [(gamma, prob) for gamma, prob in acc.items() if prob > p_break[t]]
    return layer

In [167]:
def diff_search_alpha_iter(possible_alpha):
    """Run ``diff_search`` for one input difference.

    Returns:
        list of ``((possible_alpha, beta), probability)`` entries, one per
        surviving output difference ``beta``.
    """
    tagged = []
    for beta_value, probability in diff_search(possible_alpha):
        tagged.append(((possible_alpha, beta_value), probability))
    return tagged

In [168]:
def create_c_p_texts(arg):
    """Write plaintext ``text`` and its partner ``text ^ alpha``, then run the script on both.

    ``arg`` is an ``(alpha, text)`` tuple. Each plaintext is written as a 2-byte
    little-endian file under ``./texts`` (created if missing). ``./script.sh``
    is then called as ``script(input, output, key_file)`` with an empty
    ``key_file`` argument. The ciphertexts are presumably written to the
    output paths (``ct_{text}.bin`` / ``cta_{text}.bin``), which
    ``read_c_text_pairs`` later reads. TODO confirm against script.sh.
    """
    alpha, text = arg

    # Previously this crashed when ./texts did not exist. exist_ok=True makes
    # the call safe when many pool workers race to create the directory.
    os.makedirs('./texts', exist_ok=True)

    input_file = f'./texts/pt_{text}.bin'
    output_file = f'./texts/ct_{text}.bin'
    xor_input_file = f'./texts/ptx_{text}.bin'
    xor_output_file = f'./texts/cta_{text}.bin'

    with open(input_file, 'wb') as f:
        f.write(int(text).to_bytes(2, 'little'))
    with open(xor_input_file, 'wb') as f:
        f.write(int(text ^ alpha).to_bytes(2, 'little'))

    script = sh.Command('./script.sh')
    key_file = ""
    script(input_file, output_file, key_file)
    script(xor_input_file, xor_output_file, key_file)


In [169]:
def _read_le_int(path):
    """Read the whole file at ``path`` as a little-endian unsigned integer, closing it afterwards."""
    with open(path, 'rb') as f:
        return int.from_bytes(f.read(), 'little')


def read_c_text_pairs(texts):
    """Load ``[ct, cta]`` ciphertext pairs for indices ``0 .. texts-1``.

    ``ct`` is read from ``texts/ct_{i}.bin`` and ``cta`` from
    ``texts/cta_{i}.bin`` (relative to the current working directory). Each is
    decoded as a little-endian integer. Every file is closed right away instead
    of being left for the garbage collector.
    """
    return [[_read_le_int(f'texts/ct_{i}.bin'), _read_le_int(f'texts/cta_{i}.bin')]
            for i in range(texts)]

In [170]:
# Dict shared across processes through a multiprocessing Manager server.
# Maps candidate last-round key -> number of matching ciphertext pairs;
# last_round_attack writes to it from pool workers.
_keys_manager = Manager()
iteration_keys = _keys_manager.dict()

In [171]:
def last_round_attack(arg):
    """Score one candidate last-round key against the ciphertext pairs.

    ``arg`` is ``(c_text_pair, betta, key)``, where ``c_text_pair`` is a
    sequence of ``(c1, c2)`` ciphertext pairs. Each ciphertext is stepped back
    with ``cypher.round_reverse(., key)``, and the pair counts as a match when
    the XOR of the results equals ``betta``. The match count is added to
    ``iteration_keys[key]``; keys with no matches are left untouched.
    """
    c_text_pair, betta, key = arg
    hits = sum(
        1
        for c1, c2 in c_text_pair
        if (cypher.round_reverse(c1, key) ^ cypher.round_reverse(c2, key)) == betta
    )
    # iteration_keys is a Manager proxy, so every access is an IPC round-trip.
    # Touch it once per key instead of once per matching pair.
    if hits:
        iteration_keys[key] = iteration_keys.get(key, 0) + hits

In [172]:
import multiprocessing


def max_weight_hex_keys(weights):
    """Return the keys with the highest weight in ``weights``, hex-formatted.

    ``weights`` maps candidate key -> match count. The result is sorted by
    numeric key value (not by hex string); an empty mapping yields ``[]``.
    """
    if not weights:
        return []
    best = max(weights.values())
    return [hex(k) for k in sorted(k for k, v in weights.items() if v == best)]


if __name__ == '__main__':
    num_processes = multiprocessing.cpu_count()

    print("Search differentials:")

    with multiprocessing.Pool(processes=num_processes) as pool:
        differentials = pool.map(diff_search_alpha_iter, possible_alphas)
        pool.close()
        pool.join()

    # Flatten the per-alpha lists, keeping only ((alpha, beta), probability) entries.
    differentials = [item for sublist in differentials for item in sublist if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], tuple) and isinstance(item[1], float)]
    differentials = sorted(differentials, key=lambda x: x[1], reverse=True)

    print(differentials)

    print('#######################')

    for iteration, ((alpha, betta), _probability) in enumerate(differentials):
        print(f'Iteration {iteration} START')

        # Counts left over from the previous (alpha, beta) pair must not leak
        # into this one's weights.
        iteration_keys.clear()

        #---------------------------------------------------------
        # Encrypt every plaintext and its alpha-shifted partner.
        with multiprocessing.Pool(processes=num_processes) as pool:
            pool.map(create_c_p_texts, [(alpha, l) for l in range(texts)])
            pool.close()
            pool.join()

        c_text_pairs = read_c_text_pairs(texts)

        #---------------------------------------------------------
        # Score every candidate last-round key.
        with multiprocessing.Pool(processes=num_processes) as pool:
            pool.map(last_round_attack, [(c_text_pairs, betta, key) for key in keys])
            pool.close()
            pool.join()

        #---------------------------------------------------------
        # Take one snapshot of the shared dict instead of repeated IPC lookups,
        # and compute the maximum once.
        max_weight_keys = max_weight_hex_keys(iteration_keys.copy())

        print('Alpha: ', alpha, 'Beta: ', betta)
        print('Keys with max weight: ', max_weight_keys)
        print('----------------------')
        if len(max_weight_keys) == 1:
            break
    else:
        # Previously an IndexError once the differentials ran out.
        print('No differential produced a unique max-weight key.')

            


Search differentials:
[((6, 2056), 0.00048828125), ((6, 34816), 0.00048828125), ((96, 4352), 0.00048828125), ((96, 257), 0.00048828125), ((96, 1028), 0.00048828125), ((96, 17408), 0.00048828125), ((96, 514), 0.00048828125), ((96, 8704), 0.00048828125), ((24576, 1028), 0.00048828125), ((24576, 17408), 0.00048828125), ((24576, 2056), 0.00048828125), ((24576, 34816), 0.00048828125), ((6, 34952), 0.00042724609375), ((6, 32904), 0.00042724609375), ((6, 136), 0.00042724609375), ((96, 4369), 0.00042724609375), ((96, 17), 0.00042724609375), ((96, 4113), 0.00042724609375), ((96, 17476), 0.00042724609375), ((96, 16452), 0.00042724609375), ((96, 68), 0.00042724609375), ((96, 8738), 0.00042724609375), ((96, 8226), 0.00042724609375), ((96, 34), 0.00042724609375), ((24576, 17476), 0.00042724609375), ((24576, 16452), 0.00042724609375), ((24576, 68), 0.00042724609375), ((24576, 34952), 0.00042724609375), ((24576, 32904), 0.00042724609375), ((24576, 136), 0.00042724609375)]
#######################
Iter