In [24]:
import sympy as sym 
import numpy as np
from typing import Tuple

In [25]:
#let us use cutlass notation and _ at the end of symbols 
x, y = sym.symbols('x, y')


In [26]:
f_sym = x**2 + sym.sin(y)

In [27]:
f_realized = sym.lambdify([x,y], f_sym, 'numpy')

In [28]:
X = np.random.randn(10)
Y = np.random.randn(1) 
g = f_realized(X,Y)
g

array([0.46789551, 0.49895757, 4.66119282, 3.87078259, 1.63333665,
       1.76295222, 0.47804541, 0.54590916, 1.32977942, 0.49773147])

In [29]:
X 

array([ 0.08349935, -0.19502359,  2.04945589,  1.84658041, -1.08278035,
       -1.14106479, -0.1308512 , -0.29152323,  0.93212448,  0.19185436])

In [30]:
Y

array([0.4790354])

In [86]:
class Layout: 
  def __init__ (self, shape: Tuple[sym.Symbol,...], stride: Tuple[sym.Symbol,...], 
                kind: str, input_symbol: sym.Symbol):
    
    assert len(shape) == len(stride) 
    self.m = len(shape)
    self.shape_stride = [sym.Integer(1)]*self.m
    self.shape = shape 
    self.stride = stride
    

    for i in reversed(range(self.m-1)): 
      self.shape_stride[i] = self.shape[i]*self.shape_stride[i+1]
      
    self.shape_stride = tuple(self.shape_stride)  
    self.I = input_symbol

    self.fan_out = [0]*self.m 
    for i in range(self.m): 
      self.fan_out[i] = (self.I // self.shape_stride[i]) % self.shape[i]
    
    self.fan_out = tuple(self.fan_out)
    
    self.fan_in = sym.Integer(0) 
    self.fan_out_sym = tuple([sym.Symbol(f"y_{str(i)}") for i in range(self.m)])
    for i in range(self.m):
      self.fan_in += self.stride[i]*self.fan_out_sym[i]
      
    self.layout = self._substitute(self.fan_in, self.fan_out, self.fan_out_sym)
    
    
  def _substitute (self, expr: sym.Expr, new_symbols: Tuple[sym.Expr,...], old_symbols: Tuple[sym.Expr,...]): 
    new_expr = expr 
    for i in range(len(new_symbols)): 
      new_expr = new_expr.subs(old_symbols[i], new_symbols[i])
      
    return new_expr



In [87]:
shape = (sym.symbols('M'), sym.symbols('N')) 
stride = (sym.symbols('P'), sym.symbols('Q'))
kind = "mem" 
input_symbol = sym.symbols('x')

L = Layout(shape, stride, kind, input_symbol)


In [88]:
L.shape_stride

(M, 1)

In [89]:
L.fan_out

(Mod(floor(x/M), M), Mod(floor(x), N))

In [90]:
L.fan_in

P*y_0 + Q*y_1

In [91]:
L.layout

P*(Mod(floor(x/M), M)) + Q*(Mod(floor(x), N))

In [176]:
class Layout: 
  def __init__ (self, shape: Tuple[int,...], stride: Tuple[int,...], 
                kind: str):
    
    assert len(shape) == len(stride) 
    self.m = len(shape)
    self.shape = tuple([sym.Integer(shape[i]) for i in range(self.m)])
    self.stride = tuple([sym.Integer(stride[i]) for i in range(self.m)])
    self.S = tuple([sym.Symbol(f"S_{str(i)}") for i in range(self.m)])
    self.S_S = [sym.Integer(1)]*self.m
    self.D = tuple([sym.Symbol(f"D_{str(i)}") for i in range(self.m)])
    self.I = sym.Symbol("x")
    self.N_elems = 1 
    for i in range(self.m): 
      self.N_elems *= shape[i]
      
    
    for i in reversed(range(self.m-1)): 
      self.S_S[i] = self.S[i+1]*self.S_S[i]
      
    self.S_S = tuple(self.S_S)  
    

    self.fan_out = [0]*self.m 
    for i in range(self.m): 
      self.fan_out[i] = (self.I // self.S_S[i]) % self.S[i]
    
    self.fan_out = tuple(self.fan_out)
    
    self.fan_in = sym.Integer(0) 
    self.fan_out_sym = tuple([sym.Symbol(f"y_{str(i)}") for i in range(self.m)])
    for i in range(self.m):
      self.fan_in += self.D[i]*self.fan_out_sym[i]
      
    self.layout_sym = self._substitute(self.fan_in, self.fan_out, self.fan_out_sym)
    self.layout = self._substitute(self.layout_sym, self.shape, self.S)
    self.layout = self._substitute(self.layout, self.stride, self.D)
    self.realized_layout = "un_realized"
    
    
  
  def _substitute (self, expr: sym.Expr, new_symbols: Tuple[sym.Expr,...], old_symbols: Tuple[sym.Expr,...]): 
    new_expr = expr 
    for i in range(len(new_symbols)): 
      new_expr = new_expr.subs(old_symbols[i], new_symbols[i])
      
    return new_expr
  
  def realize(self): 
    one_d_domain = np.arange(self.N_elems) 
    layout_lambda = sym.lambdify([self.I], self.layout, "numpy")
    self.realized_layout = layout_lambda(one_d_domain).astype(int)
    
   



In [177]:
shape = (16,4,2)
stride = (17,9,1)
L = Layout(shape, stride, "mem")

In [178]:
L.S

(S_0, S_1, S_2)

In [190]:
L.fan_out
L.fan_in
L.layout_sym

D_0*(Mod(floor(x/S_1), S_0)) + D_1*(Mod(floor(x/S_2), S_1)) + D_2*(Mod(floor(x), S_2))

In [180]:
L.realized_layout

'un_realized'

In [181]:
L.realize()

In [182]:
L.realized_layout

array([  0,   1,   9,  10,  35,  36,  44,  45,  34,  35,  43,  44,  69,
        70,  78,  79,  68,  69,  77,  78, 103, 104, 112, 113, 102, 103,
       111, 112, 137, 138, 146, 147, 136, 137, 145, 146, 171, 172, 180,
       181, 170, 171, 179, 180, 205, 206, 214, 215, 204, 205, 213, 214,
       239, 240, 248, 249, 238, 239, 247, 248, 273, 274, 282, 283,   0,
         1,   9,  10,  35,  36,  44,  45,  34,  35,  43,  44,  69,  70,
        78,  79,  68,  69,  77,  78, 103, 104, 112, 113, 102, 103, 111,
       112, 137, 138, 146, 147, 136, 137, 145, 146, 171, 172, 180, 181,
       170, 171, 179, 180, 205, 206, 214, 215, 204, 205, 213, 214, 239,
       240, 248, 249, 238, 239, 247, 248, 273, 274, 282, 283])

In [200]:

class Swizzle:
    
  BitAnd = sym.Function('BitAnd')
  BitOr  = sym.Function('BitOr')
  BitXor = sym.Function('BitXor')
  RShift = sym.Function('RShift')
  LShift = sym.Function('LShift')
  BitNot = sym.Function('BitNot')
  Max = sym.Max
  Min = sym.Min

  def __init__(self, m_base: int, b_bits: int, s_shift: int, N_elems: int):
      
    assert m_base >= 0 
    assert b_bits >= 0 
    assert abs(s_shift) >= b_bits 
    self.N_elems = N_elems
    
    self.m_base = sym.Integer(m_base)
    self.b_bits = sym.Integer(b_bits)
    self.s_shift = sym.Integer(s_shift)
    
    self.b = sym.Symbol('b', integer=True) 
    self.m = sym.Symbol('m', integer=True)
    self.s = sym.Symbol('s', integer=True)
    self.x = sym.Symbol('x', integer=True)

    LShift, BitAnd, BitOr = Swizzle.LShift, Swizzle.BitAnd, Swizzle.BitOr
    BitXor, RShift = Swizzle.BitXor, Swizzle.RShift
    Max, Min = Swizzle.Max, Swizzle.Min

    one = sym.Integer(1) 
    zero = sym.Integer(0)

    base_mask = (LShift(one, self.b)) - 1
    src_shift = self.m + Max(zero, self.s)
    dst_shift = self.m - Min(zero, self.s)
    
    src_mask = LShift(base_mask, src_shift) 
    dst_mask = LShift(base_mask, dst_shift)
    
    bits_to_move = BitAnd(self.x, src_mask)
    
    self.swizzle_mask = BitOr(src_mask, dst_mask)
    
    expr_pos_s = BitXor(self.x, RShift(bits_to_move, self.s))
    expr_neg_s = BitXor(self.x, LShift(bits_to_move, -self.s))

    self.swizzle_map_sym = sym.Piecewise(
        (expr_pos_s, self.s >= 0),
        (expr_neg_s, True)
    )
    self.swizzle_map = self._substitute(self.swizzle_map_sym, (self.m_base, self.b_bits, self.s_shift), (self.m, self.b, self.s))
    self.realized_swizzle_map = "un_realized"
    

  def __repr__(self):
    return (
        f"Swizzle(\n"
        f"  symbols = ({self.b}, {self.m}, {self.s}, {self.x}),\n"
        f"  mask    = {self.swizzle_mask},\n"
        f"  map     = {self.swizzle_map_sym}\n"
        f")"
    )
      
  def _get_realization_map(self):
    import numpy as np
    return {
        'BitAnd': np.bitwise_and,
        'BitOr':  np.bitwise_or,
        'BitXor': np.bitwise_xor,
        'RShift': np.right_shift,
        'LShift': np.left_shift,
        'BitNot': np.invert,
        'Max':    np.maximum,
        'Min':    np.minimum
    }
  
  def _substitute (self, expr: sym.Expr, new_symbols: Tuple[sym.Expr,...], old_symbols: Tuple[sym.Expr,...]): 
    new_expr = expr 
    for i in range(len(new_symbols)): 
      new_expr = new_expr.subs(old_symbols[i], new_symbols[i])
      
    return new_expr
  
  def realize(self):
    ops = self._get_realization_map()
    func_lamb = sym.lambdify(
            self.x, 
            self.swizzle_map, 
            modules=['numpy', ops]
        )
    arr = np.arange(self.N_elems)
    self.realized_swizzle_map = func_lamb(arr)
  
    



In [221]:
A = Swizzle(0,1,-2, 64)

In [222]:
A.swizzle_map
A.realize()

In [223]:
x = A.realized_swizzle_map
y = np.arange(64)
print(y-x)

[ 0 -4  0 -4  0  4  0  4  0 -4  0 -4  0  4  0  4  0 -4  0 -4  0  4  0  4
  0 -4  0 -4  0  4  0  4  0 -4  0 -4  0  4  0  4  0 -4  0 -4  0  4  0  4
  0 -4  0 -4  0  4  0  4  0 -4  0 -4  0  4  0  4]
