# In [36]:
import os
import torch
from IPython.core.debugger import set_trace

# Select the TRITON_INTERPRET setting for this session. check_tensors_gpu_ready
# skips its "tensor is on GPU" assertion only when this value is '1'; with '0'
# every checked tensor must live on a CUDA device.
# NOTE(review): presumably '1' turns on Triton's interpreter mode - confirm
# against the Triton documentation.
os.environ["TRITON_INTERPRET"] = '0' # Need to set before importing Triton

def check_tensors_gpu_ready(*tensors: torch.Tensor):
    """Assert that every given tensor can be handed to a kernel.

    Each tensor must be contiguous. Unless the TRITON_INTERPRET environment
    variable equals '1', each tensor must also be on a CUDA device.

    Raises:
        AssertionError: On the first tensor that violates either requirement.
    """
    # The GPU requirement is waived only when TRITON_INTERPRET is exactly '1'.
    gpu_check_waived = os.environ.get('TRITON_INTERPRET') == '1'
    for tensor in tensors:
        assert tensor.is_contiguous(), "A tensor is not contiguous"
        if gpu_check_waived:
            continue
        assert tensor.is_cuda, "A tensor is not on GPU"

def test_pid_conds(conds_str: str, pid_0=[0], pid_1=[0], pid_2=[0]) -> bool:
    pids = pid_0[0], pid_1[0], pid_2[0]
    conds = conds_str.replace(' ', '').split(',')
    for i, (cond, pid) in enumerate(zip(conds, pids)):
        if cond == "":
            continue
        try:
            op, threshold = cond[0], int(cond[1:])
        except ValueError as e:
            if len(cond[1:]) == 2:
                op, threshold = cond[0:2], int(cond[2:])
            else:
                raise ValueError(e)
        if op not in ['<','>','>=','<=','=', '!=']: 
            raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{cond}'.")
        op = "==" if op == "=" else op
        if not eval(f"{pid} {op} {threshold}"):
            return False
    return True

# Sanity checks for test_pid_conds: an empty rule string always passes, a
# single rule applies to pid_0 only, and comma-separated rules are matched by
# position against pid_0, pid_1, pid_2 (pids without a rule are unconstrained).
assert test_pid_conds(conds_str="")
assert test_pid_conds(conds_str=">=0")
assert test_pid_conds(">0", pid_0=[1], pid_1=[1])
assert not test_pid_conds(">0", pid_0=[0], pid_1=[1])
assert test_pid_conds("=0,=1", pid_0=[0], pid_1=[1], pid_2=[1])
assert test_pid_conds("=0, =1", pid_0=[0], pid_1=[1], pid_2=[2])

def breakpoint_if(conds: str, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Enter the IPython debugger when the pids satisfy the rules in `conds`.

    `conds` and the pid holders are passed unchanged to test_pid_conds;
    nothing happens when that check returns False.
    """
    should_stop = test_pid_conds(conds, pid_0, pid_1, pid_2)
    if not should_stop:
        return
    set_trace()

def print_if(txt: str, conds: str, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Print `txt` only when the pids satisfy the rules in `conds`.

    `conds` and the pid holders are passed unchanged to test_pid_conds;
    nothing is printed when that check returns False.
    """
    matched = test_pid_conds(conds, pid_0, pid_1, pid_2)
    if not matched:
        return
    print(txt)

def cdiv(a: int, b: int) -> int:
    """Return the ceiling of `a / b` using integer arithmetic only.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer that is greater than or equal to a / b.

    Raises:
        ZeroDivisionError: If `b` is 0.
    """
    # Floor-dividing the negated dividend and negating the result rounds up.
    # Unlike (a + b - 1) // b, this is also correct when b is negative.
    return -(-a // b)

# Sanity checks for cdiv: an exact division is unchanged, a remainder rounds up.
assert cdiv(a=10, b=2) == 5
assert cdiv(a=11, b=2) == 6

# Smoke test for check_tensors_gpu_ready: a 1-D CUDA tensor and a column slice
# of a 3x3 CUDA tensor that has been passed through .contiguous() should both
# pass. Both tensors are created with device="cuda", so this line needs a CUDA
# device to run.
check_tensors_gpu_ready(torch.tensor([1, 2, 3], device="cuda"), torch.tensor([[4, 5, 6], [7 ,8, 9], [10, 11, 12]], device="cuda")[:, 2].contiguous())



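# Captured notebook cell output (presumably one line per evaluated comparison
# from an earlier run - the code above does not print these):
# 0 >= 0
# 1 > 0
# 0 > 0
# 0 == 0
# 1 == 1
# 0 == 0
# 1 == 1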
