Skip to content

Commit

Permalink
Merge pull request #116 from aravindhm/tanh
Browse files Browse the repository at this point in the history
Add TanH = hyperbolic tangent activation layer (popular for sparse
autoencoders).
  • Loading branch information
shelhamer committed Feb 23, 2014
2 parents 2092fbc + 9a45a0a commit 2fce080
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ class ReLULayer : public NeuronLayer<Dtype> {
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};

template <typename Dtype>
class TanHLayer : public NeuronLayer<Dtype> {
public:
explicit TanHLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {}

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};

template <typename Dtype>
class SigmoidLayer : public NeuronLayer<Dtype> {
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new PoolingLayer<Dtype>(param);
} else if (type == "relu") {
return new ReLULayer<Dtype>(param);
} else if (type == "tanh") {
return new TanHLayer<Dtype>(param);
} else if (type == "sigmoid") {
return new SigmoidLayer<Dtype>(param);
} else if (type == "softmax") {
Expand Down
97 changes: 97 additions & 0 deletions src/caffe/layers/tanh_layer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2014 Aravindh Mahendran
// TanH neuron activation function layer. Adapted from ReLU layer code written by Yangqing Jia

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include <algorithm>

namespace caffe {

template <typename Dtype>
void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
Dtype exp2x;
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
exp2x = exp(2*bottom_data[i]);
top_data[i] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
}
}

template <typename Dtype>
Dtype TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
const int count = (*bottom)[0]->count();
Dtype exp2x;
Dtype tanhx;
for (int i = 0; i < count; ++i) {
exp2x = exp(2*bottom_data[i]);
tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
bottom_diff[i] = top_diff[i] * (1 - tanhx*tanhx);
}
}
return Dtype(0);
}

template <typename Dtype>
__global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
Dtype exp2x = exp(2*in[index]);
out[index] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
}
}

template <typename Dtype>
void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
// << " count: " << count << " bottom_data: "
// << (unsigned long)bottom_data << " top_data: " << (unsigned long)top_data
// << " blocks: " << CAFFE_GET_BLOCKS(count)
// << " threads: " << CAFFE_CUDA_NUM_THREADS;
}

template <typename Dtype>
__global__ void TanHBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
Dtype exp2x = exp(2*in_data[index]);
Dtype tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
out_diff[index] = in_diff[index] * (1 - tanhx*tanhx);
}
}

template <typename Dtype>
Dtype TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
const int count = (*bottom)[0]->count();
TanHBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_data, bottom_diff);
CUDA_POST_KERNEL_CHECK;
}
return Dtype(0);
}

INSTANTIATE_CLASS(TanHLayer);


} // namespace caffe
102 changes: 102 additions & 0 deletions src/caffe/test/test_tanh_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2014 Aravindh Mahendran
// Adapted from other test files

#include <cmath>
#include <cstring>
#include <cuda_runtime.h>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/test/test_gradient_check_util.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

template <typename Dtype>
class TanHLayerTest : public ::testing::Test {
protected:
TanHLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
blob_top_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
};
virtual ~TanHLayerTest() { delete blob_bottom_; delete blob_top_; }
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(TanHLayerTest, Dtypes);

TYPED_TEST(TanHLayerTest, TestForwardCPU) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::CPU);
TanHLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test exact values
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
for (int k = 0; k < this->blob_bottom_->height(); ++k) {
for (int l = 0; l < this->blob_bottom_->width(); ++l) {
EXPECT_GE(this->blob_top_->data_at(i,j,k,l) + 1e-4,
(exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1));
EXPECT_LE(this->blob_top_->data_at(i,j,k,l) - 1e-4,
(exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1));
}
}
}
}
}

TYPED_TEST(TanHLayerTest, TestGradientCPU) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::CPU);
TanHLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
}

TYPED_TEST(TanHLayerTest, TestForwardGPU) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::GPU);
TanHLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test exact values
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
for (int k = 0; k < this->blob_bottom_->height(); ++k) {
for (int l = 0; l < this->blob_bottom_->width(); ++l) {
EXPECT_GE(this->blob_top_->data_at(i,j,k,l) + 1e-4,
(exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1));
EXPECT_LE(this->blob_top_->data_at(i,j,k,l) - 1e-4,
(exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1));
}
}
}
}
}

TYPED_TEST(TanHLayerTest, TestGradientGPU) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::GPU);
TanHLayer<TypeParam> layer(layer_param);
GradientChecker<TypeParam> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_);
}

}

0 comments on commit 2fce080

Please sign in to comment.