# Defining new Thunder operators

We are going to add a new operator to Thunder, together with a corresponding executor. The operator will be called `sincos` and will compute the sine and cosine of a given input.

Thunder has three sets of core operators: `thunder.torch`, `thunder.clang`, and `thunder.prims`. The operators in `thunder.prims` are implemented in Python and are used to build the other two sets. A primitive is an operator that is not implemented in terms of other operators.

In [1]:
import thunder
import torch

from thunder.core.proxies import TensorProxy
from enum import Enum

In [2]:
#@title Helper functions (execute this cell)
import functools

_indentation = 0
def _log(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _log_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _log(msg)
    _indentation = 2 + _indentation

def _log_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 2
    _log(msg)
  
def log(func):
    """A decorator for functions to log arguments and results."""
    name = func.__name__
    def pp(v):
        """Print certain values more succinctly."""
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
        _log_indent("call {}({})".format(name, pp_values(args)))
        res = func(*args)
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper

Our new operator has the signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors: the first is the sine of the input and the second is the cosine of the input.
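
As a point of reference for these semantics, here is a minimal pure-Python sketch that works on a list of floats instead of a tensor. The name `sincos_reference` is hypothetical and is not part of Thunder or PyTorch; it only illustrates the input/output contract.

```python
import math
from typing import List, Sequence, Tuple


def sincos_reference(xs: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Return (sines, cosines), computed elementwise over xs.

    Hypothetical helper for illustration only; a list of floats stands in
    for the tensor input.
    """
    sines = [math.sin(x) for x in xs]
    cosines = [math.cos(x) for x in xs]
    return sines, cosines


# One input goes in, a tuple of two same-length outputs comes out.
sines, cosines = sincos_reference([0.0, math.pi / 2])
```

Both outputs have the same length as the input, mirroring the fact that both result tensors have the same shape as the input tensor.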

We call every callable that should be recorded in the trace a Symbol. Symbols are the building blocks of the trace, and each one is either a primitive or a composite operator. Composite operators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.

Let's create a new Symbol called `sincos` and implement it in Python.

In [3]:
from thunder.core.symbol import Symbol

help(Symbol)

Help on class Symbol in module thunder.core.symbol:

class Symbol(builtins.object)
 |  Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = <function default_python_printer at 0x7f30f926dd80>, _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None
 |  
 |  Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = <function default_python_printer at 0x7f30f926dd80>, _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False)
 |  
 |  Methods defined here:
 |  
 |  __call__(self, *args, **kwargs)
 |      Call self as a function.
 |  
 |  __delattr__(self, 

In [4]:
@log
def sincos_meta(input):
    return (TensorProxy(like=input), TensorProxy(like=input))

class CustomOps(Enum):
    sincos = 0

sincos = Symbol(
    id=CustomOps.sincos,
    name="sincos",
    meta=sincos_meta,
    is_prim=True,
)

That's it! We have implemented our new primitive. Let's test it.

In [5]:
def fun(a, b):
    sin, cos = sincos(a)
    return sin + cos + b

In [6]:
a = torch.randn(1, device="cuda")
b = torch.randn(1, device="cuda")

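Because `fun` calls a Thunder Symbol, it is now a Thunder function, meaning it can only accept Thunder's `TensorProxy` objects as inputs. Let's see what happens when we call it directly with `torch.Tensor` inputs.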

In [7]:
try:
    fun(a, b)
except Exception as e:
    print(e)

Couldn't find an eager implementation for sincos


In the future we will add support for `torch.Tensor` and `numpy.ndarray` inputs in an eager mode for Thunder functions. For now, this function works only in tracing mode.

In [8]:
# Let's first see how this function is represented as a trace
trace = thunder.trace()(fun, a, b)
print(trace)

call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))

# import __main__ as __main__
# import thunder as thunder
# import thunder.torch as ltorch
import torch

@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f32[1]" 
  # b: "cuda:0 f32[1]" 
  (t0, t1) = __main__.sincos(a)
  t2 = ltorch.add(t0, t1, alpha=None)  # t2: "cuda:0 f32[1]"
    # t2 = prims.add(t0, t1)  # t2: "cuda:0 f32[1]"
  t3 = ltorch.add(t2, b, alpha=None)  # t3: "cuda:0 f32[1]"
    # t3 = prims.add(t2, b)  # t3: "cuda:0 f32[1]"
  return t3


In [9]:
# We can loop over the recorded operations that we call BoundSymbols
for bound_symbol in trace.bound_symbols:
    print(f"Bound symbol with id={bound_symbol.sym.id} is represented in the trace as |{bound_symbol}|")
    if bound_symbol.subsymbols:
        print("  It has the following subsymbols:")
        for subsymbol in bound_symbol.subsymbols:
            print(f"    id={subsymbol.sym.id}  |{subsymbol}|")

Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: "cuda:0 f32[1]" |
Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: "cuda:0 f32[1]" |
Bound symbol with id=CustomOps.sincos is represented in the trace as |(t0, t1) = __main__.sincos(a)|
Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None)  # t2: "cuda:0 f32[1]"
  # t2 = prims.add(t0, t1)  # t2: "cuda:0 f32[1]"|
  It has the following subsymbols:
    id=PrimIDs.ADD  |t2 = prims.add(t0, t1)  # t2: "cuda:0 f32[1]"|
Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None)  # t3: "cuda:0 f32[1]"
  # t3 = prims.add(t2, b)  # t3: "cuda:0 f32[1]"|
  It has the following subsymbols:
    id=PrimIDs.ADD  |t3 = prims.add(t2, b)  # t3: "cuda:0 f32[1]"|
Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t3|


Let's see what happens if we try to compile a function that uses our new primitive and run it.

In [10]:
cfun = thunder.compile(fun, disable_preprocessing=True)

In [11]:
try:
    cfun(a, b)
except RuntimeError as e:
    print(e)

call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))

Could not find executor for bound symbol (t0, t1) = __main__.sincos(a)


There's no registered executor for `sincos`, so we need to register one for our new primitive. Let's do that.

Check out the "adding-operator-executor.ipynb" notebook to see how to implement an executor for a Symbol.

In [12]:
from thunder.executors import add_operator_executor

@log
def checker_sincos(a):
    # We allow the sincos function to be called with any tensor
    return True

@log
def executor_sincos(a):
    return torch.sin(a), torch.cos(a)

op_map = {
    CustomOps.sincos: ("sincos", checker_sincos, executor_sincos)
}

add_operator_executor("sincos_executor", op_map, add_to_default_executors=True)

In [13]:
# Let's try again
cfun = thunder.compile(fun, disable_preprocessing=True)

In [14]:
cfun(a, b)

call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))

call checker_sincos(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))
|<- checker_sincos = True

call checker_sincos(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))
|<- checker_sincos = True

call executor_sincos(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([-0.6296], device='cuda:0'))
|<- executor_sincos = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([-0.5889], device='cuda:0'), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([0.8082], device='cuda:0'))



tensor([0.1889], device='cuda:0')

In [15]:
# Let's check how our function is represented in the execution trace now
thunder.last_traces(cfun)[-1]

# Constructed by Delete Last Used
import torch
@torch.no_grad()
def fun(a, b):
  # a: "cuda:0 f32[1]" 
  # b: "cuda:0 f32[1]" 
  (t0, t1) = sincos(a)
  del [a]
  (t3,) = nvFusion0(b, t0, t1)
    # t2 = prims.add(t0, t1)  # t2: "cuda:0 f32[1]"
    # t3 = prims.add(t2, b)  # t3: "cuda:0 f32[1]"
  del [b, t0, t1]
  return t3

That's it! We've created our custom operator and registered an executor for it. To recap, we've done the following:
* Created a new Symbol called `sincos` that represents the sine and cosine computation (but not the actual computation itself). All we know about it is that it takes a tensor as input and returns a tuple of two tensors. We gave this Symbol `name` and `id` attributes to identify it in the trace and when processing the trace.
* Implemented the actual computation by calling PyTorch's `sin` and `cos` functions, and registered that implementation as an executor for the Symbol.