In [3]:
import numpy as np
from functools import reduce, lru_cache, cached_property

class Element:
    """A 4x4 matrix modulo ``mod`` addressed by an index in ``range(sz)``.

    An index ``x`` is split into mixed-radix digits ``(a, b, c, d)`` with
    radices ``(3, 3, 7, 4)`` and mapped to ``A**a @ B**b @ C**c @ D**d``
    (every product reduced modulo ``mod``).

    The operators are *not* ordinary arithmetic:

    * ``x + y`` is the matrix product of the two elements, so the operand
      order matters;
    * ``x * n`` and ``n * x`` are the ``n``-th matrix power of ``x``.

    Turning a matrix back into an index goes through the module-level
    ``MAPPING`` table (fingerprint -> index). It has to be built after this
    class is defined and before ``+``, ``*`` or ``to_byte`` is used.
    """

    # Generator matrices; all entries are already reduced modulo ``mod``.
    A, B, C, D = (
        np.matrix(rows, dtype=np.uint32)
        for rows in (
            [[1,0,0,0], [0,1,0,0], [0,0,0,1], [297,0,336,336]],
            [[1,269,0,0], [5,335,0,0], [0,8,0,1], [297,8,336,336]],
            [[8,0,0,0], [40,329,0,0], [232,0,295,0], [227,0,42,42]],
            [[329,0,269,0], [0,0,336,1], [105,0,8,0], [110,336,8,0]],
        )
    )
    ID = np.matrix([[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]], dtype=np.uint32)
    sz = 252   # number of valid indices: 3 * 3 * 7 * 4
    mod = 337  # modulus applied after every matrix product

    def __init__(self, x):
        """Build the element with index ``x`` (``0 <= x < sz``)."""
        assert 0 <= x < self.sz
        # Peel the digits off least-significant first: radices 3, 3, 7, then the rest.
        x, a = divmod(x, 3)
        x, b = divmod(x, 3)
        d, c = divmod(x, 7)
        self.v = reduce(Element._mulmat, [
            Element._powmat(self.A, a),
            Element._powmat(self.B, b),
            Element._powmat(self.C, c),
            Element._powmat(self.D, d),
        ])

    @staticmethod
    def _fingerprint(v):
        """Hash of *all* sixteen entries of ``v``.

        Hashing the whole matrix (rather than only its column sums) guarantees
        that two different matrices are never treated as the same element
        merely because their columns add up to the same values.
        """
        return hash(np.asarray(v, dtype=np.uint32).tobytes())

    @classmethod
    def _from_v(cls, v):
        """Return the element whose matrix is ``v``.

        Raises ``KeyError`` when ``v`` is not the matrix of any index in
        ``MAPPING``.
        """
        return cls(MAPPING[cls._fingerprint(v)])

    @staticmethod
    def _mulmat(A, B):
        """Matrix product of ``A`` and ``B`` reduced modulo ``mod``."""
        return (A @ B) % Element.mod

    @staticmethod
    def _powmat(A, n):
        """Return ``A**n`` modulo ``mod`` by binary exponentiation.

        Raises ``ValueError`` for a negative ``n`` (the recursive version this
        replaces never terminated on such input).
        """
        if n < 0:
            raise ValueError("negative exponents are not supported")
        result, base = Element.ID, A
        while n:
            if n & 1:
                result = Element._mulmat(result, base)
            n >>= 1
            if n:  # skip the final, unused squaring
                base = Element._mulmat(base, base)
        return result

    def to_byte(self):
        """Index of this element in ``range(sz)``, looked up in ``MAPPING``."""
        return MAPPING[hash(self)]

    def __add__(self, other):
        """Matrix product ``self.v @ other.v`` as an element."""
        return Element._from_v(Element._mulmat(self.v, other.v))

    def __mul__(self, n):
        """``n``-th matrix power of this element (``n`` a non-negative int)."""
        return Element._from_v(Element._powmat(self.v, n))

    def __rmul__(self, n):
        """``n * x`` means the same as ``x * n``."""
        return self * n

    def __hash__(self):
        return Element._fingerprint(self.v)

    def __eq__(self, other):
        # Duck-typed on ``.v`` so instances created before the class cell was
        # re-run still compare equal to freshly built ones.
        other_v = getattr(other, "v", None)
        if other_v is None:
            return NotImplemented
        return bool(np.array_equal(self.v, other_v))

    def __str__(self):
        return str(self.to_byte())

    def __repr__(self):
        return f"<E:{self}>"
    
# Lookup table from an element's hash to its index; Element uses it to turn a
# matrix back into an index.
MAPPING = {hash(Element(i)): i for i in range(Element.sz)}
# Two indices sharing a hash would silently overwrite each other in the dict
# and make every later lookup ambiguous, so fail loudly instead.
assert len(MAPPING) == Element.sz, "Element hash collision: MAPPING is not one-to-one"

In [4]:
def to_tuple(x):
    """Split an index into its mixed-radix digits ``(a, b, c, d)``.

    Digits are taken least-significant first with radices 3, 3 and 7;
    whatever is left over becomes ``d``.
    """
    rest, a = divmod(x, 3)
    rest, b = divmod(rest, 3)
    d, c = divmod(rest, 7)
    return (a, b, c, d)

def from_tuple(t):
    """Recombine digits ``(a, b, c, d)`` into a single index.

    The weights are 1, 3, 9 and 63, i.e. radices 3, 3 and 7 with ``a`` as the
    least-significant digit.
    """
    a, b, c, d = t
    return a + 3 * (b + 3 * (c + 7 * d))

# invmap[x] is the element y with y + x == Element(0) (the identity matrix).
# The table is slow to build, so it is only computed when the kernel does not
# have one yet.
if "invmap" not in globals():
    # Build every element once instead of re-creating it for each candidate pair.
    _elements = [Element(i) for i in range(Element.sz)]
    invmap = {
        e: next(c for c in _elements if (c + e).to_byte() == 0)
        for e in _elements
    }
    del _elements

def commutators(tuples):
    """Return the digit tuples of every commutator of two elements of ``tuples``.

    For each ordered pair ``(g, h)`` the value ``g + h + invmap[g] + invmap[h]``
    is computed; the distinct results are returned as a set of digit tuples.
    """
    elems = [Element(from_tuple(t)) for t in tuples]
    return {
        to_tuple((g + h + invmap[g] + invmap[h]).to_byte())
        for g in elems
        for h in elems
    }

In [5]:
def center(tuples):
    """Return the elements that commute with every element of ``tuples``.

    Every index in ``range(Element.sz)`` is tried; an element ``g`` is kept
    when ``g + h + invmap[g] + invmap[h]`` is the identity (index 0) for all
    ``h`` built from ``tuples``. The result is a set of digit tuples.
    """
    everything = [Element(i) for i in range(Element.sz)]
    subset = [Element(from_tuple(t)) for t in tuples]
    commuting = [
        g for g in everything
        # all() stops at the first non-trivial commutator, like the old break.
        if all((g + h + invmap[g] + invmap[h]).to_byte() == 0 for h in subset)
    ]
    return {to_tuple(g.to_byte()) for g in commuting}

def is_normal(tuples):
    """Return True when ``tuples`` is closed under conjugation by every element.

    For each ``g`` in ``range(Element.sz)`` and each ``h`` built from
    ``tuples`` the conjugate ``g + h + invmap[g]`` must again be one of the
    elements built from ``tuples``.
    """
    everything = [Element(i) for i in range(Element.sz)]
    members = [Element(from_tuple(t)) for t in tuples]
    # Hash-based membership instead of scanning the list for every conjugate.
    member_set = set(members)
    return all(
        (g + h + invmap[g]) in member_set
        for g in everything
        for h in members
    )

In [11]:
# Is the set of digit tuples (a, b, c, 2*c mod 4) closed under conjugation?
is_normal([
    (a, b, c, (2 * c) % 4)
    for a in range(3)
    for b in range(3)
    for c in range(7)
])

True

In [4]:
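# G = C3^2 |x (C7 |x C4), |G| = 3*3*7*4 = 252   (|x denotes a semidirect product)
# G > 

# The commutators are G' = (i,j,k,k*2) ~ C3^2 x C7   (|G'| = 3*3*7 = 63)
# G / G' has order 4, so it is C4 or C2 x C2 (a quick verification shows it is C4)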

# G_derived = commutators([(i,j,k,l) # prints (i,j,k,k*2)
#     for i in range(3)
#     for j in range(3)
#     for k in range(7)
#     for l in range(4)
# ])

# Candidate derived subgroup: every digit tuple (a, b, c, d) with d = 2*c mod 4.
G_derived = [
    (a, b, c, (2 * c) % 4)
    for a in range(3)
    for b in range(3)
    for c in range(7)
]
# It has to be closed under conjugation by the whole group ...
assert is_normal(G_derived)
# ... and its own commutators collapse to the identity, i.e. it is abelian.
commutators(G_derived) # prints (0,0,0,0)

# G = (C3^2 x C7) |x C4
# G/G' = C4

# Plan:
#   - Solve over G/G' ~ C4 via Gaussian elimination
#   - Move every variable into G' (the previous step already tells us which coset each variable lies in)
#   - G' is abelian, so solve there by Gaussian elimination as well

{(0, 0, 0, 0)}

In [5]:
# Powers of (0,0,1,1), an element outside G': they give one representative for each coset of G'
# (the 0th power is the identity, which represents G' itself).
# Extra: the quotient has an isomorphic image in G, i.e. G = G' |x C4

# One representative per coset: the powers 0..3 of the element (0,0,1,1).
pos_quo = [Element(from_tuple((0,0,1,1))) * power for power in range(4)]
# Coset index -> all elements of that coset (representative + each member of G').
invquomap = {
    index: {rep + Element(from_tuple(g)) for g in G_derived}
    for index, rep in enumerate(pos_quo)
}
# Maps G to C4
quomap = {
    member: index
    for index, coset in invquomap.items()
    for member in coset
}
# The four cosets together have to contain every element.
assert len(set().union(*invquomap.values())) == Element.sz

In [6]:
# Round constants as elements; expand_key adds RCON[i] to the first entry of
# a rotated key word, starting at index 1.
RCON = [
    Element(value)
    for value in (
        0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
        0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
        0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
        0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39,
    )
]
N_ROUNDS = 10  # number of rounds; expand_key produces N_ROUNDS + 1 round keys
N_BYTES = 16   # block size in bytes

def shift_rows(s):
    """Rotate position ``j`` (1, 2, 3) across the four words of ``s`` by ``j`` steps, in place.

    After the call ``s[i][j]`` holds the value that was at ``s[(i + j) % 4][j]``;
    position 0 is left untouched.
    """
    for j in (1, 2, 3):
        # Snapshot the whole lane first so every write uses the old values.
        lane = (s[0][j], s[1][j], s[2][j], s[3][j])
        for i in range(4):
            s[i][j] = lane[(i + j) % 4]
    
def add_round_key(s, k):
    """Add the 4x4 round key ``k`` entry by entry onto the state ``s``, in place.

    Each entry becomes ``s[i][j] + k[i][j]`` (state on the left, key on the
    right), using whatever ``+`` means for the entries.
    """
    for i, j in ((r, c) for r in range(4) for c in range(4)):
        word = s[i]
        word[j] += k[i][j]
            
def mix_single_column(a):
    """Replace the four entries of ``a`` by fixed weighted combinations of them, in place.

    Output ``r`` is ``w[r][0]*a[0] + w[r][1]*a[1] + w[r][2]*a[2] + w[r][3]*a[3]``,
    accumulated left to right, with the weight rows below. All four outputs
    are computed from the old values before anything is written back.
    """
    weights = (
        (2, 3, 1, 1),
        (1, 2, 3, 1),
        (1, 1, 2, 3),
        (3, 1, 1, 2),
    )
    mixed = []
    for row in weights:
        acc = row[0] * a[0]
        for idx in (1, 2, 3):
            acc = acc + row[idx] * a[idx]
        mixed.append(acc)
    for idx, value in enumerate(mixed):
        a[idx] = value
    
def mix_columns(s):
    """Apply ``mix_single_column`` to each of the four words of ``s``, in place."""
    for word in (s[i] for i in range(4)):
        mix_single_column(word)
        
def xor_bytes(a, b):
    """Combine two sequences entry by entry with ``+`` (despite the name, not XOR).

    The result is a list as long as the shorter input; each entry is
    ``a[n] + b[n]`` with ``a`` on the left.
    """
    return list(map(lambda left, right: left + right, a, b))

def bytes2matrix(text):
    """Turn a byte string into a list of words of up to four ``Element`` each.

    Consecutive bytes are grouped four at a time; each byte value becomes
    ``Element(value)``.
    """
    words = []
    for start in range(0, len(text), 4):
        words.append([Element(value) for value in text[start:start + 4]])
    return words

def matrix2bytes(matrix):
    """Flatten a list of lists of elements into ``bytes`` via each entry's ``to_byte()``."""
    flat = []
    for word in matrix:
        flat = flat + word
    return bytes(entry.to_byte() for entry in flat)
        
def expand_key(master_key):
    """Expand ``master_key`` into a list of round keys of four words each.

    New words are generated until there are ``(N_ROUNDS + 1) * 4`` of them.
    Each new word starts as a copy of the previous one. Once per
    ``len(master_key) // 4`` words it is rotated left by one entry and gets
    the next ``RCON`` constant added to its first entry. Finally the word
    ``len(master_key) // 4`` positions back is added entry-wise (``+`` on
    elements, not XOR).
    """
    words = bytes2matrix(master_key)
    stride = len(master_key) // 4
    target = 4 * (N_ROUNDS + 1)
    rcon_at = 1

    while len(words) < target:
        # Start from a copy of the most recent word.
        fresh = [*words[-1]]

        # Once per stride: rotate and add the round constant.
        if len(words) % stride == 0:
            fresh.append(fresh.pop(0))
            # Only the first entry gets a constant.
            fresh[0] += RCON[rcon_at]
            rcon_at += 1

        # Combine with the word one stride back.
        words.append(xor_bytes(fresh, words[-stride]))

    # Slice the flat word list into groups of four words.
    return [words[lo:lo + 4] for lo in range(0, 4 * (len(words) // 4), 4)]

def encrypt_block(key, plaintext):
    """Encrypt one ``N_BYTES``-long block under ``key`` and return the result as bytes.

    Structure: add round key 0, then ``N_ROUNDS - 1`` rounds of
    shift_rows / mix_columns / add_round_key, then a final shift_rows and the
    last round key. There is no byte-substitution step.
    """
    assert len(plaintext) == N_BYTES

    state = bytes2matrix(plaintext)
    schedule = expand_key(key)

    add_round_key(state, schedule[0])

    for round_key in schedule[1:N_ROUNDS]:
        shift_rows(state)
        mix_columns(state)
        add_round_key(state, round_key)

    # The last round skips mix_columns.
    shift_rows(state)
    add_round_key(state, schedule[-1])

    return matrix2bytes(state)

In [22]:
from secrets import randbelow

def gen_triple():
    """Return ``(key, plaintext, ciphertext)`` for a fresh random key.

    The key has ``N_BYTES`` bytes, each drawn below ``Element.sz``; the
    plaintext is a fixed 16-byte string.
    """
    key = bytes(randbelow(Element.sz) for _ in range(N_BYTES))
    pt = b"1234567890123456"
    return key, pt, encrypt_block(key, pt)

def to_quo(b):
    """Map every byte of ``b`` to the coset index stored in ``quomap`` for its element."""
    return [quomap[Element(value)] for value in b]

In [26]:
# Draw one random (key, plaintext, ciphertext) triple and print each of them
# with every byte replaced by its coset index from quomap.
k1,p1,c1 = gen_triple()
print(f"k1 = {to_quo(k1)}")
print(f"p1 = {to_quo(p1)}")
print(f"c1 = {to_quo(c1)}")

k1 = [0, 1, 3, 2, 3, 2, 1, 3, 3, 2, 0, 1, 2, 2, 2, 2]
p1 = [2, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0]
c1 = [1, 0, 1, 3, 2, 0, 2, 3, 3, 0, 1, 1, 1, 3, 3, 1]


In [60]:
from secrets import randbelow

# Randomized check (100 trials) that quomap is compatible with the element
# operations when coset indices are treated as integers modulo 4.
for i in range(100):
    xx = Element(randbelow(252))
    yy = Element(randbelow(252))
    x = quomap[xx]
    y = quomap[yy]
    # Element "+" corresponds to addition of coset indices mod 4.
    assert (x+y)%4 == quomap[xx+yy]
    
    g = randbelow(7)
    # Element "*" by an integer corresponds to multiplication mod 4.
    assert (x*g)%4 == quomap[xx*g]
    
    # Every commutator lands in the coset with index 0.
    assert quomap[xx+yy+invmap[xx]+invmap[yy]] == 0

In [103]:
from itertools import product

def get_iter():
    """Indices used for the check below: 0..20."""
    return range(21)


def from_cyc(i):
    """Element with digits ``(0, i % 3, i % 7, 2*(i % 7) % 4)``."""
    return Element(from_tuple((0, i % 3, i % 7, ((i % 7) * 2) % 4)))


# Digit tuple of from_cyc(i) -> i. Kept under an explicit name: `_` is
# IPython's "last output" variable and should not carry notebook state.
cyc_index = {to_tuple(from_cyc(i).to_byte()): i for i in get_iter()}


def to_cyc(x):
    """Inverse of ``from_cyc``; the first digit of ``x`` is ignored."""
    digits = to_tuple(x.to_byte())
    return cyc_index[(0, digits[1], digits[2], digits[3])]


# Check that i -> from_cyc(i) turns addition and multiplication modulo 21 into
# the element operations "+" and "*", and that these elements commute.
for i in get_iter():
    for j in get_iter():

        x = from_cyc(i)
        y = from_cyc(j)

        assert (i+j)%21 == to_cyc(x + y)
        assert (i*j)%21 == to_cyc(x*j)
        assert to_cyc(x+y+invmap[x]+invmap[y]) == 0