
Torch Function Implementation

Prasun Anand edited this page Sep 18, 2019 · 7 revisions

Based on NEP 18, this document summarises the implementation of the __torch_function__ protocol for PyTorch APIs.

Background

  1. PyTorch Issues Page
  2. NEP-18

Dispatcher

  1. Add a dispatch decorator to PyTorch methods.
  2. The dispatcher then verifies the function signature: it checks that the arguments accepted by the dispatcher match those accepted by the implementation, and only then is the code executed further.
  3. The next step is generating the source code, compiling it to Python, and injecting it as the public API.
  4. The code generation checks for overloaded args: if a supplied arg has __torch_function__ defined and it handles the torch.operator API, that implementation is used; otherwise it falls back to the Torch implementation. Note that the return type depends on the type of the first arg supplied. See NEP 18 for more details.
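The fallback rule in step 4 can be sketched in plain Python. Everything here is a hypothetical stand-in (my_func, my_func_dispatcher, Overriding), not the real PyTorch machinery, and the __torch_function__ signature is simplified relative to the actual protocol:

```python
def my_func_dispatcher(a, b):
    return (a, b)

def my_func_fallback(a, b):
    # Stands in for the native Torch implementation.
    return a + b

def my_func(a, b):
    # Step 4: if any relevant arg's type defines __torch_function__,
    # delegate to it; otherwise fall back to the native implementation.
    for arg in my_func_dispatcher(a, b):
        handler = getattr(type(arg), '__torch_function__', None)
        if handler is not None:
            return handler(arg, my_func, (a, b), {})
    return my_func_fallback(a, b)

class Overriding:
    # Simplified signature; the real protocol also receives `types`.
    def __torch_function__(self, func, args, kwargs):
        return 'handled ' + func.__name__
```

Plain arguments such as ints fall through to the fallback, while an Overriding instance captures the call.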

Implementing the dispatcher

When used as a decorator around a Torch method, the dispatcher treats it specially: it checks that the arguments declared by the dispatcher match those accepted by the Torch method, and then exposes the wrapped function as a public API.

The dispatch itself is handled by implement_torch_function, which checks whether any argument has been overloaded with a custom __torch_function__. If no argument is overloaded, the torch API exposed as public is used; otherwise the implementation provided by the overloaded argument is called.

def torch_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False):
    def decorator(implementation):
        if verify:
            verify_matching_signatures(implementation, dispatcher)

        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)

        # Equivalently, we could define this function directly instead of using
        # exec. This version has the advantage of giving the helper function a
        # more interpretable name. Otherwise, the original function does not
        # show up at all in many cases, e.g., if it's written in C++ or if the
        # dispatcher gets an invalid keyword argument.
        source = _wrapped_func_source.format(name=implementation.__name__)

        source_object = compile(
            source, filename='<__torch_function__ internals>', mode='exec')
        scope = {
            'implementation': implementation,
            'dispatcher': dispatcher,
            'functools': functools,
            'implement_torch_function': implement_torch_function,
        }
        exec(source_object, scope)

        public_api = scope[implementation.__name__]

        if module is not None:
            public_api.__module__ = module

        public_api._implementation = implementation

        return public_api

    return decorator

Dispatcher types

A dispatcher takes the arguments passed to the public function and returns them as members of a tuple.

For example:

def gemm_dispatcher(input, mat2, out=None):
    return (input, mat2, out)
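The dispatcher mirrors the public signature exactly, so every parameter slot appears in the returned tuple, including ones the caller omitted:

```python
def gemm_dispatcher(input, mat2, out=None):
    return (input, mat2, out)

# Unsupplied defaults come through as None, which is harmless:
# type(None) never defines __torch_function__, so it can never
# be picked up as an overloaded arg.
relevant_args = gemm_dispatcher('A', 'B')
```

This is why (as checked below) dispatcher defaults must all be None: the dispatcher's only job is to forward candidate arguments cheaply.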

Verify Function Signature

verify_matching_signatures checks that the arguments passed by the dispatcher are the same as those expected by the original implementation in the methods provided by the Torch library.

ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')

def verify_matching_signatures(implementation, dispatcher):
    """Verify that a dispatcher function has the right signature."""
    implementation_spec = ArgSpec(*getargspec(implementation))
    dispatcher_spec = ArgSpec(*getargspec(dispatcher))

    if (implementation_spec.args != dispatcher_spec.args or
            implementation_spec.varargs != dispatcher_spec.varargs or
            implementation_spec.keywords != dispatcher_spec.keywords or
            (bool(implementation_spec.defaults) !=
             bool(dispatcher_spec.defaults)) or
            (implementation_spec.defaults is not None and
             len(implementation_spec.defaults) !=
             len(dispatcher_spec.defaults))):
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    if implementation_spec.defaults is not None:
        if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
            raise RuntimeError('dispatcher functions can only use None for '
                               'default argument values')
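The check above relies on the legacy getargspec. On current Python the same comparison can be sketched with inspect.getfullargspec (signatures_match is a hypothetical helper, not part of the actual code):

```python
import inspect

def signatures_match(implementation, dispatcher):
    # Compare positional arg names, *args / **kwargs names, and whether
    # defaults are present, mirroring verify_matching_signatures above.
    impl = inspect.getfullargspec(implementation)
    disp = inspect.getfullargspec(dispatcher)
    return (impl.args == disp.args and
            impl.varargs == disp.varargs and
            impl.varkw == disp.varkw and
            bool(impl.defaults) == bool(disp.defaults))

def gemm(input, mat2, out=None): ...

def good_dispatcher(input, mat2, out=None):
    return (input, mat2, out)

def bad_dispatcher(input, other, out=None):  # wrong parameter name
    return (input, other, out)
```

A mismatched dispatcher is rejected up front, so signature drift between the public API and its dispatcher fails fast at decoration time rather than at call time.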

Code generation and injection into the public API

If the signatures match, _wrapped_func_source generates the Python source corresponding to the implementation, compiles it, and exposes it as the public API.

_wrapped_func_source = textwrap.dedent("""
    @functools.wraps(implementation)
    def {name}(*args, **kwargs):
        relevant_args = dispatcher(*args, **kwargs)
        return implement_torch_function(
            implementation, {name}, relevant_args, args, kwargs)
    """)

torch_function_implementation

The core dispatch logic lives in implement_torch_function:

def implement_torch_function(
        implementation, public_api, relevant_args, args, kwargs):

    # Check for __torch_function__ methods.
    types, overloaded_args = get_overloaded_types_and_args(relevant_args)
    # Short-cut for common cases: no overload or only Tensor overload
    # (directly or with subclasses that do not override __torch_function__).
    if (not overloaded_args or types == _TENSOR_ONLY or
            all(type(arg).__torch_function__ is _TORCH_FUNCTION
                for arg in overloaded_args)):
        return implementation(*args, **kwargs)

    # Call overrides
    for overloaded_arg in overloaded_args:
        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = overloaded_arg.__torch_function__(
            public_api, types, args, kwargs)

        if result is not NotImplemented:
            return result

    func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
    raise TypeError("no implementation found for '{}' on types that implement "
                    '__torch_function__: {}'
                    .format(func_name, list(map(type, overloaded_args))))

DuckTensor example

A DuckTensor is a tensor-like class that defines __torch_function__ so that torch APIs dispatch to its own implementations.
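A minimal pure-Python sketch of such a duck-typed tensor. Here gemm is a stand-in for the dispatched public API (not the real torch routine), and its body is a stripped-down version of what the generated wrapper plus implement_torch_function do together:

```python
HANDLED_FUNCTIONS = {}  # maps public API -> DuckTensor implementation

class DuckTensor:
    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

def gemm(input, mat2, out=None):
    # Simplified wrapper: delegate to the first overloaded arg, the way
    # the generated code does via implement_torch_function.
    for arg in (input, mat2, out):
        tf = getattr(type(arg), '__torch_function__', None)
        if tf is not None:
            return tf(arg, gemm, (type(arg),), (input, mat2), {})
    raise TypeError('no __torch_function__ arg; would call native gemm')

def duck_gemm(input, mat2):
    return 'DuckTensor gemm'

HANDLED_FUNCTIONS[gemm] = duck_gemm
```

Because the override receives the public API object itself (gemm), the DuckTensor can key its handler table by identity, exactly as the equality-comparison comment in the dispatcher suggests.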

Examples:

Implementing torch.gemm routine

Implementing torch.lu routine

Tests

The tests can be found here.

Benchmarks

The benchmark code was added in this commit.

Moving to C++

In torch/csrc/autograd/generated/python_torch_functions.cpp we try to inject our code.

static PyObject * THPVariable_mean(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  std::cout << "hello world from mean!" << std::endl;
  static PythonArgParser parser({
    "mean(Tensor input, *, ScalarType? dtype=None)",
    "mean(Tensor input, IntArrayRef[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse2(args, kwargs, parsed_args);

  // check if r.torch_function_dispatch == true and then look for r.tensor_like.HANDLED_FUNCTIONS[r.function_name]
  // return call(r.tensor_like[r.function_name], args, kwargs);

  if (r.idx == 0) {
    return wrap(dispatch_mean(r.tensor(0), r.scalartypeOptional(1)));
  } else if (r.idx == 1) {
    if (r.isNone(4)) {
      return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3)));
    } else {
      return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3), r.tensor(4)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

Parse2

template<int N>
inline PythonArgs PythonArgParser::parse2(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst) {
  if (N < max_args) {
    throw ValueError("PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
        (int)max_args, N);
  }
  return raw_parse2(args, kwargs, dst.args);
}

PythonArgParser ctor

PythonArgParser::PythonArgParser(std::vector<std::string> fmts, bool traceable)
 : max_args(0)
 , traceable(traceable)
{
  for (auto& fmt : fmts) {
    std::cout << "fmt=> " << fmt << std::endl;
    signatures_.emplace_back(fmt);
  }
  for (auto& signature : signatures_) {
    if (signature.max_args > max_args) {
      max_args = signature.max_args;
    }
  }
  if (signatures_.size() > 0) {
    function_name = signatures_[0].name;
    std::cout << "function_name is => " << function_name << std::endl;
  }
}

PythonArgParser::raw_parse2

PythonArgs PythonArgParser::raw_parse2(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
  std::cout << "In PythonArgParser::raw_parse" << std::endl;
  if (signatures_.size() == 1) {
    auto& signature = signatures_[0];
    signature.parse2(args, kwargs, parsed_args, true);
    auto x =  PythonArgs(0, traceable, signature, parsed_args);
    return x;
  }

  int i = 0;
  for (auto& signature : signatures_) {
    if (signature.parse2(args, kwargs, parsed_args, false)) {
      auto x = PythonArgs(i, traceable, signature, parsed_args);
      return x;
    }
    i++;
  }

  print_error(args, kwargs, parsed_args);
}

FunctionSignature::parse2

Here all the signatures are validated against args and kwargs. We add a check that detects tensor-like PyObjects with __torch_function__ defined, then collect all such arguments into overloaded_args and overloaded_types lists. Note that insertion into the overloaded_args list needs to account for subclasses: a subclass must be inserted before its parent so it gets the first chance to handle the call.
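The collection rule just described can be sketched in Python, following the NEP 18 approach (one argument per type, subclasses ordered before their parents); the C++ parse2 below would implement the equivalent logic:

```python
def get_overloaded_types_and_args(relevant_args):
    overloaded_types = []
    overloaded_args = []
    for arg in relevant_args:
        arg_type = type(arg)
        if (arg_type not in overloaded_types and
                hasattr(arg_type, '__torch_function__')):
            overloaded_types.append(arg_type)
            # Insert before the first superclass already collected, so
            # subclass overrides take precedence.
            index = len(overloaded_args)
            for i, old_arg in enumerate(overloaded_args):
                if issubclass(arg_type, type(old_arg)):
                    index = i
                    break
            overloaded_args.insert(index, arg)
    return overloaded_types, overloaded_args

class Base:
    def __torch_function__(self, func, types, args, kwargs): ...

class Sub(Base): ...

base, sub = Base(), Sub()
_, ordered = get_overloaded_types_and_args((base, sub))
```

Even though base appears first in the arguments, sub is tried first, mirroring Python's method resolution for operator overloads.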

bool FunctionSignature::parse2(PyObject* args, PyObject* kwargs, PyObject* dst[],
                              bool raise_exception) {
  std::cout << "FunctionSignature::parse2, Trying to find out torch_Function" << std::endl;
  auto nargs = PyTuple_GET_SIZE(args);
  std::cout << "nargs ->" << nargs << std::endl;
  ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
  ssize_t arg_pos = 0;
  bool allow_varargs_intlist = false;

  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
    allow_varargs_intlist = true;
  }

  if (nargs > max_pos_args && !allow_varargs_intlist) {
    if (raise_exception) {
      // foo() takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  int i = 0;
  for (auto& param : params) {
    PyObject* obj = nullptr;
    bool is_kwd = false;
    if (arg_pos < nargs) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      obj = PyTuple_GET_ITEM(args, arg_pos);
    } else if (kwargs) {
      obj = PyDict_GetItem(kwargs, param.python_name);
      for (PyObject *numpy_name: param.numpy_python_names) {
        if (obj) {
          break;
        }
        obj = PyDict_GetItem(kwargs, numpy_name);
      }
      is_kwd = true;
    }

    if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
      dst[i++] = nullptr;
    } else if (!obj) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, i);
      }
      return false;
    } else if (param.check2(obj, args, kwargs)) {
      dst[i++] = obj;
    // XXX: the Variable check is necessary because sizes become tensors when
    // tracer is enabled. This behavior easily leads to ambiguities, and we
    // should avoid having complex signatures that make use of it...
    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
               THPUtils_checkIndex(obj)) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
      dst[i++] = args;
      arg_pos = nargs;
      continue;
    } else if (raise_exception) {
      if (is_kwd) {
        // foo(): argument 'other' must be str, not int
        throw TypeError("%s(): argument '%s' must be %s, not %s",
            name.c_str(), param.name.c_str(), param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      } else {
        // foo(): argument 'other' (position 2) must be str, not int
        throw TypeError("%s(): argument '%s' (position %d) must be %s, not %s",
            name.c_str(), param.name.c_str(), arg_pos + 1,
            param.type_name().c_str(), Py_TYPE(obj)->tp_name);
      }
    } else {
      return false;
    }

    if (!is_kwd) {
      arg_pos++;
    } else if (obj) {
      remaining_kwargs--;
    }
  }

  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }

  return true;
}

Code Generation for Python Bindings

Once you have the dispatch figured out, you need to look into gen_python_functions.py, which generates these bindings.

PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  std::cout << "hello world!" << std::endl;   // added this to check
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}                                 // modify dispatch to check for torch function and call torch
                                              // function or use central dispatch machinery. See the example
                                              // below
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

Now we would need to modify ${dispatch}.

static PyObject * THPVariable_mean(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  std::cout << "hello world from mean!" << std::endl;
  static PythonArgParser parser({
    "mean(Tensor input, *, ScalarType? dtype=None)",
    "mean(Tensor input, IntArrayRef[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse2(args, kwargs, parsed_args);
  std::cout << "parsed and got r" << std::endl;

  if(r.has_torch_function()){
    std::cout << "Found torch_function" << std::endl;
    PyObject* handled_functions = maybe_get_attr(r.get_overloaded_arg(0), "__torch_function__");
    // How to get handled_functions[get torch.mean]
    return PyObject_CallFunctionObjArgs(handled_functions, PyUnicode_FromString(r.get_func_name().data()), args, kwargs, NULL);
  }
  else{
    std::cout << "Not found torch_function" << std::endl;
    if (r.idx == 0) {
      return wrap(dispatch_mean(r.tensor(0), r.scalartypeOptional(1)));
    } else if (r.idx == 1) {
      if (r.isNone(4)) {
        return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3)));
      } else {
        return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3), r.tensor(4)));
      }
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

The goal here is to generate something like the code snippet provided above.