In [1]:
from typing import Tuple, List, Iterator
import itertools
import numpy as np
import math
import numpy as np
from typing import Tuple, Union, Any
from enum import Enum, auto

In [2]:


class Bijection:
  """
  A permutation of the index range [0, length): an XOR bit-swizzle followed by
  an affine map modulo `length`.

  For an input x the result is
      swizzled = x ^ ((x >> shift) & mask)      with mask = ((1 << bits) - 1) << base
      result   = (scale * (swizzled + step)) % length

  Args:
    length: Size of the domain. Must be positive, and a power of two when bits > 0.
    bits:   Width of the swizzle mask in bits (0 disables the swizzle).
    base:   Bit position of the lowest mask bit.
    shift:  Right shift applied to x before masking; must be >= 0.
    scale:  Multiplier of the affine step; must be coprime to `length`.
    step:   Offset added before scaling.

  Raises:
    AssertionError: if `scale` is not coprime to `length`, or `shift` is negative.
    ValueError: if bits > 0 and `length` is not a power of two; if bits > 0,
      shift == 0 and the mask touches bits of the domain; or if bits == 0 and
      `length` is not positive.
  """
  def __init__(self, length, bits, base, shift, scale, step):
    # 1. Affine Safety: Scale must be coprime to length
    assert math.gcd(scale, length) == 1, "Scale must be coprime to length for bijectivity"
    assert shift >= 0

    self.length = length
    self.scale = scale
    self.step = step
    self.shift = shift
    self.mask = ((1 << bits) - 1) << base

    if bits > 0:
        # 2. Swizzle Safety: If swizzling (bits > 0), length MUST be power of 2.
        # Otherwise, XOR operations can jump outside the domain [0, length).
        is_power_of_two = (length > 0) and ((length & (length - 1)) == 0)
        if not is_power_of_two:
            raise ValueError(f"CRITICAL: Cannot apply bit-swizzle on non-power-of-2 length {length}. Collision imminent.")

        # 3. Zero-shift Safety: with shift == 0 the swizzle is x ^ (x & mask),
        # i.e. x & ~mask. That clears every masked bit, so two inputs that differ
        # only in those bits collide. It is harmless only when the mask lies
        # entirely above the bits used by the domain.
        if shift == 0 and (self.mask & (length - 1)) != 0:
            raise ValueError(
                f"Swizzle with shift=0 is not a bijection: mask {self.mask:#x} "
                f"overlaps the domain bits of length {length}.")
    elif length <= 0:
        # A modulus of zero or below does not describe a domain [0, length).
        raise ValueError(f"length must be positive, got {length}")

  def __call__(self, x):
    """Map x (a Python int or an integer numpy array) to its permuted index."""
    # Vectorized for numpy
    # 1. Swizzle (Bitwise Mix)
    swizzled = x ^ ((x >> self.shift) & self.mask)

    # 2. Affine (Linear Mix)
    return (self.scale * (swizzled + self.step)) % self.length

In [None]:
import numpy as np
from typing import Tuple, Union



class GemmTopology:
  """
  The Mathematical Definition (S, T, σ, γ).
  Defines the Geometry (Shape/Tiling) and Topology (Order/Rank) of one level.

  Args:
    shape: Total extent of the three axes (the code names them m, k, n).
    tiler: 3 x q integer array; row a holds the q tile sizes along axis a and
           must sum to shape[a]. Sizes must be non-negative.
    sigma: A permutation of (0, 1, 2). Axis a receives the stride
           (1, q, q^2)[sigma[a]].
    perm:  Callable applied to the strided linear index in `get_rank`.

  Raises:
    AssertionError: if any of the constructor validations below fails.
  """
  def __init__(self,
               shape: Tuple[int, int, int],
               tiler: np.ndarray,
               sigma: Tuple[int, int, int],
               perm: 'Bijection'):

    # --- 1. Validation (The "S" and "T") ---
    self.tiler = np.asarray(tiler, dtype=int)
    assert self.tiler.ndim == 2 and self.tiler.shape[0] == 3

    # The row-sum comparison below broadcasts, so a shape such as (1,) would
    # otherwise be accepted. Require exactly three entries first.
    assert np.array(shape).shape == (3,), "shape must have exactly 3 entries"

    # Negative tile sizes would make the prefix-sum cuts non-monotonic.
    assert np.all(self.tiler >= 0), "tile sizes must be non-negative"

    # Verify partitions sum to total shape
    assert np.all(np.sum(self.tiler, axis=1) == np.array(shape))

    # Verify Sigma is a valid permutation of (0,1,2)
    assert np.array_equal(np.sort(np.array(sigma)), np.array([0, 1, 2]))

    self.shape = shape
    self.perm = perm # The Gamma Bijection (Swizzle)

    # --- 2. Geometry Setup (The Tiler Tensor) ---
    self.q = self.tiler.shape[1] # The 'q' split factor
    self.N_tiles = self.q**3
    self.tile_tensor_shape = (self.q, self.q, self.q)
    # Pre-calculate Cuts (Prefix Sums) for O(1) Slice lookups
    # hstack with 0 column to get range starts [0, s1, s1+s2, ...]
    zeros_col = np.zeros((3, 1), dtype=int)
    self.cuts = np.cumsum(np.hstack([zeros_col, self.tiler]), axis=1)

    # --- 3. Topology Setup (The "σ" and "γ") ---
    # Base Colex strides: (1, q, q^2)
    base_strides = (1, self.q, self.q**2)

    # Permute strides according to sigma (The Layout Permutation)
    self.strides = (base_strides[sigma[0]],
                    base_strides[sigma[1]],
                    base_strides[sigma[2]])

  # --- Helper: Polymorphic Coordinate Resolver ---
  def _resolve_coords(self, x: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    """
    Implements colex^{-1} if input is an int; otherwise validates a coordinate triple.

    Raises:
      AssertionError: if an int tile ID is outside [0, N_tiles), or the
        sequence does not have exactly three entries.
      IndexError: if any coordinate of a triple is outside [0, q).
    """
    if isinstance(x, (int, np.integer)):
        assert 0 <= x < self.N_tiles, f"Tile ID {x} out of bounds"
        return (x % self.q, (x // self.q) % self.q, x // (self.q**2))

    assert len(x) == 3
    coords = tuple(x)
    for axis, coord in enumerate(coords):
        # np.asarray keeps this check valid for scalar and array coordinates.
        idx = np.asarray(coord)
        # Without this check a negative index wraps around in numpy indexing
        # and an index >= q still yields a rank in get_rank.
        if not np.all((idx >= 0) & (idx < self.q)):
            raise IndexError(
                f"Tile coordinate {coord} on axis {axis} is out of bounds for q={self.q}")
    return coords

  # --- Map 1: The Tile-Shape Map (shp_T) ---
  def get_tile_shape(self, x: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    """Return the extents (tiler[0, m], tiler[1, k], tiler[2, n]) of tile x."""
    m, k, n = self._resolve_coords(x)
    return (self.tiler[0, m], self.tiler[1, k], self.tiler[2, n])

  # --- Map 2: The Tile-Slice Map (slc_T) ---
  def get_tile_slice(self, x: Union[int, Tuple[int, int, int]]) -> Tuple[slice, slice, slice]:
    """Return one half-open slice per axis, read from the prefix-sum cuts, for tile x."""
    m, k, n = self._resolve_coords(x)
    ms = slice(self.cuts[0, m], self.cuts[0, m+1])
    ks = slice(self.cuts[1, k], self.cuts[1, k+1])
    ns = slice(self.cuts[2, n], self.cuts[2, n+1])
    return (ms, ks, ns)

  # --- Map 3: The Rank Map (R) ---
  def get_rank(self, x: Union[int, Tuple[int, int, int]]) -> int:
    """R: (x,y,z) -> gamma( phi_sigma(x,y,z) )"""
    c = self._resolve_coords(x)
    # 1. Apply Strides (Layout)
    linear_rank = (c[0]*self.strides[0] + c[1]*self.strides[1] + c[2]*self.strides[2])
    # 2. Apply Swizzle (Bijection)
    return self.perm(linear_rank)

  def __repr__(self):
      # Note: the value printed after "σ=" is the permuted stride triple
      # (self.strides), not the sigma tuple passed to the constructor.
      return f"<Topology {self.shape} | q={self.q} | σ={self.strides}>"

In [14]:


class ExecutionType(Enum):
    """Execution mode tag that a GemmNode carries alongside its topology."""

    # Spatial execution, e.g. grid tiling.
    PARALLEL = auto()

    # Temporal execution, e.g. a K-loop or a producer-consumer chain.
    SERIAL = auto()

class GemmNode:
  """
  One node of the recursive orchestration hierarchy.

  A node pairs a topology with an execution type and keeps one child slot per
  tile of that topology. Slots start out as None and are filled by `set_child`.
  """
  def __init__(self,
               topology: 'GemmTopology',
               execution_type: ExecutionType = ExecutionType.PARALLEL):
    self.topology = topology
    self.execution_type = execution_type

    # One empty slot per tile ID in 0 .. N_tiles - 1.
    self.children = dict.fromkeys(range(self.topology.N_tiles))

  def set_child(self, tile_id: int, subtopology: 'GemmTopology', sub_exec_type: ExecutionType = ExecutionType.PARALLEL):
    """
    Wrap `subtopology` in a new GemmNode and store it in the slot `tile_id`.

    Raises:
      IndexError: if `tile_id` is not one of this node's slots.
      ValueError: if the parent's tile shape differs from the child's total shape.
    """
    # Guard 1: the slot has to exist on this node.
    slot_exists = tile_id in self.children
    if not slot_exists:
        raise IndexError(f"Tile ID {tile_id} is out of bounds for parent topology with {self.topology.N_tiles} tiles.")

    # Guard 2: the child must fill the parent's tile exactly.
    child_total_shape = subtopology.shape
    parent_tile_shape = self.topology.get_tile_shape(tile_id)
    differs = np.array(parent_tile_shape) != np.array(child_total_shape)
    if np.any(differs):
        raise ValueError(f"Shape Mismatch at Tile {tile_id}!\n"
                         f"Parent Tile Shape: {parent_tile_shape}\n"
                         f"Child Total Shape: {child_total_shape}\n"
                         f"A strategy must perfectly fill the tile it occupies.")

    # Store a node (not the bare topology) so the child can receive children too.
    child_node = GemmNode(subtopology, sub_exec_type)
    self.children[tile_id] = child_node

  def __repr__(self):
    # Number of slots that already hold a child node.
    defined_children = sum(child is not None for child in self.children.values())
    return f"<GemmNode ({self.execution_type.name}) | {self.topology} | Defined Children: {defined_children}/{self.topology.N_tiles}>"



In [36]:
# Outer level: a 4x4x4 problem. Axes 0 and 2 are cut into four unit tiles,
# axis 1 is kept whole in its first tile (the remaining three tiles are empty).
shape = (4, 4, 4)
Tiler = np.array([[1, 1, 1, 1],
                  [4, 0, 0, 0],
                  [1, 1, 1, 1]], dtype=int)

sigma = (0, 1, 2)

# Bijection parameters: no swizzle bits, unit scale and zero step, so the
# mapping is the identity on [0, 64).
length = 4 * 4 * 4
bits, base, shift = 0, 0, 0
scale, step = 1, 0
b = Bijection(length, bits, base, shift, scale, step)

# Inner level: only axis 1 is split, into four unit tiles.
inner_Tiler = np.array([[1, 0, 0, 0],
                        [1, 1, 1, 1],
                        [1, 0, 0, 0]], dtype=int)


In [37]:
outer_gemm = GemmTopology(shape, Tiler, sigma, b)
outer_node = GemmNode(outer_gemm, ExecutionType.PARALLEL)

# Give every tile with k == 0 a SERIAL child that covers a (1, 4, 1) block.
# NOTE(review): the value returned by get_rank is passed as the tile ID.
# set_child decodes tile IDs with the plain colex order, so the two agree here
# only because sigma is (0, 1, 2) and b is the identity - confirm the intent
# before changing either.
for m, n in itertools.product(range(4), repeat=2):
  tile_id = outer_gemm.get_rank((m, 0, n))
  inner_gemm = GemmTopology((1, 4, 1), inner_Tiler, sigma, b)
  outer_node.set_child(tile_id, inner_gemm, ExecutionType.SERIAL)

In [38]:
def traverse(root: GemmNode):
  """Print `root`, a separator line, then every direct child slot of `root` (one level only)."""
  print(root)
  print("____")
  for child in root.children.values():
    print(child)
  

In [39]:
# Print the outer node, then each of its child slots; a slot without a child prints as None.
traverse(outer_node)

<GemmNode (PARALLEL) | <Topology (4, 4, 4) | q=4 | σ=(1, 4, 16)> | Defined Children: 16/64>
____
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
None
None
None
None
None
None
None
None
None
None
None
None
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Defined Children: 0/64>
None
None
None
None
None
None
None
None
None
None
None
None
<GemmNode (SERIAL) | <Topology (1, 4, 1) | q=4 | σ=(1, 4, 16)> | Define