# Adding an operator executor

We are going to write a simple executor for the `prims.add` primitive that calls NumPy's addition function. Our executor will be restricted to inputs with certain properties (dtype, rank, and shape). We will use the `add_operator_executor` function to create our executor.

In [1]:
import thunder
import torch
import numpy as np

In [2]:
#@title Helper functions (execute this cell)
import functools

_indentation = 0  # current log depth; each unit adds two spaces of prefix


def _log(msg=None):
    """Print *msg* after the current indentation prefix; ``None`` prints nothing."""
    if msg is None:
        return
    prefix = "  " * _indentation
    print(prefix + msg)

def _log_indent(msg=None):
    """Log *msg* at the current depth, then nest all later output one level deeper."""
    global _indentation
    _log(msg)
    _indentation += 2

def _log_unindent(msg=None):
    """Step back out to the enclosing depth, then log *msg* there."""
    global _indentation
    _indentation -= 2
    _log(msg)
  
def log(func):
    """A decorator for functions to log arguments and results.

    Each call to the wrapped function prints ``call <name>(<args>)``, indents
    any log output produced while it runs (so nested logged calls show up
    nested), then prints ``|<- <name> = <result>`` and restores the previous
    indentation.

    If the wrapped function raises, ``|<- <name> raised <ExcType>: <msg>`` is
    printed instead, the indentation is still restored, and the exception is
    re-raised unchanged. Keyword arguments are forwarded to the wrapped
    function and logged as ``key=value`` after the positional arguments.
    """
    name = func.__name__

    def pp(v):
        """Return a succinct string for *v* (tuples, TensorProxies, and tensors)."""
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)

    def pp_values(args):
        """Join the succinct representations of *args* with commas."""
        return ", ".join(pp(arg) for arg in args)

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        rendered = [pp(arg) for arg in args]
        rendered.extend(f"{key}={pp(value)}" for key, value in kwargs.items())
        _log_indent("call {}({})".format(name, ", ".join(rendered)))
        try:
            res = func(*args, **kwargs)
        except BaseException as exc:
            # Undo the indent added above; otherwise a failing call would shift
            # every later log line to the right.
            _log_unindent("|<- {} raised {}: {}\n".format(name, type(exc).__name__, exc))
            raise
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper

In [3]:
# This is our test function: it multiplies `b` by `a`, then adds `a`.
def fun(a, b):
    """Return ``a + b * a``; the function traced and compiled throughout this notebook."""
    scaled = b * a
    return a + scaled

In [4]:
# This is our test input: two random CUDA tensors (a GPU is required).
# `b` has shape (2, 1) and is broadcast against `a`'s (2, 2) inside `fun`.
a = torch.randn(2, 2, device="cuda")
b = torch.randn(2, 1, device="cuda")

In [5]:
# Let's see first how this function is represented as a trace.
# `thunder.trace()` returns a tracing callable; calling it with `fun` and the
# example inputs records the operations `fun` performs on them.
trace = thunder.trace()(fun, a, b)
print(trace)

# import thunder as thunder
# import thunder.torch as ltorch
import torch

@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f32[2, 2]" 
  # b: "cuda:0 f32[2, 1]" 
  t1 = ltorch.mul(b, a)  # t1: "cuda:0 f32[2, 2]"
    # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1))  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  t2 = ltorch.add(a, t1, alpha=None)  # t2: "cuda:0 f32[2, 2]"
    # t2 = prims.add(a, t1)  # t2: "cuda:0 f32[2, 2]"
  return t2


In [6]:
# Walk the recorded operations, which we call BoundSymbols; a symbol that is
# decomposed into lower-level operations lists them as its subsymbols.
for bsym in trace.bound_symbols:
    print(f"Bound symbol with id={bsym.sym.id} is represented in the trace as |{bsym}|")
    if not bsym.subsymbols:
        continue
    print("  It has the following subsymbols:")
    for subsym in bsym.subsymbols:
        print(f"    id={subsym.sym.id}  |{subsym}|")

Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: "cuda:0 f32[2, 2]" |
Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: "cuda:0 f32[2, 1]" |
Bound symbol with id=torch.mul is represented in the trace as |t1 = ltorch.mul(b, a)  # t1: "cuda:0 f32[2, 2]"
  # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1))  # t0: "cuda:0 f32[2, 2]"
  # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"|
  It has the following subsymbols:
    id=PrimIDs.BROADCAST_IN_DIM  |t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1))  # t0: "cuda:0 f32[2, 2]"|
    id=PrimIDs.MUL  |t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"|
Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(a, t1, alpha=None)  # t2: "cuda:0 f32[2, 2]"
  # t2 = prims.add(a, t1)  # t2: "cuda:0 f32[2, 2]"|
  It has the following subsymbols:
    id=PrimIDs.ADD  |t2 = prims.add(a, t1)  # t2: "cuda:0 f32[2, 2]"|
Bound symbol with id=PrimIDs.RETURN is represented in the tr

In [7]:
from thunder.executors import add_operator_executor

# Print the help text (including the signature) of the helper that registers a new executor.
help(add_operator_executor)

Help on function add_operator_executor in module thunder.executors:

add_operator_executor(name, op_map, *, add_to_default_executors: bool = True) -> None



The key argument here is `op_map`.

`op_map` is a dictionary that maps the id of each operator we provide an executor for (for example `thunder.prims.PrimIDs.ADD`) to a `(name, checker_fn, implementation_fn)` tuple.

* `name` is the name of our execution function that will appear in the execution trace.
* `checker_fn` accepts the same arguments as the operator itself and returns `True` or `False` to tell the executor orchestrator whether this particular set of inputs is supported. It is called while compiling, with `TensorProxy` objects rather than real tensors.
* `implementation_fn` is called at run time with real PyTorch tensors and is expected to return PyTorch tensors.

In [8]:
# Let's define the addition function that can work only with NumPy's ndarrays

@log
def add_numpy(a, b):
    """Add two NumPy ndarrays elementwise with ``np.add``.

    This is the computation our executor actually performs. It deliberately
    accepts only NumPy arrays, so PyTorch tensors must be converted first.

    Raises:
        TypeError: if ``a`` or ``b`` is not a ``np.ndarray``. The check uses an
            explicit raise rather than ``assert`` so that it is not stripped
            when Python runs with ``-O``.
    """
    if not isinstance(a, np.ndarray):
        raise TypeError("a must be a NumPy ndarray")
    if not isinstance(b, np.ndarray):
        raise TypeError("b must be a NumPy ndarray")
    return np.add(a, b)


In [9]:
# We also need conversion functions from PyTorch to NumPy and back
@log
def torch_to_numpy(tensors):
    """Convert PyTorch tensors to NumPy arrays, returned as a tuple in input order."""

    def to_array(tensor):
        # Detach from autograd and copy to the CPU first: Tensor.numpy() only
        # works on CPU tensors that do not require grad.
        return tensor.detach().cpu().numpy()

    return tuple(map(to_array, tensors))

@log
def numpy_to_torch(arrays, device):
    """Convert NumPy arrays to PyTorch tensors on ``device``, returned as a tuple."""
    converted = [torch.from_numpy(array).to(device) for array in arrays]
    return tuple(converted)

In [10]:
@log
def checker_add_numpy(a, b):
    """Decide whether our NumPy executor can run ``prims.add(a, b)``.

    Called by thunder while compiling, with ``TensorProxy`` arguments (see the
    logged calls in the cells below), and returns ``True`` or ``False``.

    Suppose we only support float32 dtype, 2D, and (2, N) shape: both inputs
    must be float32, have exactly two dimensions, and have a first dimension
    of size 2.
    """
    # Operands without tensor metadata (e.g. a plain Python number) cannot be
    # handled here; reject them instead of raising AttributeError.
    if not all(hasattr(t, attr) for t in (a, b) for attr in ("dtype", "ndim", "shape")):
        return False
    # Chain the conditions with `and` so they short-circuit: `shape[0]` is only
    # read once both inputs are known to be 2D (it would raise IndexError on a
    # 0-d input).
    return (
        a.dtype == b.dtype == thunder.dtypes.float32
        and a.ndim == b.ndim == 2
        and a.shape[0] == b.shape[0] == 2
    )

In [11]:
@log
def executor_add_numpy(a, b):
    """Implement ``prims.add`` for real PyTorch tensors by going through NumPy.

    Both tensors are converted to NumPy arrays and added with ``add_numpy``.
    The sum is then converted back to a PyTorch tensor on ``a``'s device.
    """
    first, second = torch_to_numpy((a, b))
    total = add_numpy(first, second)
    (result,) = numpy_to_torch((total,), a.device)
    return result

Now we have all the pieces to create our executor.

In [12]:
# Map the primitive's id to (name shown in the execution trace,
# checker called with TensorProxy inputs, implementation called with real tensors).
op_map = {
    thunder.prims.PrimIDs.ADD: ("add_numpy", checker_add_numpy, executor_add_numpy)
}

In [13]:
# Let's send our operator map to `add_operator_executor` to register our executor under the name "custom_add_executor"
# NOTE(review): add_to_default_executors=False presumably keeps it out of the default
# executors, so it has to be requested explicitly via `executors_list` (as done below).

add_operator_executor("custom_add_executor", op_map, add_to_default_executors=False)

In [14]:
# Let's test our executor
# Only our custom executor is listed here, so it would have to handle every
# operation in `fun` (the next cell shows what happens when it cannot).

cfun = thunder.compile(fun, executors_list=["custom_add_executor"])

In [15]:
# Expected to fail: our executor only covers prims.add, so no listed executor
# can run the multiplication. Print the error instead of stopping the notebook.
try:
    cfun(a, b)
except RuntimeError as e:
    print(e)

Could not find executor for bound symbol t1 = ltorch.mul(b, a)  # t1: "cuda:0 f32[2, 2]"
  # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1))  # t0: "cuda:0 f32[2, 2]"
  # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"


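The call above fails because no executor in `executors_list` can run `ltorch.mul`; our executor only handles `prims.add`. Let's add the PyTorch executor after ours as a fallback for the operations our executor doesn't support.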

In [16]:
# Our executor is listed first so it is offered prims.add; the PyTorch executor
# handles the operations ours does not (here, the multiplication).
cfun = thunder.compile(fun, executors_list=["custom_add_executor", thunder.executors.TORCH])

In [17]:
# Run the compiled function; the @log-decorated helpers print each of their calls.
cfun(a, b)

call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))
|<- checker_add_numpy = True

call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))
|<- checker_add_numpy = True

call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],
        [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],
        [-0.2897,  0.0392]], device='cuda:0'))
    call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],
        [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, d

tensor([[-2.0177, -2.8520],
        [ 0.6923, -0.0936]], device='cuda:0')

Our logging decorator shows that `checker_add_numpy` was called twice with `TensorProxy` arguments and returned `True` both times. This means our executor will be used for this particular execution trace.

Then we see that `executor_add_numpy` is called with regular PyTorch tensors as arguments and returns a regular PyTorch tensor.

In [18]:
# Let's check how our function is represented in the execution trace now
# (the last trace thunder recorded for `cfun`).
thunder.last_traces(cfun)[-1]

# Constructed by Delete Last Used
# import torch as torch
import torch

@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f32[2, 2]" 
  # b: "cuda:0 f32[2, 1]" 
  t1 = torch.mul(b, a)  # t1: "cuda:0 f32[2, 2]"
  del [b]
  t2 = add_numpy(a, t1)  # t2: "cuda:0 f32[2, 2]"
  del [a, t1]
  return t2

In [19]:
# Let's test whether the result is correct: compile `fun` with only the
# PyTorch executor as a reference, then compare the two results.
cfun_torch = thunder.compile(fun, executors_list=[thunder.executors.TORCH])
expected = cfun_torch(a, b)
actual = cfun(a, b)
torch.testing.assert_close(expected, actual) # Should not raise an exception

call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],
        [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],
        [-0.2897,  0.0392]], device='cuda:0'))
    call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],
        [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],
        [-0.2897,  0.0392]], device='cuda:0')))
    |<- torch_to_numpy = ([[-0.6905969  -0.97613984]
 [ 0.98193294 -0.13276565]], [[-1.3271405  -1.8758768 ]
 [-0.28966585  0.03916528]])

    call add_numpy([[-0.6905969  -0.97613984]
 [ 0.98193294 -0.13276565]], [[-1.3271405  -1.8758768 ]
 [-0.28966585  0.03916528

In [20]:
from thunder.tests.opinfos import add_opinfo

# Take the first sample input that thunder's `add` OpInfo generates for CUDA float32 tensors.
sample = next(add_opinfo.sample_input_generator(add_opinfo, device="cuda", dtype=torch.float32, requires_grad=False))
print(sample)

[SampleInput args=(tensor([[-5.6039,  5.0201, -8.2948, -0.1738],
        [ 8.4915, -2.8353, -7.4601, -4.3015],
        [ 6.0777, -7.6420,  3.4135,  3.2371],
        [-0.8413, -1.7334, -1.0025, -0.7366]], device='cuda:0'), tensor([[ 4.5391,  1.5542,  7.9208, -1.3760],
        [-6.5864,  8.6491,  6.1823, -1.8481],
        [ 7.9385, -0.4884,  4.2281,  1.3158],
        [-4.6107,  3.5805,  3.1749, -4.5989]], device='cuda:0')) kwargs={}]


In [21]:
# Let's test whether the result is correct on the OpInfo sample as well.
# Its inputs are (4, 4), so checker_add_numpy rejects them and the PyTorch
# executor (second in cfun's executors_list) performs the addition instead.
expected = cfun_torch(*sample.args)
actual = cfun(*sample.args)
torch.testing.assert_close(expected, actual) # Should not raise an exception

call checker_add_numpy(TensorProxy(name=a, shape=(4, 4), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(4, 4), dtype=float32, device=cuda:0))
|<- checker_add_numpy = False



In [22]:
# The order of executors matters today: with TORCH listed first, the addition
# is claimed by torch.add and checker_add_numpy is not consulted (no checker
# call is logged, and the trace below uses torch.add).
cfun_torch_first = thunder.compile(fun, executors_list=[thunder.executors.TORCH, "custom_add_executor"])
cfun_torch_first(a, b)
thunder.last_traces(cfun_torch_first)[-1]

# Constructed by Delete Last Used
# import torch as torch
import torch

@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f32[2, 2]" 
  # b: "cuda:0 f32[2, 1]" 
  t1 = torch.mul(b, a)  # t1: "cuda:0 f32[2, 2]"
  del [b]
  t2 = torch.add(a, t1)  # t2: "cuda:0 f32[2, 2]"
  del [a, t1]
  return t2

In [23]:
# Let's try inputs that are not supported by our executor:
# the float64 dtype and a first dimension of 3 both violate checker_add_numpy.
a = torch.randn(3, 2, device="cuda", dtype=torch.float64)
b = torch.randn(3, 1, device="cuda", dtype=torch.float64)

In [24]:
# Let's see how our function is represented in the execution trace now with the new unsupported inputs
# checker_add_numpy runs again for the new inputs and returns False, so the
# addition falls back to torch.add.
cfun(a, b)
thunder.last_traces(cfun)[-1]

call checker_add_numpy(TensorProxy(name=a, shape=(3, 2), dtype=float64, device=cuda:0), TensorProxy(name=t1, shape=(3, 2), dtype=float64, device=cuda:0))
|<- checker_add_numpy = False



# Constructed by Delete Last Used
# import torch as torch
import torch

@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f64[3, 2]" 
  # b: "cuda:0 f64[3, 1]" 
  t1 = torch.mul(b, a)  # t1: "cuda:0 f64[3, 2]"
  del [b]
  t2 = torch.add(a, t1)  # t2: "cuda:0 f64[3, 2]"
  del [a, t1]
  return t2

That's it! We've created our first executor. The process is very similar for other existing operators. Two ingredients are required to create an executor:
* `checker_fn` that checks whether the executor is applicable for a particular set of inputs (works with `TensorProxy` objects),
* `implementation_fn` that implements the operator (works with regular PyTorch tensors).