<a href="https://colab.research.google.com/github/Jolllly-bot/ToyVM-Triton/blob/main/Toy_Triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ToyVM Triton



In [None]:
!python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122

Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly-cu122
  Downloading https://github.com/mlc-ai/package/releases/download/v0.9.dev0/mlc_ai_nightly_cu122-0.19.dev58-cp310-cp310-manylinux_2_28_x86_64.whl (1454.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 GB[0m [31m698.4 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: mlc-ai-nightly-cu122
Successfully installed mlc-ai-nightly-cu122-0.19.dev58


## JIT Decorator

In [None]:
import inspect
import ast

def jit(target="cpu", verbose=True):
    """Return a decorator that wraps a function in a ``JIT`` object.

    ``target`` has to be ``"cpu"`` or ``"gpu"``; any other value trips the
    assertion as soon as the decorator is created. ``verbose`` is passed
    through to ``JIT`` untouched.
    """
    assert target in ("cpu", "gpu")

    def decorate(kernel_fn):
        # Forward the decorator options together with the decorated function.
        return JIT(kernel_fn, target=target, verbose=verbose)

    return decorate

class JIT:
    """Callable wrapper that compiles a Python function the first time it runs.

    The wrapped function is never executed as Python. Its source is parsed,
    handed to ``CodeGenerator`` and the resulting compiled kernel is invoked
    with the ``.data`` attribute of every call argument.

    Attributes:
        fn: The original, undecorated Python function.
        target: ``"cpu"`` or ``"gpu"``; forwarded to ``CodeGenerator``.
        verbose: When true, the parsed AST is dumped and ``CodeGenerator``
            is asked for verbose output as well.
    """

    def __init__(self, fn, target="cpu", verbose=True):
        self.fn = fn
        self.target = target
        self.verbose = verbose
        # Compiled entry point; filled in by the first call and reused after.
        self._compiled_kernel = None

    def _compile(self):
        """Parse the source of ``self.fn`` and build it with ``CodeGenerator``."""
        import textwrap

        # getsource() keeps the original indentation, so a function defined
        # inside a class or another function must be dedented before parsing.
        fn_src = textwrap.dedent(inspect.getsource(self.fn))
        fn_ast = ast.parse(fn_src)
        if self.verbose:
            print(ast.dump(fn_ast))

        # The code generator receives a copy of the function's globals as ctx.
        ctx = self.fn.__globals__.copy()
        code_generator = CodeGenerator(fn_ast, ctx, self.target, self.verbose)
        return code_generator.code_gen()

    def __call__(self, *args, **kwargs):
        """Run the compiled kernel on ``arg.data`` of every argument.

        Arguments may be given positionally or by keyword; they are bound to
        the parameters of ``self.fn`` so the kernel always receives them in
        declaration order.

        Raises:
            TypeError: If the arguments do not match the signature of
                ``self.fn``.
        """
        # Compile lazily and only once; later calls reuse the cached kernel.
        if self._compiled_kernel is None:
            self._compiled_kernel = self._compile()

        bound = inspect.signature(self.fn).bind(*args, **kwargs)
        input_args = [arg.data for arg in bound.arguments.values()]
        return self._compiled_kernel(*input_args)

## Frontend Lexer & Parser
https://github.com/triton-lang/triton/blob/main/python/triton/compiler/compiler.py
https://github.com/triton-lang/triton/blob/main/python/triton/compiler/code_generator.py

To keep things simple, semantic analysis ("sema") is not implemented. Triton's version, for reference: https://github.com/triton-lang/triton/blob/main/python/triton/language/semantic.py

GPU runtime:
https://mlc.ai/chapter_gpu_acceleration/part1.html

In [None]:
from typing import Dict, Any, Type
import astunparse
import tvm
from tvm import dlight as dl
from tvm import relax
from tvm.script import relax as R
from tvm.script.ir_builder import relax as relax_builder, ir as I, IRBuilder

class CodeGenerator(ast.NodeVisitor):
    """Translate the Python AST of one function into a Relax module and build it.

    The visitor walks the AST inside an ``IRBuilder`` scope and emits Relax
    expressions for the small Python subset covered by the ``visit_*``
    methods below; every other node type raises ``NotImplementedError``.
    ``code_gen`` then applies the Relax passes, builds for the requested
    target and returns the compiled entry function.

    Args:
        fn_ast: Parsed ``ast.Module`` containing the function to compile.
        ctx: Globals dictionary used to evaluate parameter annotations.
        target: ``"cpu"`` or ``"gpu"``.
        verbose: When true, print the traversal trace, the IR before and
            after the passes and, for the GPU target, the CUDA source.
    """

    # Python dunder name of each binary operator node. Only used for the
    # verbose trace printed by visit_BinOp.
    _method_name_for_bin_op: Dict[Type[ast.operator], str] = {
        ast.Add: '__add__',
        ast.Sub: '__sub__',
        ast.Mult: '__mul__',
        ast.Div: '__truediv__',
        ast.FloorDiv: '__floordiv__',
        ast.Mod: '__mod__',
        ast.Pow: '__pow__',
        ast.LShift: '__lshift__',
        ast.RShift: '__rshift__',
        ast.BitAnd: '__and__',
        ast.BitOr: '__or__',
        ast.BitXor: '__xor__',
    }

    def __init__(self, fn_ast, ctx, target, verbose):
        self.fn_ast = fn_ast
        self.target = target
        self.ib = IRBuilder()
        self.ir_module = None
        # Name of the function visited by visit_FunctionDef.
        self.entry = None
        # Value of the return statement, if one produced a value.
        self.ret = None
        # Maps Python variable names to the Relax values bound to them.
        self.local_var_table: Dict[str, Any] = {}
        self.ctx = ctx
        self.verbose = verbose

    def _log(self, *parts):
        """Print a trace line, but only when verbose output was requested."""
        if self.verbose:
            print(*parts)

    def code_gen(self):
        """Lower the AST, run the pass pipeline, build, and return the entry function."""
        with self.ib:
            self.visit(self.fn_ast)
        module = self.ib.get()
        if self.verbose:
            print("=========TVM IR=========")
            print(module)

        # apply transform pass on module
        with tvm.transform.PassContext(opt_level=3):
            seq = tvm.transform.Sequential(
                [
                    # relax.transform.ConvertToDataflow(),
                    relax.transform.LegalizeOps(),
                    # relax.transform.AnnotateTIROpPattern(),
                    relax.transform.FoldConstant(),  # Fold constant expressions

                    relax.transform.FuseOps(),
                    relax.transform.FuseTIR(),
                ])
            module = seq(module)
        if self.verbose:
            print("===>After applied passes...")
            print(module)

        mapped_target = {'cpu': 'llvm', 'gpu': 'cuda'}
        target = tvm.target.Target(mapped_target[self.target])
        if "cuda" in target.keys:
            # The CUDA path additionally applies dlight's default fallback schedule.
            with target:
                module = dl.ApplyDefaultSchedule(dl.gpu.Fallback(),)(module)
            if self.verbose:
                print("===>After ApplyDefaultSchedule...")
                print(module)
            device = tvm.cuda()
        else:
            device = tvm.cpu()

        with tvm.transform.PassContext(opt_level=3):
            ex = relax.build(module, target=target)

        if self.verbose and "cuda" in target.keys:
            print("=========CUDA CODE=========")
            print(ex.mod.imported_modules[0].imported_modules[0].get_source())

        vm = relax.VirtualMachine(ex, device=device)
        return vm[self.entry]

    def visit(self, node):
        """Dispatch to the matching ``visit_*`` method, tracing the node type when verbose."""
        self._log("Visit " + node.__class__.__name__)
        return super().visit(node)

    def visit_Module(self, node: ast.Module):
        """Open the IR module frame and visit the module's children.

        Raises:
            AssertionError: If a module has already been visited.
        """
        if self.ir_module:
            raise AssertionError("We should have only one module!")
        self.ir_module = I.ir_module()
        with self.ir_module:
            # Use NodeVisitor's traversal; this class overrides generic_visit
            # to reject unknown nodes.
            super().generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Emit a Relax function for ``node`` and record its name as the entry point."""
        fn = relax_builder.function()
        self.entry = node.name
        self._log("entry Function: {}".format(node.name))
        with fn:
            R.func_name(node.name)
            self.visit(node.args)
            self._visit_compound_stmt(node.body)

            # Without a returned value the function returns an empty ShapeExpr.
            if self.ret is None:
                R.func_ret_value(relax.ShapeExpr([]))
            else:
                R.func_ret_value(self.ret)

    def visit_arguments(self, node: ast.arguments):
        """Declare one Relax tensor parameter per entry of ``node.args``.

        Each annotation is evaluated in ``self.ctx``; the result's ``shape``
        and ``dtype`` attributes define the parameter's tensor type.

        Raises:
            ValueError: If a parameter has no annotation.
        """
        for arg in node.args:
            arg_name = arg.arg
            if arg.annotation is None:
                raise ValueError(arg, "Type annotation is required for function parameters.")
            # NOTE(security): eval executes whatever expression the annotation
            # contains, so only compile functions from trusted source code.
            anno = eval(astunparse.unparse(arg.annotation), self.ctx)
            self._log(anno)
            param = R.arg(arg_name, R.Tensor(shape=anno.shape, dtype=anno.dtype))
            self.local_var_table[arg_name] = param

    def _visit_compound_stmt(self, stmts):
        """Visit statements in order, remembering the value of a return statement."""
        assert isinstance(stmts, (list, tuple))
        for stmt in stmts:
            ret = self.visit(stmt)
            if ret is not None and isinstance(stmt, ast.Return):
                self.ret = ret

    def visit_Pass(self, node: ast.Pass):
        """``pass`` emits nothing."""
        pass

    def visit_Assign(self, node: ast.Assign):
        """Bind the value of a single-target assignment to the target's name.

        Raises:
            NotImplementedError: For chained assignments such as ``a = b = c``.
        """
        if len(node.targets) != 1:
            raise NotImplementedError("Doesn't support simultaneous multiple assignment like 'a = b = c' in AST node type: {}".format(type(node).__name__))
        target: relax.Var = self.visit(node.targets[0])
        value = self.visit(node.value)
        self.local_var_table[target.name_hint] = value
        self.ib.name(target.name_hint, value)

    def visit_Name(self, node: ast.Name):
        """Look a name up in the local variable table.

        A name in store context that is not known yet first gets a
        placeholder ``relax.Var``. Reading an unknown name raises ``KeyError``.
        """
        name = node.id
        if isinstance(node.ctx, ast.Store):
            if name not in self.local_var_table:
                self.local_var_table[name] = relax.Var(name, struct_info=relax.ObjectStructInfo())
        return self.local_var_table[name]

    def visit_BinOp(self, node: ast.BinOp):
        """Emit ``R.add`` or ``R.multiply`` for ``+`` and ``*``.

        Raises:
            NotImplementedError: For any other binary operator; the message
                names the operator.
        """
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        # TODO: refactor
        method_name = self._method_name_for_bin_op.get(type(node.op))
        self._log("method name", method_name)

        if isinstance(node.op, ast.Add):
            return R.emit(R.add(lhs, rhs))
        elif isinstance(node.op, ast.Mult):
            return R.emit(R.multiply(lhs, rhs))
        else:
            # Report the operator itself; the node type is always "BinOp" here.
            raise NotImplementedError("Unsupported binary operator: {}".format(type(node.op).__name__))

    def visit_Return(self, node: ast.Return):
        """Return the lowered return value, or ``None`` for a bare ``return``."""
        if node.value is None:
            # Nothing to lower; visit_FunctionDef then falls back to its
            # empty-ShapeExpr return value.
            return None
        return self.visit(node.value)

    def visit_Constant(self, node: ast.Constant):
        """Emit a Relax constant for a Python literal."""
        return R.emit(relax.const(node.value))

    def generic_visit(self, node: ast.AST):
        """Reject every node type that has no dedicated ``visit_*`` method."""
        raise NotImplementedError("Unsupported AST node type: {}".format(type(node).__name__))

## Tensor Definition using [DLPack](https://dmlc.github.io/dlpack/latest/)

https://dmlc.github.io/dlpack/latest/python_spec.html
Tensor usage test:

In [None]:
import torch
class Tensor:
    """Tensor declaration (shape and dtype) that can also carry device data.

    Instances are used in two ways in this file: as parameter annotations
    that tell the code generator the shape/dtype of a kernel argument, and
    as runtime containers whose ``data`` is passed to the compiled kernel.

    Args:
        shape: Sequence of dimension sizes.
        dtype: Dtype name, e.g. ``"float32"``.
    """

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
        self._data = None

    @staticmethod
    def _to_ndarray(tensor):
        """Convert ``tensor`` to a TVM NDArray, via DLPack when possible."""
        # Imported lazily so the class can be declared without TVM loaded.
        from tvm.runtime import Device
        from tvm.runtime import ndarray

        try:
            return ndarray.from_dlpack(tensor)
        except RuntimeError:
            # DLPack import is best effort; fall back to copying via NumPy.
            pass
        device_type = tensor.device.type
        device_id = tensor.device.index or 0
        # Tensor.numpy() only works for host tensors that do not require
        # grad, so detach and move to the CPU before converting.
        host_array = tensor.detach().cpu().numpy()
        return ndarray.array(
            host_array,
            device=Device(
                Device.STR2MASK[device_type],
                device_id,
            ),
        )

    @property
    def data(self):
        """The NDArray assigned through the setter, or ``None`` if unset."""
        return self._data

    @data.setter
    def data(self, data: "torch.Tensor"):
        """Convert ``data`` to a TVM NDArray and store it.

        Raises:
            ValueError: If the converted array's shape or dtype differs from
                the declared ``shape`` / ``dtype``.
        """
        data = self._to_ndarray(data)
        if data.shape != tuple(self.shape):
            raise ValueError(f"Shape mismatch: expected {tuple(self.shape)}, got {data.shape}")
        if data.dtype != self.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.dtype}, got {data.dtype}")
        self._data = data

    def __str__(self):
        """Render as ``dtype[d0, d1, ...]``, e.g. ``float32[2, 3]``."""
        return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'

In [None]:
# Declare a (2, 3) float32 tensor, attach a torch tensor of ones to it, then
# show the declaration, the converted data and the type of that data.
a = Tensor(shape=(2, 3), dtype="float32")
a.data = torch.ones((2, 3), dtype=torch.float32)
print(a, a.data, type(a.data), sep="\n")

float32[2, 3]
[[1. 1. 1.]
 [1. 1. 1.]]
<class 'tvm.runtime.ndarray.NDArray'>


## Test Functions

In [None]:
# Element-wise addition kernel compiled for the CPU target.
# The parameter annotations are evaluated by the code generator to obtain each
# argument's shape and dtype.
# NOTE: keep the body to statements the toy front end handles; any other node
# type (a docstring included) raises NotImplementedError.
@jit(target="cpu")
def add_tensor(a: Tensor(shape=(2, 3), dtype="float32"), b: Tensor(shape=(2, 3), dtype="float32")):
    out = a + b
    return out

# Runtime inputs: two (2, 3) float32 tensors of ones.
a = Tensor(shape=(2, 3), dtype="float32")
b = Tensor(shape=(2, 3), dtype="float32")
a.data = torch.ones(size=(2, 3), dtype=torch.float32)
b.data = torch.ones(size=(2, 3), dtype=torch.float32)
# Calling the wrapper compiles the kernel and runs it on a.data / b.data.
print(add_tensor(a, b))

Module(body=[FunctionDef(name='add_tensor', args=arguments(posonlyargs=[], args=[arg(arg='a', annotation=Call(func=Name(id='Tensor', ctx=Load()), args=[], keywords=[keyword(arg='shape', value=Tuple(elts=[Constant(value=2), Constant(value=3)], ctx=Load())), keyword(arg='dtype', value=Constant(value='float32'))])), arg(arg='b', annotation=Call(func=Name(id='Tensor', ctx=Load()), args=[], keywords=[keyword(arg='shape', value=Tuple(elts=[Constant(value=2), Constant(value=3)], ctx=Load())), keyword(arg='dtype', value=Constant(value='float32'))]))], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Assign(targets=[Name(id='out', ctx=Store())], value=BinOp(left=Name(id='a', ctx=Load()), op=Add(), right=Name(id='b', ctx=Load()))), Return(value=Name(id='out', ctx=Load()))], decorator_list=[Call(func=Name(id='jit', ctx=Load()), args=[], keywords=[keyword(arg='target', value=Constant(value='cpu'))])])], type_ignores=[])
Visit Module
Visit FunctionDef
entry Function: add_tensor
Visit arguments
fl

In [None]:
# Same addition kernel, compiled for the GPU target with verbose=False.
# NOTE: keep the body to statements the toy front end handles; any other node
# type (a docstring included) raises NotImplementedError.
@jit(target="gpu", verbose=False)
def add_tensor(a: Tensor(shape=(2, 3), dtype="float32"), b: Tensor(shape=(2, 3), dtype="float32")):
    out = a + b
    return out

# Inputs are created on the CUDA device to match the gpu target.
a = Tensor(shape=(2, 3), dtype="float32")
b = Tensor(shape=(2, 3), dtype="float32")
a.data = torch.ones(size=(2, 3), dtype=torch.float32, device="cuda")
b.data = torch.ones(size=(2, 3), dtype=torch.float32, device="cuda")
print(add_tensor(a, b))

Visit Module
Visit FunctionDef
entry Function: add_tensor
Visit arguments
float32[2, 3]
float32[2, 3]
Visit Assign
Visit Name
Visit BinOp
Visit Name
Visit Name
method name __add__
Visit Return
Visit Name
[[2. 2. 2.]
 [2. 2. 2.]]


In [None]:
# Element-wise multiplication kernel compiled for the GPU target (verbose by default).
# NOTE: keep the body to statements the toy front end handles; any other node
# type (a docstring included) raises NotImplementedError.
@jit(target="gpu")
def mul_tensor(a: Tensor(shape=(2, 3), dtype="float32"), b: Tensor(shape=(2, 3), dtype="float32")):
    out = a * b
    return out

# Inputs are created on the CUDA device to match the gpu target.
a = Tensor(shape=(2, 3), dtype="float32")
b = Tensor(shape=(2, 3), dtype="float32")
a.data = torch.ones(size=(2, 3), dtype=torch.float32, device="cuda")
b.data = torch.ones(size=(2, 3), dtype=torch.float32, device="cuda")
print(mul_tensor(a, b))

Module(body=[FunctionDef(name='mul_tensor', args=arguments(posonlyargs=[], args=[arg(arg='a', annotation=Call(func=Name(id='Tensor', ctx=Load()), args=[], keywords=[keyword(arg='shape', value=Tuple(elts=[Constant(value=2), Constant(value=3)], ctx=Load())), keyword(arg='dtype', value=Constant(value='float32'))])), arg(arg='b', annotation=Call(func=Name(id='Tensor', ctx=Load()), args=[], keywords=[keyword(arg='shape', value=Tuple(elts=[Constant(value=2), Constant(value=3)], ctx=Load())), keyword(arg='dtype', value=Constant(value='float32'))]))], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Assign(targets=[Name(id='out', ctx=Store())], value=BinOp(left=Name(id='a', ctx=Load()), op=Mult(), right=Name(id='b', ctx=Load()))), Return(value=Name(id='out', ctx=Load()))], decorator_list=[Call(func=Name(id='jit', ctx=Load()), args=[], keywords=[keyword(arg='target', value=Constant(value='gpu'))])])], type_ignores=[])
Visit Module
Visit FunctionDef
entry Function: mul_tensor
Visit arguments
f

In [None]:
import timeit
import inspect


# Compile once up front so the timings below cover kernel execution only.
fn_src = inspect.getsource(mul_tensor.fn)  # Get source code of the original function
fn_ast = ast.parse(fn_src)
ctx = mul_tensor.fn.__globals__.copy()
code_generator = CodeGenerator(fn_ast, ctx, target="gpu", verbose=False)  # Create CodeGenerator instance
compiled_kernel = code_generator.code_gen()  # Get compiled kernel

# Build the Torch operands on the GPU as well; timing torch.mul on CPU copies
# would compare different devices.
a_tensor = torch.from_numpy(a.data.numpy()).to("cuda")
b_tensor = torch.from_numpy(b.data.numpy()).to("cuda")

tvm_device = a.data.device


def run_tvm_mul():
    """Launch the compiled kernel and wait for the device to finish."""
    compiled_kernel(a.data, b.data)
    tvm_device.sync()


def run_torch_mul():
    """Launch torch.mul and wait for the device to finish."""
    torch.mul(a_tensor, b_tensor)
    torch.cuda.synchronize()


# Warm up both paths so one-time initialisation is not part of the measurement.
run_tvm_mul()
run_torch_mul()

# Time the compiled mul kernel (execution only)
time_your_mul = timeit.timeit(run_tvm_mul, number=1000)

# Time Torch's mul function
time_torch_mul = timeit.timeit(run_torch_mul, number=1000)

print(f"Your mul time: {time_your_mul:.6f} seconds")
print(f"Torch mul time: {time_torch_mul:.6f} seconds")

Visit Module
Visit FunctionDef
entry Function: mul_tensor
Visit arguments
float32[2, 3]
float32[2, 3]
Visit Assign
Visit Name
Visit BinOp
Visit Name
Visit Name
method name __mul__
Visit Return
Visit Name
Your add time: 0.010741 seconds
Torch add time: 0.015473 seconds


In [None]:
import timeit
import inspect


# Compile once up front so the timings below cover kernel execution only.
fn_src = inspect.getsource(add_tensor.fn)  # Get source code of the original function
fn_ast = ast.parse(fn_src)
# Evaluate annotations in the globals of the function that is being compiled.
ctx = add_tensor.fn.__globals__.copy()
code_generator = CodeGenerator(fn_ast, ctx, target="gpu", verbose=False)  # Create CodeGenerator instance
compiled_kernel = code_generator.code_gen()  # Get compiled kernel

# Build the Torch operands on the GPU as well; timing torch.add on CPU copies
# would compare different devices.
a_tensor = torch.from_numpy(a.data.numpy()).to("cuda")
b_tensor = torch.from_numpy(b.data.numpy()).to("cuda")

tvm_device = a.data.device


def run_tvm_add():
    """Launch the compiled kernel and wait for the device to finish."""
    compiled_kernel(a.data, b.data)
    tvm_device.sync()


def run_torch_add():
    """Launch torch.add and wait for the device to finish."""
    torch.add(a_tensor, b_tensor)
    torch.cuda.synchronize()


# Warm up both paths so one-time initialisation is not part of the measurement.
run_tvm_add()
run_torch_add()

# Time the compiled add kernel (execution only)
time_your_add = timeit.timeit(run_tvm_add, number=1000)

# Time Torch's add function
time_torch_add = timeit.timeit(run_torch_add, number=1000)

print(f"Your add time: {time_your_add:.6f} seconds")
print(f"Torch add time: {time_torch_add:.6f} seconds")

In [None]:
# Parameterless kernel built only from integer constants.
# The local variable `add` has the same name as the function; inside the
# kernel it is just an entry in the code generator's variable table.
# NOTE: keep the body to statements the toy front end handles; any other node
# type (a docstring included) raises NotImplementedError.
@jit(target="cpu")
def add():
    add = 1 + 1
    res = add + 1
    return res
# `add` now refers to the JIT wrapper; calling it compiles and runs the kernel.
print("Add Result:", add())

Module(body=[FunctionDef(name='add', args=arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Assign(targets=[Name(id='add', ctx=Store())], value=BinOp(left=Constant(value=1), op=Add(), right=Constant(value=1))), Assign(targets=[Name(id='res', ctx=Store())], value=BinOp(left=Name(id='add', ctx=Load()), op=Add(), right=Constant(value=1))), Return(value=Name(id='res', ctx=Load()))], decorator_list=[Call(func=Name(id='jit', ctx=Load()), args=[], keywords=[keyword(arg='target', value=Constant(value='cpu'))])])], type_ignores=[])
Visit Module
Visit FunctionDef
entry Function: add
Visit arguments
Visit Assign
Visit Name
Visit BinOp
Visit Constant
Visit Constant
method name __add__
Visit Assign
Visit Name
Visit BinOp
Visit Name
Visit Constant
method name __add__
Visit Return
Visit Name
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def add() -> R.Tensor((), dtype="int32"):
        gv: R.Ten

In [None]:
# Parameterless kernel that multiplies two integer constants.
# NOTE: keep the body to statements the toy front end handles; any other node
# type (a docstring included) raises NotImplementedError.
@jit(target="cpu")
def mul():
    res = 1 * 1
    return res

# Calling the wrapper compiles the kernel and runs it.
print("Mul Result:", mul())

Module(body=[FunctionDef(name='mul', args=arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Assign(targets=[Name(id='res', ctx=Store())], value=BinOp(left=Constant(value=1), op=Mult(), right=Constant(value=1))), Return(value=Name(id='res', ctx=Load()))], decorator_list=[Call(func=Name(id='jit', ctx=Load()), args=[], keywords=[keyword(arg='target', value=Constant(value='cpu'))])])], type_ignores=[])
Visit Module
Visit FunctionDef
entry Function: mul
Visit arguments
Visit Assign
Visit Name
Visit BinOp
Visit Constant
Visit Constant
method name __mul__
Visit Return
Visit Name
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def mul() -> R.Tensor((), dtype="int32"):
        gv: R.Tensor((), dtype="int32") = R.const(1, "int32")
        gv1: R.Tensor((), dtype="int32") = R.const(1, "int32")
        res: R.Tensor((), dtype="int32") = R.multiply(gv, gv1)
        return res
===>After applied p

Possible next kernels: dot product, 1D convolution, matmul.