In [106]:
import math


class Box:
    """A bijective lookup table over all ``width``-bit integers.

    ``width`` is derived from the number of entries in ``mapping``. The
    mapping's keys and its values must each be exactly ``range(2 ** width)``,
    i.e. the mapping is a permutation of that range.

    Args:
        mapping: dict from input integer to output integer.

    Raises:
        ValueError: if the mapping size is not a power of two (including an
            empty mapping), or if its keys or values are not exactly
            ``range(2 ** width)``.
    """

    def __init__(self, mapping: dict):
        self.mapping = mapping
        size = len(mapping)
        # Exact integer power-of-two test; a float log2 is not exact for large
        # sizes (math.log(2 ** 29, 2) != 29.0).
        if size == 0 or size & (size - 1):
            raise ValueError('Mapping size must be a power of 2')
        # For a power of two, bit_length() - 1 is its exact base-2 logarithm.
        self.width = size.bit_length() - 1

        self.is_complete()

    def is_complete(self):
        """Check that keys and values are both exactly ``range(2 ** width)``.

        Raises:
            ValueError: on a missing/extra key or a missing/duplicated value.
        """
        # A width-bit box covers 2 ** width inputs (not width ** 2).
        expected = set(range(2 ** self.width))

        if set(self.mapping.keys()) != expected:
            raise ValueError('Incorrect key set')

        if set(self.mapping.values()) != expected:
            raise ValueError('Incorrect values')

    def evaluate(self, value):
        """Return the mapped output for ``value`` (KeyError if out of range)."""
        return self.mapping[value]

    def __repr__(self):
        return f'<{self.__class__.__name__}: Width {self.width}>'
    

class SBox(Box):
    """Substitution box; all behavior (validation, lookup, repr) comes from Box."""


class PBox(Box):
    """Permutation box; all behavior (validation, lookup, repr) comes from Box."""


class Round:
    """One substitution-permutation round: parallel S-boxes, then a P-box.

    The S-boxes are applied side by side to consecutive bit slices of the
    input. The LAST S-box in ``sboxes`` handles the least-significant bits and
    the first one the most-significant bits. Their reassembled output is then
    passed through ``pbox``.

    Args:
        sboxes: boxes applied in parallel; each must expose ``width`` (bits)
            and ``evaluate(int) -> int``.
        pbox: box applied to the combined S-box output; its ``width`` must
            equal the sum of the S-box widths.

    Raises:
        ValueError: if the summed S-box width differs from the P-box width.
    """

    def __init__(self, sboxes: list, pbox: 'PBox'):
        self.sboxes = sboxes
        self.pbox = pbox
        self._input_width = sum(box.width for box in self.sboxes)
        self._output_width = self.pbox.width

        if self._input_width != self._output_width:
            raise ValueError('Sbox and Pbox width mismatch')

    def evaluate(self, value):
        """Run ``value`` through the S-box layer and then the P-box.

        Raises:
            ValueError: if ``value`` is negative or does not fit in the
                round's input width (values >= 2 ** width used to be silently
                truncated to their low bits).
        """
        limit = 2 ** self._input_width
        if not 0 <= value < limit:
            raise ValueError(f'Value must be in range [0, {limit})')

        value = self._evaluate_sboxes(value)
        return self._evaluate_pbox(value)

    def _evaluate_sboxes(self, value):
        """Apply each S-box to its own bit slice of ``value`` and reassemble."""
        ret_value = 0
        used_width = 0

        # Walk from the last S-box (least-significant slice) to the first.
        for sbox in reversed(self.sboxes):
            # Mask selecting the low ``sbox.width`` bits.
            mask = (2 ** sbox.width) - 1

            val = sbox.evaluate(value & mask)

            # Place this S-box's output at the bit offset its input came from.
            ret_value |= val << used_width

            # Drop the consumed bits so the next S-box sees its slice as the LSBs.
            value >>= sbox.width

            used_width += sbox.width

        return ret_value

    def _evaluate_pbox(self, value):
        """Apply the permutation box to the S-box layer output."""
        return self.pbox.evaluate(value)

    def __repr__(self):
        return f'<{self.__class__.__name__}: {self._input_width}/{self._output_width}>'

In [107]:
import random

def get_pbox_mapping(width, seed):
    """Return a seeded random permutation of ``range(2 ** width)`` as a dict.

    The result is identical to seeding the module-level generator with
    ``seed`` and calling ``random.sample``. However, a private
    ``random.Random`` instance is used, so the caller's global random state is
    neither reseeded nor advanced.

    Args:
        width: bit width; the mapping has ``2 ** width`` entries.
        seed: seed for the permutation (``None`` draws from OS entropy).

    Returns:
        dict mapping each key in ``range(2 ** width)`` to a distinct value in
        the same range.
    """
    keys = list(range(2 ** width))
    rng = random.Random(seed)
    values = rng.sample(keys, len(keys))
    return dict(zip(keys, values))

In [108]:
def evaluate(value, rounds):
    """Apply ``rounds`` repetitions of a fixed 4-bit round to ``value``.

    The round consists of two identical 2-bit S-boxes (x -> (x + 1) mod 4)
    followed by a 4-bit P-box generated by ``get_pbox_mapping`` with seed 0.
    """
    pbox_table = get_pbox_mapping(width=4, seed=0)
    sbox_table = {0: 1, 1: 2, 2: 3, 3: 0}  # 2 bits: x -> (x + 1) mod 4

    spn_round = Round(
        [SBox(sbox_table), SBox(sbox_table)],
        PBox(pbox_table),
    )

    result = value
    for _ in range(rounds):
        result = spn_round.evaluate(result)
    return result

In [109]:
# Scatter plot of input -> output after a fixed number of rounds.
# Import here so this cell also runs on a fresh kernel; otherwise the bokeh
# names only come from a later cell.
from bokeh.plotting import figure, output_notebook, show

output_notebook()

width = 4  # bit width of the toy round built inside `evaluate` (4-bit P-box)
max_value = 2 ** width
rounds = 4
keys = list(range(max_value))
values = [evaluate(k, rounds) for k in keys]

plot = figure(
    title=f'Output after {rounds} rounds',
    x_axis_label='input',
    y_axis_label='output',
    y_range=(-1, max_value),
    # `width`/`height` work in Bokeh 2.x and 3.x; `plot_width`/`plot_height`
    # were removed in Bokeh 3.0.
    width=400,
    height=400,
)

plot.scatter(keys, values)

show(plot)

In [114]:
import numpy as np

from bokeh.layouts import column, row
from bokeh.models import CustomJS, Slider
from bokeh.plotting import ColumnDataSource, figure, output_notebook, show

output_notebook()

width = 4  # bit width of the toy round built inside `evaluate`
max_value = 2 ** width
keys = list(range(max_value))

max_rounds = 10

# One column per round count ('1' .. str(max_rounds)) so the JS callback can
# swap the displayed column without calling back into Python.
data = {'x': keys}

round_values = keys
for rounds in range(1, max_rounds + 1):
    # evaluate(k, r) == evaluate(evaluate(k, r - 1), 1): apply one more round
    # to the previous outputs instead of recomputing every round from scratch.
    round_values = [evaluate(v, 1) for v in round_values]
    data[str(rounds)] = round_values

# 'y' is the plotted column; start on one round. Copy it so the callback's
# in-place updates never alias the '1' column.
data['y'] = list(data['1'])

source = ColumnDataSource(
    data=data
)

plot = figure(
    title='Input -> output by number of rounds',
    x_axis_label='input',
    y_axis_label='output',
    y_range=(-1, max_value),
    # `width`/`height` work in Bokeh 2.x and 3.x; `plot_width`/`plot_height`
    # were removed in Bokeh 3.0.
    width=600,
    height=600,
)

plot.scatter('x', 'y', source=source, size=10)

slider = Slider(start=1, end=max_rounds, value=1, step=1, title="Rounds")

callback = CustomJS(args=dict(source=source, slider=slider),
                    code="""
    const data = source.data;
    // Column keys are strings ('1', '2', ...); slider.value is a number.
    const r = data[String(slider.value)];
    const y = data['y'];

    for (let i = 0; i < y.length; i++) {
        y[i] = r[i];
    }
    source.change.emit();
""")

slider.js_on_change('value', callback)

layout = row(
    plot,
    column(slider),
)

show(layout)