<a href="https://colab.research.google.com/github/alialdakheel/BlockCipher_exercise/blob/main/BC_cm3_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [36]:
"""
  Cipher model 3: SPN cipher with a half linear layer.
  Uses four 4x4 S-boxes and a linear layer applied to half
  of the state (the first byte) to build a 16-bit block cipher.
"""
class CM3:
  """Toy 16-bit SPN block cipher ("cipher model 3") with a half linear layer.

  The state is two bytes, viewed as four 4-bit nibbles. Encryption XORs a
  2-byte whitening key into the plaintext, then runs ``r`` rounds of:

    1. substitute every nibble through ``sbox``;
    2. apply ``lfx`` to the first byte only (the second byte passes through);
    3. XOR in the next 2-byte round key.

  Decryption applies the inverse steps with the round keys in reverse order.

  @ Param:
    sbox: the S-box as a sequence (index = input nibble, value = output
      nibble); must be a permutation of range(len(sbox)). Every nibble value
      0..15 is used as an index, so a 16-entry S-box is expected.
    lfx: function specifying the linear layer on half the input (the first
      byte of the state); must return an int that fits in a byte.
    inv_lfx: inverse of ``lfx``, used by ``dec``.
    r: number of rounds (>= 0); keys must be (r + 1) * 2 bytes long.

  @ Raises:
    ValueError: from the constructor if ``sbox`` is not a permutation or
      ``r`` is negative; from ``enc``/``dec`` on a wrong key or block length.
  """

  _BLOCK_BYTES = 2  # 16-bit block = 2 bytes = 4 nibbles

  def __init__(self, sbox, lfx, inv_lfx, r):
    if set(sbox) != set(range(len(sbox))):
      raise ValueError("sbox must be a permutation of range(len(sbox))")
    if r < 0:
      raise ValueError(f"number of rounds must be non-negative, got {r}")
    self.sbox = sbox
    # inv_sbox[sbox[x]] == x, built in one pass rather than one
    # sbox.index() scan per value.
    self.inv_sbox = [0] * len(sbox)
    for x, y in enumerate(sbox):
      self.inv_sbox[y] = x
    self.lfx = lfx
    self.inv_lfx = inv_lfx
    self.r = r

  def _bytestolist(self, b):
    """Split a 2-byte block into its four nibbles, most significant first."""
    return [b[0] >> 4, b[0] & 0xF, b[1] >> 4, b[1] & 0xF]

  def _listtobytes(self, nibbles):
    """Pack four nibbles (most significant first) into a 2-byte block."""
    return bytes([(nibbles[0] << 4) | nibbles[1],
                  (nibbles[2] << 4) | nibbles[3]])

  @staticmethod
  def _xor(a, b):
    """Return the bytewise XOR of ``a`` and ``b`` as bytes."""
    return bytes(x ^ y for x, y in zip(a, b))

  def _round_key(self, k, i):
    """Return key i from ``k`` (0 = whitening key, r = last round key)."""
    return k[i * self._BLOCK_BYTES:(i + 1) * self._BLOCK_BYTES]

  def _check_lengths(self, block, k):
    """Validate a block/key pair before ``enc``/``dec``.

    Raises ValueError unless ``k`` is (r + 1) * 2 bytes and ``block`` is
    exactly one 2-byte block. Without the block check, zip() in the XOR
    steps would silently truncate a longer input instead of failing.
    """
    key_len = (self.r + 1) * self._BLOCK_BYTES
    if len(k) != key_len:
      raise ValueError(
          f"key must be {key_len} bytes for {self.r} rounds, got {len(k)}")
    if len(block) != self._BLOCK_BYTES:
      raise ValueError(
          f"block must be {self._BLOCK_BYTES} bytes, got {len(block)}")

  def _enc_round(self, s, k):
    """One encryption round: S-box layer, half linear layer, key XOR."""
    sub_out = self._listtobytes([self.sbox[w] for w in self._bytestolist(s)])
    # The linear layer only touches the first byte; the second passes through.
    lfx_out = bytes([self.lfx(sub_out[0]), sub_out[1]])
    return self._xor(lfx_out, k)

  def _dec_round(self, s, k):
    """Undo one round's linear and S-box layers, then XOR in ``k``.

    ``s`` must already have this round's key removed; ``k`` is the key of
    the preceding round (or the whitening key for the first round).
    """
    lfx_out = bytes([self.inv_lfx(s[0]), s[1]])
    sub_out = self._listtobytes(
        [self.inv_sbox[w] for w in self._bytestolist(lfx_out)])
    return self._xor(sub_out, k)

  def enc(self, p, k):
    """Encrypt the 2-byte block ``p`` under the (r + 1) * 2-byte key ``k``."""
    self._check_lengths(p, k)
    s = self._xor(p, self._round_key(k, 0))  # whitening key
    for i in range(1, self.r + 1):
      s = self._enc_round(s, self._round_key(k, i))
    return s

  def dec(self, c, k):
    """Decrypt the 2-byte block ``c`` under the (r + 1) * 2-byte key ``k``."""
    self._check_lengths(c, k)
    # Strip the last round key, then peel rounds off from the end; each
    # _dec_round XORs in the key of the round before it.
    s = self._xor(c, self._round_key(k, self.r))
    for i in reversed(range(self.r)):
      s = self._dec_round(s, self._round_key(k, i))
    return s

In [37]:
# 4-bit S-box: a permutation of the nibble values 0..15.
sbox = [1, 2, 3, 0, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14]
# Two-round instance whose linear layer (and its inverse) is the identity.
cm3 = CM3(sbox, lfx=lambda x: x, inv_lfx=lambda x: x, r=2)

In [38]:
# Two ASCII characters encode to exactly one 2-byte (16-bit) block.
p = "he"
b = p.encode("utf-8")

In [39]:
k = b"helloo"  # (r + 1) * 2 = 6 bytes for r = 2: whitening key + 2 round keys

In [41]:
c = cm3.enc(p=b, k=k)  # encrypt the single block `b` under key `k`

In [42]:
cm3.dec(c=c, k=k)  # same key -> recovers the original plaintext block

b'he'