In [None]:
import math

In [None]:

class bijection:
  """Permutation of the integers ``[0, length)``.

  Maps ``x -> (scale * (swizzle(x) + step)) % length``, where
  ``swizzle(x) = x ^ ((x >> shift) & mask)`` and ``mask`` selects the
  ``bits``-wide field starting at bit ``base``.

  Args:
    length: size of the domain and codomain.
    bits: width of the XOR-ed field; 0 disables the swizzle.
    base: lowest bit position of the XOR-ed field.
    shift: right-shift applied to ``x`` to fetch the XOR source field.
    scale: multiplier; must be coprime with ``length``.
    step: offset added before scaling.

  Raises:
    AssertionError: if the parameters would not yield a permutation.
  """
  def __init__(self, length, bits, base, shift, scale, step):
    # The affine step mod `length` is only invertible when scale is a unit.
    assert math.gcd(scale, length) == 1
    assert shift >= 0
    # Each target bit i is XORed with the strictly higher bit i + shift, so any
    # shift > 0 gives an invertible swizzle (undo it from the top bit down).
    # shift == 0 would compute x ^ (x & mask), i.e. clear the field.
    assert bits == 0 or shift > 0
    self.length = length
    self.scale = scale
    self.step = step
    self.shift = shift

    self.mask = ((1 << bits) - 1) << base

    if self.mask:
      # The swizzle only rewrites bits in [base, base + bits), so it permutes
      # each aligned block of 2**(base + bits) integers. Only a trailing partial
      # block of [0, length) can be pushed to values >= length (which `% length`
      # would fold onto other ids), so check exactly that block. Its size is
      # below 2**(base + bits).
      block = 1 << (base + bits)
      tail_start = (length // block) * block
      assert all(self._swizzle(x) < length for x in range(tail_start, length))

  def _swizzle(self, x):
    """XOR the masked field of ``x`` with the bits ``shift`` positions above it."""
    return x ^ ((x >> self.shift) & self.mask)

  def __call__(self, x):
    return (self.scale * (self._swizzle(x) + self.step)) % self.length



In [None]:
<gem_gem_code>
import math
import numpy as np

# --- 1. THE BIJECTION (Permutation / Swizzle) ---
class bijection:
  """Permutation of the tile ids ``[0, length)``; works on ints and numpy arrays.

  Maps ``x -> (scale * (swizzle(x) + step)) % length``, where
  ``swizzle(x) = x ^ ((x >> shift) & mask)`` and ``mask`` selects the
  ``bits``-wide field starting at bit ``base``.

  Args:
    length: size of the domain and codomain.
    bits: width of the XOR-ed field; 0 disables the swizzle.
    base: lowest bit position of the XOR-ed field.
    shift: right-shift applied to ``x`` to fetch the XOR source field.
    scale: multiplier; must be coprime with ``length``.
    step: offset added before scaling.

  Raises:
    AssertionError: if the parameters would not yield a permutation.
  """
  def __init__(self, length, bits, base, shift, scale, step):
    # The affine step mod `length` is only invertible when scale is a unit.
    assert math.gcd(scale, length) == 1
    assert shift >= 0 # Enforcing right-shift only
    # With shift == 0 the swizzle is x ^ (x & mask), which clears the field
    # instead of permuting it. For shift > 0 each target bit i is XORed with
    # the strictly higher bit i + shift, so the map can be undone top-down.
    assert bits == 0 or shift > 0

    self.length = length
    self.scale = scale
    self.step = step
    self.shift = shift
    self.mask = ((1 << bits) - 1) << base

    if self.mask:
      # The swizzle only rewrites bits in [base, base + bits), so it permutes
      # each aligned block of 2**(base + bits) ids. Only a trailing partial
      # block of [0, length) can be pushed to ids >= length (which `% length`
      # would fold onto other ids), so check exactly that block. Its size is
      # below 2**(base + bits).
      block = 1 << (base + bits)
      tail_start = (length // block) * block
      assert all(self._swizzle(x) < length for x in range(tail_start, length))

  def _swizzle(self, x):
    """XOR the masked field of ``x`` with the bits ``shift`` positions above it."""
    return x ^ ((x >> self.shift) & self.mask)

  def __call__(self, x):
    # Vectorized for numpy
    return (self.scale * (self._swizzle(x) + self.step)) % self.length

  def __repr__(self):
    return f"Bijection(Scale={self.scale}, Step={self.step}, XOR_Shift={self.shift})"

# --- 2. THE SURJECTION (Partition / Split) ---
class surjection:
  """Block-cyclic partition of slot indices onto ``num_parts`` owners.

  Consecutive runs of ``chunk_size`` slots share one owner, and owners cycle
  through ``0 .. num_parts - 1``. Plain integer arithmetic, so it applies
  elementwise to numpy integer arrays as well as to Python ints.
  """
  def __init__(self, num_parts, chunk_size):
    self.chunk_size = chunk_size
    self.num_parts = num_parts

  def __call__(self, x):
    chunk_index = x // self.chunk_size  # which run of `chunk_size` slots x is in
    owner = chunk_index % self.num_parts  # wrap runs around the owners
    return owner

  def __repr__(self):
    return f"Surjection(Parts={self.num_parts}, Chunk={self.chunk_size})"

# --- 3. THE ORCHESTRATOR (Rigid Tile Dispatch) ---
class MatmulOrchestrator:
    """Assign the (m, n, k) tiles of an M x N x K matmul to ``num_parts`` processors.

    Tiles are numbered colexically (m fastest, then n, then k). ``run()`` feeds
    the slot indices ``0 .. total_tiles - 1`` through the bijection to choose
    which tile each slot computes, and through the surjection to choose which
    processor owns each slot.
    """
    def __init__(self, M, N, K, tile_M, tile_N, tile_K, num_parts):
        self.shape = (M, N, K)
        self.tile_shape = (tile_M, tile_N, tile_K)
        self.num_parts = num_parts

        # Grid Dims: ceiling division, so a trailing partial tile along any
        # axis still gets its own grid slot instead of being silently dropped.
        self.Gm = -(-M // tile_M)
        self.Gn = -(-N // tile_N)
        self.Gk = -(-K // tile_K)
        self.total_tiles = self.Gm * self.Gn * self.Gk

        # Default: Identity Map, Block Partition. chunk_size is kept >= 1 so
        # the surjection never divides by zero when the grid is empty.
        self.bijection = bijection(self.total_tiles, bits=0, base=0, shift=0, scale=1, step=0)
        self.surjection = surjection(
            num_parts,
            chunk_size=max(1, (self.total_tiles + num_parts - 1) // num_parts),
        )

    def set_strategy(self, bij, surj):
        """Replace the slot->tile map and the slot->processor map used by ``run()``.

        Both are called by ``run()`` with a numpy int array of slot indices.
        """
        self.bijection = bij
        self.surjection = surj

    def get_coord_from_linear(self, idx):
        """Colexical Inverse: Linear -> (m, n, k)"""
        # Vectorized for numpy if idx is array
        m = idx % self.Gm
        tmp = idx // self.Gm
        n = tmp % self.Gn
        k = tmp // self.Gn
        return m, n, k

    def _check_strategy(self, permuted_ids, proc_assignments):
        """Raise ValueError unless each tile is computed once by a known processor."""
        n = self.total_tiles
        if n == 0:
            return
        if (permuted_ids.shape != (n,)
                or permuted_ids.min() < 0
                or permuted_ids.max() >= n
                or np.any(np.bincount(permuted_ids, minlength=n) != 1)):
            raise ValueError("bijection does not map the slots onto each tile id exactly once")
        if (proc_assignments.shape != (n,)
                or proc_assignments.min() < 0
                or proc_assignments.max() >= self.num_parts):
            raise ValueError(
                f"surjection must return processor ids in [0, {self.num_parts})")

    def run(self):
        """
        Returns a Dict: {ProcessorID: Matrix of shape (T, 3)}
        Where T is number of tiles assigned to that processor.

        Each row is an (m, n, k) tile coordinate, in slot order.

        Raises:
            ValueError: if the bijection does not hit every tile id exactly
                once, or the surjection returns a processor id outside
                ``[0, num_parts)`` (such slots would otherwise vanish from
                the result).
        """
        # 1. Create Linear Space
        t_domain = np.arange(self.total_tiles)

        # 2. Permute (Global Order)
        permuted_ids = np.asarray(self.bijection(t_domain))

        # 3. Partition (Assign to Processors)
        # Note: We partition based on 't' (time slots), not the permuted ID.
        # This keeps load balancing deterministic.
        proc_assignments = np.asarray(self.surjection(t_domain))

        self._check_strategy(permuted_ids, proc_assignments)

        # 4. Group Results
        # We need to resolve the (m,n,k) coords for the permuted IDs
        coords_m, coords_n, coords_k = self.get_coord_from_linear(permuted_ids)

        # Stack coords: Shape (total_tiles, 3)
        all_coords = np.stack([coords_m, coords_n, coords_k], axis=1)

        return {p: all_coords[proc_assignments == p] for p in range(self.num_parts)}

# --- DEMO ---
# 64x64 Grid of Tiles (Total 4096), 32 Processors
orch = MatmulOrchestrator(M=64*16, N=64*16, K=16, tile_M=16, tile_N=16, tile_K=16, num_parts=32)

# Strategy: XOR swizzle + block distribution.
# bits=3, base=0, shift=3 XORs bits 0-2 of the linear tile id with bits 3-5.
# Tile ids are colexical with Gm = 64, so both fields lie inside m: this
# permutes m for each fixed (n, k) and leaves n and k unchanged.
# Sizes are taken from `orch` so they cannot drift from the grid above.
b = bijection(length=orch.total_tiles, bits=3, base=0, shift=3, scale=1, step=0)
s = surjection(num_parts=orch.num_parts,
               chunk_size=orch.total_tiles // orch.num_parts)  # Equal split

orch.set_strategy(b, s)
res = orch.run()

print(f"Processor 0 handles {len(res[0])} tiles.")
print(f"First 5 coords for P0:\n{res[0][:5]}")