
[CUDA] FP16 support #1413

Merged · 13 commits · Jul 20, 2018
Conversation


@nishi-t nishi-t commented Jul 10, 2018

This PR changes NVRTCCompile to compile CUDA code with cuda_fp16.h, as discussed in #699. CUDA fp16 operators are not yet supported in this PR; I'll work on fp16 operator support later :)

For now, I confirmed that the following code works. The environment variable CUDA_HOME must be set to the CUDA install location in order to run it.

import tvm
import numpy as np

n = tvm.var("n")
m = tvm.var("m")
mysum = tvm.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), dtype="float16", name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, B], "cuda")
print("build done")
print(fun.imported_modules[0].get_source())

ctx = tvm.context("cuda", 0)

a = tvm.nd.array(np.array([[1, 2], [3, 4]]).astype("float16"), ctx)
b = tvm.nd.array(np.zeros((2,), B.dtype), ctx)

fun(a, b)
print(b.asnumpy())
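As a sanity check (not part of the PR itself), the schedule above computes per-row sums, so for the 2x2 input in the example the device result can be compared against a plain numpy computation on the host:

```python
import numpy as np

# Host-side check of the reduction in the example above: B[i] = sum over k of A[i, k].
# This uses only numpy, independent of TVM, and shows the values b.asnumpy()
# should contain for the sample input.
a = np.array([[1, 2], [3, 4]], dtype="float16")
expected = a.sum(axis=1)  # per-row sums
print(expected)  # [3. 7.]
```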

@tqchen @masahi @merrymercy @abergeron Could you review and comment on this?

@nishi-t nishi-t changed the title Add half type support for CUDA Add half type support to CUDA Jul 10, 2018

if (cudaHomePath != nullptr) {
includeOption += cudaHomePath;
includeOption += "/include";
Member

Does this work on Windows? I'm not sure.

Contributor Author

Oh, I overlooked that. I'll address it. Thanks!


tqchen commented Jul 10, 2018

Please add a test case for this; the test case needs to be guarded by

if not gpu(0).exist or not have_fp16(gpu(0).compute_version):
    return

Let us enable fp16 by default when we detect that fp16 is used in the code. Directly operating on fp16, while useful, may not be the most effective approach. We will also need to test vectorized loads and vector operations (corresponding to half2 in CUDA).
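A minimal host-side sketch of the guard described above. CUDA half-precision arithmetic requires compute capability 5.3 or higher, so a have_fp16 helper can be written as below; the function bodies here are illustrative only, not the exact code that landed in this PR.

```python
def parse_compute_version(compute_version):
    """Split a compute capability string such as "6.1" into (major, minor)."""
    major, minor = compute_version.split(".")
    return int(major), int(minor)


def have_fp16(compute_version):
    """Return True if the device supports fp16 arithmetic.

    CUDA half-precision arithmetic first appeared in compute capability 5.3.
    """
    return parse_compute_version(compute_version) >= (5, 3)


print(have_fp16("6.1"))  # True
print(have_fp16("5.2"))  # False
```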

@tqchen tqchen changed the title Add half type support to CUDA [CUDA] FP16 support Jul 10, 2018
@tqchen tqchen self-requested a review July 10, 2018 15:36

nishi-t commented Jul 11, 2018

@tqchen Ok, I'm working on it. Thanks.


if (include_path) {
std::string includeOption = "--include-path=";
const char* cudaHomePath = std::getenv("CUDA_HOME");
Contributor

How about defining CUDA_HOME as a preprocessor macro in the .cmake file?

Contributor Author
@nishi-t nishi-t Jul 11, 2018

Do you mean that cudaHomePath would be defined as a preprocessor macro at the time of building tvm? If so, I'm worried that the path would depend strongly on the environment in which tvm is built. For example, it could be a problem when distributing a pre-compiled tvm in the future. Please let me know your opinion.

Contributor

Okay, never mind.

Then, I'd suggest using CUDA_PATH instead of CUDA_HOME for the environment variable, and defaulting to /usr/local/cuda. I think it's better to be consistent with how python/tvm/contrib/nvcc.py finds the CUDA path.

Contributor Author

@kazum Thank you for the suggestion. I'll address it.

@nishi-t nishi-t force-pushed the cuda_fp16 branch 2 times, most recently from 2728ebd to c7457c5 Compare July 12, 2018 08:47
@tqchen tqchen added the status: need update need update based on feedbacks label Jul 17, 2018
@tqchen tqchen left a comment

I have made some follow-up comments. @nishi-t, can you also include a simple test case that does a vectorized add?

@@ -0,0 +1,75 @@
"""Utility for CUDA backend"""
Member

move this to nvcc.py

@@ -43,6 +84,13 @@ std::string NVRTCCompile(const std::string& code) {
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));

if (include_path) {
for (int i = 0; i < numCompileOptions; i++) {
Member

Use snake_case for variable names, per the Google C++ style.

@@ -26,11 +28,50 @@ namespace codegen {
} \
}

std::string NVRTCCompile(const std::string& code) {
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
char *compileParams[2];
Member

Use std::string to store strings; avoid using malloc and free.

@@ -0,0 +1,75 @@
"""Utility for CUDA backend"""

def parse_cc(compute_capability):
Member

parse_compute_version

@@ -66,6 +66,7 @@ COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh
RUN bash /install/ubuntu_install_redis.sh

# Environment variables
ENV CUDA_HOME=/usr/local/cuda
Member

Ideally, we should not rely on CUDA_HOME; instead, allow some local search to happen (in contrib.nvcc) that is aware of CUDA_HOME but will also search /usr/local/cuda by default.
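The search order suggested above could be sketched as follows. The helper name find_cuda_path and the exact precedence are assumptions for illustration, not the actual code in contrib.nvcc.

```python
import os

def find_cuda_path():
    """Locate the CUDA install directory.

    Illustrative order: an explicit CUDA_PATH, then CUDA_HOME, then the
    conventional /usr/local/cuda default on Linux.
    """
    for env_var in ("CUDA_PATH", "CUDA_HOME"):
        path = os.environ.get(env_var)
        if path and os.path.isdir(path):
            return path
    if os.path.isdir("/usr/local/cuda"):
        return "/usr/local/cuda"
    raise RuntimeError("Cannot find CUDA path; set CUDA_PATH explicitly")
```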


nishi-t commented Jul 18, 2018

@tqchen Sorry for the delay, and thank you for the comments. I'll address your and the reviewers' comments soon.
By the way, I thought I had already added a simple vectorized add test case here. If you have any problem, please let me know.

@nishi-t nishi-t force-pushed the cuda_fp16 branch 5 times, most recently from f5f264f to c7784d0 Compare July 18, 2018 08:58
@tqchen tqchen left a comment

Thanks for the updates. I just have one minor comment and then it is good to go.

@@ -26,11 +30,65 @@ namespace codegen {
} \
}

std::string NVRTCCompile(const std::string& code) {

std::string find_cuda_include_path() {
Member

Functions in C++ should be in CamelCase: FindCUDAIncludePath.


tqchen commented Jul 18, 2018

major = int(split_ver[0])
minor = int(split_ver[1])
return major, minor

Contributor

I'd suggest using exceptions:

try:
    major, minor = compute_version.split('.')
    return int(major), int(minor)
except ValueError as err:
    ....

minor = int(split_ver[1])
return major, minor

raise RuntimeError("the compute capability string is unsupported format: " + cc)
Contributor

cc is not defined here.
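Combining the two review points above (use exceptions, and `cc` is undefined in the error message), a corrected sketch might look like this; it is illustrative, not the exact code that was merged.

```python
def parse_compute_version(compute_version):
    """Parse a compute capability string like "6.1" into (major, minor)."""
    try:
        major, minor = compute_version.split(".")
        return int(major), int(minor)
    except ValueError:
        # Reference the actual argument, not an undefined `cc` variable.
        raise RuntimeError(
            "unsupported compute capability string format: " + compute_version)


print(parse_compute_version("6.1"))  # (6, 1)
```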

np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)

check_cuda("float32", 64, 2)
if not tvm.gpu(0).exist or not have_fp16(tvm.gpu(0).compute_version):
Contributor

The check of tvm.gpu(0).exist is common to fp32 and fp16. It should be moved into check_cuda().


nishi-t commented Jul 19, 2018

@masahi Thank you for the review.
@tqchen @kazum I've addressed the comments; please review again.

@kazum kazum left a comment

Looks good to me, thanks.

@tqchen tqchen left a comment

Some final comments on cross platform handling

}

cuda_include_path = "/usr/local/cuda/include";
if (stat(cuda_include_path.c_str(), &st) == 0) {
Member

The stat function may not be available in some MSVC versions; consider using the stat query only on Linux and requiring the user to set CUDA_PATH otherwise.


Member

cf. https://msdn.microsoft.com/en-us/library/14h5k7ff.aspx Sometimes the stat function is not available and _stat must be used instead. Given that the proposed path does not work on Windows anyway, let us just skip this on Windows.

@@ -5,14 +5,18 @@
*
* \file build_cuda.cc
*/
#include <sys/stat.h>
Member

Consider including sys/stat.h only when Linux is detected.



nishi-t commented Jul 20, 2018

@tqchen Thank you for the comment. I've addressed it; please review again.

@tqchen tqchen merged commit 5f7b4d5 into apache:master Jul 20, 2018

tqchen commented Jul 20, 2018

Thanks @nishi-t @kazum @masahi, this is now merged!

@tqchen tqchen mentioned this pull request Jul 20, 2018
tqchen pushed a commit to tqchen/tvm that referenced this pull request Aug 4, 2018
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request Aug 8, 2018