
Commit 759e155

Merge branch 'caffe-0.16' of https://github.com/drnikolaev/nvcaffe-dev into caffe-0.16
borisgin committed Nov 1, 2017
2 parents b3a0fe1 + db2f3c7 commit 759e155
Showing 17 changed files with 238 additions and 201 deletions.
17 changes: 12 additions & 5 deletions include/caffe/blob.hpp
@@ -37,8 +37,6 @@ class TBlob;
 class Blob {
  public:
   void Swap(Blob& other) noexcept {
-    CHECK_EQ(data_tensor_->type(), other.data_tensor_->type());
-    CHECK_EQ(diff_tensor_->type(), other.diff_tensor_->type());
     std::swap(data_tensor_, other.data_tensor_);
     std::swap(diff_tensor_, other.diff_tensor_);
     std::swap(shape_data_, other.shape_data_);
@@ -50,7 +48,7 @@ class Blob {
   Blob(Type data_type, Type diff_type)
       : data_tensor_(make_shared<Tensor>(data_type)),
         diff_tensor_(make_shared<Tensor>(diff_type)),
-        count_(0) {}
+        count_(0), safe_reshape_mode_(false) {}
   explicit Blob(Type dtype)
       : Blob(dtype, dtype) {}
 
@@ -208,8 +206,16 @@ class Blob {
     return count_;
   }
 
-  size_t size_of(bool of_data) const {
-    return of_data ? data_tensor_->size_of() : diff_tensor_->size_of();
+  size_t sizeof_data(bool allocated = false) const {
+    return data_tensor_->size_of(allocated);
   }
 
+  size_t sizeof_diff(bool allocated = false) const {
+    return diff_tensor_->size_of(allocated);
+  }
+
+  void safe_reshape_mode(bool mode) {
+    safe_reshape_mode_ = mode;
+  }
+
   /**
@@ -541,6 +547,7 @@ class Blob {
   shared_ptr<SyncedMemory> shape_data_;
   vector<int> shape_;
   int count_;
+  bool safe_reshape_mode_;  // if true, reshape never shrinks
 
   bool is_current_data_valid() const {
     return data_tensor_->is_current_valid();
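
The new safe_reshape_mode gives a blob a keep-capacity guarantee: per the header comment, a reshape never shrinks the allocation once the mode is on, and sizeof_data(true)/sizeof_diff(true) report allocated rather than logical bytes. A minimal usage sketch, assuming the FLOAT type tag and shapes as placeholders (Blob::create and the accessors are from this commit):

    // Sketch: a shrinking reshape keeps capacity when safe_reshape_mode is on.
    shared_ptr<Blob> blob = Blob::create(FLOAT, FLOAT);
    blob->safe_reshape_mode(true);
    blob->Reshape(vector<int>{64, 3, 224, 224});  // allocates for 64 images
    blob->Reshape(vector<int>{32, 3, 224, 224});  // count shrinks, buffer does not
    size_t logical = blob->sizeof_data();         // bytes for the current count
    size_t capacity = blob->sizeof_data(true);    // bytes actually allocated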
5 changes: 4 additions & 1 deletion include/caffe/data_transformer.hpp
@@ -48,7 +48,8 @@ class DataTransformer {
    * @return Output shape
    */
   template<typename Dtype>
-  vector<int> Transform(const Datum* datum, Dtype* buf, size_t buf_len, bool repack = true) {
+  vector<int> Transform(const Datum* datum, Dtype* buf, size_t buf_len,
+      Packing& out_packing, bool repack = true) {
     vector<int> shape;
     const bool shape_only = buf == nullptr;
     CHECK(!(param_.force_color() && param_.force_gray()))
@@ -66,6 +67,7 @@ class DataTransformer {
         TransformV1(*datum, buf, buf_len);
         shape = vector<int>{1, datum->channels(), datum->height(), datum->width()};
         v1_path = true;
+        out_packing = NCHW;
       }
     }
     if (param_.crop_size() > 0) {
@@ -75,6 +77,7 @@ class DataTransformer {
     if (!shape_only && !v1_path) {
       CHECK_NOTNULL(img.data);
       Transform(img, buf, buf_len, repack);
+      out_packing = NHWC;
     }
     return shape;
   }
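
Transform now tells the caller which memory layout it produced instead of leaving it implicit: the legacy Datum (V1) path emits channel-major NCHW, while the cv::Mat path leaves OpenCV's interleaved NHWC in the buffer. A hedged sketch of the new calling convention, with the transformer, datum, and sizes as placeholders:

    // Illustrative caller of the new signature.
    Packing packing = NCHW;
    vector<float> buf(3 * 224 * 224);
    vector<int> shape = transformer.Transform(&datum, buf.data(), buf.size(), packing);
    if (packing == NHWC) {
      // cv::Mat path ran: data is channels-last; repack or forward as NHWC
    }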
8 changes: 5 additions & 3 deletions include/caffe/layers/base_data_layer.hpp
@@ -55,7 +55,9 @@ class Batch {
 
   Batch(Type data_type, Type diff_type)
       : data_(Blob::create(data_type, diff_type)), label_(Blob::create(data_type, diff_type)),
-        id_((size_t) -1), data_packing_(NCHW) {}
+        id_((size_t) -1), data_packing_(NCHW) {
+    data_->safe_reshape_mode(true);
+  }
 
   size_t id() const {
     return id_;
@@ -64,7 +66,7 @@ class Batch {
     id_ = id;
   }
   size_t bytes() const {
-    return data_->size_of(true) + label_->size_of(true);
+    return data_->sizeof_data(true) + label_->sizeof_data(true);
   }
   Packing data_packing() const {
     return data_packing_;
@@ -113,9 +115,9 @@ class BasePrefetchingDataLayer : public BaseDataLayer<Ftype, Btype>, public Inte
  protected:
   void InternalThreadEntry() override;
   void InternalThreadEntryN(size_t thread_id) override;
-  void ResizeQueues();
   void AllocatePrefetch();
 
+  virtual void ResizeQueues();
   virtual void InitializePrefetch();
   virtual void load_batch(Batch* batch, int thread_id, size_t queue_id) = 0;
   virtual void start_reading() = 0;
3 changes: 2 additions & 1 deletion include/caffe/layers/data_layer.hpp
@@ -44,6 +44,7 @@ class DataLayer : public BasePrefetchingDataLayer<Ftype, Btype> {
   }
 
  protected:
+  void ResizeQueues() override;
   void InitializePrefetch() override;
   void load_batch(Batch* batch, int thread_id, size_t queue_id = 0UL) override;
   size_t queue_id(size_t thread_id) const override;
@@ -56,7 +57,7 @@ class DataLayer : public BasePrefetchingDataLayer<Ftype, Btype> {
   shared_ptr<DataReader> sample_reader_, reader_;
 
 #ifndef CPU_ONLY
-  vector<shared_ptr<GPUMemory::Workspace>> tmp_batch_holder_;
+  vector<shared_ptr<GPUMemory::Workspace>> tmp_holder_;
 #endif
 
   // stored random numbers for this batch
8 changes: 5 additions & 3 deletions include/caffe/tensor.hpp
@@ -42,15 +42,15 @@ class Tensor {
     return type_;
   }
 
-  size_t size_of() const {
-    return tsize(type_) * count_;
+  size_t size_of(bool allocated = false) const {
+    return tsize(type_) * (allocated ? alloc_count_ : count_);
   }
 
   void set(float value);
   void scale(float new_scale, void* handle = nullptr);
   void invalidate_others();
   void convert(Type new_type);
-  void Reshape(int count);
+  void Reshape(int count, bool safe_reshape = false);
   float asum() const;
   const shared_ptr<SyncedMemory>& synced_mem() const;
   shared_ptr<SyncedMemory>& mutable_synced_mem(bool flush = true);
@@ -85,6 +85,8 @@ class Tensor {
   shared_ptr<vector<shared_ptr<SyncedMemory>>> synced_arrays_;
   // number of entries - comes from Blob via Reshape
   int count_;
+  // number of entries allocated (useful when avoiding deallocations is needed)
+  int alloc_count_;
 
   DISABLE_COPY_MOVE_AND_ASSIGN(Tensor);
 };  // class Tensor
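
Tensor now tracks logical size (count_) and allocated capacity (alloc_count_) separately. The header implies bookkeeping along the following lines; this is a hypothetical sketch inferred from the declarations and comments above, not the verbatim tensor.cpp implementation:

    // Hypothetical sketch of the implied bookkeeping (assumption, not the real code).
    void Tensor::Reshape(int count, bool safe_reshape) {
      if (safe_reshape && count <= alloc_count_) {
        count_ = count;  // shrink logically; keep the buffer, size_of(true) unchanged
        return;
      }
      count_ = count;
      alloc_count_ = count;
      // ... (re)allocate the backing SyncedMemory for alloc_count_ entries ...
    }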
18 changes: 18 additions & 0 deletions include/caffe/util/gpu_math_functions.cuh
@@ -45,6 +45,15 @@ return __half2float(a) < __half2float(b);
 #endif
 }
 
+__device__ __inline__
+half hmul(half a, half b) {
+#if __CUDA_ARCH__ >= 530
+  return __hmul(a, b);
+#else
+  return float2half_clip(__half2float(a) * __half2float(b));
+#endif
+}
+
 __device__ __inline__
 half2 hmul2(half2 a, half2 b) {
 #if __CUDA_ARCH__ >= 530
@@ -75,6 +84,15 @@ half2 hge2(half2 a, half2 b) {
 #endif
 }
 
+__device__ __inline__
+half hadd(half a, half b) {
+#if __CUDA_ARCH__ >= 530
+  return __hadd(a, b);
+#else
+  return float2half_clip(__half2float(a) + __half2float(b));
+#endif
+}
+
 __device__ __inline__
 half2 hadd2(half2 a, half2 b) {
 #if __CUDA_ARCH__ >= 530
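
The new scalar hmul/hadd helpers mirror the existing half2 pattern: native half intrinsics on sm_53 and newer, and a float round trip clipped back to half everywhere else. A small illustrative kernel built on them (not part of this commit):

    // Illustrative elementwise kernel using the new scalar helpers.
    __global__ void axpy_half_kernel(int n, half alpha, const half* x, half* y) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        y[i] = hadd(hmul(alpha, x[i]), y[i]);  // y[i] = alpha * x[i] + y[i]
      }
    }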
2 changes: 2 additions & 0 deletions include/caffe/util/io.hpp
@@ -189,6 +189,8 @@ void CVMatToDatum(const cv::Mat& cv_img, Datum& datum);
 vector<int> DatumToCVMat(const Datum& datum, cv::Mat& img, bool shape_only);
 vector<int> DecodeDatumToCVMat(const Datum& datum, int color_mode, cv::Mat& cv_img,
     bool shape_only, bool accurate_jpeg = true);
+void DecodeDatumToSignedBuf(const Datum& datum, int color_mode,
+    char* buf, size_t buf_len, bool accurate_jpeg);
 
 template<typename Dtype>
 void TBlobDataToCVMat(const TBlob<Dtype>& blob, cv::Mat& img) {
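
DecodeDatumToSignedBuf decodes straight into a caller-owned signed byte buffer, skipping the intermediate cv::Mat. A hedged usage sketch, assuming color_mode follows the DecodeDatumToCVMat convention and the datum's shape fields are populated:

    // Hypothetical caller; sizing assumes the datum dimensions are set.
    size_t len = datum.channels() * datum.height() * datum.width();
    vector<char> buf(len);
    DecodeDatumToSignedBuf(datum, 0 /* color_mode: keep as stored (assumption) */,
        buf.data(), buf.size(), true /* accurate_jpeg */);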
8 changes: 0 additions & 8 deletions include/caffe/util/math_functions.hpp
@@ -197,10 +197,6 @@ template <typename Dtype>
 void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
     Dtype* Y, void* handle = nullptr);
 
-void caffe_gpu_axpy_extfp16(const int N, const float alpha, const float16* X,
-    float16* Y);
-
-
 template <typename Dtype>
 void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
     const Dtype beta, Dtype* Y);
@@ -223,9 +219,6 @@ void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype *X);
 template <typename Dtype>
 void caffe_gpu_scal(const int N, const Dtype alpha, Dtype* X);
 
-void caffe_gpu_scal_fp16(const int N, const float alpha, float16* X,
-    cublasHandle_t cublas_handle);
-
 template <typename Dtype>
 void caffe_gpu_scal(const int N, const Dtype alpha, Dtype* X, cublasHandle_t cublas_handle);
 
@@ -244,7 +237,6 @@ void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 template <typename Dtype>
 void caffe_gpu_square(const int N, const Dtype* a, Dtype* y);
 
-
 template <typename Dtype>
 void caffe_gpu_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 
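
With the one-off _extfp16/_fp16 declarations removed, float16 flows through the same templates as float and double; the gpu_axpy change in src/caffe/blob.cpp below is the matching call site. A minimal sketch, assuming float16 instantiations of these templates and a cuBLAS handle from Caffe::cublas_handle():

    // float16 now routes through the generic templates (see blob.cpp below).
    // x and y are device pointers with n elements each; allocation elided.
    const int n = 1024;
    float16 alpha(0.5f);
    caffe_gpu_axpy(n, alpha, x, y);                       // y = alpha * x + y
    caffe_gpu_scal(n, alpha, y, Caffe::cublas_handle());  // y = alpha * y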
16 changes: 10 additions & 6 deletions src/caffe/blob.cpp
@@ -50,8 +50,8 @@ void Blob::Reshape(const vector<int>& shape) {
     shape_[i] = shape[i];
     shape_data[i] = shape[i];
   }
-  data_tensor_->Reshape(count_);
-  diff_tensor_->Reshape(count_);
+  data_tensor_->Reshape(count_, safe_reshape_mode_);
+  diff_tensor_->Reshape(count_, safe_reshape_mode_);
   CHECK(is_current_data_valid());
   CHECK(is_current_diff_valid());
 }
@@ -202,11 +202,11 @@ void Blob::gpu_axpy(int count, Type dtype, float alpha, const void* X, void* Y)
     caffe_gpu_axpy(count, alpha, static_cast<const float*>(X),
         static_cast<float*>(Y));
   } else if (is_type<float16>(dtype)) {
-    caffe_gpu_axpy_extfp16(count, alpha,
-        static_cast<const float16*>(X), static_cast<float16*>(Y));
+    caffe_gpu_axpy(count, static_cast<float16>(alpha), static_cast<const float16*>(X),
+        static_cast<float16*>(Y));
   } else if (is_type<double>(dtype)) {
-    caffe_gpu_axpy(count, static_cast<double>(alpha),
-        static_cast<const double*>(X), static_cast<double*>(Y));
+    caffe_gpu_axpy(count, static_cast<double>(alpha), static_cast<const double*>(X),
+        static_cast<double*>(Y));
   } else {
     LOG(FATAL) << "Unsupported data type: " << Type_Name(dtype);
   }
@@ -284,17 +284,20 @@ void Blob::CopyFrom(const Blob& source, bool copy_diff, bool reshape,
   Type src_type = copy_diff ? source.diff_type() : source.data_type();
   Type dst_type = copy_diff ? diff_type() : data_type();
   const bool is_gpu = Caffe::mode() == Caffe::GPU;
+#ifndef CPU_ONLY
   if ((src_packing == dst_packing && src_type == dst_type)
       || !is_gpu || shape().size() != 4 || source.shape().size() != 4) {
     if (srct == dstt) {
       return;
     }
+#endif
     Tensor::copy_helper(is_gpu, count_,
         is_gpu ? src->gpu_data() : src->cpu_data(),
         src_type,
         is_gpu ? dst->mutable_gpu_data(false) : dst->mutable_cpu_data(false),
         dst_type);
     dst->validate();
+#ifndef CPU_ONLY
   } else {
     CHECK(srct != dstt);
     cudnnHandle_t handle = Caffe::cudnn_handle();
@@ -311,6 +314,7 @@
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc));
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc));
   }
+#endif
 }
 
 void Blob::FromProto(const BlobProto& proto, bool reshape) {
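
Read together with the new guards, CopyFrom has one live path in CPU_ONLY builds (the preprocessor removes both the branch condition and the early return) and two on GPU builds. The sketch below paraphrases the dispatch as a standalone function with placeholder names; it is an aid to reading the hunk above, not the real implementation:

    // Paraphrase of the CopyFrom dispatch after this commit.
    enum class Path { FlatCopy, CudnnTransform, NoOp };

    Path copy_from_path(bool cpu_only_build, bool is_gpu, bool same_packing,
                        bool same_type, int src_rank, int dst_rank, bool same_tensor) {
      if (cpu_only_build) {
        return Path::FlatCopy;  // guards compile everything else out
      }
      if ((same_packing && same_type) || !is_gpu || src_rank != 4 || dst_rank != 4) {
        return same_tensor ? Path::NoOp : Path::FlatCopy;  // Tensor::copy_helper
      }
      // cudnnTransformTensor: repack (NCHW<->NHWC) and convert type in one pass
      return Path::CudnnTransform;
    }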
26 changes: 11 additions & 15 deletions src/caffe/layers/base_data_layer.cpp
@@ -196,21 +196,17 @@ void BasePrefetchingDataLayer<Ftype, Btype>::InitializePrefetch() {
 template<typename Ftype, typename Btype>
 void BasePrefetchingDataLayer<Ftype, Btype>::AllocatePrefetch() {
 #ifndef CPU_ONLY
-  if (Caffe::mode() == Caffe::GPU) {
-    for (int i = 0; i < prefetch_.size(); ++i) {
-      Btype* bdata = prefetch_[i]->data_->template mutable_cpu_data_c<Btype>(false);
-      (void) bdata;
-      Ftype* tdata = prefetch_[i]->data_->template mutable_cpu_data_c<Ftype>(false);
-      (void) tdata;
-      if (this->output_labels_) {
-        bdata = prefetch_[i]->label_->template mutable_cpu_data_c<Btype>(false);
-        (void) bdata;
-        tdata = prefetch_[i]->label_->template mutable_cpu_data_c<Ftype>(false);
-        (void) tdata;
-      }
-    }
-  }
-  LOG(INFO) << this->print_current_device() << " Prefetch allocated.";
+//  if (Caffe::mode() == Caffe::GPU) {
+//    for (int i = 0; i < prefetch_.size(); ++i) {
+//      Ftype* tdata = prefetch_[i]->data_->template mutable_gpu_data_c<Ftype>(false);
+//      (void) tdata;
+//      if (this->output_labels_) {
+//        tdata = prefetch_[i]->label_->template mutable_gpu_data_c<Ftype>(false);
+//        (void) tdata;
+//      }
+//    }
+//  }
+//  LOG(INFO) << this->print_current_device() << " Prefetch allocated.";
 #else
   if (Caffe::mode() == Caffe::CPU) {
     for (int i = 0; i < prefetch_.size(); ++i) {
13 changes: 2 additions & 11 deletions src/caffe/layers/base_data_layer.cu
@@ -10,23 +10,14 @@ void BasePrefetchingDataLayer<Ftype, Btype>::Forward_gpu(const vector<Blob*>& bo
   // Note: this function runs in one thread per object and one object per one Solver thread
   shared_ptr<Batch> batch =
       prefetches_full_[next_batch_queue_]->pop("Data layer prefetch queue empty");
-  if (top[0]->data_type() == batch->data_->data_type()
-      && top[0]->diff_type() == batch->data_->diff_type()
-      && top[0]->shape() == batch->data_->shape()
-      && batch->data_packing() == this->transform_param_.forward_packing()) {
+  if (batch->data_packing() == this->transform_param_.forward_packing()) {
     top[0]->Swap(*batch->data_);
   } else {
     top[0]->CopyDataFrom(*batch->data_, true, batch->data_packing(),
         this->transform_param_.forward_packing());
   }
   if (this->output_labels_) {
-    if (top[1]->data_type() == batch->label_->data_type()
-        && top[1]->diff_type() == batch->label_->diff_type()
-        && top[1]->shape() == batch->label_->shape()) {
-      top[1]->Swap(*batch->label_);
-    } else {
-      top[1]->CopyDataFrom(*batch->label_, true);
-    }
+    top[1]->Swap(*batch->label_);
   }
   batch->set_id((size_t) -1);
   prefetches_free_[next_batch_queue_]->push(batch);
