In [3]:
import math
from typing import Tuple
from typing import List

import numpy as np 
import sympy as sp

In [4]:
class Flat_layout:
  """A flat (un-nested) ``shape:stride`` layout mapping linear indices to offsets.

  A linear index ``x`` in ``[0, size)`` is unpacked into a coordinate tuple using the
  colexicographic (first-mode-fastest) ordering of ``shape``. That coordinate is then
  dotted with ``stride`` to give an offset.

  Attributes set in ``__init__``:
    rank, size, co_size, shape, stride, colex_stride, domain, co_ordinate_tensor, range.
  Attributes set by ``realize``:
    co_ordinate_map_tensor, layout_map_array.
  """

  def __init__(self, shape: Tuple[int, ...], stride: Tuple[int, ...]):
    """Build the layout ``shape:stride``.

    Args:
      shape: extent of each mode.
      stride: offset step of each mode; must have the same length as ``shape``.

    Raises:
      AssertionError: if ``shape`` and ``stride`` differ in length.
    """
    assert len(shape) == len(stride), f"shape {shape} and stride {stride} must have equal length"
    self.rank = len(shape)
    self.shape = shape
    self.stride = stride
    # math.prod keeps these exact Python ints. np.prod(()) returns the float 1.0
    # (making size/co_size floats for a rank-0 layout), and np.prod can silently
    # overflow int64 for large shapes.
    self.size = math.prod(shape)
    # Largest reachable offset + 1 (assumes non-negative strides).
    self.co_size = sum((s - 1) * d for s, d in zip(shape, stride)) + 1
    # colex_stride[i] = shape[0] * ... * shape[i-1] (empty product = 1): the step of
    # mode i in the compact first-mode-fastest ordering.
    self.colex_stride = tuple(math.prod(shape[:i]) for i in range(self.rank))

    self.domain = np.arange(self.size)
    # co_ordinate_tensor[c0, ..., c_{r-1}] == (c0, ..., c_{r-1}); its shape is shape + (rank,).
    self.co_ordinate_tensor = np.moveaxis(np.indices(self.shape), 0, -1)
    self.range = np.arange(self.co_size)

  def colex_map(self, t):
    """Map coordinate tuple ``t`` to its colexicographic linear index."""
    return sum(self.colex_stride[i] * t[i] for i in range(self.rank))

  def colex_inv_map(self, x):
    """Unpack linear index ``x`` into a coordinate tuple (first mode varies fastest)."""
    return tuple((x // self.colex_stride[i]) % self.shape[i] for i in range(self.rank))

  def co_ordinate_map(self, t):
    """Map coordinate tuple ``t`` to its offset ``sum(stride[i] * t[i])``."""
    return sum(self.stride[i] * t[i] for i in range(self.rank))

  def layout_map(self, x):
    """Map linear index ``x`` to its offset, i.e. ``co_ordinate_map(colex_inv_map(x))``."""
    return self.co_ordinate_map(self.colex_inv_map(x))

  def realize(self):
    """Materialise the layout as numpy arrays.

    Sets:
      co_ordinate_map_tensor: array with shape ``shape`` whose entry at coordinate c
        is the offset of c.
      layout_map_array: 1-D array of length ``size`` whose entry x is ``layout_map(x)``.

    A rank-0 layout gets empty arrays for both.
    """
    if not self.shape:
      self.co_ordinate_map_tensor = np.empty((0,))
      self.layout_map_array = np.empty((0,))
      return
    self.co_ordinate_map_tensor = np.dot(self.co_ordinate_tensor, np.array(self.stride))
    # Vectorised colex_inv_map for the whole domain: row i of `coords` holds
    # (x // colex_stride[i]) % shape[i] for every x; weighting rows by stride and
    # summing over modes yields the offsets.
    colex_col = np.array(self.colex_stride).reshape(self.rank, 1)
    shape_col = np.array(self.shape).reshape(self.rank, 1)
    stride_col = np.array(self.stride).reshape(self.rank, 1)
    coords = (self.domain.reshape(1, self.size) // colex_col) % shape_col
    self.layout_map_array = np.sum(coords * stride_col, axis=0)

  def __repr__(self):
    return f"{self.shape}:{self.stride}\n"
  
    
      

In [5]:
# A 3-mode example layout; realize() fills in its offset tables.
F = Flat_layout(shape=(3, 7, 2), stride=(2, 5, 9))
F.realize();

In [6]:
# Offset of every coordinate (i, j, k) of F: entry [i, j, k] == 2*i + 5*j + 9*k.
F.co_ordinate_map_tensor

array([[[ 0,  9],
        [ 5, 14],
        [10, 19],
        [15, 24],
        [20, 29],
        [25, 34],
        [30, 39]],

       [[ 2, 11],
        [ 7, 16],
        [12, 21],
        [17, 26],
        [22, 31],
        [27, 36],
        [32, 41]],

       [[ 4, 13],
        [ 9, 18],
        [14, 23],
        [19, 28],
        [24, 33],
        [29, 38],
        [34, 43]]])

In [7]:
# Let's try to find the layout that tiles an 8x8 row-major matrix with 2x2 tiles.

# 8x8 matrix stored row-major: one step along a row mode moves 8 elements, along a column mode 1.
L_8_8_row_major = Flat_layout(shape=(8, 8), stride=(8, 1))

In [8]:
# Materialise the offset tables of the 8x8 row-major layout.
L_8_8_row_major.realize();

In [9]:
# NOTE: despite the name, this holds the *offset* table (co_ordinate_map_tensor),
# not Flat_layout.co_ordinate_tensor, which stores the raw coordinate tuples.
co_ordinate_tensor = L_8_8_row_major.co_ordinate_map_tensor

In [10]:
# Row-major strides (8, 1): entry [r, c] == 8*r + c.
co_ordinate_tensor

array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62, 63]])

In [11]:
# Compact 2x2 tile stored row-major (row stride 2, column stride 1).
L_2_2_tiler_row_major = Flat_layout(shape=(2, 2), stride=(2, 1))


In [12]:
# We say that a layout $A$ "tile-divides" a layout $B$ if there exists some layout $C$ such that coalesce(concat(A, C)) = B,
# where coalescing repeatedly merges adjacent modes: $(s_i, s_{i+1}):(d_i, s_i d_i) \to (s_i s_{i+1}):(d_i)$.

  
def coalesce(layout:Flat_layout): 
  """Return a new layout with adjacent mergeable modes fused.

  Scanning left to right, the next mode (s', d') is fused into the current mode
  (s, d) whenever d' == s * d; the fused mode becomes (s * s', d). Fusing such a
  pair does not change the offset any linear index maps to.

  Args:
    layout: the layout to coalesce (not modified).

  Returns:
    A new Flat_layout. A rank-0 layout coalesces to the rank-0 layout ``():()``.
  """
  shape, stride = layout.shape, layout.stride
  if layout.rank == 0:
    # Nothing to merge; indexing shape[0] below would raise IndexError.
    return Flat_layout((), ())
  new_shape, new_stride = [], []
  curr_shape, curr_stride = shape[0], stride[0]
  for s, d in zip(shape[1:], stride[1:]):
    if d == curr_shape * curr_stride:
      # This mode starts exactly where the current one ends: fuse them.
      curr_shape *= s
    else:
      new_shape.append(curr_shape)
      new_stride.append(curr_stride)
      curr_shape, curr_stride = s, d
  new_shape.append(curr_shape)
  new_stride.append(curr_stride)
  return Flat_layout(tuple(new_shape), tuple(new_stride))
       
       
def concat(A:Flat_layout, B:Flat_layout): 
  """Join two layouts mode-wise: (A.shape ++ B.shape):(A.stride ++ B.stride)."""
  joined_shape = A.shape + B.shape
  joined_stride = A.stride + B.stride
  return Flat_layout(joined_shape, joined_stride)

  
    
   

In [68]:
# Shape S and stride D of a 5-mode layout; the cells below compare its offsets with their unique values.
S, D = (3, 2, 2, 4, 5), (1, 4, 8, 15, 60)

In [69]:
H = Flat_layout(shape=S, stride=D)

In [70]:
# Materialise H's offset tables (realize() returns None, so nothing is displayed).
H.realize();

In [71]:
# Offset of every linear index 0..H.size-1 (H.size == 3*2*2*4*5 == 240).
X = H.layout_map_array

In [72]:
# Offsets in linear-index order.
X

array([  0,   1,   2,   4,   5,   6,   8,   9,  10,  12,  13,  14,  15,
        16,  17,  19,  20,  21,  23,  24,  25,  27,  28,  29,  30,  31,
        32,  34,  35,  36,  38,  39,  40,  42,  43,  44,  45,  46,  47,
        49,  50,  51,  53,  54,  55,  57,  58,  59,  60,  61,  62,  64,
        65,  66,  68,  69,  70,  72,  73,  74,  75,  76,  77,  79,  80,
        81,  83,  84,  85,  87,  88,  89,  90,  91,  92,  94,  95,  96,
        98,  99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113,
       114, 115, 117, 118, 119, 120, 121, 122, 124, 125, 126, 128, 129,
       130, 132, 133, 134, 135, 136, 137, 139, 140, 141, 143, 144, 145,
       147, 148, 149, 150, 151, 152, 154, 155, 156, 158, 159, 160, 162,
       163, 164, 165, 166, 167, 169, 170, 171, 173, 174, 175, 177, 178,
       179, 180, 181, 182, 184, 185, 186, 188, 189, 190, 192, 193, 194,
       195, 196, 197, 199, 200, 201, 203, 204, 205, 207, 208, 209, 210,
       211, 212, 214, 215, 216, 218, 219, 220, 222, 223, 224, 22

In [73]:
# Sorted distinct offsets; if len(Y) == len(X), no two linear indices share an offset.
Y = np.unique(X)

In [74]:
# Distinct offsets in ascending order.
Y

array([  0,   1,   2,   4,   5,   6,   8,   9,  10,  12,  13,  14,  15,
        16,  17,  19,  20,  21,  23,  24,  25,  27,  28,  29,  30,  31,
        32,  34,  35,  36,  38,  39,  40,  42,  43,  44,  45,  46,  47,
        49,  50,  51,  53,  54,  55,  57,  58,  59,  60,  61,  62,  64,
        65,  66,  68,  69,  70,  72,  73,  74,  75,  76,  77,  79,  80,
        81,  83,  84,  85,  87,  88,  89,  90,  91,  92,  94,  95,  96,
        98,  99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113,
       114, 115, 117, 118, 119, 120, 121, 122, 124, 125, 126, 128, 129,
       130, 132, 133, 134, 135, 136, 137, 139, 140, 141, 143, 144, 145,
       147, 148, 149, 150, 151, 152, 154, 155, 156, 158, 159, 160, 162,
       163, 164, 165, 166, 167, 169, 170, 171, 173, 174, 175, 177, 178,
       179, 180, 181, 182, 184, 185, 186, 188, 189, 190, 192, 193, 194,
       195, 196, 197, 199, 200, 201, 203, 204, 205, 207, 208, 209, 210,
       211, 212, 214, 215, 216, 218, 219, 220, 222, 223, 224, 22

In [75]:
# Number of offsets produced, one per linear index.
X.shape

(240,)

In [76]:
# Injectivity check: np.unique drops repeated offsets, so equal lengths mean no two
# linear indices of H map to the same offset.
assert Y.shape == X.shape, "layout H maps two different indices to the same offset"
Y.shape

(240,)

In [77]:
# Reuses the name H for a compact 3x4 layout whose strides equal its colex strides (1, 3).
H = Flat_layout(shape=(3, 4), stride=(1, 3))

In [78]:
# Materialise the compact layout's offset tables.
H.realize();

In [79]:
# Strides equal the colex strides, so the layout map is the identity: arange(12).
H.layout_map_array

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])