Merge pull request #147 from sergeyk/hdf5_data
HDF5DataLayer: read matrix of features and labels from HDF5 file as input
sergeyk authored and shelhamer committed Feb 26, 2014
2 parents b6af3c9 + 3873048 commit 723c168
Showing 10 changed files with 355 additions and 2 deletions.
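In short: the new layer opens one HDF5 file, reads a `data` and a `label` matrix from it, and serves them row-by-row in batches. A minimal usage sketch, assembled from the unit test further down (the path and batch size are the test fixture's):

```cpp
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/vision_layers.hpp"

void hdf5_data_example() {  // hypothetical driver; mirrors the test below
  caffe::LayerParameter param;
  param.set_batchsize(5);
  param.set_source("src/caffe/test/test_data/sample_data.h5");

  // The layer takes no bottom blobs and fills two top blobs.
  caffe::HDF5DataLayer<float> layer(param);
  caffe::Blob<float> data, label;
  std::vector<caffe::Blob<float>*> bottom;
  std::vector<caffe::Blob<float>*> top;
  top.push_back(&data);
  top.push_back(&label);

  layer.SetUp(bottom, &top);    // opens the file, shapes the top blobs
  layer.Forward(bottom, &top);  // copies the next batch of rows
}
```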
9 changes: 7 additions & 2 deletions Makefile
@@ -69,8 +69,13 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64

INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
-LIBRARIES := cudart cublas curand mkl_rt pthread \
-    glog protobuf leveldb snappy boost_system \
+LIBRARIES := cudart cublas curand \
+    mkl_rt \
+    pthread \
+    glog protobuf leveldb \
+    snappy \
+    boost_system \
+    hdf5 hdf5_hl \
opencv_core opencv_highgui opencv_imgproc
PYTHON_LIBRARIES := boost_python python2.7
WARNINGS := -Wall
9 changes: 9 additions & 0 deletions include/caffe/util/io.hpp
@@ -5,6 +5,10 @@

#include <google/protobuf/message.h>

#include <boost/scoped_ptr.hpp>
#include "hdf5.h"
#include "hdf5_hl.h"

#include <string>

#include "caffe/blob.hpp"
@@ -48,6 +52,11 @@ inline bool ReadImageToDatum(const string& filename, const int label,
return ReadImageToDatum(filename, label, 0, 0, datum);
}

template <typename Dtype>
void load_2d_dataset(
hid_t file_id, const char* dataset_name_,
boost::scoped_ptr<Dtype>* array, hsize_t* dims);

} // namespace caffe

#endif // CAFFE_UTIL_IO_H_
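The matching definition lands in src/caffe/util/io.cpp, one of the ten changed files but not expanded in this view. A sketch of what the `float` specialization plausibly looks like, using the HDF5 Lite (`H5LT`) API; the exact checks and messages here are assumptions:

```cpp
#include "caffe/util/io.hpp"

namespace caffe {

template <>
void load_2d_dataset<float>(hid_t file_id, const char* dataset_name_,
    boost::scoped_ptr<float>* array, hsize_t* dims) {
  // The dataset must be a 2D matrix: rows x columns.
  int ndims;
  herr_t status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
  CHECK_GE(status, 0) << "Failed to get ndims for " << dataset_name_;
  CHECK_EQ(ndims, 2) << "Dataset " << dataset_name_ << " must be 2D.";

  // Fetch the extents: dims[0] rows, dims[1] columns.
  H5T_class_t type_class;
  status = H5LTget_dataset_info(
      file_id, dataset_name_, dims, &type_class, NULL);
  CHECK_GE(status, 0) << "Failed to get info for " << dataset_name_;

  // Read the whole matrix into a freshly allocated row-major buffer.
  // (boost::scoped_array would be the safer holder for new[] storage.)
  array->reset(new float[dims[0] * dims[1]]);
  status = H5LTread_dataset_float(file_id, dataset_name_, array->get());
  CHECK_GE(status, 0) << "Failed to read dataset " << dataset_name_;
}

}  // namespace caffe
```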
30 changes: 30 additions & 0 deletions include/caffe/vision_layers.hpp
@@ -5,6 +5,9 @@

#include <leveldb/db.h>
#include <pthread.h>
#include <boost/scoped_ptr.hpp>

#include "hdf5.h"

#include <vector>

@@ -351,6 +354,33 @@ class DataLayer : public Layer<Dtype> {
};


template <typename Dtype>
class HDF5DataLayer : public Layer<Dtype> {
public:
explicit HDF5DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~HDF5DataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);

boost::scoped_ptr<Dtype> data;
boost::scoped_ptr<Dtype> label;
hsize_t data_dims[2];
hsize_t label_dims[2];
hsize_t current_row;
};


template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
public:
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -27,6 +27,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new ConvolutionLayer<Dtype>(param);
} else if (type == "data") {
return new DataLayer<Dtype>(param);
} else if (type == "hdf5_data") {
return new HDF5DataLayer<Dtype>(param);
} else if (type == "dropout") {
return new DropoutLayer<Dtype>(param);
} else if (type == "euclidean_loss") {
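With this case in place, a net whose layer `type` is `"hdf5_data"` resolves to the new class through the factory. A minimal sketch of the dispatch (the helper name is hypothetical; `GetLayer`'s declaration is repeated here only to keep the sketch self-contained):

```cpp
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {
// Defined with the factory in layer_factory.cpp.
template <typename Dtype>
Layer<Dtype>* GetLayer(const LayerParameter& param);
}

caffe::Layer<float>* make_hdf5_data_layer() {  // hypothetical helper
  caffe::LayerParameter param;
  param.set_type("hdf5_data");  // the string the factory switches on
  param.set_source("src/caffe/test/test_data/sample_data.h5");
  param.set_batchsize(5);
  return caffe::GetLayer<float>(param);
}
```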
106 changes: 106 additions & 0 deletions src/caffe/layers/hdf5_data_layer.cpp
@@ -0,0 +1,106 @@
/*
TODO:
- only load parts of the file, in accordance with a prototxt param "max_mem"
*/

#include <iostream>
#include <stdint.h>
#include <string>
#include <vector>

#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using std::string;

namespace caffe {

template <typename Dtype>
HDF5DataLayer<Dtype>::~HDF5DataLayer<Dtype>() { }

template <typename Dtype>
void HDF5DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 0) << "HDF5DataLayer takes no input blobs.";
CHECK_EQ(top->size(), 2) << "HDF5DataLayer takes two blobs as output.";

// Load the HDF5 file and initialize the counter.
const char* hdf_filename = this->layer_param_.source().c_str();
LOG(INFO) << "Loading HDF5 file" << hdf_filename;
hid_t file_id = H5Fopen(hdf_filename, H5F_ACC_RDONLY, H5P_DEFAULT);
load_2d_dataset(file_id, "data", &data, data_dims);
load_2d_dataset(file_id, "label", &label, label_dims);
herr_t status = H5Fclose(file_id);
CHECK_GE(status, 0) << "Failed to close HDF5 file " << hdf_filename;
CHECK_EQ(data_dims[0], label_dims[0])
    << "Data and label must have the same number of rows.";
current_row = 0;

// Reshape blobs.
(*top)[0]->Reshape(this->layer_param_.batchsize(), data_dims[1], 1, 1);
(*top)[1]->Reshape(this->layer_param_.batchsize(), label_dims[1], 1, 1);
LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
<< (*top)[0]->channels() << "," << (*top)[0]->height() << ","
<< (*top)[0]->width();
}

template <typename Dtype>
void HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const int batchsize = this->layer_param_.batchsize();
for (int i = 0; i < batchsize; ++i, ++current_row) {
if (current_row == data_dims[0]) {
current_row = 0;
}

memcpy( &(*top)[0]->mutable_cpu_data()[i * data_dims[1]],
&(data.get()[current_row * data_dims[1]]),
sizeof(Dtype) * data_dims[1]);

memcpy( &(*top)[1]->mutable_cpu_data()[i * label_dims[1]],
&(label.get()[current_row * label_dims[1]]),
sizeof(Dtype) * label_dims[1]);
}
}

template <typename Dtype>
void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const int batchsize = this->layer_param_.batchsize();
for (int i = 0; i < batchsize; ++i, ++current_row) {
if (current_row == data_dims[0]) {
current_row = 0;
}

CUDA_CHECK(cudaMemcpy(
&(*top)[0]->mutable_gpu_data()[i * data_dims[1]],
&(data.get()[current_row * data_dims[1]]),
sizeof(Dtype) * data_dims[1],
cudaMemcpyHostToDevice));

CUDA_CHECK(cudaMemcpy(
&(*top)[1]->mutable_gpu_data()[i * label_dims[1]],
&(label.get()[current_row * label_dims[1]]),
sizeof(Dtype) * label_dims[1],
cudaMemcpyHostToDevice));
}
}

// The backward operations are dummies: they do not carry any computation.
template <typename Dtype>
Dtype HDF5DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

template <typename Dtype>
Dtype HDF5DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

INSTANTIATE_CLASS(HDF5DataLayer);

} // namespace caffe
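Note that current_row persists across Forward calls and wraps to zero when it reaches the end of the matrix, so batches stream through the rows cyclically regardless of whether the batch size divides the row count. The indexing in isolation, as a hypothetical standalone helper:

```cpp
#include <cstring>  // std::memcpy

// Copy `batch` rows of width `cols` from `src` into `dst`, advancing and
// wrapping the persistent cursor `*row` exactly as Forward_cpu does.
void copy_next_batch(const float* src, float* dst, size_t rows, size_t cols,
                     size_t batch, size_t* row) {
  for (size_t i = 0; i < batch; ++i, ++*row) {
    if (*row == rows) *row = 0;  // wrap back to the first row
    std::memcpy(dst + i * cols, src + *row * cols, sizeof(float) * cols);
  }
}
```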
17 changes: 17 additions & 0 deletions src/caffe/test/test_data/generate_sample_data.py
@@ -0,0 +1,17 @@
"""
Generate data used in the HDF5DataLayer test.
"""

import numpy as np
import h5py

num_cols = 8
num_rows = 10
data = np.arange(num_cols * num_rows).reshape(num_rows, num_cols)
label = np.arange(num_rows)[:, np.newaxis]
print(data)
print(label)

with h5py.File('./sample_data.h5', 'w') as f:
f['data'] = data.astype('float32')
f['label'] = label.astype('float32')
Binary file added src/caffe/test/test_data/sample_data.h5
16 changes: 16 additions & 0 deletions src/caffe/test/test_data_layer.cpp
@@ -81,6 +81,7 @@ TYPED_TEST(DataLayerTest, TestRead) {
EXPECT_EQ(this->blob_top_label_->channels(), 1);
EXPECT_EQ(this->blob_top_label_->height(), 1);
EXPECT_EQ(this->blob_top_label_->width(), 1);

// Go through the data 100 times.
for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
@@ -94,6 +95,21 @@ }
}
}
}

// Same test, in GPU mode.
Caffe::set_mode(Caffe::GPU);
for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
}
for (int i = 0; i < 5; ++i) {
for (int j = 0; j < 24; ++j) {
EXPECT_EQ(i, this->blob_top_data_->cpu_data()[i * 24 + j])
<< "debug: i " << i << " j " << j;
}
}
}
}

}
130 changes: 130 additions & 0 deletions src/caffe/test/test_hdf5data_layer.cpp
@@ -0,0 +1,130 @@
// Copyright 2013 Yangqing Jia

#include <cuda_runtime.h>
#include <leveldb/db.h>

#include <string>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/test/test_caffe_main.hpp"

using std::string;

namespace caffe {

extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

template <typename Dtype>
class HDF5DataLayerTest : public ::testing::Test {
protected:
HDF5DataLayerTest()
: blob_top_data_(new Blob<Dtype>()),
blob_top_label_(new Blob<Dtype>()),
filename(NULL) {}
virtual void SetUp() {
blob_top_vec_.push_back(blob_top_data_);
blob_top_vec_.push_back(blob_top_label_);

// TODO: generate sample HDF5 file on the fly.
// For now, use example HDF5 file.
// TODO: how to best deal with the relativeness of the path?
filename = "src/caffe/test/test_data/sample_data.h5";
LOG(INFO) << "Using sample HDF5 data file " << filename;
}

virtual ~HDF5DataLayerTest() {
delete blob_top_data_;
delete blob_top_label_;
}

const char* filename;
Blob<Dtype>* const blob_top_data_;
Blob<Dtype>* const blob_top_label_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(HDF5DataLayerTest, Dtypes);

TYPED_TEST(HDF5DataLayerTest, TestRead) {
// Create LayerParameter with the known parameters.
// The data file we are reading has 10 rows and 8 columns,
// with values 0 through 10*8 - 1 in row-major order.
LayerParameter param;
int batchsize = 5;
param.set_batchsize(batchsize);
param.set_source(this->filename);
int num_rows = 10;
int num_cols = 8;
HDF5DataLayer<TypeParam> layer(param);

// Test that the layer setup got the correct parameters.
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
EXPECT_EQ(this->blob_top_data_->num(), batchsize);
EXPECT_EQ(this->blob_top_data_->channels(), num_cols);
EXPECT_EQ(this->blob_top_data_->height(), 1);
EXPECT_EQ(this->blob_top_data_->width(), 1);

EXPECT_EQ(this->blob_top_label_->num(), batchsize);
EXPECT_EQ(this->blob_top_label_->channels(), 1);
EXPECT_EQ(this->blob_top_label_->height(), 1);
EXPECT_EQ(this->blob_top_label_->width(), 1);

// Go through the data 100 times.
for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);

// On even iterations, we're reading the first half of the data.
// On odd iterations, we're reading the second half of the data.
int label_offset = (iter % 2 == 0) ? 0 : batchsize;
int data_offset = (iter % 2 == 0) ? 0 : batchsize * num_cols;

for (int i = 0; i < batchsize; ++i) {
EXPECT_EQ(
label_offset + i,
this->blob_top_label_->cpu_data()[i]);
}
for (int i = 0; i < batchsize; ++i) {
for (int j = 0; j < num_cols; ++j) {
EXPECT_EQ(
data_offset + i * num_cols + j,
this->blob_top_data_->cpu_data()[i * num_cols + j])
<< "debug: i " << i << " j " << j;
}
}
}

// Exact same test in GPU mode.
Caffe::set_mode(Caffe::GPU);
// Go through the data 100 times.
for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);

// On even iterations, we're reading the first half of the data.
// On odd iterations, we're reading the second half of the data.
int label_offset = (iter % 2 == 0) ? 0 : batchsize;
int data_offset = (iter % 2 == 0) ? 0 : batchsize * num_cols;

for (int i = 0; i < batchsize; ++i) {
EXPECT_EQ(
label_offset + i,
this->blob_top_label_->cpu_data()[i]);
}
for (int i = 0; i < batchsize; ++i) {
for (int j = 0; j < num_cols; ++j) {
EXPECT_EQ(
data_offset + i * num_cols + j,
this->blob_top_data_->cpu_data()[i * num_cols + j])
<< "debug: i " << i << " j " << j;
}
}
}
}

} // namespace caffe
