Skip to content

Commit

Permalink
pycaffe: add Net.set_input_arrays for input from numpy
Browse files Browse the repository at this point in the history
This requires a net whose first layer is a MemoryDataLayer.
  • Loading branch information
longjon committed May 2, 2014
1 parent e1072a6 commit 76c2554
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion python/caffe/_caffe.cpp
Expand Up @@ -158,6 +158,8 @@ struct CaffeNet {

virtual ~CaffeNet() {}

// this function is mostly redundant with the one below, but should go away
// with new pycaffe
inline void check_array_against_blob(
PyArrayObject* arr, Blob<float>* blob) {
CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
Expand All @@ -170,6 +172,29 @@ struct CaffeNet {
CHECK_EQ(dims[3], blob->width());
}

// generate Python exceptions for badly shaped or discontiguous arrays
inline void check_contiguous_array(PyArrayObject* arr, string name,
int channels, int height, int width) {
if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
throw std::runtime_error(name + " must be C contiguous");
}
if (PyArray_NDIM(arr) != 4) {
throw std::runtime_error(name + " must be 4-d");
}
if (PyArray_TYPE(arr) != NPY_FLOAT32) {
throw std::runtime_error(name + " must be float32");
}
if (PyArray_DIMS(arr)[1] != channels) {
throw std::runtime_error(name + " has wrong number of channels");
}
if (PyArray_DIMS(arr)[2] != height) {
throw std::runtime_error(name + " has wrong height");
}
if (PyArray_DIMS(arr)[3] != width) {
throw std::runtime_error(name + " has wrong width");
}
}

// The actual forward function. It takes in a python list of numpy arrays as
// input and a python list of numpy arrays as output. The input and output
// should all have correct shapes, are single-precisionabcdnt- and
Expand Down Expand Up @@ -267,6 +292,41 @@ struct CaffeNet {
net_->ForwardPrefilled();
}

void set_input_arrays(object data_obj, object labels_obj) {
// check that this network has an input MemoryDataLayer
shared_ptr<MemoryDataLayer<float> > md_layer =
boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
if (!md_layer) {
throw std::runtime_error("set_input_arrays may only be called if the"
" first layer is a MemoryDataLayer");
}

// check that we were passed appropriately-sized contiguous memory
PyArrayObject* data_arr =
reinterpret_cast<PyArrayObject*>(data_obj.ptr());
PyArrayObject* labels_arr =
reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
md_layer->datum_height(), md_layer->datum_width());
check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
throw std::runtime_error("data and labels must have the same first"
" dimension");
}
if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
throw std::runtime_error("first dimensions of input arrays must be a"
" multiple of batch size");
}

// hold references
input_data_ = data_obj;
input_labels_ = labels_obj;

md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
static_cast<float*>(PyArray_DATA(labels_arr)),
PyArray_DIMS(data_arr)[0]);
}

// The caffe::Caffe utility functions.
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
Expand All @@ -292,6 +352,9 @@ struct CaffeNet {

// The pointer to the internal caffe::Net instant.
shared_ptr<Net<float> > net_;
// if taking input from an ndarray, we need to hold references
object input_data_;
object input_labels_;
};

class CaffeSGDSolver {
Expand Down Expand Up @@ -334,7 +397,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("set_device", &CaffeNet::set_device)
// rename blobs here since the pycaffe.py wrapper will replace it
.add_property("_blobs", &CaffeNet::blobs)
.add_property("layers", &CaffeNet::layers);
.add_property("layers", &CaffeNet::layers)
.def("set_input_arrays", &CaffeNet::set_input_arrays);

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"Blob", boost::python::no_init)
Expand Down

0 comments on commit 76c2554

Please sign in to comment.