
corr_gemm optimization to improve CNN performance #4591

Merged
merged 7 commits into from Jul 8, 2016

Conversation

ciyongch
Contributor

@ciyongch ciyongch commented Jun 6, 2016

Hi experts,

This PR optimizes the current corr_gemm Ops (including corrMM, corrMM_gradWeights and corrMM_gradInputs) by using OpenMP to unroll the first loop (at the batch-size level). It achieves a 5.2x speedup (counting all three Ops) compared to the original version.

Thanks.
Ciyong
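
For illustration, here is a minimal sketch (not the actual Theano kernel) of the batch-level OpenMP unrolling described above; the names corr_mm_batched, im2col and gemm_one_image are hypothetical stand-ins for the real corrMM code:

#include <omp.h>
#include <cstddef>

// Placeholder helpers standing in for Theano's im2col and sgemm calls.
static void im2col(const float* image, float* col) { /* ... */ }
static void gemm_one_image(const float* col, const float* weights, float* out) { /* ... */ }

void corr_mm_batched(const float* bottom, const float* weights, float* top,
                     int batch_size, std::size_t image_stride,
                     std::size_t col_stride, std::size_t top_stride,
                     float* col_workspace)
{
    // Unroll the first (batch-size) loop across threads; each thread writes
    // into its own slice of the shared column workspace, so the per-image
    // im2col + gemm pairs run independently.
    #pragma omp parallel for schedule(static)
    for (int n = 0; n < batch_size; ++n) {
        float* col = col_workspace + (std::size_t)omp_get_thread_num() * col_stride;
        im2col(bottom + n * image_stride, col);
        gemm_one_image(col, weights, top + n * top_stride);
    }
}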

@JesseLivezey
Contributor

This looks similar to #3689 for the corrMM and corrMM_gradInputs. Can you run the tests with gcc 4.9 or 5.x? For the other PR the tests passed with gcc 4.8, but failed with the other versions.

@@ -359,8 +377,8 @@ class CorrMM(BaseCorrMM):
Set to `(1, 1)` to disable subsampling.

"""
def __init__(self, border_mode="valid", subsample=(1, 1)):
super(CorrMM, self).__init__(border_mode, subsample)
def __init__(self, border_mode="valid", subsample=(1, 1), openmp=None):
Member

In fact, we don't need that method for the 3 subclasses. Can you just remove them?

@nouiz
Member

nouiz commented Jun 7, 2016

Thanks for the PR. Overall, it seems OK; I have a few small comments. As @JesseLivezey said, we need to make sure this PR works in the cases where the others didn't.

@ciyongch
Contributor Author

ciyongch commented Jun 8, 2016

Sure, let me try with gcc 4.9 and 5.x and update the test result.


Regards,
Ciyong


@ciyongch
Contributor Author

ciyongch commented Jun 8, 2016

@nouiz @JesseLivezey Thanks for your comments. The code has been updated accordingly, and it works well with both gcc 4.9 and 5.3.

@@ -377,8 +374,8 @@ class CorrMM(BaseCorrMM):
Set to `(1, 1)` to disable subsampling.

"""
def __init__(self, border_mode="valid", subsample=(1, 1), openmp=None):
super(CorrMM, self).__init__(border_mode, subsample, openmp=openmp)
def __init__(self, border_mode="valid", subsample=(1, 1)):
Member

Here, I meant to remove the init method completely. It needs to keep the openmp parameter as before this commit, and the code ends up being exactly what the default init does when it is not implemented, so we don't need it.

Contributor Author

Sorry for the late response. That makes sense; I will update the code as you suggested.

Member

I didn't see any new commits. Did you forget to push?

Contributor Author

@nouiz I will push the code today.

@JesseLivezey
Contributor

I can also try and test this on OSX this week.

@nouiz
Member

nouiz commented Jun 8, 2016

thanks


col_dim[0] = (npy_intp)max_threads;
col_dim[1] = (npy_intp)(nChannels * kW * kH);
col_dim[2] = (npy_intp)(topHeight * topWidth);
PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3,
Member

Why zero col? If this is needed only when max_threads > 1, then do it only in that case so it doesn't slow down the single-thread case.

Contributor Author

PyArray_ZEROS is faster than PyArray_EMPTY in my tests with max_threads > 1, but I didn't test their performance in the single-thread case. If you mean that PyArray_ZEROS could be slower than PyArray_EMPTY with a single thread, then we can distinguish the two conditions and select the faster one.

Member

I don't understand why ZEROS was faster than EMPTY. Do you remember how big the difference was? If it is small, we could just leave it like this for single thread.

Contributor Author

As I recall, ZEROS is around 1.7x faster than EMPTY when running the AlexNet workload. Anyway, I can retry, then gather and post the data here.

Contributor Author

@nouiz it looks like PyArray_ZEROS can run in multi-thread mode in NumPy, while PyArray_EMPTY can't, so I got the better result with PyArray_ZEROS when max_threads > 1.

Here's the snippet of profiling of AlexNet Workload when max_threads>1:

Calling PyArray_ZEROS

 <% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
   19.7%    40.1%      23.662s       1.18e-01s     C      200      15   theano.tensor.nnet.corr.CorrMM
   10.3%    67.1%      12.288s       1.23e-01s     C      100       5   theano.tensor.nnet.corr.CorrMM_gradWeights
    9.5%    76.6%      11.422s       1.43e-01s     C       80       4   theano.tensor.nnet.corr.CorrMM_gradInputs

Calling PyArray_EMPTY

 <% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  23.1%    23.1%      30.808s       1.54e-01s     C      200      15   theano.tensor.nnet.corr.CorrMM
  12.2%    68.4%      16.244s       1.62e-01s     C      100       5   theano.tensor.nnet.corr.CorrMM_gradWeights
  10.7%    79.0%      14.218s       1.78e-01s     C       80       4   theano.tensor.nnet.corr.CorrMM_gradInputs

Member

I just saw this comment. It clearly shows that ZEROS is faster, but I don't understand why. This won't block the PR, but if you know why, I would be interested to know.

Did you time this with a single thread? Since I don't understand why this is faster, maybe it is faster in single thread too.

Contributor Author

Hi nouiz,

When running in single thread, using PyArray_ZEROS is a little bit slower than using PyArray_EMPTY; that is what I found in the test results, but I haven't figured out why the two APIs differ in performance. I looked into their code, and the only difference between them is that PyArray_ZEROS can use multiple threads (when enabled) to zero the memory.

Here's the snippet of profiling of AlexNet workload when max_threads=1
Calling PyArray_ZEROS

 <% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
   45.0%    45.0%     321.487s       1.61e+00s     C      200      15   theano.tensor.nnet.corr.CorrMM
   20.0%    65.0%     142.646s       1.43e+00s     C      100       5   theano.tensor.nnet.corr.CorrMM_gradWeights
   19.9%    84.9%     141.937s       1.77e+00s     C       80       4   theano.tensor.nnet.corr.CorrMM_gradInputs

Calling PyArray_EMPTY

<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
   44.9%    44.9%     319.600s       1.60e+00s     C      200      15   theano.tensor.nnet.corr.CorrMM
   20.0%    64.9%     142.701s       1.43e+00s     C      100       5   theano.tensor.nnet.corr.CorrMM_gradWeights
   19.8%    84.8%     141.364s       1.77e+00s     C       80       4   theano.tensor.nnet.corr.CorrMM_gradInputs

@JesseLivezey
Contributor

JesseLivezey commented Jun 13, 2016

This is the script I've been using:
https://gist.github.com/JesseLivezey/12e0e320960a58c278138d402f724fca
to get these timings.

@nouiz, not setting MKL gives about the same performance as setting MKL to 1 thread. Overall, I get about a 2-4x speedup when all threads are dedicated to OMP versus having all threads dedicated to MKL. On the machine I'm using, it seems that the BLAS performance gets worse with more than 8 threads.

Note that the baseline op is different than in the previous comparison. I've taken out the legacy conv op.

Timing done on 12-core Intel Xeon X5650 with 24GB RAM.
BLAS was MKL from Anaconda Python, gcc 4.8.5.

Op: forward
Image shape, filter shape
(128, 3, 128, 128) (96, 3, 5, 5)
Imagenet-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      1668      0.45      single thread
12        None      752       1.00      BLAS only, baseline
None      12        234       3.21      only OMP set
2         1         1002      0.75      BLAS only 
4         1         680       1.10      BLAS only 
8         1         619       1.21      BLAS only 
1         2         863       0.87      OMP only  
1         4         479       1.57      OMP only  
1         8         277       2.71      OMP only  
1         12        234       3.21      OMP only  


Op: forward
Image shape, filter shape
(128, 85, 2, 258) (64, 85, 2, 20)
Spectrogram-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      757       0.39      single thread
12        None      294       1.00      BLAS only, baseline
None      12        90        3.25      only OMP set
2         1         511       0.58      BLAS only 
4         1         338       0.87      BLAS only 
8         1         271       1.09      BLAS only 
1         2         397       0.74      OMP only  
1         4         209       1.40      OMP only  
1         8         117       2.50      OMP only  
1         12        87        3.36      OMP only  


Op: gradInputs
Image shape, filter shape
(128, 3, 128, 128) (96, 3, 5, 5)
Imagenet-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      3256      0.27      single thread
12        None      883       1.00      BLAS only, baseline
None      12        374       2.36      only OMP set
2         1         2073      0.43      BLAS only 
4         1         1301      0.68      BLAS only 
8         1         1004      0.88      BLAS only 
1         2         1681      0.53      OMP only  
1         4         888       0.99      OMP only  
1         8         487       1.81      OMP only  
1         12        375       2.35      OMP only  


Op: gradInputs
Image shape, filter shape
(128, 85, 2, 258) (64, 85, 2, 20)
Spectrogram-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      1619      0.30      single thread
12        None      491       1.00      BLAS only, baseline
None      12        217       2.26      only OMP set
2         1         1050      0.47      BLAS only 
4         1         699       0.70      BLAS only 
8         1         545       0.90      BLAS only 
1         2         834       0.59      OMP only  
1         4         440       1.11      OMP only  
1         8         271       1.81      OMP only  
1         12        214       2.29      OMP only  


Op: gradWeights
Image shape, filter shape
(128, 3, 128, 128) (96, 3, 5, 5)
Imagenet-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      1661      0.45      single thread
12        None      744       1.00      BLAS only, baseline
None      12        235       3.16      only OMP set
2         1         1001      0.74      BLAS only 
4         1         667       1.11      BLAS only 
8         1         621       1.20      BLAS only 
1         2         856       0.87      OMP only  
1         4         465       1.60      OMP only  
1         8         278       2.67      OMP only  
1         12        235       3.16      OMP only  


Op: gradWeights
Image shape, filter shape
(128, 85, 2, 258) (64, 85, 2, 20)
Spectrogram-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      749       0.41      single thread
12        None      307       1.00      BLAS only, baseline
None      12        89        3.44      only OMP set
2         1         512       0.60      BLAS only 
4         1         336       0.91      BLAS only 
8         1         270       1.14      BLAS only 
1         2         416       0.74      OMP only  
1         4         211       1.46      OMP only  
1         8         127       2.40      OMP only  
1         12        88        3.48      OMP only

@JesseLivezey
Contributor

@pcs-theano, the machine I tested this on is pretty old. If you have a more modern server with lots of cores, I'd be interested to see the timings. You might have to mess around with the script a bit, but hopefully it is fairly clear.

@ciyongch
Contributor Author

ciyongch commented Jun 14, 2016

@JesseLivezey, thanks for your test data and script. I ran it on my workstation, which has two Intel(R) Xeon(R) CPU E5-2697 v3 @ 2.60GHz sockets (28 cores in total) and 128 GB RAM, but it got stuck at some point.

Op: forward
Image shape, filter shape
(128, 3, 128, 128) (96, 3, 5, 5)
Imagenet-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      715       0.39      single thread
28        None      275       1.00      BLAS only, baseline
None      28        69        3.95      only OMP set
2         1         411       0.67      BLAS only 
4         1         311       0.89      BLAS only 
8         1         599       0.46      BLAS only 
16        1         601       0.46      BLAS only 
1         2         339       0.81      OMP only  
1         4         179       1.54      OMP only  
1         8         120       2.30      OMP only  
1         16        74        3.70      OMP only  
1         28        58        4.71      OMP only  


Op: forward
Image shape, filter shape
(128, 85, 2, 258) (64, 85, 2, 20)
Spectrogram-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      275       0.49      single thread
28        None      134       1.00      BLAS only, baseline
None      28        159       0.84      only OMP set
2         1         196       0.68      BLAS only 
4         1         275       0.49      BLAS only 
8         1         276       0.49      BLAS only 
16        1         316       0.42      BLAS only 
1         2         139       0.96      OMP only  
1         4         76        1.76      OMP only  
1         8         46        2.88      OMP only  
1         16        25        5.20      OMP only  
1         28        18        7.16      OMP only  


Op: gradInputs
Image shape, filter shape
(128, 3, 128, 128) (96, 3, 5, 5)
Imagenet-like
MKL       OpenMP    Time(ms)  Speedup   Notes     
1         None      1398      0.28      single thread
28        None      398       1.00      BLAS only, baseline
None      28        798       0.50      only OMP set
2         1         893       0.45      BLAS only 
4         1         616       0.65      BLAS only 
8         1         988       0.40      BLAS only    <-- got stuck here


PyArrayObject* col;
if (max_threads > 1) {
col = (PyArrayObject*)PyArray_ZEROS(3,
Member

Did you do the timing with a single thread to know which one was faster?

@nouiz
Member

nouiz commented Jun 14, 2016

What do you mean by "got stuck here"? Was the process still running or was it blocked?

The numbers seem to say that we are better off just setting MKL to 0 threads and OpenMP to the max. Could you try a mix, like MKL 2 and OpenMP 14?

From a user's point of view, it is not great to set those 2 env variables separately, especially as other operations in the graph could need a different setting. Could we force MKL to use 1 thread here?

@ciyongch
Contributor Author

@nouiz When I ran the script, it did not print any result after the place I marked "got stuck here"; the process looked like it was still running, but there was no further output on the terminal. I'm checking the script...

It's a good point that we can try other combinations of BLAS threads and OMP threads. BLAS could still run in parallel when OMP hasn't taken all of the available threads. Let's see the result when it's done.

MKL BLAS checks both the MKL_NUM_THREADS and OMP_NUM_THREADS env variables, but the former has higher priority, which means MKL_NUM_THREADS overrides the thread count when OMP_NUM_THREADS is also set.

If the result shows that MKL 1 and OpenMP max_threads are the best combination, then we can force mkl to use 1 thread in this place.

@JesseLivezey
Contributor

JesseLivezey commented Jun 15, 2016

Two more concerns:

  • not every BLAS library allows the number of threads to be set dynamically (ATLAS, for instance). I'm not sure how this would behave. Maybe it's not a problem.
  • if the BLAS isn't from MKL (OpenBLAS for instance), then is it possible to have the loop use more than 1 thread and BLAS use 1 thread dynamically?

An example of the second point would be: we would want omp-loop parallelism and single thread BLAS in the conv layers of a conv net, but BLAS parallelism in the FC layers.

I'm not familiar enough with OpenMP and how compilers behave when OpenMP and BLAS are combined to know what to watch out for.

@ciyongch
Contributor Author

@nouiz @JesseLivezey ,

There's a minor issue at line 42 of the script, which should be if direction == 'gradWeights' instead of if direction == 'gradInputs'.
After cleaning up/killing all python processes before re-running the script, the script hang issue is gone.

As for ATLAS, it doesn't support changing the number of threads dynamically; the maximum number of threads is determined at compile time. But I think the new code would not slow down its performance, since parallelizing the for loop performs better than a parallel gemm inside the corr Op.

As for OpenBLAS, it can change the number of threads dynamically in the same way as MKL, via the OPENBLAS_NUM_THREADS env variable or by calling openblas_set_num_threads().

I've collected test data with both MKL and OpenBLAS, using g++ (4.8.4), on a workstation with 36 cores (Intel(R) Xeon(R) CPU E5-2699 v3 @ 2.30GHz) and 128 GB RAM.
Besides the original combinations of BLAS_NUM_THREADS and OMP_NUM_THREADS, I also added some other combinations (BLAS_NUM_THREADS x OMP_NUM_THREADS = max_threads).
The test results also include timings with both PyArray_ZEROS and PyArray_EMPTY.

From the test results, calling PyArray_ZEROS gives better performance than PyArray_EMPTY for forward and gradInputs when OpenMP > 1 in the AlexNet-like case, and it didn't slow things down in single-thread mode.
So maybe we can change PyArray_EMPTY to PyArray_ZEROS in corr.py?

GEMM=1 and OpenMP=max_threads gives the best and most stable performance, while GEMM=None and OpenMP=max_threads can sometimes lead to worse results, such as gradWeights of ImageNet-like in files 1 and 2, and gradInputs of Spectrogram-like in files 3 and 4.
So it's better to force GEMM to 1 here. Should we only support the MKL and OpenBLAS APIs here?

Hopefully these test results make things clear.

Please check the log files as attachment:
mkl_g++_pyarray_empty.txt
mkl_g++_pyarray_zeros.txt
openblas_g++_pyarray_empty.txt
openblas_g++_pyarray_zeros.txt

@nouiz
Member

nouiz commented Jun 16, 2016

Forcing gemm=1 for mkl and openblas would be great.

There are a few corner cases where using MKL=2 OMP=18 would be a little bit faster with MKL (~10%; it can be much slower with OpenBLAS in the same case), but I think forcing GEMM=1 is simpler and guarantees we avoid the big slowdown. That ~10% probably depends on the shape and the number of cores available, so I would, like you, just ignore it.

Agreed to always use ZEROS; just add a comment that it was timed to be faster, so later people won't try to revert it back to EMPTY.


@ciyongch
Contributor Author

@nouiz, the code has been updated to force gemm=1 and use PyArray_ZEROS.

@@ -224,6 +217,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
bottomWidth, kH, kW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
Member

How do we restore the previous value after the call to gemm? This is to allow BLAS to be parallel for other layers, like a dense layer that does just one big gemm call.

Also, I would call that before the for loop. Same for the 3 other places.

Contributor Author

@ciyongch ciyongch Jun 18, 2016

Good point, it should restore the previous thread number after the call to gemm as you mentioned. It looks like OpenBLAS does this differently than MKL; let me verify both and update later.
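
For reference, a minimal sketch of the save-and-restore pattern being discussed, assuming the MKL entry points mkl_get_max_threads()/mkl_set_num_threads() and that mkl.h is available (see the later discussion about declaring these functions directly when it is not); OpenBLAS exposes the analogous openblas_get_num_threads()/openblas_set_num_threads():

#include <mkl.h>
#include <omp.h>

void run_parallel_corr(int batch_size)
{
    // Save the current BLAS thread count, force single-threaded gemm for the
    // OpenMP-parallelized batch loop, then restore the old value so later
    // layers (e.g. a dense layer's single big gemm) can still use a parallel BLAS.
    int prev_blas_threads = mkl_get_max_threads();
    mkl_set_num_threads(1);   // set once, before the loop, as suggested above

    #pragma omp parallel for
    for (int n = 0; n < batch_size; ++n) {
        // per-image im2col + single-threaded gemm would go here
    }

    mkl_set_num_threads(prev_blas_threads);  // restore for subsequent BLAS calls
}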

@nouiz
Member

nouiz commented Jun 17, 2016

Travis has some errors. It doesn't find the file mkl.h. Do you know how to find it automatically? The update to our code that finds the right flags to pass (I suppose a -IPATH is missing) should be in the function default_blas_ldflags() in the file theano/configdefaults.py.

@nouiz
Member

nouiz commented Jun 17, 2016

In fact, when using Python from Anaconda, we don't have access to an mkl.h file. So instead of requiring this file, can you just declare the function signature so that it knows how to compile, and we should link correctly?

Should we do the same for OpenBLAS?

@ciyongch
Contributor Author

@nouiz
If users want to call MKL or OpenBLAS gemm functions, they should have the mkl.h or cblas.h header file along with the library on their system.

Regarding the Travis build, it has such BLAS header files (mkl.h and cblas.h), right? Are you referring to adding the flag in theano/configdefaults.py so it finds the header file automatically?

Regarding the scenario of using Python from Anaconda, is it able to access other header files, or is it only the BLAS header files it can't access? I'm not familiar with Anaconda, but I think it may not be a good idea to port the function declarations from the BLAS header file into the corr file; if we do, we may also need to port other function declarations that might be used in the future. What do you think?

@nouiz
Member

nouiz commented Jun 18, 2016

Theano provides the BLAS header in the file:

https://github.com/Theano/Theano/blob/master/theano/tensor/blas_headers.py

This helps us be more robust to different installation setups.

In the case of Anaconda, it includes the library, but not the headers, so there is no mkl.h file. Anaconda is a popular way to install Python, so we can't just drop that setup.

The only workaround I can think of is to remove the include of mkl.h and provide the declaration of that extra function ourselves.

If you have another idea, tell us.

@ciyongch
Contributor Author

Got it; it looks like adding the declarations within Theano is the only way to stay compatible with different installation setups. I will move forward in this way.
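
As a sketch of that workaround, the threading entry points could be forward-declared instead of including the vendor header; the prototypes below are assumptions about the usual MKL/OpenBLAS signatures, not code taken from this PR:

// Declared directly so that no mkl.h / cblas.h header is needed at compile
// time (e.g. with Anaconda's MKL, which ships libraries but no headers);
// the symbols resolve at link time against the BLAS library Theano links.
extern "C" {
    void mkl_set_num_threads(int n);        // assumed MKL prototype
    int  mkl_get_max_threads(void);         // assumed MKL prototype
    void openblas_set_num_threads(int n);   // assumed OpenBLAS prototype
    int  openblas_get_num_threads(void);    // assumed OpenBLAS prototype
}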

sub['get_blas_threads'] = 'openblas_get_num_threads()'
elif self.blas_type == 'mkl':
sub['set_blas_threads'] = 'mkl_set_num_threads'
sub['get_blas_threads'] = 'mkl_get_max_threads()'
Member

We need an else clause here in case another BLAS is used.

@nouiz
Member

nouiz commented Jun 20, 2016

It seems pretty good; just one question and we should be good.

@ciyongch
Contributor Author

@nouiz the code has been updated; any new comments on the latest commit?

@nouiz
Member

nouiz commented Jul 7, 2016

All looks good, but a rebase is required. thanks.

@theano-bot
Contributor

Can one of the admins verify this patch?

@theano-bot theano-bot merged commit 2cdc1e6 into Theano:master Jul 8, 2016
@theano-bot
Contributor

thanks
