Embed layer #2032
Merged
Commits
4 commits
6067869  test_gradient_check_util: check_bottom < -1 only checks params  (jeffdonahue)
ac9e29f  EmbedBackward with no loops -- use caffe_gpu_atomic_add instead  (jeffdonahue)
443b16f  Add gpu_util.cuh, with caffe_gpu_atomic_add  (jeffdonahue)
4d299c3  Add EmbedLayer for inner products with sparse input (one-hot vectors),  (jeffdonahue)
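The layer computes an inner product with one-hot input without ever materializing the one-hot vectors: multiplying a one-hot row vector by the K x N weight matrix simply selects one row of the matrix, so the forward pass reduces to a per-input row lookup plus an optional bias. A minimal plain-C++ sketch of that equivalence (hypothetical names, not the Caffe API):

#include <cassert>
#include <vector>

// Sketch only: the "matrix product" with a one-hot input is a row lookup.
std::vector<float> embed_forward(const std::vector<int>& indices,   // M values in [0, K)
                                 const std::vector<float>& weight,  // K x N, row-major
                                 const std::vector<float>& bias,    // N entries, or empty
                                 int K, int N) {
  std::vector<float> top(indices.size() * N);
  for (int n = 0; n < static_cast<int>(indices.size()); ++n) {
    const int index = indices[n];
    assert(index >= 0 && index < K);  // same bounds the layer's DCHECKs enforce
    for (int d = 0; d < N; ++d) {
      top[n * N + d] = weight[index * N + d] + (bias.empty() ? 0.0f : bias[d]);
    }
  }
  return top;
}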
include/caffe/util/gpu_util.cuh (new file):

@@ -0,0 +1,35 @@
+#ifndef CAFFE_UTIL_GPU_UTIL_H_
+#define CAFFE_UTIL_GPU_UTIL_H_
+
+namespace caffe {
+
+template <typename Dtype>
+inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address);
+
+template <>
+inline __device__
+float caffe_gpu_atomic_add(const float val, float* address) {
+  return atomicAdd(address, val);
+}
+
+// double atomicAdd implementation taken from:
+// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG
+template <>
+inline __device__
+double caffe_gpu_atomic_add(const double val, double* address) {
+  unsigned long long int* address_as_ull =  // NOLINT(runtime/int)
+      // NOLINT_NEXT_LINE(runtime/int)
+      reinterpret_cast<unsigned long long int*>(address);
+  unsigned long long int old = *address_as_ull;  // NOLINT(runtime/int)
+  unsigned long long int assumed;  // NOLINT(runtime/int)
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+        __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+
+}  // namespace caffe
+
+#endif  // CAFFE_UTIL_GPU_UTIL_H_
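CUDA has no hardware atomicAdd for double on GPUs of compute capability below 6.0, so the double specialization above emulates it: it reinterprets the accumulator's 64 bits as an unsigned long long and retries atomicCAS until no other thread modified the address between the read and the swap. A usage sketch (the kernel below is illustrative only, not part of the PR):

#include "caffe/util/gpu_util.cuh"

// Illustrative only: every thread folds one element into a single double
// accumulator; correctness relies on the CAS loop in caffe_gpu_atomic_add.
__global__ void SumKernel(const int n, const double* data, double* result) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    caffe::caffe_gpu_atomic_add(data[i], result);
  }
}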
src/caffe/layers/embed_layer.cpp (new file):

@@ -0,0 +1,122 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  N_ = this->layer_param_.embed_param().num_output();
+  CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
+  K_ = this->layer_param_.embed_param().input_dim();
+  CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
+  bias_term_ = this->layer_param_.embed_param().bias_term();
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    if (bias_term_) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    // Initialize the weights --
+    // transposed from InnerProductLayer for spatial locality.
+    vector<int> weight_shape(2);
+    weight_shape[0] = K_;
+    weight_shape[1] = N_;
+    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+    // fill the weights
+    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+        this->layer_param_.embed_param().weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    // If necessary, initialize and fill the bias term
+    if (bias_term_) {
+      vector<int> bias_shape(1, N_);
+      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+          this->layer_param_.embed_param().bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
+  }  // parameter initialization
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  // Figure out the dimensions
+  M_ = bottom[0]->count();
+  vector<int> top_shape = bottom[0]->shape();
+  top_shape.push_back(N_);
+  top[0]->Reshape(top_shape);
+  // Set up the bias multiplier
+  if (bias_term_) {
+    vector<int> bias_shape(1, M_);
+    bias_multiplier_.Reshape(bias_shape);
+    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  int index;
+  for (int n = 0; n < M_; ++n) {
+    index = static_cast<int>(bottom_data[n]);
+    DCHECK_GE(index, 0);
+    DCHECK_LT(index, K_);
+    DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
+    caffe_copy(N_, weight + index * N_, top_data + n * N_);
+  }
+  if (bias_term_) {
+    const Dtype* bias = this->blobs_[1]->cpu_data();
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    // Gradient with respect to weight
+    Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+    int index;
+    for (int n = 0; n < M_; ++n) {
+      index = static_cast<int>(bottom_data[n]);
+      DCHECK_GE(index, 0);
+      DCHECK_LT(index, K_);
+      DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
+          << "non-integer input";
+      caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
+    }
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(EmbedLayer);
+#endif
+
+INSTANTIATE_CLASS(EmbedLayer);
+REGISTER_LAYER_CLASS(Embed);
+
+}  // namespace caffe
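Note how the weight gradient accumulates: every input that maps to the same index contributes to the same weight row, which is why the loop uses caffe_axpy (y += a*x) rather than a copy. A tiny standalone worked example (not Caffe code), assuming two inputs, K = 4 rows and N = 2 outputs:

#include <cstdio>

// Two inputs sharing index 3 accumulate into one weight row, mirroring the
// caffe_axpy loop in Backward_cpu above.
int main() {
  const int N = 2;
  float weight_diff[4 * N] = {0.0f};                // K = 4 rows of gradients
  const int indices[2] = {3, 3};                    // both inputs pick row 3
  const float top_diff[2][N] = {{1, 2}, {10, 20}};
  for (int n = 0; n < 2; ++n)
    for (int d = 0; d < N; ++d)
      weight_diff[indices[n] * N + d] += top_diff[n][d];
  // prints "row 3: 11 22" -- the sum of both top_diff rows
  std::printf("row 3: %g %g\n", weight_diff[3 * N], weight_diff[3 * N + 1]);
  return 0;
}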
src/caffe/layers/embed_layer.cu (new file):

@@ -0,0 +1,85 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* weight, const int M, const int N, const int K,
+    Dtype* top_data) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    top_data[top_index] = weight[weight_index];
+  }
+}
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff);
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  const int count = top[0]->count();
+  EmbedForward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, weight, M_, N_, K_, top_data);
+  if (bias_term_) {
+    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.gpu_data(),
+        this->blobs_[1]->gpu_data(), Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const int top_count = top[0]->count();
+    const int count = this->blobs_[0]->count();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    EmbedBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+        top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);
+
+}  // namespace caffe
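The GPU backward launches one thread per top element, so two inputs that share an index write the same weight_diff entries concurrently; a plain += would lose updates, which is presumably why the "EmbedBackward with no loops" commit switches to caffe_gpu_atomic_add. A contrast sketch (hypothetical float-only kernel, not from the PR):

// Illustrative only: the commented-out line races whenever indices[i / N]
// collides across threads; the atomic version serializes those updates.
__global__ void EmbedBackwardSketch(const int nthreads, const int* indices,
    const float* top_diff, const int N, float* weight_diff) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nthreads) {
    const int weight_index = indices[i / N] * N + i % N;
    // weight_diff[weight_index] += top_diff[i];         // data race: lost updates
    atomicAdd(weight_diff + weight_index, top_diff[i]);  // indivisible read-modify-write
  }
}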