# 对实验一的分组密码算法进行差分分析。要求：
1. 分组长度 = 128 bit
2. 主密钥长度 >= 128 bit
3. 轮数 = 1
4. 附加：分组算法轮函数引入S盒

## 首先对实验一的算法加以改造
+ 将轮数修改为 1
+ 在轮函数中引入 S 盒

In [116]:
# Block size in bytes (128-bit blocks).
BLOCK_SIZE = 128 // 8
# Number of Feistel rounds applied per block.
LOOP = 1

# 4-bit S-box used by the round function.
sbox = [3, 14, 1, 10, 4, 9, 5, 6, 8, 11, 15, 2, 13, 12, 0, 7]
# Inverse permutation, derived from sbox so the two tables can never disagree:
# sbox_rev[v] is the position of v in sbox.
sbox_rev = [sbox.index(value) for value in range(len(sbox))]
def F_function(input_bytes: bytes, key: bytes):
    """Round function: nibble-wise S-box substitution followed by a key XOR.

    Every input byte ``i`` is split into its high and low nibble. Each nibble
    is replaced by ``sbox[nibble]`` and XOR-ed with the key byte
    ``key[i % len(key)]``; both nibbles of byte ``i`` use that same key byte.
    The result holds two bytes per input byte, high nibble first.

    NOTE(review): the output is twice as long as the input. feistel_encrypt /
    feistel_decrypt zip it against a half block, so only the bytes derived
    from the first half of ``input_bytes`` ever reach the ciphertext.
    Confirm this is intended.
    """
    result = bytearray()
    key_len = len(key)
    for idx, byte in enumerate(input_bytes):
        round_key_byte = key[idx % key_len]
        result.append(sbox[byte >> 4] ^ round_key_byte)
        result.append(sbox[byte & 0xF] ^ round_key_byte)
    return bytes(result)



In [117]:
def feistel_encrypt(input_bytes: bytes, key, F):
    """Apply one encryption round to a BLOCK_SIZE-byte block.

    Writing ``input_bytes`` as ``A || B`` (two halves of BLOCK_SIZE // 2
    bytes), the result is ``A || (B XOR F(A, key))``. The first half is copied
    unchanged and fed to ``F``; the second half is masked with F's output.
    The XOR stops at the shorter of the half and F's output.

    NOTE(review): this round does not swap the halves, and block_encrypt does
    not swap between rounds either. With LOOP > 1 every round would feed the
    same first half to ``F`` and that half would never be mixed. Revisit
    before raising LOOP.
    """
    half = BLOCK_SIZE // 2
    kept = input_bytes[:half]
    masked = input_bytes[half:]
    mask = F(kept, key)
    mixed = bytes(m ^ f for m, f in zip(masked, mask))
    return bytes(kept) + mixed


def feistel_decrypt(input_bytes: bytes, key, F):
    """Undo one round on a half-swapped BLOCK_SIZE-byte block.

    Writing ``input_bytes`` as ``C || D``, the result is
    ``(C XOR F(D, key)) || D``. block_decrypt calls this on the ciphertext
    whose halves were swapped by block_encrypt, so ``D`` is the half that was
    passed through ``F`` during encryption and the XOR removes the mask.
    """
    half = BLOCK_SIZE // 2
    masked = input_bytes[:half]
    kept = input_bytes[half:]
    mask = F(kept, key)
    unmasked = bytes(m ^ f for m, f in zip(masked, mask))
    return unmasked + bytes(kept)

# 交换左右部分的字节
def exchange(input: bytes):
    a = input[BLOCK_SIZE // 2:]
    a += input[:BLOCK_SIZE // 2]
    return a
def get_bits(num, width=0):
    """Return the binary digits of a non-negative integer, most significant first.

    Args:
        num: Non-negative integer to decompose.
        width: Minimum length of the result. Shorter results are left-padded
            with zeros. The default 0 adds no padding.

    Returns:
        List of 0/1 ints. ``get_bits(0)`` is ``[0]``, matching ``bin(0)``.

    Raises:
        ValueError: If ``num`` is negative (it has no finite bit string).
    """
    if num < 0:
        raise ValueError(f"get_bits expects a non-negative integer, got {num}")
    # format(..., "b") yields "0" for zero, so zero is no longer an empty list.
    return [int(digit) for digit in format(num, "b").zfill(width)]

class LFSR:
    """Linear feedback shift register used to derive one round key per shift.

    On every shift the register moves one bit to the right. The XOR of the
    tapped bits becomes the new most significant bit. The whole register
    value is the output.
    """

    def __init__(self, tap_positions, seed: int, width=None):
        """Create a register loaded with ``seed``.

        Args:
            tap_positions: Bit indices (0 = least significant bit) whose XOR
                forms the feedback bit. Taps at or above the register width
                always read 0.
            seed: Initial, non-negative register value; reset() restores it.
            width: Register width in bits. Defaults to ``seed.bit_length()``,
                which drops any leading zero bits of the seed. Pass it
                explicitly (e.g. ``8 * len(key)``) to keep them.

        Raises:
            ValueError: If ``seed`` is negative, the width is < 1 (e.g. a
                zero seed with no explicit width), or ``seed`` does not fit
                in ``width`` bits.
        """
        if seed < 0:
            raise ValueError(f"LFSR seed must be non-negative, got {seed}")
        if width is None:
            width = seed.bit_length()
        if width < 1:
            # Without this check shift() would fail later with a cryptic
            # "negative shift count" when inserting the feedback bit.
            raise ValueError(f"LFSR width must be at least 1 bit, got {width} (seed={seed})")
        if seed.bit_length() > width:
            raise ValueError(f"LFSR seed {seed:#x} does not fit in {width} bits")
        self.tap_positions = tap_positions
        self.seed = seed
        self.register = seed
        # Index of the most significant register bit; feedback is inserted here.
        self._move_length_ = width - 1
        self.debug = False

    def shift_get_bytes(self):
        """Shift once and return the new register as big-endian bytes.

        The result is always ceil(width / 8) bytes long.
        """
        shift = self.shift()
        return shift.to_bytes((self._move_length_ + 1 + 7) // 8, "big")

    def reset(self):
        """Restore the register to the initial seed."""
        self.register = self.seed

    def shift(self):
        """Advance the register by one step and return its new value."""
        if self.debug:
            print("before: {}".format(get_bits(self.register)))
        feedback = 0
        for position in self.tap_positions:
            feedback ^= (self.register >> position) & 1
        self.register = (self.register >> 1) | (feedback << self._move_length_)
        if self.debug:
            print("after : {}".format(get_bits(self.register)))
            print("--------------")
        return self.register

def block_encrypt(block: bytes, lfsr: LFSR):
    """Encrypt exactly one BLOCK_SIZE-byte block.

    Runs LOOP rounds of feistel_encrypt with F_function. Each round takes a
    fresh round key from ``lfsr``, which advances its state. The two halves
    are then swapped.

    Raises:
        ValueError: If ``block`` is not exactly BLOCK_SIZE bytes long.
    """
    if len(block) != BLOCK_SIZE:
        # Explicit check rather than ``assert``: asserts are stripped under -O.
        raise ValueError(f"block must be {BLOCK_SIZE} bytes, got {len(block)}")
    cipher = block
    for _ in range(LOOP):
        cipher = feistel_encrypt(cipher, lfsr.shift_get_bytes(), F_function)
    return bytes(exchange(cipher))


def block_decrypt(block: bytes, lfsr: LFSR):
    """Decrypt one BLOCK_SIZE-byte block produced by block_encrypt.

    Draws the LOOP round keys from ``lfsr`` in the same order block_encrypt
    consumed them. It then applies feistel_decrypt with those keys in reverse
    order and finally swaps the halves back.

    Raises:
        ValueError: If ``block`` is not exactly BLOCK_SIZE bytes long.
            Without this check a truncated block would silently decrypt to
            garbage (or to ``b""``).
    """
    if len(block) != BLOCK_SIZE:
        raise ValueError(f"block must be {BLOCK_SIZE} bytes, got {len(block)}")
    keys = [lfsr.shift_get_bytes() for _ in range(LOOP)]
    plain = block
    for round_key in reversed(keys):
        plain = feistel_decrypt(plain, round_key, F_function)
    return bytes(exchange(plain))

def slice_arr(arr, size):
    """Split ``arr`` into consecutive chunks of ``size`` items.

    The last chunk may be shorter. An empty ``arr`` gives an empty list.
    Each chunk is a slice, so it has the same type as ``arr``.
    """
    return [arr[start:start + size] for start in range(0, len(arr), size)]


def encrypt(message: bytes, key: bytes, tap_positions):
    """Encrypt ``message`` block by block with round keys from an LFSR.

    The LFSR is seeded with ``key`` read as a big-endian integer. A final
    block shorter than BLOCK_SIZE is left-padded with ASCII ``b"0"`` (0x30)
    bytes.

    Args:
        message: Plaintext to encrypt.
        key: Key bytes used to seed the LFSR.
        tap_positions: LFSR tap positions.

    Returns:
        ``(ciphertext, tail_fill)``. ``tail_fill`` is the number of padding
        bytes prepended to the final block (0 when the message length is a
        multiple of BLOCK_SIZE). decrypt() needs it to strip the padding.
    """
    blocks = slice_arr(message, BLOCK_SIZE)
    lfsr = LFSR(tap_positions, int.from_bytes(key, "big"))
    encrypted_blocks = []
    tail_fill = 0
    for chunk in blocks:
        chunk = bytes(chunk)
        if len(chunk) < BLOCK_SIZE:
            tail_fill = BLOCK_SIZE - len(chunk)
            # Explicit left padding. bytes.zfill gives the same bytes except
            # that it inserts the padding *after* a leading b"+"/b"-". That
            # broke decrypt(), which strips tail_fill bytes from the front.
            chunk = b"0" * tail_fill + chunk
        encrypted_blocks.append(block_encrypt(chunk, lfsr))
    return b"".join(encrypted_blocks), tail_fill


def decrypt(message: bytes, key: bytes, tap_positions, tail_fill: int):
    """Decrypt ciphertext produced by encrypt() with the same key and taps.

    Args:
        message: Ciphertext; its length should be a multiple of BLOCK_SIZE.
        key: Key bytes used to seed the LFSR (same as for encrypt()).
        tap_positions: LFSR tap positions (same as for encrypt()).
        tail_fill: Padding count returned by encrypt(). That many bytes are
            stripped from the front of the last decrypted block.

    Returns:
        The plaintext bytes; ``b""`` for an empty ciphertext.
    """
    blocks = slice_arr(message, BLOCK_SIZE)
    if not blocks:
        # encrypt(b"") returns b"". Handle it here, since indexing the last
        # block below would raise IndexError.
        return b""
    lfsr = LFSR(tap_positions, int.from_bytes(key, "big"))
    # All blocks draw round keys from one LFSR stream, so order matters.
    plain_blocks = [block_decrypt(bytes(block), lfsr) for block in blocks]
    plain_blocks[-1] = plain_blocks[-1][tail_fill:]
    return b"".join(plain_blocks)


In [118]:
# Round-trip smoke test: a 7-byte message (shorter than one block) is
# encrypted, decrypted, and must come back unchanged.
# NOTE(review): this demo key is 9 bytes (72 bits), below the ">= 128 bit"
# master-key requirement stated at the top of the notebook.
msg_origin = b"abcdefg"
key = b"not happy"
tap_position = [1,3,4,9]
enc, fill = encrypt(msg_origin,key,tap_position)
msg = decrypt(enc,key,tap_position, fill)
print(enc.hex())
print(msg)
assert msg==msg_origin

0d555f57d4dc7c743030303030303030
b'abcdefg'


## 差分分析
### 生成DDT

In [119]:
# Difference distribution table of the 4-bit S-box:
# DDT[dx][dy] counts the ordered input pairs (c, d) with c ^ d == dx and
# sbox[c] ^ sbox[d] == dy. Every row sums to 16 (one pair per c).
DDT = [[0 for _ in range(16)] for _ in range(16)]
for c in range(16):
    for d in range(16):
        DDT[c ^ d][sbox[c] ^ sbox[d]] += 1
# Bare expression so the notebook displays the table.
DDT

[[16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 2, 0, 4, 0, 0, 0, 2, 0, 0, 0, 2, 0, 6, 0, 0],
 [0, 2, 2, 0, 2, 0, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2],
 [0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 4, 0, 4, 0, 0, 2],
 [0, 0, 0, 0, 2, 4, 0, 6, 0, 0, 0, 0, 2, 0, 0, 2],
 [0, 0, 2, 0, 2, 0, 2, 2, 2, 0, 4, 0, 0, 0, 0, 2],
 [0, 0, 2, 2, 0, 2, 2, 0, 4, 0, 0, 0, 2, 0, 2, 0],
 [0, 0, 0, 2, 0, 2, 0, 0, 2, 0, 0, 4, 0, 0, 2, 4],
 [0, 2, 0, 0, 0, 6, 0, 0, 2, 2, 0, 2, 0, 0, 2, 0],
 [0, 0, 2, 2, 2, 2, 4, 0, 4, 0, 0, 0, 0, 0, 0, 0],
 [0, 2, 0, 0, 2, 0, 0, 0, 2, 2, 2, 0, 4, 0, 2, 0],
 [0, 4, 2, 2, 0, 0, 0, 0, 0, 4, 2, 2, 0, 0, 0, 0],
 [0, 2, 4, 0, 2, 0, 0, 0, 0, 0, 2, 0, 2, 2, 2, 0],
 [0, 2, 0, 2, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 4],
 [0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 4, 2, 4, 0, 0],
 [0, 0, 0, 0, 2, 0, 4, 2, 0, 0, 0, 0, 0, 2, 6, 0]]

### 分析输出差分

In [153]:
# Differential experiment: encrypt a plaintext and a copy with every byte
# XOR-ed by 0x04 under the same key, then print the nibble-wise XOR
# difference of the two ciphertexts.
msg_origin = b"\x00\x00\x01\x10"
msg_origin_xor = bytes([b ^ 4 for b in msg_origin])
key = b"\x12"
tap_position = [1,3,4,9]
enc, fill = encrypt(msg_origin,key,tap_position)
enc1, fill1 = encrypt(msg_origin_xor,key,tap_position)
# msg = decrypt(enc,key,tap_position, fill)
print(msg_origin)
print(msg_origin_xor)
print(enc.hex())
print(enc1.hex())
# NOTE(review): the 4-byte messages are left-padded into one block, so the
# differing bytes lie in the second half of the block. That half is only
# XOR-ed with F(first half), and the first half (padding) is identical in
# both runs. The 0x04 difference therefore passes straight through (bytes
# 4..7 of the ciphertext) and the S-box is not exercised by this input pair.
# NOTE(review): with the default LFSR width a 1-byte key 0x12 gives a 5-bit
# register (0x12.bit_length() == 5), so tap position 9 always reads 0.
for i,v in enumerate(enc):
    # print(hex(enc[i]))
    # High-nibble difference of ciphertext byte i, then low-nibble difference.
    print((enc[i]>>4)^(enc1[i]>>4))
    print((enc[i]&0xf)^(enc1[i]&0xf))
    # NOTE(review): compares the halved index i // 2 with a byte count, so the
    # loop prints bytes 0 .. 2 * (len(enc) - fill) + 2 (0..10 here).
    # Confirm this is the intended range.
    if i//2 > len(enc)-fill:
        break


b'\x00\x00\x01\x10'
b'\x04\x04\x05\x14'
333a333a030a021a3030303030303030
333a333a070e061e3030303030303030
0
0
0
0
0
0
0
0
0
4
0
4
0
4
0
4
0
0
0
0
0
0


In [114]:

# Scratch buffer filled by gen_char_data(): after a call, chardat0[:chardatmax]
# holds the S-box inputs that satisfy the most recently queried differential
# and the remaining entries are 0.
chardat0 = [0] * 16
chardatmax = 0

def gen_char_data(indiff, outdiff, sbox_table=None):
    """List the S-box inputs ``x`` with ``S[x] ^ S[x ^ indiff] == outdiff``.

    Each match is printed. The matches are also written to the module-level
    ``chardat0`` / ``chardatmax`` pair, and the whole buffer is overwritten so
    no entries survive from a previous query.

    Args:
        indiff: Input difference.
        outdiff: Output difference.
        sbox_table: S-box to analyse. Defaults to the module-level ``sbox``.

    Returns:
        The matching inputs ``x`` in ascending order.
    """
    global chardatmax
    table = sbox if sbox_table is None else sbox_table
    print(f"\nGenerating possible intermediate values based on differential({indiff} --> {outdiff}):")

    matches = []
    for f in range(len(table)):
        my_comp = f ^ indiff
        if (table[f] ^ table[my_comp]) == outdiff:
            print(f"  Possibles:   {f} ^ {indiff} = {my_comp} --> {table[f]} + {table[my_comp]}")
            matches.append(f)
    # Mutate in place (same list object) and zero the tail so results from an
    # earlier, longer query do not linger.
    chardat0[:] = matches + [0] * (len(chardat0) - len(matches))
    chardatmax = len(matches)
    return matches
# Which S-box inputs x satisfy sbox[x] ^ sbox[x ^ 6] == 6? (DDT[6][6] == 2.)
gen_char_data(6,6)
# Only the first chardatmax entries of chardat0 are results of this query.
chardat0


Generating possible intermediate values based on differential(6 --> 6):
  Possibles:   0 ^ 6 = 6 --> 3 + 5
  Possibles:   6 ^ 6 = 0 --> 5 + 3


[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]