Add BatchNormalizationLayer
We use the cuDNN v4 release candidate implementation of BatchNormalization (cudnnBatchNormalizationForwardTraining, cudnnBatchNormalizationForwardInference, and cudnnBatchNormalizationBackward). There are three things to be aware of:

1) We do not actually use cudnnBatchNormalizationForwardInference because we believe there is a bug in its implementation. Instead we use cudnnBatchNormalizationForwardTraining, but set all the parameters related to maintaining running statistics to NULL.
2) The cudnnBatchNormalization* routines do not work with FP16, so you must use float or double.
3) If a BatchNormalizationLayer has multiple in responses, it must satisfy the following requirements:
  - The number of in and out responses must be the same
  - The dimensions of each in response must be the same
If these are met, then the running statistics are handled as follows: in the forward pass, we compute the running statistics (mean and variance) for the first in response. We copy these statistics into the running-statistics slot for the second in response and then update that slot using the second in response, and so on. Because we store the running statistics for each in response separately, we can use the matching statistics in the backward pass. Once the backward pass finishes, we reset all of the running statistics to those of the last in response.

This mechanism lets every input benefit from the accumulated running statistics while still matching the correct statistics between the forward and backward passes (see the sketch below). HOWEVER!! It also means that BatchNormalizationLayer is not appropriate for other use cases that require multiple in responses. You have been warned.
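
For reference, the bookkeeping reduces to something like the following CPU-only sketch. All names here are hypothetical stand-ins (this is not the marvin/cuDNN code), and the "+= 1.0" update is just a placeholder for cuDNN's running-average update:

    // Simplified illustration of the per-response running-statistics bookkeeping.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const int numInputs = 3;  // number of "in" responses
        const int statNumel = 4;  // elements per statistics buffer
        // One running-mean buffer per in response, stored back to back.
        std::vector<double> runningMean(numInputs * statNumel, 0.0);

        // Forward (training): update buffer i from response i, then seed buffer i+1 with it.
        for (int i = 0; i < numInputs; ++i) {
            for (int c = 0; c < statNumel; ++c)
                runningMean[i * statNumel + c] += 1.0;  // placeholder for the cuDNN update
            if (i + 1 < numInputs)
                std::copy(runningMean.begin() + i * statNumel,
                          runningMean.begin() + (i + 1) * statNumel,
                          runningMean.begin() + (i + 1) * statNumel);
        }

        // Backward: each response reads its own saved buffer (not shown here); afterwards
        // every buffer is reset to the statistics of the last response.
        for (int i = 0; i + 1 < numInputs; ++i)
            std::copy(runningMean.begin() + (numInputs - 1) * statNumel,
                      runningMean.end(),
                      runningMean.begin() + i * statNumel);

        for (double v : runningMean) std::printf("%.1f ", v);  // prints "3.0" twelve times
        std::printf("\n");
        return 0;
    }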
danielsuo committed Jan 7, 2016
1 parent 1fe98f2 commit 9861e29
Showing 1 changed file with 282 additions and 8 deletions: marvin.hpp
@@ -12,7 +12,7 @@
#define sizeofStorageT 2
#define sizeofComputeT 4
#define CUDNNStorageT CUDNN_DATA_HALF
-#define CUDNNConvStorageT CUDNN_DATA_FLOAT
+#define CUDNNConvComputeT CUDNN_DATA_FLOAT
#define CPUStorage2ComputeT(x) (cpu_half2float(x))
#define CPUCompute2StorageT(x) (cpu_float2half(x))
#define GPUStorage2ComputeT(x) (__half2float(x))
@@ -29,7 +29,7 @@
#define sizeofStorageT 4
#define sizeofComputeT 4
#define CUDNNStorageT CUDNN_DATA_FLOAT
-#define CUDNNConvStorageT CUDNN_DATA_FLOAT
+#define CUDNNConvComputeT CUDNN_DATA_FLOAT
#define CPUStorage2ComputeT(x) (x)
#define CPUCompute2StorageT(x) (x)
#define GPUStorage2ComputeT(x) (x)
@@ -45,7 +45,7 @@
#define sizeofStorageT 8
#define sizeofComputeT 8
#define CUDNNStorageT CUDNN_DATA_DOUBLE
-#define CUDNNConvStorageT CUDNN_DATA_DOUBLE
+#define CUDNNConvComputeT CUDNN_DATA_DOUBLE
#define CPUStorage2ComputeT(x) (x)
#define CPUCompute2StorageT(x) (x)
#define GPUStorage2ComputeT(x) (x)
@@ -417,6 +417,11 @@ class JSON{
else variable = (int)(this->member[name]->returnReal());
};

void set(std::string name, double &variable, double default_value){
if (this->member.find(name) == this->member.end()) variable = default_value;
else variable = (double)(this->member[name]->returnReal());
};

void set(std::string name, unsigned int &variable, unsigned int default_value){
if (this->member.find(name) == this->member.end()) variable = default_value;
else variable = (unsigned int)(this->member[name]->returnReal());
@@ -547,6 +552,12 @@ class JSON{
else{ std::cout<<"Unsupported "<<name<<" = "<<this->member[name]->returnString()<<std::endl; FatalError(__LINE__); }
};

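// Parse a cudnnBatchNormMode_t from its JSON string ("Spatial" or "PerActivation"), falling back to the given default when the field is absent.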
void set(std::string name, cudnnBatchNormMode_t &variable, cudnnBatchNormMode_t default_value){
if (this->member.find(name) == this->member.end()) variable = default_value;
else if (0 == this->member[name]->returnString().compare("Spatial")) variable = CUDNN_BATCHNORM_SPATIAL;
else if (0 == this->member[name]->returnString().compare("PerActivation")) variable = CUDNN_BATCHNORM_PER_ACTIVATION;
else{ std::cout<<"Unsupported "<<name<<" = "<<this->member[name]->returnString()<<std::endl; FatalError(__LINE__); }
};

void set(std::string name, cudnnPoolingMode_t &variable, cudnnPoolingMode_t default_value){
if (this->member.find(name) == this->member.end()) variable = default_value;
@@ -2784,7 +2795,7 @@ class Layer {

std::string name;
Phase phase;
-bool train_me; // user specify whehter they want to tune this layer
+bool train_me; // user specifies whether they want to tune this layer

ComputeT weight_lr_mult;
Filler weight_filler;
@@ -3449,7 +3460,6 @@ template <class T>
class DiskDataLayer : public DataLayer {
std::future<void> lock;


std::vector<size_t> ordering;
std::bernoulli_distribution* distribution_bernoulli;
std::vector<std::uniform_int_distribution<int>*> distribution_uniform;
@@ -3879,7 +3889,7 @@ class ConvolutionLayer : public Layer {
&stride[0],
&upscale[0],
CUDNN_CROSS_CORRELATION,
-CUDNNConvStorageT) );
+CUDNNConvComputeT) );

std::vector<int> bias_stride(bias_dim.size());

@@ -6024,13 +6034,276 @@ class LSTMLayer : public Layer {
};
};

class BatchNormalizationLayer : public Layer{
cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc;
cudnnBatchNormMode_t mode;
double epsilon;

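// Number of training-mode forward calls seen so far; used to derive the cumulative moving-average factor for the running statistics.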
unsigned int numForwardTrainingPasses;

// NOTE: we use weight and bias in place of scale and bias so we don't modify the solver
// StorageT* bnScale; => weight_dataGPU
// StorageT* bnBias; => bias_dataGPU
StorageT* resultRunningMean;
StorageT* resultRunningInvVariance;
StorageT* resultSaveMean;
StorageT* resultSaveInvVariance;
// StorageT* resultBnScaleDiff; => weight_diffGPU
// StorageT* resultBnBiasDiff; => bias_diffGPU
public:

BatchNormalizationLayer(JSON* json){
if (in.size() != out.size()){ std::cout << "BatchNormalizationLayer " << name << " needs same number of input and output responses" << std::endl; FatalError(__LINE__); }
for (int i = 1; i < in.size(); ++i){
if (!same_dim(in[0]->dim, in[i]->dim)){ std::cout << "BatchNormalizationLayer " << name << " requires same dim for each input" << std::endl; FatalError(__LINE__); }
}

SetOrDie(json, name)
SetValue(json, phase, TrainingTesting)
SetValue(json, train_me, true)
SetValue(json, mode, CUDNN_BATCHNORM_PER_ACTIVATION)
SetValue(json, epsilon, CUDNN_BN_MIN_EPSILON)
SetValue(json, weight_lr_mult, 1.0)
SetValue(json, weight_filler, Constant)
SetValue(json, weight_filler_param, 1)
SetValue(json, bias_lr_mult, 1.0)
SetValue(json, bias_filler, Constant)
SetValue(json, bias_filler_param, 0.001)

// Initialize the statistics pointers so the destructor's NULL checks are meaningful even if Malloc is never called.
resultRunningMean = NULL;
resultRunningInvVariance = NULL;
resultSaveMean = NULL;
resultSaveInvVariance = NULL;

numForwardTrainingPasses = 0;
};
// BatchNormalizationLayer(attributes)

size_t Malloc(Phase phase_){
size_t memoryBytes = 0;

train_me = train_me && phase_ != Testing;
std::cout << (train_me? "* " : " ");
std::cout << name;

// NOTE: calculated by cudnnDeriveBNTensorDescriptor
// TODO: delete this section
// std::vector<int> dim;
// std::vector<int> stride;

// dim.resize(in[i]->dim.size());
// stride.resize(in[i]->dim.size());

// // There is one statistic per batch
// dim[0] = 1;

// // In Spatial mode, we have 1xCx1x1 tensor for batch normalization
// // statistic in 2D images and 1xCx1x1x1 in 3D (not documented)
// if (mode == CUDNN_BATCHNORM_SPATIAL){
// for (int j = 2; j < dim.size(); ++j){
// dim[j] = 1;
// }
// }

// // Calculate appropriate stride
// stride[dim.size() - 1] = 1;
// for (int d = dim.size() - 2; d >= 0; --d){
// stride[d] = stride[d + 1] * dim[d + 1];
// }

// checkCUDNN(__LINE__, cudnnCreateTensorDescriptor(&bnScaleBiasMeanVarDesc) );

// checkCUDNN(__LINE__, cudnnSetTensorNdDescriptor(
// bnScaleBiasMeanVarDesc,
// CUDNNStorageT,
// dim.size(),
// &dim[0],
// &stride[0]) );

// We check that the dimensions of all in responses are the same, so this is OK
std::vector<int> dim(in[0]->dim);

// There is one statistic per batch
dim[0] = 1;

// In Spatial mode, we have 1xCx1x1 tensor for batch normalization
// statistic in 2D images and 1xCx1x1x1 in 3D
if (mode == CUDNN_BATCHNORM_SPATIAL){
for (int j = 2; j < dim.size(); ++j){
dim[j] = 1;
}
}

weight_dim = dim;
bias_dim = dim;

// We check that the dimensions of all in responses are the same, so this is OK
checkCUDNN(__LINE__, cudnnCreateTensorDescriptor(&bnScaleBiasMeanVarDesc) );
checkCUDNN(__LINE__, cudnnDeriveBNTensorDescriptor(bnScaleBiasMeanVarDesc, in[0]->getDesc(), mode) );

// These should be the same, but for clarity
weight_numel = numel(weight_dim);
bias_numel = numel(bias_dim);

size_t sizeofBNScaleBiasMeanVar = numel(dim) * sizeofStorageT;

std::cout << " 6x bnScaleBiasMeanVar"; veciPrint(dim);
checkCUDA(__LINE__, cudaMalloc( &weight_dataGPU, sizeofBNScaleBiasMeanVar) );
checkCUDA(__LINE__, cudaMalloc( &bias_dataGPU, sizeofBNScaleBiasMeanVar) );
memoryBytes += sizeofBNScaleBiasMeanVar * 2;

// TODO: We could allocate this as a contiguous block of memory
// We want to keep a history of the running statistic for each in response
checkCUDA(__LINE__, cudaMalloc( &resultRunningMean, sizeofBNScaleBiasMeanVar * in.size()) );
checkCUDA(__LINE__, cudaMalloc( &resultRunningInvVariance, sizeofBNScaleBiasMeanVar * in.size()) );
checkCUDA(__LINE__, cudaMalloc( &resultSaveMean, sizeofBNScaleBiasMeanVar * in.size()) );
checkCUDA(__LINE__, cudaMalloc( &resultSaveInvVariance, sizeofBNScaleBiasMeanVar * in.size()) );
memoryBytes += sizeofBNScaleBiasMeanVar * 4 * in.size();

for (int i = 0; i < out.size(); ++i){
out[i]->need_diff = train_me || in[i]->need_diff; // if one of them need the grad
out[i]->receptive_field = in[i]->receptive_field;
out[i]->receptive_gap = in[i]->receptive_gap;
out[i]->receptive_offset = in[i]->receptive_offset;
memoryBytes += out[i]->Malloc(in[i]->dim);
}

return memoryBytes;
};

// TODO: 'Much higher performance for HW-packed tensors for both x and y'
// TODO: We should keep running tally of where we are in dataGPU and diffGPU
// TODO: Implement groups?
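// Training: update the per-response running statistics, copying them forward from one in response to the next (see commit message).
// Testing: reuse the training kernel with NULL running-statistics arguments as a workaround for a suspected bug in cudnnBatchNormalizationForwardInference.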
void forward(Phase phase_){
for (int i = 0; i < in.size(); ++i){
if (phase_ == Training){

int weight_index = i * weight_numel;
int bias_index = i * bias_numel;

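// exponentialAverageFactor = 1/(N+1): the running mean/variance become a cumulative average over the N+1 training passes seen so far.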
checkCUDNN(__LINE__, cudnnBatchNormalizationForwardTraining(
cudnnHandle,
mode,
one,
zero,
in[i]->getDesc(),
in[i]->dataGPU,
out[i]->getDesc(),
out[i]->dataGPU,
bnScaleBiasMeanVarDesc,
weight_dataGPU,
bias_dataGPU,
1.0 / (1 + numForwardTrainingPasses),
resultRunningMean + weight_index,
resultRunningInvVariance + bias_index,
epsilon,
resultSaveMean + weight_index,
resultSaveInvVariance + bias_index) );

// We want the next in response to use the running
// mean/variance we have just computed
if (i < in.size() - 1) {
checkCUDA(__LINE__, cudaMemcpy(resultRunningMean + weight_index + weight_numel, resultRunningMean + weight_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultRunningInvVariance + bias_index + bias_numel, resultRunningInvVariance + bias_index, sizeofStorageT * bias_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultSaveMean + weight_index + weight_numel, resultSaveMean + weight_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultSaveInvVariance + bias_index + bias_numel, resultSaveInvVariance + bias_index, sizeofStorageT * bias_numel, cudaMemcpyDeviceToDevice));
}

numForwardTrainingPasses++;
} else{
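// Inference path: call the training kernel but pass NULL for every running-statistics argument; cudnnBatchNormalizationForwardInference is avoided because of a suspected bug (see commit message).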
checkCUDNN(__LINE__, cudnnBatchNormalizationForwardTraining(
cudnnHandle,
mode,
one,
zero,
in[i]->getDesc(),
in[i]->dataGPU,
out[i]->getDesc(),
out[i]->dataGPU,
bnScaleBiasMeanVarDesc,
weight_dataGPU,
bias_dataGPU,
1,
NULL,
NULL,
epsilon,
NULL,
NULL) );
}
// TODO: Filed bug report with NVIDIA
// checkCUDNN(__LINE__, cudnnBatchNormalizationForwardInference(
// cudnnHandle,
// mode,
// one,
// zero,
// in[i]->getDesc(),
// in[i]->dataGPU,
// out[i]->getDesc(),
// out[i]->dataGPU,
// bnScaleBiasMeanVarDesc,
// weight_dataGPU,
// bias_dataGPU,
// resultRunningMean,
// resultRunningInvVariance,
// epsilon) );
// }
}
};

// TODO: 'Much higher performance when HW-packed tensors are used for all of x, dy, dx'
// TODO: We should keep running tally of where we are in dataGPU and diffGPU
// TODO: Implement groups?
void backward(Phase phase_){

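// Indices of the statistics slot belonging to the last in response; after backward, every earlier slot is reset to these values (see the copies below).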
int last_weight_index = (in.size() - 1) * weight_numel;
int last_bias_index = (in.size() - 1) * bias_numel;

for (int i = 0; i < in.size(); ++i){

int weight_index = i * weight_numel;
int bias_index = i * bias_numel;

checkCUDNN(__LINE__, cudnnBatchNormalizationBackward(
cudnnHandle,
mode,
one,
zero,
in[i]->getDesc(),
in[i]->dataGPU,
out[i]->getDesc(),
out[i]->diffGPU,
in[i]->getDesc(),
in[i]->diffGPU,
bnScaleBiasMeanVarDesc,
weight_dataGPU,
weight_diffGPU,
bias_diffGPU,
epsilon,
resultSaveMean + weight_index,
resultSaveInvVariance + bias_index) );

// Copy running statistic from last in response to earlier in
// responses. Technically, we only need to copy to the first one since
// the next forward pass should take care of propagating new running
// values
if (i < in.size() - 1) {
checkCUDA(__LINE__, cudaMemcpy(resultRunningMean + weight_index, resultRunningMean + last_weight_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultRunningInvVariance + bias_index, resultRunningInvVariance + last_bias_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultSaveMean + weight_index, resultSaveMean + last_weight_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
checkCUDA(__LINE__, cudaMemcpy(resultSaveInvVariance + bias_index, resultSaveInvVariance + last_bias_index, sizeofStorageT * weight_numel, cudaMemcpyDeviceToDevice));
}
}
};

~BatchNormalizationLayer(){
checkCUDNN(__LINE__, cudnnDestroyTensorDescriptor(bnScaleBiasMeanVarDesc) );

if (resultRunningMean != NULL) checkCUDA(__LINE__, cudaFree(resultRunningMean) );
if (resultRunningInvVariance != NULL) checkCUDA(__LINE__, cudaFree(resultRunningInvVariance) );
if (resultSaveMean != NULL) checkCUDA(__LINE__, cudaFree(resultSaveMean) );
if (resultSaveInvVariance != NULL) checkCUDA(__LINE__, cudaFree(resultSaveInvVariance) );
};
};

//////////////////////////////////////////////////////////////////////////////////////////////////
// Add your new layers here
//////////////////////////////////////////////////////////////////////////////////////////////////



//////////////////////////////////////////////////////////////////////////////////////////////////
// Net
//////////////////////////////////////////////////////////////////////////////////////////////////
@@ -6099,6 +6372,7 @@ class Net{
else if (0==type.compare("Tensor")) pLayer = new TensorLayer(p);
else if (0==type.compare("LSTM")) pLayer = new LSTMLayer(p);
else if (0==type.compare("SequenceGeneration")) {pLayer = new SequenceGenerationLayer(p); sequence_layers.push_back((SequenceGenerationLayer*)pLayer); }
else if (0==type.compare("BatchNormalization")) pLayer = new BatchNormalizationLayer(p);
else if (0==type.compare("Loss")) {pLayer = new LossLayer(p); loss_layers.push_back((LossLayer*)pLayer); }
else { std::cout<<"ERROR: unrecognizable layer in JSON file: "<<type<<std::endl; FatalError(__LINE__);};

