# Defining custom forward and backward for existing operators

We are going to add custom executor for forward and backward of `torch.nn.functional.cross_entropy` operator.

Here's `SoftmaxCrossEntropyLoss` definition from https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py:

```py
import torch

import xentropy_cuda


class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        losses.masked_fill_(labels==padding_idx, 0)

        ctx.save_for_backward(logits, max_log_sum_exp, labels,
            torch.FloatTensor([smoothing]),
            torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
```

In [1]:
import thunder
import torch

from thunder.core.proxies import TensorProxy

In [2]:
#@title Helper functions (execute this cell)
import functools

_indentation = 0
def _log(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _log_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _log(msg)
    _indentation = 2 + _indentation

def _log_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 2
    _log(msg)
  
def log(func):
    """A decorator for functions to log arguments and results."""
    name = func.__name__
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
        _log_indent("call {}({})".format(name, pp_values(args)))
        res = func(*args)
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper

We need to make `xentropy_cuda.forward` and `xentropy_cuda.backward` traceable by Thunder, for this we need to create corresponding Symbols.

In [3]:
from thunder.core.symbol import Symbol

help(Symbol)

Help on class Symbol in module thunder.core.symbol:

class Symbol(builtins.object)
 |  Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = <function default_python_printer at 0x7f946b601a20>, _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None
 |  
 |  Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = <function default_python_printer at 0x7f946b601a20>, _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False)
 |  
 |  Methods defined here:
 |  
 |  __call__(self, *args, **kwargs)
 |      Call self as a function.
 |  
 |  __delattr__(self, 

In [4]:
@log
def apex_xentropy_forward_meta(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    max_log_sum_exp = TensorProxy(like=target)
    if reduction == "none":
        return TensorProxy(like=target), max_log_sum_exp
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

xentropy_forward = Symbol(
    id="xentropy_forward",
    name="xentropy_forward",
    meta=apex_xentropy_forward_meta,
    is_prim=True,
)

In [5]:
@log
def apex_xentropy_backward_meta(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    return TensorProxy(like=logits)

xentropy_backward = Symbol(
    id="xentropy_backward",
    name="xentropy_backward",
    meta=apex_xentropy_backward_meta,
    is_prim=True,
)

In [6]:
from thunder.core.transforms import register_augmented_forward

@register_augmented_forward("torch.nn.functional.cross_entropy")
def apex_cross_entropy_forward_rule(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    loss, max_log_sum_exp = xentropy_forward(
        a,
        target,
        weight,
        size_average,
        ignore_index,
        reduce,
        reduction,
        label_smoothing,
    )
    primal = loss
    saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)
    return primal, saved_for_backward

In [7]:
from thunder.core.transforms import register_backward

@register_backward("torch.nn.functional.cross_entropy")
def apex_cross_entropy_backward_rule(
    logits,
    labels,
    max_log_sum_exp,
    reduction,
    smoothing,
    grad,
):
    if reduction != "none":
        raise ValueError(f"Invalid reduction: {reduction}")

    grad_logits = xentropy_backward(
        grad,
        logits,
        labels,
        max_log_sum_exp,
        smoothing,
    )
    return grad_logits, *([None] * 7)

In [8]:
from thunder.core.transforms import inline, value_and_grad
from thunder import torch as ltorch

torch.manual_seed(0)

logits = torch.randn([2048, 50257], device="cuda")
labels = torch.randint(0, 50257, [2048], device="cuda")

@inline
@value_and_grad
def fun(logits, labels):
    return ltorch.cross_entropy(logits, labels, reduction="none", ignore_index=-1)

trace = thunder.trace()(fun, logits, labels)
print(trace)

call apex_xentropy_forward_meta(TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t4, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t3, shape=(2048,), dtype=int64, device=cuda:0))

call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t3, shape=(2048,), dtype=int64, device=cuda:0), 0.0)
|<- apex_xentropy_backward_meta = TensorProxy(name=t5, shape=(2048, 50257), dtype=float32, device=cuda:0)

call apex_xentropy_forward_meta(TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentr

In [9]:
import xentropy_cuda

@log
def apex_xentropy_forward_impl(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, False)

    if reduction == "none":
        losses = losses.to(a.dtype)
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    return losses, max_log_sum_exp

@log
def apex_xentropy_backward_impl(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    return xentropy_cuda.backward(grad.contiguous(), logits, max_log_sum_exp, labels, smoothing)

In [10]:
always_executable = lambda *args, **kwargs: True

op_to_xentropy = {
    "xentropy_forward": ("xentropy_forward_impl", always_executable, apex_xentropy_forward_impl),
    "xentropy_backward": ("xentropy_backward_impl", always_executable, apex_xentropy_backward_impl),
}

In [11]:
from thunder.executors import add_operator_executor

add_operator_executor("xentropy", op_to_xentropy)

That's it! We have implemented our custom forward and backward for `torch.nn.functional.cross_entropy` operator. Let's test it.

The trace looks correct, let's check whether we can compile and run it.

In [12]:
cfun = thunder.compile(fun, disable_preprocessing=True)

In [13]:
out, (grad_logits, grad_labels) = cfun(logits, labels)

call apex_xentropy_forward_meta(TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t4, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t3, shape=(2048,), dtype=int64, device=cuda:0))

call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t3, shape=(2048,), dtype=int64, device=cuda:0), 0.0)
|<- apex_xentropy_backward_meta = TensorProxy(name=t5, shape=(2048, 50257), dtype=float32, device=cuda:0)

call apex_xentropy_forward_meta(TensorProxy(name=t0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentr

Our logging shows that we have successfully compiled and run our custom forward and backward for `torch.nn.functional.cross_entropy` operator calling `xentropy_cuda.forward` and `xentropy_cuda.backward` respectively. Let's check the correctness of our implementation.

In [14]:
# Let's compute the gradients with respect to the logits using PyTorch's autograd and compare
logits.requires_grad_(True)

loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none", ignore_index=-1)
logits.grad = None
loss.sum().backward()

print("Max error in logits grad:", (grad_logits - logits.grad).abs().max())

Max error in logits grad: tensor(1.3970e-09, device='cuda:0')


That's it! We have successfully implemented custom forward and backward for `torch.nn.functional.cross_entropy` operator using CUDA extension from Apex. The same approach can be used for any other operator. The key is to make forward and backward traceable by Thunder by creating corresponding Symbols and registering an executor for them.