In [None]:
# Resource budgets for the warp-count exploration.
# These values mirror the defaults used by calc_max_warps.
N_regs_max = 64 * 1024  # register budget (65536)
SMEM_bytes_kb_max = 48  # shared-memory budget, in KB

In [1]:
import math

def _min_perimeter_grid(warps):
    """Factor ``warps`` as ``w_m * w_n`` with the smallest ``w_m + w_n``.

    ``w_m`` is scanned upward from 1 and ``min`` keeps the first minimum.
    Among equally good factorizations, the one with the smaller ``w_m``
    wins, so the result satisfies ``w_m <= w_n``.

    Args:
        warps: Positive number of warps to arrange in a 2-D grid.

    Returns:
        Tuple ``(w_m, w_n)`` with ``w_m * w_n == warps``.
    """
    w_m = min(
        (d for d in range(1, warps + 1) if warps % d == 0),
        key=lambda d: d + warps // d,
    )
    return w_m, warps // w_m


def calc_max_warps(smem_limit_kb=48, reg_limit=65536, k_stages=4, regs_per_thread=128,
                   *, tile_dim=16, bytes_per_elem=2, max_warps=32):
    """Find the largest warp count that fits both the register and SMEM budgets.

    Warp counts are tried from ``max_warps`` downward, and the first one that
    satisfies both constraints is returned. For each count, the warps are
    laid out as the ``w_m x w_n`` grid with the smallest perimeter. The
    shared-memory footprint modelled below grows with ``w_m + w_n``, so the
    most square grid is the cheapest.

    Shared-memory model (bytes):
        A tile: (w_m * tile_dim) rows x (k_stages * tile_dim) cols x bytes_per_elem
        B tile: (k_stages * tile_dim) rows x (w_n * tile_dim) cols x bytes_per_elem

    Args:
        smem_limit_kb: Shared-memory budget in KB (1 KB = 1024 bytes here).
        reg_limit: Register budget shared by all threads of the block.
        k_stages: Number of ``tile_dim``-wide K slices held in shared memory.
        regs_per_thread: Registers used by each thread.
        tile_dim: Edge length of one per-warp tile unit (default 16).
        bytes_per_elem: Bytes per matrix element (default 2, i.e. bf16).
        max_warps: Largest warp count to try (default 32 = 1024 threads).

    Returns:
        ``{"valid": False}`` if no warp count in ``1..max_warps`` fits.
        Otherwise a dict with these keys:

        - ``valid``: ``True``.
        - ``warps``: the chosen warp count.
        - ``shape``: ``(w_m, w_n)``.
        - ``smem_kb``: shared memory used, as a float in KB.
        - ``regs``: total registers used.
        - ``k_stages``: echoed back from the argument.
    """
    smem_limit_bytes = smem_limit_kb * 1024
    threads_per_warp = 32

    for warps in range(max_warps, 0, -1):
        # 1. Register constraint: total registers for every thread in the block.
        total_regs = warps * threads_per_warp * regs_per_thread
        if total_regs > reg_limit:
            continue

        # 2. Geometry: the smallest-perimeter grid minimizes the SMEM footprint.
        #    A valid factorization always exists (w_m = 1 divides any count).
        w_m, w_n = _min_perimeter_grid(warps)

        # 3. Shared-memory constraint for the A and B tiles.
        k_extent = k_stages * tile_dim
        bytes_a = (w_m * tile_dim) * k_extent * bytes_per_elem
        bytes_b = k_extent * (w_n * tile_dim) * bytes_per_elem
        total_smem = bytes_a + bytes_b

        if total_smem <= smem_limit_bytes:
            return {
                "valid": True,
                "warps": warps,
                "shape": (w_m, w_n),
                "smem_kb": total_smem / 1024,
                "regs": total_regs,
                "k_stages": k_stages,
            }

    return {"valid": False}

# --- Run Scenarios ---
# Tabulate the largest feasible warp count for every combination of
# pipeline stage count and per-thread register budget. Combinations that
# fit no warp count are omitted from the table.
print(f"{'Stages':<10} {'Regs/Thr':<10} {'Max Warps':<10} {'Shape(m,n)':<12} {'SMEM(KB)':<10}")
print("-" * 60)

# Per-thread register budgets swept:
#   64  -> aggressive (high optimization)
#   128 -> comfortable / standard for complex kernels
# (255 would be heavy-spilling territory and is not swept.)
for k, r in [(k, r) for k in (2, 3, 4, 8) for r in (64, 128)]:
    res = calc_max_warps(k_stages=k, regs_per_thread=r)
    if not res["valid"]:
        continue
    print(f"{k:<10} {r:<10} {res['warps']:<10} {str(res['shape']):<12} {res['smem_kb']:<10.1f}")

Stages     Regs/Thr   Max Warps  Shape(m,n)   SMEM(KB)  
------------------------------------------------------------
2          64         32         (4, 8)       12.0      
2          128        16         (4, 4)       8.0       
3          64         32         (4, 8)       18.0      
3          128        16         (4, 4)       12.0      
4          64         32         (4, 8)       24.0      
4          128        16         (4, 4)       16.0      
8          64         32         (4, 8)       48.0      
8          128        16         (4, 4)       32.0      
