Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fuse optimize op transpiler #8940

Conversation

jacquesqiao
Copy link
Member

@jacquesqiao jacquesqiao commented Mar 9, 2018

issue: #8941

@jacquesqiao jacquesqiao changed the title Merge optimize op transpiler fuse optimize op transpiler Mar 9, 2018
int grid_size = blockDim.x * gridDim.x;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
T g_data = g[i];
T p_data = p[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assignments aren't needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these assignments can be saved.

auto* grad_data = grads[j]->data<T>();
auto* param_data = params[j]->data<T>();
int param_num = params[j]->numel();
int block = 512;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is the block number decided?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The block number should is decided by the property of GPU and the task of CUDA Kernel.
In most GPU, the max threads of one block are 1024, so the block number is 1024 in most case, but for SGDKernel, sometimes, the number of param_data maybe less 200 which is less 1024, if we also set the block number to 1024, it will cause waste the resource of GPU.
So block = 512 maybe not appropriate too.

template <typename T>
class SGDGroupOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgd_group_op.h is just used to analysis the affection of using only one sgd op_kernel on GPU. It doesn't attempt to merge.

# limitations under the License.


def fuse_optimize_op(input_program):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have plan to do transpiler in C++? Iterating ops in Python won't scale to large programs

Copy link
Contributor

@panyx0718 panyx0718 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in general. A question is, how to scale your method to other optimizers? like momentum and adam?

@chengduoZH
Copy link
Contributor

@panyx0718 Other optimizers also need a group operation separately, if adopting this strategy.

@chengduoZH
Copy link
Contributor

I have done some experiments about replacing sgd_op with sgd_group_op, this is the experiment result:
Profile script: performance_tuning_se_resnet_train

  sgd sgd_group after/before
total time 318.099 171.128 0.462029117
called times 26920 40  

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants