Sparse convolutional neural networks #4328

wenwei202 opened this issue Jun 18, 2016 · 194 comments

wenwei202 commented Jun 18, 2016

Is anyone interested in utilizing sparsity to accelerate DNNs?

I am working on the fork https://github.com/wenwei202/caffe/tree/scnn and currently achieve, on average, ~5x CPU and ~3x GPU layer-wise speedups of the convolutional layers in AlexNet using off-the-shelf GEMM (after ~2% top-1 accuracy loss).

http://papers.nips.cc/paper/6504-learning-structured-sparsity-in-deep-neural-networks.pdf

jpiabrantes commented Jun 20, 2016

@wenwei202 could you explain a bit further how to use your fork? Any example? I have convolution layers where 90% of the weights are zero. If I use your version of Caffe, will the computations automatically take advantage of this sparsity? If I use a dense matrix, will the computations be slower, or will it fall back to the normal way of computing? Thanks for sharing your work 👍

@wenwei202

@jpiabrantes You can use conv_mode in each conv layer to indicate which method is used to do the computation.
e.g.
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "norm1"
  top: "conv2"
  convolution_param {
    num_output: 256
    pad: 2
    kernel_size: 5
    group: 2
    conv_mode: LOWERED_CSRMM # sparse weight matrix in CSR format * lowered feature maps
    # conv_mode: LOWERED_GEMM # default original matrix multiplication
  }
}

Thanks
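
As a concrete illustration of the two modes, here is a minimal NumPy/SciPy sketch (not the fork's code; the shapes are hypothetical) of what LOWERED_GEMM and LOWERED_CSRMM compute after the usual im2col lowering:

import numpy as np
from scipy.sparse import csr_matrix

num_output, patch_size, out_spatial = 256, 5 * 5 * 48, 27 * 27  # hypothetical conv2-like shapes
W = np.random.randn(num_output, patch_size)
W[np.abs(W) < 1.0] = 0.0                              # pretend training left W sparse
X_lowered = np.random.randn(patch_size, out_spatial)  # im2col-lowered feature maps

Y_gemm = W @ X_lowered                                # dense path (LOWERED_GEMM)
Y_csrmm = csr_matrix(W) @ X_lowered                   # sparse path (LOWERED_CSRMM)
assert np.allclose(Y_gemm, Y_csrmm)                   # same result, different storage/kernels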

jpiabrantes commented Jun 24, 2016

I just tested on the Lenet network for the MNIST example. I was able to achieve the following sparse layers:

conv1 is 75.4 percent sparse
conv2 is 94.7 percent sparse
ip1 is 74.5 percent sparse
ip2 is 89.5 percent sparse

I used conv_mode: LOWERED_CSRMM and connectivity_mode: DISCONNECTED_GRPWISE. I used the GPU, and the sparse network was not faster; sometimes it was even slower. My batch size is 1.

wenwei202 commented Jun 25, 2016

@jpiabrantes In CPU mode, you need to use MKL. LOWERED_CSRMM is only implemented with MKL Sparse BLAS, since sparse BLAS is not supported by OpenBLAS or ATLAS.

@jpiabrantes

@wenwei202 I used the GPU mode.

@wenwei202

@jpiabrantes It is normal to achieve a very limited 'speedup' on GPU even if you have sparsity higher than 90%, because GPUs are highly parallel and the irregular sparsity pattern hurts performance. I am working on structured sparsity to achieve speedups on GPU.

ghost commented Jun 29, 2016

@wenwei202 I am not able to complete compilation. 'make runtest' fails.

wenwei202 commented Jun 29, 2016

@Rupeshd @wenwei202
When running make runtest, use ATLAS instead of MKL (MKL seems to have problems passing some test cases) and export the following variable if you have more than one GPU:

export CUDA_VISIBLE_DEVICES=0 # use one GPU

To stabilize the sparsity during training, I zero out weights whose absolute values are smaller than 0.0001 after each weight update. So, the precision of RMSPropSolverTest may not be enough to pass the test. You can comment out the following code if you do not want to zero out weights (but it is recommended during training to stabilize the sparsity).

template <typename Dtype>
void Net<Dtype>::Update() {
  for (int i = 0; i < learnable_params_.size(); ++i) {
    learnable_params_[i]->Update();
    learnable_params_[i]->Zerout(); //comment this if you do not want to zerout.
  }
}

The only failed (crashed) test case is "TYPED_TEST(ConvolutionLayerTest, Test0DConvolution)" of https://github.com/wenwei202/caffe/blob/scnn/src/caffe/test/test_convolution_layer.cpp#L311, and I don't know why. If you guys can figure it out, that would be great. Temporarily, I commented out the code within it and passed all other test cases. Test0DConvolution is not used for usual 2D or 3D convolution, so it might not be a concern.

Hope this helps.

-Wei

@zhaishengfu

@wenwei202
Hello, I think you have implemented Liu's CVPR Sparse Convolutional Neural Networks. But in your fork of caffe_scnn (https://github.com/wenwei202/caffe/tree/scnn), I can't find any procedure that implements it (I know you implemented group lasso and so on, but how does your code implement the methods described in Liu's paper? Can you give me a simple tutorial?)
Thank you in advance.

@zhaishengfu

@wenwei202
Besides, I can see you reference 'models/eilab_reference_sparsenet/deploy_scnn.prototxt' and similar files in some Python scripts, but I can't find any of them. How can I generate them, or where can I find them?

@wenwei202

@zhaishengfu The implementation was abandoned. It can hardly achieve good speedup unless the sparse weights are hardcoded in the source code, as the paper did. I didn't try hardcoding weights, but you are free to try if you are interested. What the paper did was convert each conv layer into three small layers. You can use this to generate the equivalent net prototxt and this to generate the corresponding decomposed caffemodel. But the code is deprecated.

@zhaishengfu

@wenwei202 Thank you for your reply. But I don't understand what you mean by 'hardcoded'. I didn't see it described in the paper. According to my understanding, you get a speedup as long as your network is sparse and you implement the sparse-dense matrix multiplication methods described in the paper. Am I wrong?

@wenwei202

@zhaishengfu Please refer to section 4 in the paper, e.g., "Therefore, the location of non-zero elements are known and can be encoded directly in the compiled multiplication code." The replication of that work was abandoned because of that tricky scheme. Our speedup is achieved by structured sparsity, which overcomes the irregular memory access pattern caused by the random distribution of sparse weights in memory. Hopefully, we can release our related paper soon.

@zhaishengfu

@wenwei202 Thank you very much. Really looking forward to your paper. Can you let me know when you release your paper? (Or can you tell me the name of your paper?)

wenwei202 commented Aug 15, 2016

Hi @zhaishengfu @jpiabrantes @Rupeshd @pluskid @sergeyk, our paper related to this Caffe fork has just been accepted by NIPS 2016. You are welcome to contribute, in case you still have an interest in sparse convolutional neural networks. [paper] [GitHub code]

zhaishengfu commented Aug 15, 2016

@wenwei202 Thank you very much!! I will read it carefully!! I really appreciate your contribution to this fork.

@zhaishengfu

@wenwei202 Hello, I have skimmed your paper and code. Is the code the same as your original code? I didn't see any difference (or maybe I should look more carefully).
Besides, I have used your original code to train my model (a regression problem). It is useful, but I lose some accuracy. And if I set the learning rate to >10^-5, the loss goes to NaN, so I can only set it to a small value and the convergence is very slow...

@wenwei202

@zhaishengfu Please use the scnn fork; I have updated the tutorial. Hope that helps.

zhaishengfu commented Aug 26, 2016

@wenwei202 OK, indeed I have used your code already. I used all of your related parameters to generate my prototxt as follows. I see that you don't use tensor decomposition.
layer {
  name: "conv1_1"
  type: "Convolution"
  bottom: "image"
  top: "conv1_1"
  param {
    lr_mult: 1
    decay_mult: 1
    breadth_decay_mult: 1.0
    kernel_shape_decay_mult: 1.0
    block_group_lasso {
      xdimen: 9
      ydimen: 64
      block_decay_mult: 1.0
    }
    regularization_type: "L1"
  }
  param {
    lr_mult: 1
    decay_mult: 1
    breadth_decay_mult: 0.0
    kernel_shape_decay_mult: 0.0
    regularization_type: "L1"
  }
  connectivity_mode: DISCONNECTED_ELTWISE
  convolution_param {
    num_output: 64
    bias_term: true
    pad: 1
    kernel_size: 3
    group: 1
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
    }
  }
}

hiyijian commented Sep 1, 2016

For the setting:
block_group_lasso {
  xdimen: 9
  ydimen: 64
  block_decay_mult: 1.0
}
what is the meaning of 9x64? Does it mean that it will reserve 9x64 groups of weights and zero out the others?

@zhaishengfu

@hiyijian The code says it clearly: xdimen and ydimen represent the column and row dimensions of a block, respectively. For example, if you have A rows and ydimen is B, then you will have A/B groups, and regularization is applied within each group.
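
To make the grouping concrete, here is a small NumPy sketch of a block group-lasso term for a 2-D weight matrix tiled into ydimen-by-xdimen blocks (my own illustration of the idea, not the fork's code):

import numpy as np

def block_group_lasso(W, ydimen, xdimen):
    """Sum of L2 norms over non-overlapping ydimen-by-xdimen blocks of W."""
    rows, cols = W.shape
    assert rows % ydimen == 0 and cols % xdimen == 0
    total = 0.0
    for r in range(0, rows, ydimen):
        for c in range(0, cols, xdimen):
            block = W[r:r + ydimen, c:c + xdimen]
            total += np.sqrt(np.sum(block ** 2))
    return total

# 64x27 weight matrix (64 filters, 3 input channels, 3x3 kernels): ydimen=64 and xdimen=9
# gives 3 groups, one per input channel.
W = np.random.randn(64, 27)
print(block_group_lasso(W, ydimen=64, xdimen=9))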

hiyijian commented Sep 1, 2016

Thanks, it's clear now.
Is there any guide for setting proper xdimen and ydimen in order to achieve better accuracy and speed?

@zhaishengfu

@hiyijian Indeed, I also want to know the answer. In my training trials (my problem is regression, not classification), when the sparsity gets above about 60%, the accuracy decreases noticeably. I think the configuration of xdimen and ydimen depends on your network and problem. Maybe you can set the configuration as the paper says (e.g., xdimen equal to the number of columns of your convolution kernel and ydimen equal to the number of rows).

hiyijian commented Sep 1, 2016

Thank you @zhaishengfu.
Maybe the network could be fine-tuned without SSL to regain the accuracy, as the paper reports. I will give it a try.

@wenwei202

@zhaishengfu @hiyijian The setups of xdimen and ydimen are based on what kind of structured sparsity you want. For example, if a weight matrix with many all-zero columns is expected, then xdimen = 1 and ydimen = the number of rows. For the trade-off between accuracy and sparsity, please train the NN without SSL first to get the baseline, then train it with SSL, and finally fine-tune it without SSL. Make sure your training converges well at every phase.

wenwei202 reopened this Sep 1, 2016

hiyijian commented Sep 2, 2016

Thanks @wenwei202. That helps a lot.
You introduce 5 ways of applying group lasso:
1. filter-wise and channel-wise
2. shape-wise
3. depth-wise
4. 2D-filter-wise
5. filter-wise and shape-wise

Would you like to make it clearer how to put each of them into practice via the xdimen/ydimen settings?

hiyijian commented Sep 2, 2016

Say we have a typical conv layer with nfilter * nchannel * nHeight * nWidth = 128 * 64 * 3 * 3:
1. filter-wise and channel-wise: xdimen = 9 and ydimen >= 1
2. shape-wise: xdimen != 9 and ydimen = 0
3. depth-wise: no idea
4. 2D-filter-wise: xdimen = 9 and ydimen = 1
5. filter-wise and shape-wise: xdimen != 9 and ydimen >= 1

Did I do anything obviously stupid?

@gsrivas4

Thanks for the explanation. I get the intuition now and can relate it to the theory in the paper. However, I am not sure about this line in the code above:
t = tf.reduce_sum(t, axis=axis)
Shouldn't this be t = tf.reduce_sum(t, axis=(1,2,3))?
For row regularization, we should be taking the squared sum of all the weights in a filter. Please let me know where I am going wrong in my interpretation.

Regarding the issue, it is occurring during make test; if I resolve it, I will post my fix here.

@wenwei202

@srivastavag89 you are right, you should reduce along three axes if it is a 3D filter. My error.
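
For reference, a minimal TensorFlow sketch of the row-wise (filter-wise) regularizer being discussed, assuming Caffe's [filters, channels, height, width] weight layout (my illustration, not the code from the repository):

import tensorflow as tf

def filter_wise_group_lasso(weights):
    # weights: 4-D tensor [num_filters, channels, height, width] (Caffe layout assumed)
    # One group per output filter: L2 norm over axes (1, 2, 3), summed over filters.
    per_filter_norm = tf.sqrt(tf.reduce_sum(tf.square(weights), axis=(1, 2, 3)))
    return tf.reduce_sum(per_filter_norm)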

@hyunjaelee410

@wenwei202
Thanks for your kind answer!
As a matter of fact, there was a slight problem with my configuration.

I would like to ask a few more questions, if you don't mind.

  1. I've read your ICCV 2017 paper (nice work, by the way) and would like to apply it together with SSL.
    However, SSD, which is also implemented in Caffe, has many newly implemented layers.
    Do you think the best way of applying your work to SSD is to analyze your code and then merge it with the SSD code, or do you have any other idea?

  2. SSD uses a feature-extractor net (VGG by default) pre-trained on ImageNet, then fine-tunes on VOC with newly added layers at the back. Do you think the network will converge well with only the fine-tuning procedure on VOC (using a pre-trained feature extractor to which SSL has been applied), or do you think the pre-trained network should be re-trained with SSL and then applied to SSD?

  3. I've read your comment about implementing your work in TensorFlow. However, I have found that it doesn't support cuSPARSE at the moment, only simple sparse matrix acceleration (https://www.tensorflow.org/api_docs/python/tf/sparse_tensor_dense_matmul).
    Do you think the performance will improve in TensorFlow without support for cuSPARSE?

  4. I have run tests on a ConvNet with CIFAR-10 and came up with the results below.

Baseline (GEMM) : 1657.76 ms
Baseline (CSRMM) : 1594.99 ms
After SSL (GEMM) : 1638.37 ms
After SSL (CSRMM) : 1361.54 ms
After Fine-tuning(GEMM) : 1678.58 ms
After Fine-tuning(CSRMM) : 1183.33 ms

and the sparsities for each case are

Baseline : (0.430833, 0.655312, 0.659883)
After SSL : (0.603333, 0.845117, 0.659688)
After Fine-tuning : (0.624167, 0.860391, 0.708887)
-in conv1, conv2, conv3 order

What interests me is that the inference speed changed a lot after fine-tuning even though there wasn't a dramatic difference in sparsity. From what I understand, fine-tuning is about regaining accuracy. Is it natural to get such a gain?

Sorry for so many questions. I would like to get your intuition before applying this to detection.
I hope I can share nice results from applying your work to SSD.
Thanks a lot!

wenwei202 commented Jul 22, 2017

@HyunJaeLee2

  1. Merging is a good way since everything is trackable.
  2. It is a safe way to first use SSL to get a sparsified VggNet and then fine-tune the back end on VOC, but it is also doable, or even better, to train the whole net using SSL and VOC. SSL is a kind of regularization, and if it does not overfit, you may get higher sparsity because VOC is a smaller dataset and the model can be compressed more? This is just my thought; the best way is to try both.
  3. cuSPARSE is still slow, if you check the figure here. I recommend LOWERED_CCNMM (see the sketch after this list). TensorFlow is good for training; for inference, you may need to optimize your inference code based on the sparsity you obtained.
  4. Fine-tuning is for regaining accuracy while maintaining the sparsity trained by SSL. It can get ~1.0-2.0% accuracy back. I do not know how you measured the speed, but the code is not the best one for final deployment because some code hacking is required.
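
A rough NumPy sketch of the idea behind the concatenation-based mode (assuming LOWERED_CCNMM drops all-zero columns/rows to shrink the GEMM, which is my reading rather than the fork's actual code):

import numpy as np

# With column-wise structured sparsity, the all-zero columns of the weight matrix and the
# matching rows of the lowered feature maps can be dropped, leaving a smaller dense GEMM.
W = np.random.randn(96, 1152)
W[:, np.random.rand(1152) < 0.6] = 0.0         # pretend SSL zeroed out ~60% of the columns
X_lowered = np.random.randn(1152, 729)

keep = ~np.all(W == 0, axis=0)                 # surviving columns of W / rows of X_lowered
Y_full = W @ X_lowered
Y_small = W[:, keep] @ X_lowered[keep, :]      # smaller dense multiplication, same result
assert np.allclose(Y_full, Y_small)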

gsrivas4 commented Aug 9, 2017

@wenwei202 Thanks for your work.
I am trying to implement structured sparsity in Theano/Lasagne. I have a very similar network architecture to yours, but when I run the test, the sparsity for filters and channels jumps from 0% to 100% within an epoch. Could you suggest or give some pointers, based on your experience, on where I could be going wrong in my implementation? Also, I wanted to confirm the hyperparameters: did you use filter and channel sparsity hyperparameters of .003, as given on your GitHub?

@wenwei202

@srivastavag89 Maybe the hyper-parameter of the structured sparsity regularization is too large, such that it only optimizes the sparsity?
The hyper-parameter depends on the network architecture and dataset. It is easier to tune the hyper-parameter by retraining a trained DNN instead of training from scratch.
One good hyper-parameter to start from is the one such that the value of the cross-entropy and the value of the regularization are close. Another good starting value is that of the L2 weight decay.
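
One way to compute that starting point, as a small sketch (the numbers are placeholders, not values from the paper): measure both terms once on the trained net and scale the regularization so they are comparable.

cross_entropy_value = 2.3      # cross-entropy of the trained net on one batch (placeholder)
unweighted_reg_value = 1.5e3   # sum of group L2 norms of the trained weights (placeholder)
initial_decay = cross_entropy_value / unweighted_reg_value
print("start the group-lasso decay around %.2e" % initial_decay)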

gsrivas4 commented Aug 9, 2017

Thanks for the reply. I will try taking a trained network and then retraining it, adding structured sparsity this time. Also, I will try hyperparameters per your suggestions.
Another suggestion I need: would it be fine if I start with another network in Theano which gives similar baseline accuracy without SSL applied to it? This network is not the same as, but similar to, your baseline network, with convolution, max pooling, and batch normalization layers. Would you suggest trying with this network, or constructing a network exactly the same as yours? I am just trying to get your views on whether there is anything specific about the network you used, or whether you would expect SSL to work on other networks as well.

@wenwei202

@srivastavag89 SSL is a general method; it should work for a variety of networks, not just those in the paper.

zlheos commented Sep 14, 2017

@wenwei202 I tried to train with SSL in CPU mode,
but there is a problem: "Deprecated in CPU mode: breadth and kernel shape decay (use block group decay instead)".
I would appreciate it if you could answer me!

@wenwei202

@zlheos Please use block_group_lasso in the net prototxt and block_group_decay in the solver prototxt instead. You may check the tutorial here.

wenwei202 commented Sep 18, 2017

FYI, the Structured Sparsity Learning (SSL) approach is now also implemented in TensorFlow (code). We have also extended SSL to recurrent neural networks to reduce the hidden dimension of LSTMs, i.e., learning the number of hidden/cell states/neurons in RNNs. Missing details (e.g., the training method) are included in Section 3.2 here.

m8w commented Sep 18, 2017

Theory
begin Comments
If I use FFT (Fast Fourier Transformations) in exponential convolutions for video filters, the system implementation breaks and crashes. While I wait for the solutions to these video filters and their stream applications by finding non-identity type third-order differential derivatives in some Banach Lp singular integral operators with non-convolution type manifolds, I have found it is easier with gauge systems and solitons, yet neural network systems of a gaugian type might only be theory at this point.
end Comment

When going beyond the theories and applications we can simplify all kinds of identities into one type or another.

@zhuochen24

@wenwei202 Great job on this paper! It is very impressive! I wonder where I can find your prototxt for training AlexNet on ImageNet (the one shown in the paper)? For the needs of my experiment, I would like to make sure that I am compressing the network in the correct way (which should reproduce your results).
Much appreciated!

@wenwei202

@Jarvistonychen It took a while to look through the logs, but I found the hyperparameter of 0.0005 for entry 5 in Table 4.

@zhuochen24

@wenwei202 Thanks a lot for spending time answering my question! Since there is only column sparsity in entry 5 of Table 4, I think 0.0005 is for kernel_shape_decay? If I also want to apply breadth_decay, is 0.0005 also the right value to use?
Thanks again!

@wenwei202

@Jarvistonychen Yes, you may start from there.

@ananddb90

@wenwei202
For a conv layer with dimensions nfilter x nchannel x nheight x nwidth (64x128x3x3),
how can I combine filter-wise and channel-wise sparsity in one ResNet prototxt?
Do I have to write
block_group_lasso { # filter-wise
  xdim: 128x3x3
  ydim: 1
}
block_group_lasso { # channel-wise
  xdim: 3x3
  ydim: 64
}

@wenwei202

Correct. More precisely:

block_group_lasso {
  xdim: 1152
  ydim: 1
}
block_group_lasso {
  xdim: 9
  ydim: 64
}
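
The numbers follow from the flattened weight-matrix shape; a small sketch of the arithmetic, assuming Caffe's filters-by-(channels*h*w) layout:

n_filters, n_channels, kh, kw = 64, 128, 3, 3
filter_row_length = n_channels * kh * kw   # 1152 -> filter-wise group: xdim=1152, ydim=1
channel_block_width = kh * kw              # 9    -> channel-wise group: xdim=9, ydim=64 (= n_filters)
print(filter_row_length, channel_block_width, n_filters)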

@ananddb90

@wenwei202
In the log file, what is the meaning of

  1. I1024 15:28:53.767140 10494 sgd_solver.cpp:120] Element Sparsity %:
  2. I1024 15:28:53.776428 10494 sgd_solver.cpp:130] Column Sparsity %:
  3. I1024 15:28:53.777498 10494 sgd_solver.cpp:139] Row Sparsity %:
  4. I1024 15:28:53.777539 10494 sgd_solver.cpp:153] Block Sparsity %:

I1024 15:29:05.849658 10494 solver.cpp:231] Iteration 4420, loss = 0.494514
I1024 15:29:05.849678 10494 solver.cpp:247] Train net output #0: accuarcy = 0.920354
I1024 15:29:05.849684 10494 solver.cpp:247] Train net output #1: loss_bbox = 0.142928 (* 1 = 0.142928 loss)
I1024 15:29:05.849689 10494 solver.cpp:247] Train net output #2: loss_cls = 0.347726 (* 1 = 0.347726 loss)
I1024 15:29:05.849691 10494 solver.cpp:247] Train net output #3: rpn_cls_loss = 0.0584648 (* 1 = 0.0584648 loss)
I1024 15:29:05.849695 10494 solver.cpp:247] Train net output #4: rpn_loss_bbox = 0.0126774 (* 1 = 0.0126774 loss)
I1024 15:29:05.849699 10494 sgd_solver.cpp:106] Iteration 4420, lr = 0.0001
I1024 15:29:06.022259 10494 sgd_solver.cpp:120] Element Sparsity %:
0.28699 100 0 0 0 0 0 1.58691 0 1.17188 0 0 0 0.366211 0 0 0 0 0 0.496419 0 0 0 0 0 1.55029 0 1.17188 0 0 0 1.82495 0 0 0 0 0 0.311957 0 0 0 0 0 0.57373 0 0 0 0 0 7.91092 0 0 0 0 0 7.32727 0 0 0 0 0.78125 6.30493 0 0 0 0 0 16.4474 0 0 0 0 0 11.7126 0 0.78125 0 0 0 6.78982 0 0 0 0 0 21.0297 0 0 0 0.195312 0 7.89948 0 0 0 0 0 7.76062 0 0 0 0 0 6.10775 0 0 0 0 0 8.77342 0 0 0 0 0 11.3316 0 0 0 0 0 6.27865 0 0 0 0 0 8.37975 0 0 0 0 0 8.17313 0 0 0 0 0 7.22427 0 0 0 0 0 5.82258 0 0 0 0 0 7.35607 0.146484 0 0 0 0 7.47519 0 0 0 0 0 5.42183 0 0 0 0 0 7.55234 0.0488281 0 0 0 0 3.44215 21.0938 2.3112 0 10.9375 19.4444 3.55883 14.9414 3.04452 28.5714 77.5074 87.5
I1024 15:29:06.031704 10494 sgd_solver.cpp:130] Column Sparsity %:
0 100 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.17188 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.5625 0 0 0 0 0 0.78125 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.976562 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0976562 0
I1024 15:29:06.032747 10494 sgd_solver.cpp:139] Row Sparsity %:
0 100 0 0 0 0 0 0 0 1.17188 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.17188 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2.34375 0 0 0 0 0 0 0 0 0 0 0.78125 0 0 0 0 0 0 7.42188 0 0 0 0 0 0.78125 0 0.78125 0 0 0 0 0 0 0 0 0 12.6953 0 0 0 0.195312 0 0.390625 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.292969 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0976562 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.146484 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0488281 0 0 0 0 0 21.0938 0 0 0 19.4444 0 14.9414 0 28.5714 27.8061 87.5
I1024 15:29:06.032781 10494 sgd_solver.cpp:153] Block Sparsity %:

and how can I calculate/check the percentage of sparsity in the model before and after training?

wenwei202 commented Oct 24, 2017

@ananddb90 During training, you will see some sparsity statistics. The sparsity is shown in the order of the layers, and within each layer, in the order of weights and then biases. Basically, it prints the sparsity of every parameter blob in Caffe, including, e.g., the parameters of batch normalization layers. We usually care only about the sparsity of the weights.

The "Element Sparsity" is the percentage of zeros. "Block Sparsity" is the percentage of all-zero blocks if you used block_group_lasso. Others are pretty self explained. Thanks.

ananddb90 commented Oct 25, 2017

@wenwei202
thank you for your reply
but I couldn't understand your first paragraph. How can I get the sparsity in the order of layers, and then the layer-specific weights/biases, since in my log file I only get the training loss and the element, column, and row sparsity?

Also, I am using block_group_lasso but I am not getting any block sparsity output in my log file.

Train_scnn.prototxt
layer {
  bottom: "res2b"
  top: "res3a_branch1"
  name: "res3a_branch1"
  type: "Convolution"
  convolution_param {
    num_output: 512
    kernel_size: 1
    pad: 0
    stride: 2
    bias_term: false
  }
  param {
    lr_mult: 1.0
    block_group_lasso { # filter-wise structured sparsity
      xdimen: 256 # CxWxH
      ydimen: 1 # the block size along the y (row) dimension
      block_decay_mult: 1.0 # the local multiplier of weight decay (block_group_decay) by group lasso regularization
    }
    block_group_lasso { # channel-wise structured sparsity
      xdimen: 1 # CxWxH
      ydimen: 512 # the block size along the y (row) dimension
      block_decay_mult: 1.0 # the local multiplier of weight decay (block_group_decay) by group lasso regularization
    }
  }
}

solver.prototxt
#base_lr: 0.001
base_lr: 0.0001 #scnn_lr = 0.1*base_lr
lr_policy: "step"
gamma: 0.1
stepsize: 80000
display: 20
momentum: 0.9

weight_decay: 0.0005

#kernel_shape_decay: 0.0
#breadth_decay: 0.0
block_group_decay: 0.001 #0.007

snapshot: 10000
snapshot_prefix: "rfcn_scnn"

iter_size: 2
debug_info: true

@wenwei202

@ananddb90 Here is how the sparsity is displayed. Reading the code may be the best way. Alternatively, you may use pycaffe to analyze the trained model.
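
A minimal pycaffe sketch of that kind of analysis (the prototxt/caffemodel paths are placeholders):

import caffe
import numpy as np

net = caffe.Net('deploy.prototxt', 'trained.caffemodel', caffe.TEST)  # placeholder paths
for name, blobs in net.params.items():
    W = blobs[0].data                        # the weight blob of this layer
    print('%s is %.1f percent sparse' % (name, 100.0 * np.mean(W == 0)))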

@AllenFenglei

@jpiabrantes
Hello! I'm a newbie to Caffe. May I ask how to get the sparsity of each layer? E.g.,
conv1 is 75.4 percent sparse
conv2 is 94.7 percent sparse
ip1 is 74.5 percent sparse
ip2 is 89.5 percent sparse

Demohai commented Sep 30, 2018

@wenwei202 I read your paper and have been following your blog recently. I am also studying for a master's degree, at Beihang University, and I very much admire you and your research results. I have a question that I hope you can help me with. In your SSL paper, you learn structured sparsity by setting breadth_decay, kernel_shape_decay, or block_group_decay, but as you said earlier, during SSL, zeros in a row or column can still go back to nonzero if they get a large update from the gradients of the cross-entropy. Then, after fine-tuning the SSL model, the weights in a row, column, or block may not all be zero; there can be nonzero values anywhere, so there is no structured sparsity. Am I right?

wenwei202 commented Sep 30, 2018

Hello @Demohai, the answer differs for the different stages. In the learning stage of structured sparsity using group Lasso, zeros can go back only when those weights are very important, since the group Lasso regularization enforces them to zero; in the fine-tuning stage after group Lasso, we simply fix the zeros in all-zero groups/rows/columns and retrain the remaining ones, so the structured sparsity is kept. Thanks! :)
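
A NumPy sketch of that fine-tuning idea (my illustration, assuming column groups): the mask of all-zero groups is computed once after SSL and re-applied after every update so the structured zeros stay fixed.

import numpy as np

def finetune_step(W, grad, mask, lr=0.01):
    """One fine-tuning update that keeps the structured zeros found by SSL fixed."""
    W -= lr * grad
    W *= mask                                      # re-zero the groups that SSL removed
    return W

W = np.random.randn(64, 1152)
W[:, np.random.rand(1152) < 0.3] = 0.0             # pretend SSL zeroed out ~30% of the columns
mask = (~np.all(W == 0, axis=0)).astype(W.dtype)   # fixed column mask, computed once
W = finetune_step(W, np.random.randn(*W.shape), mask)
assert np.all(W[:, mask == 0] == 0)                # structured sparsity preserved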

Demohai commented Oct 6, 2018

hello @wenwei202 I find that each time I run the program, some weight files for each layer appear under the Caffe root directory. What are they used for, and which section of the source code produces them?
[screenshot: qq 20181006190802]

@wenwei202

It's just used to analyze the sparsity pattern. The code generating the weights is here and here.

Demohai commented Oct 7, 2018

Thanks a lot !!!

Demohai commented Oct 8, 2018

hello @wenwei202 Sorry to bother you again. When deploying the fine-tuned SSL network, we have several conv modes to choose from. I want to know what the lowered tensors and lowered feature maps are, as the following picture shows.
[screenshot: qq 20181008105659]

@wenwei202

@Demohai Sorry for the late reply. The feature maps are lowered to a matrix for matrix multiplication, and this is where the name LOWERED comes from. Please refer to this post for details.
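
For reference, a minimal NumPy sketch of the lowering (im2col) step being described, for stride 1 and no padding (hypothetical shapes); convolution then becomes one matrix multiplication between the weight matrix and this lowered matrix:

import numpy as np

def im2col(x, kh, kw):
    """Lower a (C, H, W) feature map to a (C*kh*kw, out_h*out_w) matrix (stride 1, no padding)."""
    c, h, w = x.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    cols = np.empty((c * kh * kw, out_h * out_w), dtype=x.dtype)
    idx = 0
    for i in range(out_h):
        for j in range(out_w):
            cols[:, idx] = x[:, i:i + kh, j:j + kw].ravel()
            idx += 1
    return cols

x = np.random.randn(3, 8, 8)              # one (C, H, W) feature map
W = np.random.randn(16, 3 * 3 * 3)        # 16 filters with 3x3 kernels over 3 channels
y = W @ im2col(x, 3, 3)                   # each row of y is one filter's (flattened) output map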
