In [1]:
import numpy as np
from itertools import product

In [6]:
# MDS matrix taken from https://dergipark.org.tr/en/download/article-file/2290583
#
# The 16x16 binary matrix is a 4x4 grid of 4x4 blocks. Only four distinct
# blocks occur, and every row/column of the grid contains each of them once:
#
#     [ M  I  P  Q ]
#     [ I  M  Q  P ]
#     [ P  Q  M  I ]
#     [ Q  P  I  M ]
#
# Assembling A from the blocks yields exactly the same 0/1 entries as writing
# out all 256 of them by hand.
_blk_m = np.array([
    [0, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
    [1, 0, 0, 1]])
_blk_id = np.eye(4, dtype=int)
_blk_p = np.array([
    [1, 1, 0, 1],
    [1, 0, 1, 1],
    [0, 1, 0, 1],
    [1, 0, 1, 0]])
_blk_q = np.array([
    [0, 1, 1, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [1, 1, 0, 1]])

A = np.block([
    [_blk_m, _blk_id, _blk_p, _blk_q],
    [_blk_id, _blk_m, _blk_q, _blk_p],
    [_blk_p, _blk_q, _blk_m, _blk_id],
    [_blk_q, _blk_p, _blk_id, _blk_m]])

In [8]:
def _gf2_rank(matrix):
    """Return the rank of a 2-D integer matrix over GF(2).

    Entries are reduced modulo 2 and a private working copy is row-reduced
    using XOR, so the argument itself is never modified.
    """
    work = (np.asarray(matrix) % 2).astype(np.uint8)
    n_rows, n_cols = work.shape
    found = 0  # pivots placed so far; also the row the next pivot goes to
    for col in range(n_cols):
        if found == n_rows:
            break
        candidates = np.flatnonzero(work[found:, col])
        if candidates.size == 0:
            continue  # this column has no pivot below the rows already used
        pivot = found + int(candidates[0])
        if pivot != found:
            work[[found, pivot]] = work[[pivot, found]]
        # Cancel every remaining 1 below the pivot in this column.
        below = found + 1 + np.flatnonzero(work[found + 1:, col])
        work[below] ^= work[found]
        found += 1
    return found


# A is applied modulo 2 inside the S-box, so the relevant rank is the one over
# GF(2). np.linalg.matrix_rank works over the reals: a real rank of 16 only
# shows det(A) != 0, not that det(A) is odd, and so does not prove that A is
# invertible modulo 2.
rank = _gf2_rank(A)
print(rank)

16


In [10]:
# Constant vector added (mod 2) to the S-box output: eight ones and eight
# zeros, permuted by NumPy's legacy global RNG under a fixed seed so the same
# vector is produced on every run.
np.random.seed(42)
b = np.repeat([1, 0], 8)
np.random.shuffle(b)
print(b)

[1 1 1 0 0 0 0 0 1 0 1 1 0 0 1 1]


In [4]:
# Every coordinate function has the same shape: the XOR of four products
# x[i] * x[j]. Row k-1 below lists the (i, j) index pairs of f<k>, in the
# order the terms are XOR-ed together.
#
# NOTE(review): f13 contains exactly the same four monomials as f12 (each
# product is just written with its factors swapped), and f16 the same four as
# f11. The sixteen functions are therefore not pairwise distinct, which limits
# how many different vectors they can produce together. Confirm whether this
# is intended before relying on the resulting S-box.
_QUADRATIC_TERMS = (
    ((0, 1), (2, 3), (4, 5), (6, 7)),          # f1
    ((0, 2), (1, 3), (4, 6), (5, 7)),          # f2
    ((0, 3), (1, 2), (8, 9), (10, 11)),        # f3
    ((4, 5), (6, 7), (12, 13), (14, 15)),      # f4
    ((8, 9), (10, 11), (0, 4), (1, 5)),        # f5
    ((2, 6), (3, 7), (12, 14), (13, 15)),      # f6
    ((0, 8), (1, 9), (2, 10), (3, 11)),        # f7
    ((4, 12), (5, 13), (6, 14), (7, 15)),      # f8
    ((0, 5), (1, 4), (2, 7), (3, 6)),          # f9
    ((8, 13), (9, 12), (10, 15), (11, 14)),    # f10
    ((0, 13), (1, 12), (2, 15), (3, 14)),      # f11
    ((4, 9), (5, 8), (6, 11), (7, 10)),        # f12
    ((8, 5), (9, 4), (10, 7), (11, 6)),        # f13 (same monomials as f12)
    ((0, 10), (1, 11), (2, 8), (3, 9)),        # f14
    ((4, 14), (5, 15), (6, 12), (7, 13)),      # f15
    ((12, 1), (13, 0), (14, 3), (15, 2)),      # f16 (same monomials as f11)
)


def _make_quadratic(name, pairs):
    """Build the function ``x -> XOR over (i, j) in pairs of x[i] * x[j]``.

    Args:
        name: value assigned to the returned function's ``__name__``.
        pairs: non-empty sequence of ``(i, j)`` index pairs into ``x``.

    Returns:
        A one-argument callable. It indexes its argument with every index in
        ``pairs`` and raises ``IndexError`` if the argument is too short.
    """
    def quadratic(x):
        products = [x[i] * x[j] for i, j in pairs]
        # Start from the first product (not from 0) so the result has the
        # same type as a plain ``p0 ^ p1 ^ p2 ^ p3`` expression would.
        acc = products[0]
        for term in products[1:]:
            acc = acc ^ term
        return acc

    quadratic.__name__ = name
    quadratic.__doc__ = "XOR of the products x[i] * x[j] for (i, j) in " + repr(tuple(pairs)) + "."
    return quadratic


bent_functions = [
    _make_quadratic(f"f{k}", pairs)
    for k, pairs in enumerate(_QUADRATIC_TERMS, start=1)
]

# Keep the individual names available for code that calls f1 ... f16 directly.
(f1, f2, f3, f4, f5, f6, f7, f8,
 f9, f10, f11, f12, f13, f14, f15, f16) = bent_functions

In [5]:
def S_box(x, verbose=True):
    """Map an integer through the 16-bit S-box built from bent_functions, A and b.

    Steps:
      1. Split ``x`` into 16 bits, least-significant first (bit ``i`` is
         ``(x >> i) & 1``). Bits above position 15 are never read.
      2. Evaluate every function in ``bent_functions`` on that bit list.
      3. Multiply the resulting vector by ``A`` and reduce modulo 2.
      4. Add ``b`` modulo 2.
      5. Pack the result into an integer, element ``i`` having weight ``2**i``.

    Args:
        x: integer input.
        verbose: when true (the default, matching the original behaviour),
            print the vector from step 2 and the vector from step 3. Pass
            ``False`` when evaluating many inputs to avoid flooding the output.

    Returns:
        The S-box output as a Python ``int``.
    """
    n_bits = 16
    x_bits = [(x >> i) & 1 for i in range(n_bits)]
    S_x = np.array([f(x_bits) for f in bent_functions])
    if verbose:
        print(S_x)
    S_prime_x = (A @ S_x) % 2
    if verbose:
        print(S_prime_x)
    output_bits = (S_prime_x + b) % 2
    # Pack with shifts instead of building a binary string and re-parsing it.
    return sum(int(bit) << i for i, bit in enumerate(output_bits))

# Example for one fixed, hard-coded input (bits 6, 7, 14 and 15 are set).
x = 0b1100000011000000
S_x = S_box(x)
# bin() omits leading zeros, so the printed output can show fewer than 16 digits.
print(f"Input: {bin(x)}\nOutput: {bin(S_x)}")


[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 1 0 0 0 1 1 0 1 0 0 1 1]
Input: 0b1100000011000000
Output: 0b11000011111
