Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU gemv speedup #5257

Merged
merged 3 commits into from
Nov 23, 2016
Merged

GPU gemv speedup #5257

merged 3 commits into from
Nov 23, 2016

Conversation

khaotik
Copy link
Contributor

@khaotik khaotik commented Nov 20, 2016

Added call to cublasSdot in GPU gemv code.

This is proposed solution for the 1st issue in #1168. Despite not solving all the problem, I'm getting 4x speedup on my GPU.

Copy link
Member

@nouiz nouiz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few small changes in comments.

Otherwise, I'll accept this small PR, but we don't develop the old gpu back-end. It would be great to port this to the new back-end.

// alpha and beta parameter
// 2. permanant solution:
// define a new "InnerProduct" Op, add an optimization
// "gemv -> inner_prod", perhaps for CPU/GPU both
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new op trick won't always work. Sometimes we won't know the shapes, so having it here is good I think and is easier. So just remove this comment.

cublasPointerMode_t pmode;
cublasGetPointerMode(handle, &pmode);
// need to store dot result on device here
cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't set the pointermode. We don't do this anywhere, we always pass them on the CPU (the default). Also, you are passing them on the host bellow. So I don't understand why you added this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't add this, this will cause CUBLAS take dot result as a host pointer, causing segfault (Example). I set it back to make sure other CUDA code still function correctly.

// 2. permanant solution:
// define a new "InnerProduct" Op, add an optimization
// "gemv -> inner_prod", perhaps for CPU/GPU both
float* dev_dst = CudaNdarray_DEV_DATA(C)+1-sc_0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the "+1-sc_0"? I would completly remove that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was misunderstanding cublas doc, removed.

@nouiz
Copy link
Member

nouiz commented Nov 22, 2016

jenkins test this

@nouiz nouiz merged commit c89d973 into Theano:master Nov 23, 2016
@khaotik khaotik deleted the gpu_gemv_speedup branch November 24, 2016 01:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants