
PyTorch CUDA binding tutorial

This repository contains cheatsheet code for binding C and CUDA functions to PyTorch. Read the PyTorch extension docs for more detail.

Usage

Build

```shell
python build.py install
```
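`build.py` is typically a thin setuptools script built on `torch.utils.cpp_extension`. A minimal sketch of what it might contain — the extension names come from the usage example below, but the source file paths are assumptions, not taken from this repository:

```python
# Hypothetical build.py sketch; source file paths are illustrative.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

setup(
    name="tiny_api",
    ext_modules=[
        # Plain C/C++ extension
        CppExtension("tiny_api_c", ["tiny_api_c/binding.cpp"]),
        # CUDA extension; nvcc is invoked for .cu files
        CUDAExtension("tiny_api_cuda", ["tiny_api_cuda/binding.cu"]),
    ],
    # BuildExtension supplies the compiler/linker flags torch extensions need
    cmdclass={"build_ext": BuildExtension},
)
```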

Use the custom package

```python
# NOTE: import torch first so its shared libraries are loaded
import torch
from tiny_api_c import hello as hello_c
from tiny_api_cuda import hello as hello_cuda

def main():
    hello_c()
    hello_cuda()

if __name__ == "__main__":
    main()
```

Roadmap

  • api binding
  • torch data binding

API binding

  • Use PYBIND11_MODULE to bind the API

```cpp
#include <torch/python.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // m.def("exported_name", &function_name, "function_docstring")
  m.def("hello", &hello, "Prints hello world from cuda file");
  m.def("vector_add", &vector_add, "Add two vectors on cuda");
}
```
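The functions passed to `m.def` are ordinary C++/CUDA functions declared in the same translation unit. A sketch of the declarations the module above would bind — the `vector_add` signature is an assumption matching its docstring, not copied from this repository:

```cpp
#include <torch/extension.h>

// Implemented in the CUDA file; prints from device code
void hello();

// Hypothetical signature: element-wise a + b on the GPU
torch::Tensor vector_add(torch::Tensor a, torch::Tensor b);
```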

Data binding

  • torch::Tensor is the tensor type
  • tensor.data_ptr<T>() gets the underlying pointer (the older tensor.data<T>() is deprecated)
  • AT_DISPATCH_FLOATING_TYPES() determines the data type, e.g. AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ([&]{your_kernel_call<scalar_t>();}))
    • the typename scalar_t is defined inside the AT_DISPATCH_FLOATING_TYPES lambda

AT_DISPATCH_FLOATING_TYPES() expands to roughly the following switch:

```cpp
switch (tensor.scalar_type()) {
  case torch::ScalarType::Double:
    return function<double>(tensor.data_ptr<double>());
  case torch::ScalarType::Float:
    return function<float>(tensor.data_ptr<float>());
  ...
}
```
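Putting the pieces together, a typical dispatch call site looks like this — a minimal sketch of a `vector_add` implemented as a templated CUDA kernel; the kernel and launch parameters are illustrative, not from this repository:

```cpp
#include <torch/extension.h>

// Templated kernel: scalar_t is filled in by AT_DISPATCH_FLOATING_TYPES
template <typename scalar_t>
__global__ void vector_add_kernel(const scalar_t* a, const scalar_t* b,
                                  scalar_t* out, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = a[i] + b[i];
}

torch::Tensor vector_add(torch::Tensor a, torch::Tensor b) {
  auto out = torch::empty_like(a);
  const int64_t n = a.numel();
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Inside the lambda, scalar_t is bound to a's element type (float or double)
  AT_DISPATCH_FLOATING_TYPES(a.scalar_type(), "vector_add", ([&] {
    vector_add_kernel<scalar_t><<<blocks, threads>>>(
        a.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(),
        out.data_ptr<scalar_t>(), n);
  }));
  return out;
}
```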
