In [None]:
import numpy as np 

class layout: 
  def __init__ (self, m: int, shape: tuple[int, ...], stride: tuple[int, ...]): 
    if (len(shape) != m) or (len(stride) != m): 
      raise ValueError("m is the dimensionality, both the shape and stride must be an m-tuple")
    for i in range(m): 
      assert isinstance(shape[i], int)  
      assert isinstance(stride[i], int) 
      assert shape[i] > 1
      assert stride[i] > 0
      
    self.m = m 
    self.N = np.prod(shape).item()
    self.shape = shape
    self.stride = stride
    
    shape_stride = [1 for i in range(m)] 
    for i in reversed(range(m-1)): 
      shape_stride[i] = shape[i+1]*shape_stride[i+1]
      
    self.shape_stride = tuple(shape_stride)
    
    self.one_d_domain = np.arange(self.N).reshape(self.N, 1)
    self.md_domain = np.indices(self.shape)
    
    self.p_shape_stride = np.array([self.N] + list(self.shape_stride))
    self.P_one_d_to_md = (self.one_d_domain // self.p_shape_stride[1:].reshape(1,self.m)) % (self.p_shape_stride[:m].reshape(1,self.m))
    
    self.P_stride = np.array(self.stride).reshape(tuple([self.m] + [1 for _ in range (self.m)]))
    self.P_md_to_one_d = np.sum(self.md_domain * self.P_stride, axis = 0)

    self.layout = np.sum((self.P_one_d_to_md.T).reshape(tuple([m] + list(self.shape)))*self.P_stride, axis = 0)
    self.test_layout = [0 for _ in range(self.N)] 
    for i in range(self.N): 
      md_index = self.one_d_to_md(i)
      print(md_index)
      
    self.perm = [i for i in range(self.m)] 
    self.sorted_stride = stride
    for i in range(self.m): 
      for j in range(self.m):
        if self.sorted_stride[i] <= self.sorted_stride[j]: 
          p_i = self.perm[i]
          p_j = self.perm[j]
          s_i = self.sorted_stride[i]
          s_j = self.sorted_stride[j]
          self.perm[i], self.perm[j] = p_j, p_i
          self.sorted_stride[i], self.sorted_stride[j] = s_j, s_i
          
         
        
    
    
  def md_to_one_d (self, md_index): 
    assert len(md_index) == self.m 
    for i in range(self.m): 
      assert md_index[i] >= 0
      assert md_index[i] < self.shape[i]
    
    one_d_index = 0
    for i in range(self.m): 
      one_d_index += self.stride[i]*md_index[i]
    
    return one_d_index
  
  def one_d_to_md (self, one_d_index): 
    assert one_d_index >= 0 
    assert one_d_index < self.N 
    
    md_index = [0 for _ in range(self.m)]
    
    temp_shape_stride = tuple([self.N]) + self.shape_stride
    for i in range(self.m): 
      md_index[i] = ((one_d_index)//temp_shape_stride[i+1]) % temp_shape_stride[i]
      
    return tuple(md_index)

  def __repr__(self):
      return (f"N_elements = {self.N} \n"
              f"layout(m={self.m}, \n"
              f"       shape={self.shape}, \n"
              f"       stride={self.stride}, \n"
              f"       shape_stride = {self.shape_stride})")
      
  
  

In [3]:
shape = (3,4,2)
stride = (8,2,1)
m = 3

matrix = layout(m, shape, stride)
print(matrix.P_md_to_one_d)
print(matrix.P_md_to_one_d.shape)
print(matrix.P_one_d_to_md)
print(matrix.P_one_d_to_md.shape)

(0, 0, 0)
(0, 0, 1)
(0, 1, 0)
(0, 1, 1)
(0, 2, 0)
(0, 2, 1)
(0, 3, 0)
(0, 3, 1)
(1, 4, 0)
(1, 4, 1)
(1, 5, 0)
(1, 5, 1)
(1, 6, 0)
(1, 6, 1)
(1, 7, 0)
(1, 7, 1)
(2, 0, 0)
(2, 0, 1)
(2, 1, 0)
(2, 1, 1)
(2, 2, 0)
(2, 2, 1)
(2, 3, 0)
(2, 3, 1)
[[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]
  [12 13]
  [14 15]]

 [[16 17]
  [18 19]
  [20 21]
  [22 23]]]
(3, 4, 2)
[[0 0 0]
 [0 0 1]
 [0 1 0]
 [0 1 1]
 [0 2 0]
 [0 2 1]
 [0 3 0]
 [0 3 1]
 [1 4 0]
 [1 4 1]
 [1 5 0]
 [1 5 1]
 [1 6 0]
 [1 6 1]
 [1 7 0]
 [1 7 1]
 [2 0 0]
 [2 0 1]
 [2 1 0]
 [2 1 1]
 [2 2 0]
 [2 2 1]
 [2 3 0]
 [2 3 1]]
(24, 3)


In [4]:
print(matrix)

N_elements = 24 
layout(m=3, 
       shape=(3, 4, 2), 
       stride=(8, 2, 1), 
       shape_stride = (8, 2, 1))


In [5]:
q = matrix.one_d_to_md(21)

q

(2, 2, 1)

In [6]:
l = matrix.md_to_one_d(q)

In [7]:
l

21

In [8]:
shape = (3,4,2)
domain = np.indices(shape)

In [9]:
domain.shape

(3, 3, 4, 2)

In [10]:
matrix.layout

print(matrix.layout.shape)


(3, 4, 2)


In [None]:
import numpy as np 

class layout: 
  def __init__ (self, m: int, shape: tuple[int, ...], stride: tuple[int, ...]): 
    if (len(shape) != m) or (len(stride) != m): 
      raise ValueError("m is the dimensionality, both the shape and stride must be an m-tuple")
    assert shape[0] > 0 
    assert stride[0] > 0
    for i in range(1,m): 
      assert isinstance(shape[i], int)  
      assert isinstance(stride[i], int) 
      assert shape[i] > 1
      assert stride[i] > 0
      
    self.m = m 
    self.N = np.prod(shape).item()
    self.shape = np.array(shape)
    self.stride = np.array(stride)
    
    shape_stride = [1 for i in range(m)] 
    for i in reversed(range(m-1)): 
      shape_stride[i] = shape[i+1]*shape_stride[i+1]
      
    self.shape_stride = np.array(shape_stride)
    self.one_d_domain = np.arange(self.N)
    self.fan_map = (self.one_d_domain.reshape(self.N,1) // self.shape_stride.reshape(1, self.m)) % (self.shape.reshape(1, self.m))
    #the shape of the fan map is (N, m) the reason being, each one d element is a position in N (the outermost) dimension
    #the rest of the inner dimensions, is the literal m_tuple co-ordinate of that that one. that is: 
    #fan_map(i) (which is a tuple depicting co-oridnates)  =  tuple(self.fan_map[i]) (as a numpy slice)
    #the fan map itself, simply presents a row major layout ordering. 
    #the stride now, is going to induce the map from the m_tuple, back to the 1d. so composed, we get the layout
    self.layout = np.sum(self.fan_map * self.stride.reshape(1, self.m), axis = 1)
    self.bank_conflicts = -np.ones((self.N//32, 32)).astype(int)
    self.layout_bank_id = (self.layout.reshape(self.N//32, 32) % 32)
    for i in range(self.N//32): 
      for j in range(32): 
        self.bank_conflicts[i, self.layout_bank_id[i,j]] += 1
        
        
    self.perm = [i for i in range(self.m)] 
    self.sorted_stride = stride
    for i in range(self.m): 
      for j in range(self.m):
        if self.sorted_stride[i] <= self.sorted_stride[j]: 
          p_i = self.perm[i]
          p_j = self.perm[j]
          s_i = self.sorted_stride[i]
          s_j = self.sorted_stride[j]
          self.perm[i], self.perm[j] = p_j, p_i
          self.sorted_stride[i], self.sorted_stride[j] = s_j, s_i
    
  def __repr__(self):
    return (f"N_elements = {self.N} \n"
            f"layout(m={self.m}, \n"
            f"       shape={self.shape}, \n"
            f"       stride={self.stride}, \n"
            f"       shape_stride = {self.shape_stride})")
    
  

In [12]:
tensor = layout(4, (2,4,8,2), (50,2,4,4))

In [13]:
tensor.one_d_domain

tensor.layout.shape

(128,)

In [14]:
tensor.layout_bank_id

array([[ 0,  4,  4,  8,  8, 12, 12, 16, 16, 20, 20, 24, 24, 28, 28,  0,
         2,  6,  6, 10, 10, 14, 14, 18, 18, 22, 22, 26, 26, 30, 30,  2],
       [ 4,  8,  8, 12, 12, 16, 16, 20, 20, 24, 24, 28, 28,  0,  0,  4,
         6, 10, 10, 14, 14, 18, 18, 22, 22, 26, 26, 30, 30,  2,  2,  6],
       [18, 22, 22, 26, 26, 30, 30,  2,  2,  6,  6, 10, 10, 14, 14, 18,
        20, 24, 24, 28, 28,  0,  0,  4,  4,  8,  8, 12, 12, 16, 16, 20],
       [22, 26, 26, 30, 30,  2,  2,  6,  6, 10, 10, 14, 14, 18, 18, 22,
        24, 28, 28,  0,  0,  4,  4,  8,  8, 12, 12, 16, 16, 20, 20, 24]])

In [15]:
tensor.bank_conflicts

array([[ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
         1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1],
       [ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
         1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1],
       [ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
         1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1],
       [ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
         1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1]])

In [43]:
### For a given shape, when is a stride going to produce an injective map? 
### idea 1. the number of elements is same, that is stride product equals shape product
### that doesn't work. consider shape = (4,3,4,4) stride = (2,2,3,16)
### what if stride(i) <= row_major_stride(i)? 


S = 2*2*3*4*4
print(S)
shape = (4,3,4,4) 
row_major_stride = (48,16,4,1)
stride = (49,17,5,1)

L = layout(4, shape, stride)

192


In [45]:
print(L.layout)
print(L.layout.shape)
X = L.layout 
X_unique = np.unique(X)
print(X_unique.size == X.size) 
print(X_unique.shape)

[  0   1   2   3   5   6   7   8  10  11  12  13  15  16  17  18  17  18
  19  20  22  23  24  25  27  28  29  30  32  33  34  35  34  35  36  37
  39  40  41  42  44  45  46  47  49  50  51  52  49  50  51  52  54  55
  56  57  59  60  61  62  64  65  66  67  66  67  68  69  71  72  73  74
  76  77  78  79  81  82  83  84  83  84  85  86  88  89  90  91  93  94
  95  96  98  99 100 101  98  99 100 101 103 104 105 106 108 109 110 111
 113 114 115 116 115 116 117 118 120 121 122 123 125 126 127 128 130 131
 132 133 132 133 134 135 137 138 139 140 142 143 144 145 147 148 149 150
 147 148 149 150 152 153 154 155 157 158 159 160 162 163 164 165 164 165
 166 167 169 170 171 172 174 175 176 177 179 180 181 182 181 182 183 184
 186 187 188 189 191 192 193 194 196 197 198 199]
(192,)
False
(164,)
