
Commit

gineshidalgo99 committed Apr 27, 2018
2 parents 3cb22ee + 8645207 commit 9453eb0
Showing 12 changed files with 400 additions and 25 deletions.
37 changes: 26 additions & 11 deletions .github/ISSUE_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
Please use the [caffe-users list](https://groups.google.com/forum/#!forum/caffe-users) for usage, installation, or modeling questions, or other requests for help.
_Do not post such requests to Issues._ Doing so interferes with the development of Caffe.
## Important - read before submitting

Please read the [guidelines for contributing](https://github.com/BVLC/caffe/blob/master/CONTRIBUTING.md) before submitting this issue.
*Please read the [guidelines for contributing](https://github.com/BVLC/caffe/blob/master/CONTRIBUTING.md) before submitting this issue!*

*Please do not post installation, build, usage, or modeling questions, or other requests for help to Issues.*
Use the [caffe-users list](https://groups.google.com/forum/#!forum/caffe-users) instead.
This helps developers maintain a clear, uncluttered, and efficient view of the state of Caffe.

### Issue summary


### Steps to reproduce

If you are having difficulty building Caffe or training a model, please ask the caffe-users mailing list. If you are reporting a build error that seems to be due to a bug in Caffe, please attach your build configuration (either Makefile.config or CMakeCache.txt) and the output of the make (or cmake) command.

### Your system configuration
Operating system:
Compiler:
CUDA version (if applicable):
CUDNN version (if applicable):
BLAS:
Python or MATLAB version (for pycaffe and matcaffe respectively):
### Tried solutions


### System configuration

* Operating system:
* Compiler:
* CUDA version (if applicable):
* CUDNN version (if applicable):
* BLAS:
* Python version (if using pycaffe):
* MATLAB version (if using matcaffe):

### Issue checklist

- [ ] read the guidelines and removed the first paragraph
- [ ] written a short summary and detailed steps to reproduce
- [ ] explained how solutions to related problems failed (tick if found none)
- [ ] filled system configuration
- [ ] attached relevant logs/config files (tick if not applicable)
56 changes: 49 additions & 7 deletions CONTRIBUTING.md
@@ -1,21 +1,63 @@
# Contributing

Below you will find a collection of guidelines for submitting issues as well as contributing code to the Caffe repository.
Please read those before starting an issue or a pull request.

## Issues

Specific Caffe design and development issues, bugs, and feature requests are maintained by GitHub Issues.

_Please do not post usage, installation, or modeling questions, or other requests for help to Issues._
Use the [caffe-users list](https://groups.google.com/forum/#!forum/caffe-users) instead. This helps developers maintain a clear, uncluttered, and efficient view of the state of Caffe.

When reporting a bug, it's most helpful to provide the following information, where applicable:
*Please do not post installation, build, usage, or modeling questions, or other requests for help to Issues.*
Use the [caffe-users list](https://groups.google.com/forum/#!forum/caffe-users) instead.
This helps developers maintain a clear, uncluttered, and efficient view of the state of Caffe.
See the chapter [caffe-users](#caffe-users) below for guidance on posting to the users list.

* What steps reproduce the bug?
* Can you reproduce the bug using the latest [master](https://github.com/BVLC/caffe/tree/master), compiled with the `DEBUG` make option?
* What hardware and operating system/distribution are you running?
When reporting an issue, it's most helpful to provide the following information, where applicable:
* What does the problem look like, and what steps reproduce it?
* Can you reproduce it using the latest [master](https://github.com/BVLC/caffe/tree/master), compiled with the `DEBUG` make option?
* What hardware and software are you running? In particular:
* GPU make and model, if relevant,
* operating system/distribution,
* compiler, including its version (for example, with GCC run `gcc --version` to check),
* CUDA version, if applicable (run `nvcc --version` to check),
* cuDNN version, if applicable (version number is stored in `cudnn.h`, look for lines containing `CUDNN_MAJOR`, `CUDNN_MINOR` and `CUDNN_PATCHLEVEL`),
* BLAS library,
* Python version, if relevant,
* MATLAB version, if relevant.
* **What have you already tried** to solve the problem? How did it fail? Are there any other issues related to yours?
* If this is not a build-related issue, does your installation pass `make runtest`?
* If the bug is a crash, provide the backtrace (usually printed by Caffe; always obtainable with `gdb`).
* If you are reporting a build error that seems to be due to a bug in Caffe, please attach your build configuration (either Makefile.config or CMakeCache.txt) and the output of the make (or cmake) command.

If only a small portion of the code/log is relevant to your issue, you may paste it directly into the post, preferably using Markdown syntax for code blocks: triple backticks ( \`\`\` ) to open and close a block.
In other cases (multiple files, or long files), please **attach** them to the post - this greatly improves readability.

If the problem arises during a complex operation (e.g. a large script using pycaffe, or a long network prototxt), please reduce the example to the minimal size that still causes the error.
Also, minimize the influence of external modules, data, etc. - this way it will be easier for others to understand and reproduce your issue, and eventually to help you.
Sometimes you will find the root cause yourself in the process.

Try to give your issue a title that is succinct and specific. The devs will rename issues as needed to keep track of them.

## Caffe-users

Before you post to the [caffe-users list](https://groups.google.com/forum/#!forum/caffe-users), make sure you look for existing solutions.
The Caffe community has encountered and found solutions to countless problems - benefit from the collective experience.
Recommended places to look:
* the [users list](https://groups.google.com/forum/#!forum/caffe-users) itself,
* [`caffe`](https://stackoverflow.com/questions/tagged/caffe) tag on StackOverflow,
* [GitHub issues](https://github.com/BVLC/caffe/issues) tracker (some problems have been answered there),
* the public [wiki](https://github.com/BVLC/caffe/wiki),
* the official [documentation](http://caffe.berkeleyvision.org/).

Found a post/issue with your exact problem, but with no answer?
Don't just leave a "me too" message - provide the details of your case.
Problems with more available information are easier to solve and attract more attention.

When posting to the list, make sure you provide as much relevant information as possible - recommendations for an issue report (see above) are a good starting point.
*Please make it very clear which version of Caffe you are using, especially if it is a fork not maintained by BVLC.*

Formatting recommendations hold: paste short logs/code fragments into the post (use fixed-width text for them), **attach** long logs or multiple files.

## Pull Requests

Caffe welcomes all contributions.
2 changes: 1 addition & 1 deletion Makefile
@@ -646,7 +646,7 @@ $(PROTO_BUILD_DIR)/%.pb.cc $(PROTO_BUILD_DIR)/%.pb.h : \
$(PY_PROTO_BUILD_DIR)/%_pb2.py : $(PROTO_SRC_DIR)/%.proto \
$(PY_PROTO_INIT) | $(PY_PROTO_BUILD_DIR)
@ echo PROTOC \(python\) $<
$(Q)protoc --proto_path=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $<
$(Q)protoc --proto_path=src --python_out=python $<

$(PY_PROTO_INIT): | $(PY_PROTO_BUILD_DIR)
touch $(PY_PROTO_INIT)
6 changes: 6 additions & 0 deletions cmake/Cuda.cmake
@@ -109,6 +109,12 @@ function(caffe_select_nvcc_arch_flags out_variable)
set(__nvcc_flags "")
set(__nvcc_archs_readable "")

string(COMPARE LESS "${CUDA_VERSION}" "9.0" iscudaolderthan90)
if(NOT iscudaolderthan90)
string(REPLACE "21(20)" "" __cuda_arch_bin "${__cuda_arch_bin}")
string(REPLACE "20" "" __cuda_arch_bin "${__cuda_arch_bin}")
endif()

# Tell NVCC to add binaries for the specified GPUs
foreach(__arch ${__cuda_arch_bin})
if(__arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
2 changes: 1 addition & 1 deletion cmake/ProtoBuf.cmake
@@ -78,7 +78,7 @@ function(caffe_protobuf_generate_cpp_py output_dir srcs_var hdrs_var python_var)
"${output_dir}/${fil_we}_pb2.py"
COMMAND ${CMAKE_COMMAND} -E make_directory "${output_dir}"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --cpp_out ${output_dir} ${_protoc_include} ${abs_fil}
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${output_dir} ${_protoc_include} ${abs_fil}
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${PROJECT_BINARY_DIR}/include --proto_path ${PROJECT_SOURCE_DIR}/src ${_protoc_include} ${abs_fil}
DEPENDS ${abs_fil}
COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM )
endforeach()
96 changes: 96 additions & 0 deletions include/caffe/layers/swish_layer.hpp
@@ -0,0 +1,96 @@
#ifndef CAFFE_SWISH_LAYER_HPP_
#define CAFFE_SWISH_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/neuron_layer.hpp"
#include "caffe/layers/sigmoid_layer.hpp"

namespace caffe {

/**
* @brief Swish non-linearity @f$ y = x \sigma (\beta x) @f$.
* A novel activation function that tends to work better than ReLU [1].
*
* [1] Prajit Ramachandran, Barret Zoph, Quoc V. Le. "Searching for
* Activation Functions". arXiv preprint arXiv:1710.05941v2 (2017).
*/
template <typename Dtype>
class SwishLayer : public NeuronLayer<Dtype> {
public:
/**
* @param param provides SwishParameter swish_param,
* with SwishLayer options:
* - beta (\b optional, default 1).
* the value @f$ \beta @f$ in the @f$ y = x \sigma (\beta x) @f$.
*/
explicit SwishLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param),
sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
sigmoid_input_(new Blob<Dtype>()),
sigmoid_output_(new Blob<Dtype>()) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Swish"; }

protected:
/**
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$
* @param top output Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the computed outputs @f$
* y = x \sigma (\beta x)
* @f$.
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

/**
 * @brief Computes the error gradient w.r.t. the Swish inputs.
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (N \times C \times H \times W) @f$
* containing error gradients @f$ \frac{\partial E}{\partial y} @f$
* with respect to computed outputs @f$ y @f$
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$; Backward fills their diff with
* gradients @f$
* \frac{\partial E}{\partial x}
* = \frac{\partial E}{\partial y}(\beta y +
* \sigma (\beta x)(1 - \beta y))
* @f$ if propagate_down[0]
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

/// The internal SigmoidLayer
shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
/// sigmoid_input_ stores the input of the SigmoidLayer.
shared_ptr<Blob<Dtype> > sigmoid_input_;
/// sigmoid_output_ stores the output of the SigmoidLayer.
shared_ptr<Blob<Dtype> > sigmoid_output_;
/// bottom vector holder to call the underlying SigmoidLayer::Forward
vector<Blob<Dtype>*> sigmoid_bottom_vec_;
/// top vector holder to call the underlying SigmoidLayer::Forward
vector<Blob<Dtype>*> sigmoid_top_vec_;
};

} // namespace caffe

#endif // CAFFE_SWISH_LAYER_HPP_
5 changes: 5 additions & 0 deletions src/caffe/layers/embed_layer.cu
@@ -15,6 +15,11 @@ __global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
const int n = top_index / N;
const int d = top_index % N;
const int index = static_cast<int>(bottom_data[n]);
#ifdef DEBUG
assert(index >= 0);
assert(index < K);
assert(static_cast<Dtype>(index) == bottom_data[n]);
#endif
const int weight_index = index * N + d;
top_data[top_index] = weight[weight_index];
}
68 changes: 68 additions & 0 deletions src/caffe/layers/swish_layer.cpp
@@ -0,0 +1,68 @@
#include <cmath>
#include <vector>

#include "caffe/layers/swish_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
NeuronLayer<Dtype>::LayerSetUp(bottom, top);
sigmoid_bottom_vec_.clear();
sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
sigmoid_top_vec_.clear();
sigmoid_top_vec_.push_back(sigmoid_output_.get());
sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
}

template <typename Dtype>
void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
NeuronLayer<Dtype>::Reshape(bottom, top);
sigmoid_input_->ReshapeLike(*bottom[0]);
sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
}

template <typename Dtype>
void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const int count = bottom[0]->count();
Dtype beta = this->layer_param_.swish_param().beta();
caffe_copy(count, bottom_data, sigmoid_input_data);
caffe_scal(count, beta, sigmoid_input_data);
sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
}

template <typename Dtype>
void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* top_data = top[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
const int count = bottom[0]->count();
Dtype beta = this->layer_param_.swish_param().beta();
for (int i = 0; i < count; ++i) {
const Dtype swish_x = top_data[i];
bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
* (1. - beta * swish_x));
}
}
}

#ifdef CPU_ONLY
STUB_GPU(SwishLayer);
#endif

INSTANTIATE_CLASS(SwishLayer);
REGISTER_LAYER_CLASS(Swish);

} // namespace caffe
54 changes: 54 additions & 0 deletions src/caffe/layers/swish_layer.cu
@@ -0,0 +1,54 @@
#include <cmath>
#include <vector>

#include "caffe/layers/swish_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int count = bottom[0]->count();
Dtype beta = this->layer_param_.swish_param().beta();
caffe_copy(count, bottom_data, sigmoid_input_data);
caffe_gpu_scal(count, beta, sigmoid_input_data);
sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
}

template <typename Dtype>
__global__ void SwishBackward(const int n, const Dtype* in_diff,
const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
const Dtype beta) {
CUDA_KERNEL_LOOP(index, n) {
const Dtype swish_x = out_data[index];
out_diff[index] = in_diff[index] * (beta * swish_x
+ sigmoid_output_data[index] * (1 - beta * swish_x));
}
}

template <typename Dtype>
void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* top_data = top[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int count = bottom[0]->count();
Dtype beta = this->layer_param_.swish_param().beta();
// NOLINT_NEXT_LINE(whitespace/operators)
SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
CUDA_POST_KERNEL_CHECK;
}
}

INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);

} // namespace caffe
