In [18]:
import numpy as np 
import sympy as sym
from typing import Tuple

In [68]:
class Flat_Layout:
  """Symbolic model of a flat layout given by ``shape`` and ``stride``.

  The layout function maps a flat integer index ``x`` to
  ``sum_i stride[i] * ((x // prod(shape[:i])) % shape[i])``. It is built as the
  composition of two SymPy expressions:

  * the colexicographic inverse ``x -> (c_0, ..., c_{n-1})`` with
    ``c_i = (x // SS_i) % s_i`` and ``SS_i = s_0 * ... * s_{i-1}`` (``SS_0 = 1``);
  * the coordinate map ``(c_0, ..., c_{n-1}) -> sum_i c_i * d_i``.

  At construction every ``s_i`` / ``d_i`` is a free symbol (the ``*_sym``
  attributes); call ``realize()`` to substitute the concrete integers.

  Parameters
  ----------
  input_symbol : str
      Name of the flat input symbol (e.g. ``"x"``).
  middle_symbol : str
      Base name of the intermediate coordinate symbols ``c_i``.
  marker : str
      Tag embedded in every generated symbol name; layouts built with
      distinct markers get distinct ``s``/``d``/coordinate symbols.
  shape, stride : sequences of int
      Must have the same length.
  """
  def __init__(self, input_symbol: str, middle_symbol: str, marker: str, shape, stride):
    assert len(shape) == len(stride), "shape and stride must have the same length"
    self.length = len(shape)
    self.size = np.prod(shape).item()  # number of flat indices in the domain
    self.name = marker
    self.middle_sym = middle_symbol
    self.shape = [sym.Integer(s) for s in shape]
    self.stride = [sym.Integer(d) for d in stride]
    # 1 + largest reachable offset (assuming non-negative strides).
    self.co_size = (1 + np.sum(np.array(stride) * (np.array(shape) - 1))).item()

    # Symbolic shape entries s_i and strides d_i, tagged with the marker.
    self.S = [sym.Symbol(f"s^{self.name}_{i}") for i in range(self.length)]
    self.D = [sym.Symbol(f"d^{self.name}_{i}") for i in range(self.length)]
    self.input_sym = sym.Symbol(input_symbol)

    # Exclusive prefix products SS_i = s_0 * ... * s_{i-1}, with SS_0 = 1.
    self.SS = [sym.Integer(1) for _ in range(self.length)]
    for i in range(1, self.length):
      self.SS[i] = self.S[i-1] * self.SS[i-1]

    # Colexicographic inverse: the i-th coordinate of the flat index.
    self.colex_inv_sym = tuple((self.input_sym // self.SS[i]) % self.S[i] for i in range(self.length))
    self.co_ordinate_tuple = tuple(sym.Symbol(f"{self.middle_sym}^{self.name}_{i}") for i in range(self.length))
    # Coordinate map: dot product of the coordinates with the symbolic strides.
    self.co_ordinate_map_sym = sum(self.co_ordinate_tuple[i] * self.D[i] for i in range(self.length))
    # Layout map = coordinate map composed with the colexicographic inverse.
    self.layout_map_sym = self._substitute(self.co_ordinate_map_sym, self.colex_inv_sym, self.co_ordinate_tuple)

  def _substitute(self, expr: sym.Expr, new_symbols: Tuple[sym.Expr, ...], old_symbols: Tuple[sym.Expr, ...]):
    """Replace ``old_symbols[i]`` by ``new_symbols[i]`` in ``expr`` for all i at once.

    The replacement is simultaneous: a replacement value that itself contains
    one of the ``old_symbols`` is not rewritten again. (Substituting one pair
    at a time would turn ``x + 2*y`` under ``x -> y, y -> x`` into ``3*x``
    rather than ``y + 2*x``.)

    Returns ``expr`` unchanged when there is nothing to substitute.
    Raises ValueError if the two sequences have different lengths.
    """
    if len(new_symbols) != len(old_symbols):
      raise ValueError(
        f"expected as many new symbols as old ones, got {len(new_symbols)} and {len(old_symbols)}")
    if len(old_symbols) == 0:
      return expr
    return sym.sympify(expr).subs(list(zip(old_symbols, new_symbols)), simultaneous=True)

  def realize(self):
    """Substitute the concrete shape and stride into the symbolic maps.

    Sets
      colex_inv       : list of coordinate expressions in ``input_sym`` (shape substituted),
      co_ordinate_map : coordinate map with concrete strides (coordinates stay symbolic),
      layout_map      : the fully concrete layout function of ``input_sym``.
    """
    self.colex_inv = [self._substitute(c, self.shape, self.S) for c in self.colex_inv_sym]
    self.co_ordinate_map = self._substitute(self.co_ordinate_map_sym, self.stride, self.D)
    layout_map = self._substitute(self.layout_map_sym, self.shape, self.S)
    self.layout_map = self._substitute(layout_map, self.stride, self.D)
    
    

In [78]:
# Layout under study: a rank-4 shape with non-compact strides.
shape = (3, 4, 2, 7)
stride = (1, 13, 9, 5)
# Compact colexicographic (column-major) strides of `shape`: col_stride[i] = prod(shape[:i]).
# Derived rather than hard-coded so it stays consistent if `shape` is edited; here (1, 3, 12, 24).
col_stride = tuple(int(np.prod(shape[:i])) for i in range(len(shape)))

# Mode permutation applied to the permuted layouts below (modes 2 and 3 are swapped).
perm = (0, 1, 3, 2)
x, y, z, w = perm

shape_perm = tuple(shape[i] for i in perm)
stride_perm = tuple(stride[i] for i in perm)
col_stride_perm = tuple(col_stride[i] for i in perm)

L = Flat_Layout("x", "y", "", shape, stride)
L_perm = Flat_Layout("v", "w", "p", shape_perm, stride_perm)
col_perm = Flat_Layout("t", "h", "c", shape_perm, col_stride_perm)

In [80]:
# Substitute concrete shapes/strides into each layout's symbolic maps.
for layout in (L, L_perm, col_perm):
  layout.realize()

In [81]:
# Concrete layout functions of the three layouts.
phi_L, phi_col_perm, phi_L_perm = (
  L.layout_map,
  col_perm.layout_map,
  L_perm.layout_map,
)

In [82]:
# Realized layout function of L, a SymPy expression in x:
#   sum_i stride[i] * ((x // prod(shape[:i])) % shape[i])
phi_L

5*(Mod(floor(x/24), 7)) + 9*(Mod(floor(x/12), 2)) + 13*(Mod(floor(x/3), 4)) + Mod(floor(x), 3)

In [83]:
# Realized layout function of col_perm (permuted shape with the permuted
# column-major strides of `shape`), a SymPy expression in t.
phi_col_perm

12*(Mod(floor(t/84), 2)) + 24*(Mod(floor(t/12), 7)) + 3*(Mod(floor(t/3), 4)) + Mod(floor(t), 3)

In [84]:
# Realized layout function of L_perm (shape and stride of L with modes permuted), in v.
phi_L_perm

9*(Mod(floor(v/84), 2)) + 5*(Mod(floor(v/12), 7)) + 13*(Mod(floor(v/3), 4)) + Mod(floor(v), 3)

In [95]:
# Compose the layout maps: every occurrence of L's input symbol in phi_L
# is to be replaced by the expression phi_col_perm (a function of t).
expression = phi_L
old_symbol = (L.input_sym,)
new_symbol = (phi_col_perm,)

In [99]:
# X(t) = phi_L(phi_col_perm(t))
X = L._substitute(
  expr=expression,
  new_symbols=new_symbol,
  old_symbols=old_symbol,
)
X

5*(Mod(floor((Mod(floor(t/84), 2))/2 + Mod(floor(t/12), 7) + (Mod(floor(t/3), 4))/8 + (Mod(floor(t), 3))/24), 7)) + 9*(Mod(floor(Mod(floor(t/84), 2) + 2*(Mod(floor(t/12), 7)) + (Mod(floor(t/3), 4))/4 + (Mod(floor(t), 3))/12), 2)) + 13*(Mod(floor(4*(Mod(floor(t/84), 2)) + 8*(Mod(floor(t/12), 7)) + Mod(floor(t/3), 4) + (Mod(floor(t), 3))/3), 4)) + Mod(floor(12*(Mod(floor(t/84), 2)) + 24*(Mod(floor(t/12), 7)) + 3*(Mod(floor(t/3), 4)) + Mod(floor(t), 3)), 3)

In [97]:
def evaluate(expression, domain, input_symbol):
  """Evaluate a one-variable SymPy expression at the integers 0 .. domain-1.

  Parameters
  ----------
  expression : SymPy expression (or number) in ``input_symbol``.
  domain : int
      Number of consecutive integer inputs, starting at 0.
  input_symbol : sympy.Symbol
      The variable ``expression`` is a function of.

  Returns
  -------
  numpy.ndarray of int with shape ``(domain,)``; float results from the numpy
  backend are truncated by ``astype(int)``.
  """
  np_domain = np.arange(domain)
  expression_lambda = sym.lambdify(input_symbol, expression, "numpy")
  values = expression_lambda(np_domain)
  # If the expression does not depend on input_symbol (e.g. an all-ones shape or
  # all-zero strides), lambdify returns a scalar; broadcast it so there is always
  # exactly one value per domain point.
  return np.broadcast_to(values, np_domain.shape).astype(int)
    
  

In [100]:
# Evaluate the composition phi_L(phi_col_perm(t)) at t = 0 .. L.size - 1.
X_eval = evaluate(
  expression=X,
  domain=L.size,
  input_symbol=col_perm.input_sym,
)

In [101]:
# Values of phi_L(phi_col_perm(t)) for t = 0 .. L.size - 1.
X_eval

array([ 0,  1,  2, 13, 14, 15, 26, 27, 28, 39, 40, 41,  5,  6,  7, 18, 19,
       20, 31, 32, 33, 44, 45, 46, 10, 11, 12, 23, 24, 25, 36, 37, 38, 49,
       50, 51, 15, 16, 17, 28, 29, 30, 41, 42, 43, 54, 55, 56, 20, 21, 22,
       33, 34, 35, 46, 47, 48, 59, 60, 61, 25, 26, 27, 38, 39, 40, 51, 52,
       53, 64, 65, 66, 30, 31, 32, 43, 44, 45, 56, 57, 58, 69, 70, 71,  9,
       10, 11, 22, 23, 24, 35, 36, 37, 48, 49, 50, 14, 15, 16, 27, 28, 29,
       40, 41, 42, 53, 54, 55, 19, 20, 21, 32, 33, 34, 45, 46, 47, 58, 59,
       60, 24, 25, 26, 37, 38, 39, 50, 51, 52, 63, 64, 65, 29, 30, 31, 42,
       43, 44, 55, 56, 57, 68, 69, 70, 34, 35, 36, 47, 48, 49, 60, 61, 62,
       73, 74, 75, 39, 40, 41, 52, 53, 54, 65, 66, 67, 78, 79, 80])

In [93]:
# Evaluate the permuted layout directly at v = 0 .. L.size - 1.
L_perm_eval = evaluate(
  expression=phi_L_perm,
  domain=L.size,
  input_symbol=L_perm.input_sym,
)

In [94]:
# The permuted layout should coincide with L composed with col_perm:
# phi_L_perm(i) == phi_L(phi_col_perm(i)) for every index i.
# Check it explicitly instead of comparing the two printed arrays by eye.
np.testing.assert_array_equal(L_perm_eval, X_eval)
L_perm_eval

array([ 0,  1,  2, 13, 14, 15, 26, 27, 28, 39, 40, 41,  5,  6,  7, 18, 19,
       20, 31, 32, 33, 44, 45, 46, 10, 11, 12, 23, 24, 25, 36, 37, 38, 49,
       50, 51, 15, 16, 17, 28, 29, 30, 41, 42, 43, 54, 55, 56, 20, 21, 22,
       33, 34, 35, 46, 47, 48, 59, 60, 61, 25, 26, 27, 38, 39, 40, 51, 52,
       53, 64, 65, 66, 30, 31, 32, 43, 44, 45, 56, 57, 58, 69, 70, 71,  9,
       10, 11, 22, 23, 24, 35, 36, 37, 48, 49, 50, 14, 15, 16, 27, 28, 29,
       40, 41, 42, 53, 54, 55, 19, 20, 21, 32, 33, 34, 45, 46, 47, 58, 59,
       60, 24, 25, 26, 37, 38, 39, 50, 51, 52, 63, 64, 65, 29, 30, 31, 42,
       43, 44, 55, 56, 57, 68, 69, 70, 34, 35, 36, 47, 48, 49, 60, 61, 62,
       73, 74, 75, 39, 40, 41, 52, 53, 54, 65, 66, 67, 78, 79, 80])