In [260]:
import os
import json
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Util.strxor import strxor as xor

class TightSchedule:
    S = [
            0x63,0x7C,0x77,0x7B,0xF2,0x6B,0x6F,0xC5,0x30,0x01,0x67,0x2B,0xFE,0xD7,0xAB,0x76,
            0xCA,0x82,0xC9,0x7D,0xFA,0x59,0x47,0xF0,0xAD,0xD4,0xA2,0xAF,0x9C,0xA4,0x72,0xC0,
            0xB7,0xFD,0x93,0x26,0x36,0x3F,0xF7,0xCC,0x34,0xA5,0xE5,0xF1,0x71,0xD8,0x31,0x15,
            0x04,0xC7,0x23,0xC3,0x18,0x96,0x05,0x9A,0x07,0x12,0x80,0xE2,0xEB,0x27,0xB2,0x75,
            0x09,0x83,0x2C,0x1A,0x1B,0x6E,0x5A,0xA0,0x52,0x3B,0xD6,0xB3,0x29,0xE3,0x2F,0x84,
            0x53,0xD1,0x00,0xED,0x20,0xFC,0xB1,0x5B,0x6A,0xCB,0xBE,0x39,0x4A,0x4C,0x58,0xCF,
            0xD0,0xEF,0xAA,0xFB,0x43,0x4D,0x33,0x85,0x45,0xF9,0x02,0x7F,0x50,0x3C,0x9F,0xA8,
            0x51,0xA3,0x40,0x8F,0x92,0x9D,0x38,0xF5,0xBC,0xB6,0xDA,0x21,0x10,0xFF,0xF3,0xD2,
            0xCD,0x0C,0x13,0xEC,0x5F,0x97,0x44,0x17,0xC4,0xA7,0x7E,0x3D,0x64,0x5D,0x19,0x73,
            0x60,0x81,0x4F,0xDC,0x22,0x2A,0x90,0x88,0x46,0xEE,0xB8,0x14,0xDE,0x5E,0x0B,0xDB,
            0xE0,0x32,0x3A,0x0A,0x49,0x06,0x24,0x5C,0xC2,0xD3,0xAC,0x62,0x91,0x95,0xE4,0x79,
            0xE7,0xC8,0x37,0x6D,0x8D,0xD5,0x4E,0xA9,0x6C,0x56,0xF4,0xEA,0x65,0x7A,0xAE,0x08,
            0xBA,0x78,0x25,0x2E,0x1C,0xA6,0xB4,0xC6,0xE8,0xDD,0x74,0x1F,0x4B,0xBD,0x8B,0x8A,
            0x70,0x3E,0xB5,0x66,0x48,0x03,0xF6,0x0E,0x61,0x35,0x57,0xB9,0x86,0xC1,0x1D,0x9E,
            0xE1,0xF8,0x98,0x11,0x69,0xD9,0x8E,0x94,0x9B,0x1E,0x87,0xE9,0xCE,0x55,0x28,0xDF,
            0x8C,0xA1,0x89,0x0D,0xBF,0xE6,0x42,0x68,0x41,0x99,0x2D,0x0F,0xB0,0x54,0xBB,0x16
    ]
    RCON = [0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]

    def __init__(self, key):
        self.key = key
        self.rk = self.expandKey(self.key)

    def _round(self, x, cst = 0): # fancy permutation # can be inverted
        a, b, c, d = x[-4:]
        t = bytes([self.S[b] ^^ cst, self.S[c], self.S[d], self.S[a]])
        y  = xor(x[ 0: 4], t)
        y += xor(x[ 4: 8], y[-4:]) # append bytes at the end
        y += xor(x[ 8:12], y[-4:])
        y += xor(x[12:16], y[-4:])
        return y

    def _round_inv(self, y, cst = 0):
        t_0 = y[ 0: 4]
        t_1 = y[ 4: 8]
        t_2 = y[ 8: 12]
        t_3 = y[ 12: 16]
        x = [0]*16
        x[ 0: 4] = t_0   
        x[ 4: 8] = xor(t_0, t_1)
        x[ 8:12] = xor(t_1, t_2)
        x[12:16] = xor(t_2, t_3)
        x[0] = x[0] ^^ self.S[x[13]] ^^ cst
        x[1] = x[1] ^^ self.S[x[14]]
        x[2] = x[2] ^^ self.S[x[15]]
        x[3] = x[3] ^^ self.S[x[12]]
        return bytes(x)
    
    def expandKey(self, k):
        rk = [k]
        for _ in range(10):
            rk.append(self._round(rk[-1], self.RCON[len(rk)])) # value 0 of RCON is unused
        return rk

    def encrypt(self, p):
        c = p
        for sk in self.rk[:-1]:
            c = xor(c, sk)
            for _ in range(5):
                c = self._round(c)
        return xor(c, self.rk[-1])

    def decrypt(self, p):
        p = xor(p, self.rk[-1])
        for sk in self.rk[:-1][::-1]:
            for _ in range(5):
                p = self._round_inv(p)
            p = xor(p, sk)
        return p

In [529]:
a, b, c, d = 1, 1, 1, 1
k = [a+b+c+d,a+b+c+d,a+b+c+d,a+b+c+d,a+c,a+c,a+c,a+c,a+d,a+d,a+d,a+d,a,a,a,a]
k = tab_to_bytes(k)
print(pi_i(k, 0))
print(pi_i(k, 1))
print(pi_i(k, 2))
print(pi_i(k, 3))

[3, 3, 5, 1, 0, 3, 0, 1, 3, 0, 0, 1, 0, 0, 0, 1]
[3, 5, 1, 3, 3, 0, 1, 0, 0, 0, 1, 3, 0, 0, 1, 0]
[5, 1, 3, 3, 0, 1, 0, 3, 0, 1, 3, 0, 0, 1, 0, 0]
[1, 3, 3, 5, 1, 0, 3, 0, 1, 3, 0, 0, 1, 0, 0, 0]


In [467]:
P = TightSchedule(k)
bytes_to_tab(P.encrypt(b'\x00'*16))

[102, 21, 197, 103, 126, 255, 65, 218, 236, 89, 4, 114, 215, 71, 238, 169]

In [344]:
import os
import json
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Util.strxor import strxor as xor

class ClikeTightSchedule:
    S = [
            0x63,0x7C,0x77,0x7B,0xF2,0x6B,0x6F,0xC5,0x30,0x01,0x67,0x2B,0xFE,0xD7,0xAB,0x76,
            0xCA,0x82,0xC9,0x7D,0xFA,0x59,0x47,0xF0,0xAD,0xD4,0xA2,0xAF,0x9C,0xA4,0x72,0xC0,
            0xB7,0xFD,0x93,0x26,0x36,0x3F,0xF7,0xCC,0x34,0xA5,0xE5,0xF1,0x71,0xD8,0x31,0x15,
            0x04,0xC7,0x23,0xC3,0x18,0x96,0x05,0x9A,0x07,0x12,0x80,0xE2,0xEB,0x27,0xB2,0x75,
            0x09,0x83,0x2C,0x1A,0x1B,0x6E,0x5A,0xA0,0x52,0x3B,0xD6,0xB3,0x29,0xE3,0x2F,0x84,
            0x53,0xD1,0x00,0xED,0x20,0xFC,0xB1,0x5B,0x6A,0xCB,0xBE,0x39,0x4A,0x4C,0x58,0xCF,
            0xD0,0xEF,0xAA,0xFB,0x43,0x4D,0x33,0x85,0x45,0xF9,0x02,0x7F,0x50,0x3C,0x9F,0xA8,
            0x51,0xA3,0x40,0x8F,0x92,0x9D,0x38,0xF5,0xBC,0xB6,0xDA,0x21,0x10,0xFF,0xF3,0xD2,
            0xCD,0x0C,0x13,0xEC,0x5F,0x97,0x44,0x17,0xC4,0xA7,0x7E,0x3D,0x64,0x5D,0x19,0x73,
            0x60,0x81,0x4F,0xDC,0x22,0x2A,0x90,0x88,0x46,0xEE,0xB8,0x14,0xDE,0x5E,0x0B,0xDB,
            0xE0,0x32,0x3A,0x0A,0x49,0x06,0x24,0x5C,0xC2,0xD3,0xAC,0x62,0x91,0x95,0xE4,0x79,
            0xE7,0xC8,0x37,0x6D,0x8D,0xD5,0x4E,0xA9,0x6C,0x56,0xF4,0xEA,0x65,0x7A,0xAE,0x08,
            0xBA,0x78,0x25,0x2E,0x1C,0xA6,0xB4,0xC6,0xE8,0xDD,0x74,0x1F,0x4B,0xBD,0x8B,0x8A,
            0x70,0x3E,0xB5,0x66,0x48,0x03,0xF6,0x0E,0x61,0x35,0x57,0xB9,0x86,0xC1,0x1D,0x9E,
            0xE1,0xF8,0x98,0x11,0x69,0xD9,0x8E,0x94,0x9B,0x1E,0x87,0xE9,0xCE,0x55,0x28,0xDF,
            0x8C,0xA1,0x89,0x0D,0xBF,0xE6,0x42,0x68,0x41,0x99,0x2D,0x0F,0xB0,0x54,0xBB,0x16
    ]
    RCON = [0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]

    def __init__(self, key):
        self.key = key
        self.rk = rk = [[0]*16 for _ in range (11)] # call to malloc   
        self.expandKey(self.key, self.rk)

    def _xor(self, x, y):
        # modify x
        for i in range (16):
            x[i] ^^= y[i]

    def _round_in_place(self, input, cst = 0):
        for i in range (4):
            if i != 0:
                for j in range (4):
                    input[4*i + j] ^^= input[4*(i-1) + j]
            else:
                input[0] ^^= self.S[input[13]] ^^ cst
                input[1] ^^= self.S[input[14]]
                input[2] ^^= self.S[input[15]]
                input[3] ^^= self.S[input[12]]
    
    def _round_out_of_place(self, input, output, cst = 0):
        for i in range (4):
            if i != 0:
                for j in range (4):
                    output[4*i + j] ^^= input[4*i + j]
                    output[4*i + j] ^^= output[4*(i-1) + j]
            else:
                output[0] = input[0] ^^ self.S[input[13]] ^^ cst
                output[1] = input[1] ^^ self.S[input[14]]
                output[2] = input[2] ^^ self.S[input[15]]
                output[3] = input[3] ^^ self.S[input[12]]

    def expandKey(self, k, rk):
        rk[0] = k
        for i in range(1, 11):
            self._round_out_of_place(rk[i-1], rk[i], self.RCON[i])

    def encrypt(self, input, output):
        # we must copy all the value of the input in the ouput
        output_ = input.copy() 
        for i in range (10):
            self._xor(output_, self.rk[i])
            for _ in range(5):
                self._round_in_place(output_)
        self._xor(output_, self.rk[10])
        # that should not occur while writing C code
        for i in range (16):
            output[i] = output_[i]

In [426]:
k = os.urandom(16)
msg = os.urandom(16)

In [458]:
from tqdm import tqdm

P = TightSchedule(k)
clear  = bytes_to_tab(msg)
cipher = bytes_to_tab(P.encrypt(msg))

for i in range (4):
    k_i = sum (pi_i(bytes_to_tab(k), i), pi_i(bytes_to_tab(os.urandom(16)), i+1))
    #msg_i = pi_i(bytes_to_tab(msg), i)
    msg_i = bytes_to_tab(msg)
    
    output = [0] * 16
    P = ClikeTightSchedule(k_i)
    P.encrypt(msg_i, output)
    assert pi_i(output, i+2) == pi_i(cipher, i+2)

In [405]:
flag = b"text"

P.decrypt(P.encrypt(pad(flag, 16))) == pad(flag, 16)

AttributeError: 'ClikeTightSchedule' object has no attribute 'decrypt'

In [28]:
P.decrypt(P.encrypt(p)) == p

False

The matrix $A$ transforms the function a byte from the canonical basis to the new basis.

It takes a vector $x\in (\mathbb{F}_{2^8})^{16}$ and outputs its decomposition in the new basis.

In [508]:
A = [ 
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],
        [0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0],
        [0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0],
        [0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],
        [0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0],
        [0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],
        [1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1],
        [0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],
        [0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1],
        [0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0],
        [0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0]
    ]

A_inv = [
        [0,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0],
        [0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,1],
        [0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0],
        [1,0,0,0,0,0,0,1,0,0,1,0,0,1,0,0],
        [0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0],
        [0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0],
        [0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0],
        [1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],
        [0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1],
        [0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0],
        [1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],
        [0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],
        [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
]

#A = Matrix(FiniteField(2), A)
#A_inv = Matrix(FiniteField(2), A_inv)

#assert A * A_inv == Matrix.identity(16)

A = Matrix(ZZ, A)
A_inv = Matrix(ZZ, A_inv)

def bytes_to_tab (b):
    return [c for c in b]

def tab_to_bytes(t):
    return bytes(t)

def canonical_to_invariant (input: list):
    output = [0] * 16
    for i in range (16):
        for j in range (16):
            output[i] ^^= A[i, j] * input[j]
    return output

def invariant_to_canonical (input: list):
    output = [0] * 16
    for i in range (16):
        for j in range (16):
            output[i] ^^= A_inv[i, j] * input[j]
    return output

def pi_i(input, i):
    # return the ith projection of input
    i = i%4
    output = canonical_to_invariant(input)
    for j in range (16):
        if not (4*i <= j < 4*(i+1)):
            output[j] = 0
    output = invariant_to_canonical(output)
    return output

def lift_pi_i(a, b, c, d, i):
    i = i%4
    if i==0:
        return [d,c,b,a,0,c,0,a,d,0,0,a,0,0,0,a]
    elif i==1:
        return [c,b,a,d,c,0,a,0,0,0,a,d,0,0,a,0]
    elif i==2:
        return [b,a,d,c,0,a,0,c,0,a,d,0,0,a,0,0]
    else: # i==3
        return [a,d,c,b,a,0,c,0,a,d,0,0,a,0,0,0]

def sum (x, y):
    return [a ^^ b for a, b in zip(x, y)]

def sum_parts (pi_0, pi_1, pi_2, pi_3):
    s = [a ^^ b ^^ c ^^ d for a, b, c, d in zip(pi_0, pi_1, pi_2, pi_3)]
    return tab_to_bytes(s)

In [509]:
key = bytes_to_tab(k)
pi_0 = pi_i(key, 0)
pi_1 = pi_i(key, 1)
pi_2 = pi_i(key, 2)
pi_3 = pi_i(key, 3)
k == sum_parts(pi_0, pi_1, pi_2, pi_3)

True

In [425]:
a, b, c, d = 1, 2, 3, 4
for i in range (4):
    x = lift_pi_i(a, b, c, d, i)
    assert pi_i(x, i) == x

In [257]:
v = vector((0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0))
v = (2^8 - 1) * v
v = v.list()
pi_i(v, 3) == v

False

In [495]:
from json import loads
file = "output.txt"
tab = json.loads(open(file, "r").read())
plaintext = bytes.fromhex(tab["p"])
ciphertext = bytes.fromhex(tab["c"])
iv = bytes.fromhex(tab["iv"])
flag_enc = bytes.fromhex(tab["flag_enc"])

In [507]:
bytes_to_tab(ciphertext)

[212, 237, 25, 224, 105, 65, 1, 182, 177, 81, 225, 28, 45, 185, 115, 191]

In [514]:
pi_2 = [7, 64, 187, 9, 0, 64, 0, 9, 0, 64, 187, 0, 0, 64, 0, 0]
pi_1 = [96, 217, 153, 215, 96, 0, 153, 0, 0, 0, 153, 215, 0, 0, 153, 0]
pi_0 = [71, 142, 75, 90, 0, 142, 0, 90, 71, 0, 0, 90, 0, 0, 0, 90]
pi_3 = [76, 127, 191, 108, 76, 0, 191, 0, 76, 127, 0, 0, 76, 0, 0, 0]

In [515]:
k = sum_parts(pi_0, pi_1, pi_2, pi_3)

In [516]:
from Crypto.Cipher import AES
E = AES.new(k, AES.MODE_CBC, iv = iv)
E.decrypt(flag_enc)

b'\xbcPo\x0c\xab\x07\xfe\x9c\xa4\x80\xe1\\\xa2Xj\x80\xcb].\xd2~|\xc0\x05|n\xffI\xf7)\x98XU\xc2\x82,YB\x85IM\xab\xdbW\xba\x89\x05v\xb9\xdc\xe6\x80`\rU#\xc4\x82\xf8\xe7e\xc0\xa2A3\x8b\xc6\xc1J\xce\x01?\xe7\x15{b\x12\x07\xf2\x06'

In [517]:
P = TightSchedule(k)
P.encrypt(plaintext) == ciphertext

False