In [2]:
from Orch import Bijection, ExecutionType, GemmTopology, GemmNode
import numpy as np

In [3]:
# --- Coarse level: the full (M, K, N) problem -------------------------------
shape = (4, 4, 4)
# Tiler rows are interpreted by Orch.GemmTopology; values kept verbatim.
# NOTE(review): row/column semantics are not visible in this notebook — confirm in Orch.
tiler = np.array([
    [1, 1, 1, 1],
    [4, 0, 0, 0],
    [1, 1, 1, 1],
], dtype=int)

# Bijection parameters shared by both levels (meanings defined by Orch.Bijection).
bits = 1
base = 1
shift = 1
scale = 3
step = 5
# Derived from the shape rather than hard-coded (was 4*4*4) so the two cannot drift apart.
length = int(np.prod(shape))

# --- Fine level: the per-tile sub-problem ----------------------------------
sub_shape = (1, 4, 1)
sub_tiler = np.array([
    [1, 0, 0, 0],
    [1, 1, 1, 1],
    [1, 0, 0, 0],
], dtype=int)
# Derived from sub_shape (was the literal 4).
sub_length = int(np.prod(sub_shape))

b1 = Bijection(length, bits, base, shift, scale, step)
b2 = Bijection(sub_length, bits, base, shift, scale, step)

# NOTE: "toplogy" is misspelled, but later cells reference this exact name.
# NOTE(review): the (1,0,2) / (0,2,1) tuples look like axis permutations — confirm in Orch.
coarse_toplogy = GemmTopology(shape, tiler, (1, 0, 2), b1)
fine_subtopology = GemmTopology(sub_shape, sub_tiler, (0, 2, 1), b2)

In [4]:
root = GemmNode(coarse_toplogy, ExecutionType.PARALLEL)

# Attach fine_subtopology (one shared object) with SERIAL execution to every
# (m, 0, n) tile of the coarse topology; the K coordinate is fixed at 0.
# NOTE(review): the 4 x 4 (m, n) extent is hard-coded; it is not derived from the topology.
tile_coords = [(m, 0, n) for m in range(4) for n in range(4)]
for coord in tile_coords:
    tile_id = root.topology.get_tile_id(coord)
    root.set_child(tile_id, fine_subtopology, sub_exec_type=ExecutionType.SERIAL)

In [5]:
# Print the coarse node's layout table (columns: ID, Rank, Slice (M, K, N), Topology, Status).
root.print_layout_table()


Node Layout: PARALLEL | Topology: (4, 4, 4)
-------------------------------------------------------------------------------------
ID   | Rank | Slice (M, K, N)                | Topology             | Status
-------------------------------------------------------------------------------------
0    | 15   | [0:1, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
1    | 33   | [1:2, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
2    | 39   | [2:3, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
3    | 57   | [3:4, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
16   | 63   | [0:1, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
17   | 17   | [1:2, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
18   | 23   | [2:3, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
19   | 41   | [3:4, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
32   | 47   | [0:1, 0:4, 2:3]     

In [41]:
class SliceIR_node: 
    """One node of the slice IR: a tile of a GEMM plus its execution metadata.

    Attributes (set from the constructor arguments of the same names):
        rank: Execution rank of this tile; used by callers to order siblings.
        shape: Tile extent; ``__repr__`` prints it as a tuple of Python ints.
        slice: Tuple of entries locating the tile, normally ``slice`` objects
            (``None`` prints as ``:``, non-slice entries print via ``str``).
        depth_id: Depth of this node in the tree.
        execution_policy: Enum-like value; ``__repr__`` prints its ``.name``.
        children: List of child ``SliceIR_node``; starts empty.
        parent: Parent ``SliceIR_node``; starts as ``None``.
    """

    def __init__(self, rank, shape, slice, depth_id, execution_policy):
        # `slice` shadows the builtin only inside __init__; the name is kept for callers.
        self.rank = rank
        self.shape = shape 
        self.slice = slice 
        self.depth_id = depth_id 
        self.execution_policy = execution_policy
        self.children = []  # List[SliceIR_node]
        self.parent = None

    def _fmt_slice(self, s):
        """Format one entry of the slice tuple in Python subscript syntax.

        ``slice(0, 4, None)`` -> ``'0:4'``, ``slice(None, 4, 2)`` -> ``':4:2'``,
        ``None`` -> ``':'``. A non-slice entry (e.g. an integer index) is
        rendered with ``str`` instead of raising ``AttributeError``.
        """
        if s is None:
            return ":"
        if not isinstance(s, slice):
            return str(s)
        start = s.start if s.start is not None else ""
        stop = s.stop if s.stop is not None else ""
        step = f":{s.step}" if s.step is not None else ""
        return f"{start}:{stop}{step}"

    def __repr__(self):
        """Return a one-line summary: depth, rank, policy name, shape, slice, child count."""
        slice_str = "[" + ", ".join(self._fmt_slice(s) for s in self.slice) + "]"
        # Cast each extent to a Python int: entries may be numpy integers (whose repr
        # is 'np.int64(4)' under NumPy 2). Works for any number of dimensions.
        shape_tuple = tuple(int(d) for d in self.shape)
        return (f"<IR_Node d={self.depth_id} | r={self.rank} | {self.execution_policy.name} | "
                f"shape={shape_tuple} | slice={slice_str} | Children={len(self.children)}>")

class SliceIR_tree: 
    """Slice-IR mirror of a ``GemmNode`` hierarchy.

    Construction walks the ``GemmNode`` tree once, depth-first. It builds a
    ``SliceIR_node`` for the root and for every tile that has an attached child
    ``GemmNode``. Tiles with no attached child are left out of the IR, and
    siblings are stored sorted by their topology rank.

    Attributes:
        root: The ``SliceIR_node`` built for ``root_gemm_node``.
    """

    def __init__(self, root_gemm_node: GemmNode): 
        """Build the IR for ``root_gemm_node``.

        The root is assigned rank 0 and depth 0. Its slice comes from
        ``get_owned_slice_rel_parent()``. If that raises ``AttributeError``, the
        fallback is ``slice(0, s)`` for each extent ``s`` of its topology shape.

        Raises:
            AssertionError: a zero-volume tile somewhere in the tree has an
                attached child.
        """
        root_shape = root_gemm_node.topology.shape
        root_slice = self._root_slice(root_gemm_node, root_shape)
        self.root = self._build_recursive(
            node=root_gemm_node,
            rank=0,
            shape=root_shape,
            slc=root_slice,
            depth=0,
        )

    @staticmethod
    def _root_slice(root_gemm_node, root_shape):
        """Return the root's slice relative to its parent, or its full extent as a fallback."""
        try:
            return root_gemm_node.get_owned_slice_rel_parent()
        except AttributeError:
            # NOTE(review): presumably hit when the root has no parent (or lacks the method);
            # this also hides AttributeErrors raised *inside* the method — confirm.
            return tuple(slice(0, s) for s in root_shape)

    def _build_recursive(self, node: GemmNode, rank, shape, slc, depth) -> SliceIR_node:
        """Create the IR node for ``node`` and, recursively, for its attached children.

        Args:
            node: ``GemmNode`` whose ``topology`` tiles and ``children`` are walked.
            rank: Execution rank recorded on the new IR node.
            shape: Tile shape recorded on the new IR node.
            slc: Tile slice recorded on the new IR node.
            depth: Tree depth recorded on the new IR node.

        Returns:
            The new ``SliceIR_node``. Its ``children`` are sorted by ``rank``, and
            each child's ``parent`` points back to it.

        Raises:
            AssertionError: a tile whose shape contains 0 has an attached child.
        """
        ir_node = SliceIR_node(rank, shape, slc, depth, node.execution_type)
        topology = node.topology

        children = []
        for c in range(topology.N_tiles):
            child_shape = topology.get_tile_shape(c)
            child_gemm_node = node.children[c]

            if 0 in child_shape:
                # Explicit raise rather than `assert`, so the check survives `python -O`.
                if child_gemm_node is not None:
                    raise AssertionError(
                        f"Tile {c} has 0-volume {child_shape} but has an attached strategy.")
                continue

            if child_gemm_node is None:
                continue  # unattached tile: not represented in the IR

            # Rank and slice are queried only for tiles that actually get an IR node.
            child_ir = self._build_recursive(
                node=child_gemm_node,
                rank=topology.get_rank(c),
                shape=child_shape,
                slc=topology.get_tile_slice(c),
                depth=depth + 1,
            )
            child_ir.parent = ir_node
            children.append(child_ir)

        # Sort so children[0] has the lowest rank. list.sort is stable, so equal
        # ranks keep their tile-index order.
        children.sort(key=lambda child: child.rank)
        ir_node.children = children
        return ir_node

In [42]:
# Build the slice IR for the GemmNode hierarchy rooted at `root`.
S = SliceIR_tree(root_gemm_node=root)

In [43]:
# Root SliceIR_node of the IR tree (depth 0, rank 0).
slice_root = S.root

In [44]:
# SliceIR_node defines only __repr__, so print() shows that one-line summary.
print(slice_root)

<IR_Node d=0 | r=0 | PARALLEL | shape=(4, 4, 4) | slice=[0:4, 0:4, 0:4] | Children=16>


In [46]:
# One line per direct child of the root, in ascending rank order.
for child_node in slice_root.children:
    print(child_node)

<IR_Node d=1 | r=1 | SERIAL | shape=(1, 4, 1) | slice=[1:2, 0:4, 2:3] | Children=0>
<IR_Node d=1 | r=7 | SERIAL | shape=(1, 4, 1) | slice=[2:3, 0:4, 2:3] | Children=0>
<IR_Node d=1 | r=9 | SERIAL | shape=(1, 4, 1) | slice=[3:4, 0:4, 3:4] | Children=0>
<IR_Node d=1 | r=15 | SERIAL | shape=(1, 4, 1) | slice=[0:1, 0:4, 0:1] | Children=0>
<IR_Node d=1 | r=17 | SERIAL | shape=(1, 4, 1) | slice=[1:2, 0:4, 1:2] | Children=0>
<IR_Node d=1 | r=23 | SERIAL | shape=(1, 4, 1) | slice=[2:3, 0:4, 1:2] | Children=0>
<IR_Node d=1 | r=25 | SERIAL | shape=(1, 4, 1) | slice=[3:4, 0:4, 2:3] | Children=0>
<IR_Node d=1 | r=31 | SERIAL | shape=(1, 4, 1) | slice=[0:1, 0:4, 3:4] | Children=0>
<IR_Node d=1 | r=33 | SERIAL | shape=(1, 4, 1) | slice=[1:2, 0:4, 0:1] | Children=0>
<IR_Node d=1 | r=39 | SERIAL | shape=(1, 4, 1) | slice=[2:3, 0:4, 0:1] | Children=0>
<IR_Node d=1 | r=41 | SERIAL | shape=(1, 4, 1) | slice=[3:4, 0:4, 1:2] | Children=0>
<IR_Node d=1 | r=47 | SERIAL | shape=(1, 4, 1) | slice=[0:1, 0:4, 2: