In [1]:
from Orch import Bijection, ExecutionType, GemmTopology, GemmNode, generate_dot_source, SliceIR_node, SliceIR_tree
import numpy as np

In [2]:
# --- Coarse (root) level -------------------------------------------------------
# GEMM problem extents in (M, K, N) order.
shape = (4,4,4)
# Tiler for the coarse topology, passed straight to Orch.GemmTopology.
# NOTE(review): presumably one row per (M, K, N) axis; the meaning of the
# 4 columns is defined by GemmTopology — confirm there.
tiler = np.zeros((3,4), dtype=int)
tiler[0,:] = [1,1,1,1]
tiler[1,:] = [4,0,0,0]
tiler[2,:] = [1,1,1,1]

# Bijection parameters shared by both levels (forwarded unchanged to Orch.Bijection).
bits = 1
base = 1
shift = 1
scale = 3
step = 5
# Length passed to the coarse Bijection: the element count of `shape` (4*4*4 = 64).
# Derived rather than hard-coded so it stays in sync if `shape` is edited.
length = int(np.prod(shape))

# --- Fine (sub) level ----------------------------------------------------------
sub_shape = (1,4,1)
sub_tiler = np.zeros((3,4), dtype=int)
sub_tiler[0,:] = [1,0,0,0]
sub_tiler[1,:] = [1,1,1,1]
sub_tiler[2,:] = [1,0,0,0]

# Length passed to the fine Bijection: the element count of `sub_shape` (1*4*1 = 4).
sub_length = int(np.prod(sub_shape))

b1 = Bijection(length, bits, base, shift, scale, step)
b2 = Bijection(sub_length, bits, base, shift, scale, step)

# The third GemmTopology argument is a permutation of (0, 1, 2); its exact meaning
# (e.g. loop/axis order) is defined by GemmTopology — NOTE(review): confirm.
# `coarse_toplogy` (sic) is kept as-is because later cells reference this name.
coarse_toplogy = GemmTopology(shape, tiler, (1,0,2), b1)
fine_subtopology = GemmTopology(sub_shape,sub_tiler, (0,2,1), b2)

In [3]:
# Two-level GEMM tree: a PARALLEL root over the coarse topology, with the SERIAL
# fine subtopology attached to the tile at every (m, 0, n) coordinate
# (m and n each in 0..3; the K tile coordinate is fixed at 0).
root = GemmNode(coarse_toplogy, ExecutionType.PARALLEL)

tile_coords = [(m, 0, n) for m in range(4) for n in range(4)]
for coord in tile_coords:
  tile_id = root.topology.get_tile_id(coord)
  root.set_child(tile_id, fine_subtopology, sub_exec_type=ExecutionType.SERIAL)

In [4]:
# Print the root's per-tile layout (ID, rank, slice, topology, child status) as a text table.
root.print_layout_table()


Node Layout: PARALLEL | Topology: (4, 4, 4)
-------------------------------------------------------------------------------------
ID   | Rank | Slice (M, K, N)                | Topology             | Status
-------------------------------------------------------------------------------------
0    | 15   | [0:1, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
1    | 33   | [1:2, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
2    | 39   | [2:3, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
3    | 57   | [3:4, 0:4, 0:1]                | Top(1, 4, 1)         | 0/64 children
16   | 63   | [0:1, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
17   | 17   | [1:2, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
18   | 23   | [2:3, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
19   | 41   | [3:4, 0:4, 1:2]                | Top(1, 4, 1)         | 0/64 children
32   | 47   | [0:1, 0:4, 2:3]     

In [5]:
# Wrap the GemmNode hierarchy in a SliceIR_tree; the cells below render and verify it.
S = SliceIR_tree(root)

In [6]:
# Handle to the SliceIR root node for interactive inspection.
# NOTE(review): `slice_root` is not used by any later cell shown here.
slice_root = S.root

In [7]:
# Render the SliceIR tree as Graphviz DOT source text.
# NOTE(review): the next cell regenerates the same DOT text into `dot_data`, and
# `dot_graph` is not used by any cell shown here — candidate for removal.
dot_graph = generate_dot_source(S)

In [8]:
# Render the SliceIR tree to Graphviz DOT, show it, and save it to debug_tree.dot.
dot_data = generate_dot_source(S)
print(dot_data)
# A context manager closes (and flushes) the file deterministically. The previous
# bare `open(...).write(...)` also echoed the written character count as cell output.
with open("debug_tree.dot", "w", encoding="utf-8") as dot_file:
  dot_file.write(dot_data)

digraph SliceIR {
    rankdir=TB;
    splines=polyline;
    node [shape=record, fontname="Courier", style=filled, fillcolor="#f8f9fa"];
    edge [fontname="Courier"];
    node_139773539933536 [label="{ Rank:0 | Depth:0 | child_exec:PARALLEL } | { Shape:(4, 4, 4) | Slice:[0:4, 0:4, 0:4] | Offset:(0, 0, 0) }"];
    node_139773539933536 -> node_139773539934160;
    node_139773539934160 [label="{ Rank:1 | Depth:1 | child_exec:SERIAL } | { Shape:(1, 4, 1) | Slice:[1:2, 0:4, 2:3] | Offset:(1, 0, 2) }"];
    node_139773539933536 -> node_139773539934208;
    node_139773539934208 [label="{ Rank:7 | Depth:1 | child_exec:SERIAL } | { Shape:(1, 4, 1) | Slice:[2:3, 0:4, 2:3] | Offset:(2, 0, 2) }"];
    node_139773539933536 -> node_139773539934448;
    node_139773539934448 [label="{ Rank:9 | Depth:1 | child_exec:SERIAL } | { Shape:(1, 4, 1) | Slice:[3:4, 0:4, 3:4] | Offset:(3, 0, 3) }"];
    node_139773539933536 -> node_139773539933632;
    node_139773539933632 [label="{ Rank:15 | Depth:1 | child_ex

3397

In [None]:
import torch

def torch_verify(T: "SliceIR_tree", seed=None, atol=1e-4):
  """Check that the tiles of a SliceIR tree reassemble C = A @ B (torch backend).

  Draws random A (M x K) and B (K x N), then for every node yielded by
  ``T.walk(strategy="dfs")`` converts its parent-local slice to global
  coordinates, computes that tile's partial product and accumulates it into
  ``C_test[m_slice, n_slice]``. The result is compared against ``A @ B``.

  Args:
    T: tree exposing ``root.shape`` as (M, K, N) and ``walk(strategy=...)``
      yielding nodes with ``.parent`` (falsy, or an object with
      ``.global_offset`` = (m, k, n)) and ``.slice`` = three ``slice`` objects
      local to that parent.
    seed: optional int. When given, A and B come from a private
      ``torch.Generator`` seeded with it (reproducible, global RNG untouched);
      when None, torch's default generator is used as before.
    atol: absolute tolerance forwarded to ``torch.allclose``.

  Returns:
    ``(C_test, is_close)``: the accumulated (M, N) tensor and whether it matched.

  Raises:
    ValueError: if a node's slice has a ``None`` start or stop.
  """
  M, K, N = map(int, T.root.shape)

  gen = None if seed is None else torch.Generator().manual_seed(seed)
  A = torch.randn(M, K, generator=gen)
  B = torch.randn(K, N, generator=gen)
  C_ref = torch.matmul(A, B)
  C_test = torch.zeros(M, N)

  print(f"Verifying GEMM ({M}x{N}x{K})...")

  def to_global(local, offset):
    # Shift a parent-local slice by the parent's global offset on that axis.
    if local.start is None or local.stop is None:
      raise ValueError(f"slice bounds must be explicit integers, got {local!r}")
    return slice(local.start + offset, local.stop + offset, local.step)

  # NOTE(review): assumes walk() yields only leaf tiles; if it also yielded the
  # root/internal nodes, their regions would be accumulated more than once — confirm.
  for leaf in T.walk(strategy="dfs"):
    # 1. Anchor: the parent's global offset (a node without a parent sits at the origin).
    off_m, off_k, off_n = leaf.parent.global_offset if leaf.parent else (0, 0, 0)
    slice_m, slice_k, slice_n = leaf.slice

    # 2. Global slice = parent offset + local slice, per axis.
    g_m_slice = to_global(slice_m, off_m)
    g_k_slice = to_global(slice_k, off_k)
    g_n_slice = to_global(slice_n, off_n)

    # 3. Accumulate this tile's partial product over its K range.
    C_test[g_m_slice, g_n_slice] += torch.matmul(A[g_m_slice, g_k_slice], B[g_k_slice, g_n_slice])

  is_close = torch.allclose(C_test, C_ref, atol=atol)
  print(f"Verification: {'PASS' if is_close else 'FAIL'}")
  if not is_close:
    print(f"Max Diff: {(C_test - C_ref).abs().max().item()}")

  return C_test, is_close


import numpy as np

def numpy_verify(T: "SliceIR_tree", seed=None, atol=1e-4):
  """Check that the tiles of a SliceIR tree reassemble C = A @ B (NumPy backend).

  Draws random A (M x K) and B (K x N), then for every node yielded by
  ``T.walk(strategy="dfs")`` converts its parent-local slice to global
  coordinates, computes that tile's partial product and accumulates it into
  ``C_test[m_slice, n_slice]``. The result is compared against ``A @ B``.

  Args:
    T: tree exposing ``root.shape`` as (M, K, N) and ``walk(strategy=...)``
      yielding nodes with ``.parent`` (falsy, or an object with
      ``.global_offset`` = (m, k, n)) and ``.slice`` = three ``slice`` objects
      local to that parent.
    seed: optional int. When given, A and B come from
      ``np.random.default_rng(seed)``; when None, the legacy global RNG is used
      as before (so a prior ``np.random.seed(...)`` still controls the run).
    atol: absolute tolerance forwarded to ``np.allclose``.

  Returns:
    ``(C_test, is_close)``: the accumulated (M, N) array and whether it matched.

  Raises:
    ValueError: if a node's slice has a ``None`` start or stop.
  """
  M, K, N = map(int, T.root.shape)

  if seed is None:
    A = np.random.randn(M, K)
    B = np.random.randn(K, N)
  else:
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((M, K))
    B = rng.standard_normal((K, N))
  C_ref = A @ B  # Ground truth using standard matmul
  C_test = np.zeros((M, N))

  print(f"Verifying GEMM ({M}x{N}x{K})...")

  def to_global(local, offset):
    # Shift a parent-local slice by the parent's global offset on that axis.
    if local.start is None or local.stop is None:
      raise ValueError(f"slice bounds must be explicit integers, got {local!r}")
    return slice(local.start + offset, local.stop + offset, local.step)

  # NOTE(review): assumes walk() yields only leaf tiles; if it also yielded the
  # root/internal nodes, their regions would be accumulated more than once — confirm.
  for leaf in T.walk(strategy="dfs"):
    # 1. Anchor: the parent's global offset (a node without a parent sits at the origin).
    off_m, off_k, off_n = leaf.parent.global_offset if leaf.parent else (0, 0, 0)
    slice_m, slice_k, slice_n = leaf.slice

    # 2. Global slice = parent offset + local slice, per axis.
    g_m_slice = to_global(slice_m, off_m)
    g_k_slice = to_global(slice_k, off_k)
    g_n_slice = to_global(slice_n, off_n)

    # 3. Accumulate this tile's partial product over its K range.
    C_test[g_m_slice, g_n_slice] += A[g_m_slice, g_k_slice] @ B[g_k_slice, g_n_slice]

  # 4. Check
  is_close = np.allclose(C_test, C_ref, atol=atol)
  print(f"Verification: {'PASS' if is_close else 'FAIL'}")
  if not is_close:
    print(f"Max Diff: {np.abs(C_test - C_ref).max()}")

  return C_test, is_close

In [18]:
# Verify the SliceIR tree built above against a dense torch matmul. Assert on the
# flag so a FAIL stops Restart & Run All instead of only being printed.
C_out, Truth = torch_verify(S)
assert Truth, "SliceIR tiling does not reproduce the dense A @ B reference"

Verifying GEMM (4x4x4)...
Verification: PASS


In [None]:
# Display the verification flag returned by torch_verify (True when C_test matched A @ B).
Truth

True