In [None]:
import random
import math

In [None]:
# Masked AES for d = 1 with no ifs

In [None]:
# 8bit random number
def rand8():
  """Return a pseudo-random integer drawn uniformly from 0..255 (one byte).

  NOTE(review): the value comes from the `random` module, which is not a
  cryptographically secure generator.
  """
  return random.randrange(256)

In [None]:
# multiply by x in GF(256) :

def shift_com_r(a, r, shift):
  """Shift `a` left `shift` times, OR-ing `r` into the value after every shift.

  With a == r == 1 and shift == 7 this spreads a single bit into 0xFF;
  with a == r == 0 the result stays 0.
  """
  acc = a
  for _ in range(shift):
    acc = (acc << 1) | r
  return acc

def xtime(a):
  """Multiply `a` by x in GF(256) without branching on the value.

  The byte is doubled; the bit shifted out of position 7 is expanded into an
  all-ones / all-zeros mask that selects between the plain doubled byte and
  the doubled byte reduced with 0x11B.
  """
  doubled = a << 1
  carry = (doubled >> 8) & 1            # bit 8 of the doubled value
  select = shift_com_r(carry, carry, 7) # 0xFF when carry is 1, else 0x00
  plain = doubled & 0xFF
  reduced = (doubled ^ 0x11B) & 0xFF
  # mask-based select: `reduced` when select == 0xFF, `plain` when select == 0
  return plain ^ (select & (plain ^ reduced))

In [None]:
# multiply a by x^n in GF(256), i.e. apply xtime n times:
def rot(a, n):
  """Apply `xtime` to `a` a total of `n` times and return the result.

  For n == 0 the input is returned unchanged.
  """
  r = a
  for _ in range(n):
    # Feed the running value back in. Calling xtime(a) here instead would
    # recompute a single doubling on every iteration regardless of n.
    r = xtime(r)
  return r

In [None]:
# a * b in GF(256) :
def gf_mul(a, b):
  """Return a * b in GF(256) using shift-and-add over the bits of `b`.

  For every bit of `b` (least significant first) the current multiple of `a`
  is XOR-ed into the accumulator when that bit is 1, then doubled with xtime.
  """
  acc = 0
  term = a
  bits = b
  while bits:
    # -(bit) is all-ones for bit == 1 and 0 for bit == 0: same as term * bit
    acc ^= -(bits & 1) & term
    term = xtime(term)
    bits >>= 1
  return acc


In [None]:
# a^e in GF(256) :
def gf_pow(a, e):
  """Return a**e in GF(256) by square-and-multiply (exponent bits LSB first).

  The product is computed on every iteration and then selected arithmetically
  by the current exponent bit, so there is no branch on that bit.
  """
  result = 1
  base = a
  exponent = e
  while exponent:
    bit = exponent & 1
    product = gf_mul(result, base)
    # keep `result` when bit == 0, take `product` when bit == 1
    result ^= -bit & (result ^ product)
    base = gf_mul(base, base)
    exponent >>= 1
  return result

In [None]:
# Simple inverse in GF(256):
def get_inv(a):
  """Return the GF(256) inverse of `a`, computed as a**254 via gf_pow."""
  inverse_exponent = 254
  return gf_pow(a, inverse_exponent)

In [None]:
# Given p, calculate S[p] :

def rotl8(x, shift):
    """Rotate the 8-bit value `x` left by `shift` bit positions."""
    upper = x << shift          # bits moved towards the top
    lower = x >> (8 - shift)    # bits that wrap around to the bottom
    return (upper | lower) & 0xFF

def sbox(p):
    """Return S[p]: invert `p` with get_inv, then XOR the inverse with its
    left-rotations by 1..4 bits and with the constant 0x63.

    `nonzero` is 0 for p == 0 and 1 for 1 <= p <= 255; it selects arithmetically
    between the fixed value 0x63 and the computed value.
    """
    nonzero = math.ceil(p/255)
    inv = get_inv(p) & 0xFF
    mixed = inv
    for amount in range(1, 5):
        mixed ^= rotl8(inv, amount)
    return (1-nonzero)*0x63 + nonzero*(mixed ^ 0x63)

In [None]:
# Given a and b perform masked multiplication using the ISW algorithm with d+1 shares

def isw_mul(a, b, d):
  """Masked GF(256) multiplication of two shared values with d+1 shares.

  Args:
    a, b: lists holding d+1 shares each.
    d: masking order; the number of shares is d+1.

  Returns:
    A list of d+1 shares whose XOR equals the GF(256) product of the XOR of
    the shares of `a` and the XOR of the shares of `b`.
  """
  n_shares = d + 1
  # Build every row separately: `[[0]*n]*n` would repeat ONE row object n
  # times, so a write to z[j][i] would also land in every other row. That
  # aliasing only happens to give the right answer for two shares.
  z = [[0] * n_shares for _ in range(n_shares)]
  for i in range(n_shares):
    for j in range(i + 1, n_shares):
      fresh = rand8()
      z[i][j] = (fresh ^ gf_mul(a[i], b[j])) ^ gf_mul(a[j], b[i])
      z[j][i] = fresh
  c = [0] * n_shares
  for i in range(n_shares):
    # summation over row i; the diagonal entry z[i][i] is never written and
    # stays 0, so it can be included without a conditional
    acc = 0
    for j in range(n_shares):
      acc ^= z[i][j]
    c[i] = gf_mul(a[i], b[i]) ^ acc
  return c

In [None]:
# Inverse of a masked value in GF(256), working for d = 1 only :
def gf_inv_mskd(p, d):
  """Raise the shared value `p` to the power 254 using isw_mul.

  The accumulator starts as a two-share encoding of 1, so this only works
  for d = 1. Returns the list of shares of the result.
  """
  mask = rand8()
  acc = [0x01 ^ mask, mask]  # only works for d = 1
  base = p
  exponent = 254
  for position in range(exponent.bit_length()):
    bit = (exponent >> position) & 1
    # Repeating a list 0 or 1 times gives a branch-free select: the product is
    # always computed, and kept only when the exponent bit is 1.
    acc = isw_mul(acc, base, d) * bit + acc * (1 - bit)
    base = isw_mul(base, base, d)
  return acc

In [None]:
# Given a masked value p, calculate S[p], only working for d = 1:

def mk_tmp_mskd(b, d):
  """Apply the rotation/XOR part of the S-box to each of the first d+1 shares.

  Every share is XOR-ed with its own left-rotations by 1, 2, 3 and 4 bits.
  The constant 0x63 is NOT added here. Returns a new list of d+1 values.
  """
  def spread(share):
    acc = share
    for amount in (1, 2, 3, 4):
      acc ^= rotl8(share, amount)
    return acc

  return [spread(b[i]) for i in range(d + 1)]

def sbox_mskd(p, d):
  """Compute the S-box of a masked byte without ever recombining its shares.

  Args:
    p: list of shares of the input byte.
    d: masking order, forwarded to gf_inv_mskd and mk_tmp_mskd (gf_inv_mskd
       only supports d = 1).

  Returns:
    A list of shares whose XOR is S[x], where x is the XOR of the shares of p.
  """
  # No special case for a zero input is needed: 0**254 is 0 in GF(256), so the
  # masked inversion already yields shares of 0 and the steps below turn them
  # into shares of 0x63. Testing `p[0] ^ p[1]` for zero would compute the
  # unmasked secret byte, which is exactly what masking is meant to prevent.
  b = gf_inv_mskd(p, d)
  tmp = mk_tmp_mskd(b, d)
  # The rotation/XOR step is applied to every share; the constant 0x63 must be
  # added to exactly one share so it appears once in the recombined value.
  tmp[1] = tmp[1] ^ 0x63
  return tmp

In [None]:
# Mix Columns with a masked state, working for d = 1 only :
def mix_columns_mskd(state, d):
  """Mix each column of a 4x4 state whose cells are two-share lists.

  The constants 0x02 and 0x03 are encoded as the shares [c, 0] and multiplied
  into the state cells with isw_mul; the remaining terms are XOR-ed in share
  by share. `d` is not read: isw_mul is always called with order 1.
  Returns a new 4x4 matrix of two-share lists; `state` is not modified.
  """
  two = [0x02, 0x00]
  three = [0x03, 0x00]
  mixed = [[[0, 0] for _ in range(4)] for _ in range(4)]
  for col in range(4):
    s0 = state[0][col]
    s1 = state[1][col]
    s2 = state[2][col]
    s3 = state[3][col]

    # row 0: 2*s0 ^ 3*s1 ^ s2 ^ s3
    p = isw_mul(two, s0, 1)
    q = isw_mul(three, s1, 1)
    mixed[0][col] = [(p[k] ^ q[k] ^ s2[k] ^ s3[k]) & 0xFF for k in range(2)]

    # row 1: s0 ^ 2*s1 ^ 3*s2 ^ s3
    p = isw_mul(two, s1, 1)
    q = isw_mul(three, s2, 1)
    mixed[1][col] = [(s0[k] ^ p[k] ^ q[k] ^ s3[k]) & 0xFF for k in range(2)]

    # row 2: s0 ^ s1 ^ 2*s2 ^ 3*s3
    p = isw_mul(two, s2, 1)
    q = isw_mul(three, s3, 1)
    mixed[2][col] = [(s0[k] ^ s1[k] ^ p[k] ^ q[k]) & 0xFF for k in range(2)]

    # row 3: 3*s0 ^ s1 ^ s2 ^ 2*s3 (the 3*s0 product is computed first)
    q = isw_mul(three, s0, 1)
    p = isw_mul(two, s3, 1)
    mixed[3][col] = [(q[k] ^ s1[k] ^ s2[k] ^ p[k]) & 0xFF for k in range(2)]

  return mixed

In [None]:
# Shift Rows (it does not matter whether the state is masked or not):
def shiftrows(state):
    """Cyclically shift row i of a 4x4 state left by i positions.

    The rows of `state` are replaced in place and the same `state` object is
    returned. Raises AssertionError if `state` is not 4 rows of 4 entries.
    """
    assert len(state) == 4 and all(len(row) == 4 for row in state)
    # Row 0 is left untouched; rows 1..3 move left by their own index.
    for idx in range(1, 4):
        row = state[idx]
        state[idx] = row[idx:] + row[:idx]
    return state

In [None]:
# Given a masked state and round_key, perform an xor term by term :
def add_round_key_mskd(state, round_key, d):
    """XOR a masked round key into a masked state, share by share.

    Args:
      state: 4x4 matrix whose cells are lists of shares.
      round_key: 4x4 matrix of share lists, same layout as `state`.
      d: masking order (not read; the first two shares of each cell are used).

    Returns:
      A new 4x4 matrix. Neither `state` nor `round_key` is modified.
    """
    # Copy every cell: a shallow `state.copy()` would still share the inner
    # share lists with the caller, so the XOR below would silently overwrite
    # the caller's data (e.g. the plaintext shares passed to masked_aes).
    tmp = [[list(cell) for cell in row] for row in state]
    for k in range(2):
      for i in range(4):
        for j in range(4):
          tmp[i][j][k] ^= round_key[i][j][k]
    return tmp

In [None]:
###################################### BEGIN KEY SCHEDULE ######################################################
#

# Table of round constants indexed by round number (index 0 is a placeholder).
# NOTE(review): key_exp_core below derives its constant with gf_pow(0x02, i-1)
# and does not read this table.
RCON = [
    0x00, # unused
    0x01, 0x02, 0x04, 0x08,
    0x10, 0x20, 0x40, 0x80,
    0x1b, 0x36
]

def rot_word(word):
    """Return `word` rotated left by one position (first element moves to the end)."""
    head = word[:1]
    tail = word[1:]
    return tail + head

def sub_word(word):
    """Return a new list with sbox applied to every element of `word`."""
    substituted = []
    for value in word:
        substituted.append(sbox(value))
    return substituted

def key_exp_core(word, i):
    """Rotate `word`, substitute each byte with sbox, then XOR the round
    constant for iteration `i` into the first byte.

    The round constant is computed as gf_pow(0x02, i-1). Returns a new list.
    """
    transformed = sub_word(rot_word(word))
    transformed[0] = transformed[0] ^ gf_pow(0x02, i-1)  # rcon(i) = gf_pow(2, i-1)
    return transformed

def key_exp(key):
    """Expand a 16-byte key into a flat list of 176 bytes.

    Each new 4-byte word is the XOR of the word 16 bytes earlier with the
    previous word; at every 16-byte boundary the previous word first goes
    through key_exp_core. Raises AssertionError if `key` is not 16 long.
    """
    # key: list of 16 bytes
    assert len(key) == 16
    expanded = [0] * 176
    expanded[0:16] = key

    position = 16
    rcon_iter = 1

    while position < 176:
        # previous 4 bytes
        word = expanded[position-4:position]

        if position % 16 == 0:
            word = key_exp_core(word, rcon_iter)
            rcon_iter += 1

        for offset in range(4):
            expanded[position + offset] = expanded[position + offset - 16] ^ word[offset]
        position += 4

    return expanded


# All-zero 16-byte key (not expanded); it is passed to mask_exp_keys by the print call further down this cell.
key = [0x0]*16
#
def mask_exp_keys(expd_keys):
  """Split every byte of `expd_keys` into two shares [byte ^ m, m].

  Args:
    expd_keys: sequence of key bytes (any length).

  Returns:
    A list with one two-share entry per input byte, in the same order.

  NOTE(review): a single mask `m` is drawn once and reused for every byte.
  """
  m = rand8()
  # Size the result from the input instead of a fixed 176 slots: a pre-sized
  # list leaves malformed one-element `[0]` entries behind for shorter inputs
  # and raises IndexError for longer ones.
  return [[byte ^ m, m] for byte in expd_keys]

# Quick visual check: mask the 16-byte `key` defined earlier in this cell and print the resulting shares.
print(mask_exp_keys(key))


def test_key_exp():
    """Check key_exp against the first three round keys listed below.

    Raises AssertionError if the length or any of the three rounds differs;
    prints a confirmation line otherwise.
    """
    key = [
        0x2b, 0x7e, 0x15, 0x16,
        0x28, 0xae, 0xd2, 0xa6,
        0xab, 0xf7, 0x15, 0x88,
        0x09, 0xcf, 0x4f, 0x3c
    ]
    expanded = key_exp(key)
    assert len(expanded) == 176

    # (start offset, expected 16 bytes, failure message) for rounds 1..3
    checks = [
        (16,
         [0xa0,0xfa,0xfe,0x17, 0x88,0x54,0x2c,0xb1,
          0x23,0xa3,0x39,0x39, 0x2a,0x6c,0x76,0x05],
         "Round 1 mismatch"),
        (32,
         [0xf2,0xc2,0x95,0xf2, 0x7a,0x96,0xb9,0x43,
          0x59,0x35,0x80,0x7a, 0x73,0x59,0xf6,0x7f],
         "Round 2 mismatch"),
        (48,
         [0x3d,0x80,0x47,0x7d, 0x47,0x16,0xfe,0x3e,
          0x1e,0x23,0x7e,0x44, 0x6d,0x7a,0x88,0x3b],
         "Round 3 mismatch"),
    ]
    for start, expected_round, message in checks:
        assert expanded[start:start + 16] == expected_round, message

    print("AES-128 key expansion matches FIPS-197 Appendix A.1 up to round 3!")
test_key_exp()


[[143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [143, 143], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]]
AES-128 key expansion matches FIPS-197 Appendix A.1 up to round 3!

In [None]:
# For each value in a masked state, get the corresponding sbox value :
def sub_bytes_masked(state, d):
  """Replace every cell of the 4x4 masked state with sbox_mskd(cell, d).

  Cells are visited row by row. `state` is updated in place and returned.
  """
  for index in range(16):
    row, col = divmod(index, 4)
    state[row][col] = sbox_mskd(state[row][col], d)
  return state

In [None]:
# Masked AES for d = 1
def masked_aes(text, keys, d):
  """Encrypt one masked 16-byte block with a masked 176-entry key schedule.

  Args:
    text: list of 16 share lists (one per plaintext byte).
    keys: list of 176 share lists (the masked expanded key).
    d: masking order, forwarded to every masked step.

  Returns:
    The final 4x4 state of share lists.
  """
  def round_key(n):
    # 16 consecutive entries of the schedule, laid out as a 4x4 matrix
    return mk_matrix_2(keys[16*n:16*n + 16])

  state = add_round_key_mskd(mk_matrix_2(text), round_key(0), d)
  # rounds 1..9: substitute, shift, mix, add key
  for rnd in range(1, 10):
    state = shiftrows(sub_bytes_masked(state, d))
    state = mix_columns_mskd(state, d)
    state = add_round_key_mskd(state, round_key(rnd), d)
  # final round: same steps without the column mix
  state = shiftrows(sub_bytes_masked(state, d))
  return add_round_key_mskd(state, round_key(10), d)

In [None]:
# The operations of this AES implementation are done on matrices, so if an input is given as a list
# it must be converted to a matrix in column-major order:
def mk_matrix_2(l):
    """Arrange the first 16 items of `l` into a 4x4 matrix, filling column by column.

    Item k ends up at row k % 4, column k // 4. The items themselves are not
    copied, so the matrix cells refer to the same objects as `l`.
    """
    return [[l[4 * col + row] for col in range(4)] for row in range(4)]


In [None]:
# TEST 1 From AESAVS.doc: 3ad78e726c1ec02b7ebfe92b23d9ec34
expected = "3ad78e726c1ec02b7ebfe92b23d9ec34"
key = [0x0]*16
exp_key = key_exp(key)
exp_key = mask_exp_keys(exp_key)

# Plaintext 80 00 ... 00; every byte is stored as the share pair [byte ^ mask, mask].
text = [[0x80 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1],
        [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2],
        [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3],
        [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4]]

cipher = masked_aes(text, exp_key, 1)
out = []
for i in range(4):
    for j in range(4):
        # the state is indexed [row][column]; read it column by column and
        # recombine the two shares of each byte
        out.append(cipher[j][i][0] ^ cipher[j][i][1])
ciphertext_hex = "".join(f"{b:02x}" for b in out)
print("Ciphertext:", ciphertext_hex)
# Actually check the known answer instead of relying on reading the printout.
assert ciphertext_hex == expected, f"expected {expected}, got {ciphertext_hex}"

Ciphertext: 3ad78e726c1ec02b7ebfe92b23d9ec34


In [None]:
# TEST 2 From AESAVS.doc: 79bf5dce14bb7dd73a8e3611de7ce026
expected = "79bf5dce14bb7dd73a8e3611de7ce026"
key = [0x00]*16
exp_key = key_exp(key)

# Plaintext ff x12, fc, 00 00 00; every byte is stored as [byte ^ mask, mask].
text = [[0xff ^ 0x1, 0x1], [0xff ^ 0x1, 0x1], [0xff ^ 0x1, 0x1], [0xff ^ 0x1, 0x1],
        [0xff ^ 0x2, 0x2], [0xff ^ 0x2, 0x2], [0xff ^ 0x2, 0x2], [0xff ^ 0x2, 0x2],
        [0xff ^ 0x3, 0x3], [0xff ^ 0x3, 0x3], [0xff ^ 0x3, 0x3], [0xff ^ 0x3, 0x3],
        [0xfc ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4]]
exp_key = mask_exp_keys(exp_key)
cipher = masked_aes(text, exp_key, 1)
out = []
for i in range(4):
    for j in range(4):
        # the state is indexed [row][column]; read it column by column and
        # recombine the two shares of each byte
        out.append(cipher[j][i][0] ^ cipher[j][i][1])
ciphertext_hex = "".join(f"{b:02x}" for b in out)
print("Ciphertext:", ciphertext_hex)
# Actually check the known answer instead of relying on reading the printout.
assert ciphertext_hex == expected, f"expected {expected}, got {ciphertext_hex}"

Ciphertext: 79bf5dce14bb7dd73a8e3611de7ce026


In [None]:
# TEST 3 From AESAVS.doc: 6d251e6944b051e04eaa6fb4dbf78465
expected = "6d251e6944b051e04eaa6fb4dbf78465"
key = [0x10, 0xa5, 0x88, 0x69, 0xd7, 0x4b, 0xe5, 0xa3,
       0x74, 0xcf, 0x86, 0x7c, 0xfb, 0x47, 0x38, 0x59]
exp_key = key_exp(key)

# All-zero plaintext; every byte is stored as the share pair [byte ^ mask, mask].
text = [[0x0 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1], [0x0 ^ 0x1, 0x1],
        [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2], [0x0 ^ 0x2, 0x2],
        [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3], [0x0 ^ 0x3, 0x3],
        [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4], [0x0 ^ 0x4, 0x4]]
exp_key = mask_exp_keys(exp_key)
cipher = masked_aes(text, exp_key, 1)
out = []
for i in range(4):
    for j in range(4):
        # the state is indexed [row][column]; read it column by column and
        # recombine the two shares of each byte
        out.append(cipher[j][i][0] ^ cipher[j][i][1])
ciphertext_hex = "".join(f"{b:02x}" for b in out)
print("Ciphertext:", ciphertext_hex)
# Actually check the known answer instead of relying on reading the printout.
assert ciphertext_hex == expected, f"expected {expected}, got {ciphertext_hex}"

Ciphertext: 6d251e6944b051e04eaa6fb4dbf78465
