diff --git a/.gitignore b/.gitignore index a80b9184e039..f07156a70df4 100644 --- a/.gitignore +++ b/.gitignore @@ -115,4 +115,9 @@ scala-package/*/*/target/ .project .cproject .pydevproject +CMakeFiles +cmake_install.cmake +dmlc-core +ps-lite nnvm +lib diff --git a/example/caffe/caffe_net.py b/example/caffe/caffe_net.py index dc65131e8547..c91d37bcbecb 100644 --- a/example/caffe/caffe_net.py +++ b/example/caffe/caffe_net.py @@ -1,4 +1,3 @@ -import os, sys import mxnet as mx from data import get_iterator import argparse @@ -51,13 +50,18 @@ def get_lenet(): lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax') return lenet +def get_network_from_json_file(file_name): + network = mx.sym.load(file_name) + return network + def parse_args(): - parser = argparse.ArgumentParser(description='train an image classifer on mnist') + parser = argparse.ArgumentParser(description='train an image classifier on mnist') parser.add_argument('--network', type=str, default='lenet', - choices = ['mlp', 'lenet'], - help='the cnn to use') + help='the cnn to use (mlp | lenet | ') parser.add_argument('--caffe-loss', type=int, default=0, help='Use CaffeLoss symbol') + parser.add_argument('--caffe-data', type=bool, default=False, + help='Use Caffe input-data layer (True | False)') parser.add_argument('--data-dir', type=str, default='mnist/', help='the input data directory') parser.add_argument('--gpus', type=str, @@ -88,16 +92,21 @@ def parse_args(): if __name__ == '__main__': args = parse_args() use_caffe_loss = args.caffe_loss + use_caffe_data = args.caffe_data + data_shape = () if args.network == 'mlp': data_shape = (784, ) net = get_mlp() - else: - data_shape = (1, 28, 28) + elif args.network == 'lenet': + if not use_caffe_data: + data_shape = (1, 28, 28) net = get_lenet() + else: + net = get_network_from_json_file(args.network) # train if use_caffe_loss: - train_model.fit(args, net, get_iterator(data_shape), mx.metric.Caffe()) + train_model.fit(args, net, get_iterator(data_shape, use_caffe_data), mx.metric.Caffe()) else: - train_model.fit(args, net, get_iterator(data_shape)) + train_model.fit(args, net, get_iterator(data_shape, use_caffe_data)) diff --git a/example/caffe/data.py b/example/caffe/data.py index 3512b1b7799f..0ed6ed9d0d79 100644 --- a/example/caffe/data.py +++ b/example/caffe/data.py @@ -6,12 +6,12 @@ import get_data import mxnet as mx -def get_iterator(data_shape): - def get_iterator_impl(args, kv): +def get_iterator(data_shape, use_caffe_data): + def get_iterator_impl_mnist(args, kv): """return train and val iterators for mnist""" # download data get_data.GetMNIST_ubyte() - flat = False if len(data_shape) == 3 else True + flat = False if len(data_shape) != 1 else True train = mx.io.MNISTIter( image = "data/train-images-idx3-ubyte", @@ -33,5 +33,60 @@ def get_iterator_impl(args, kv): part_index = kv.rank) return (train, val) - return get_iterator_impl + def get_iterator_impl_caffe(args, kv): + flat = False if len(data_shape) != 1 else True + train = mx.io.CaffeDataIter( + prototxt = + 'layer { \ + name: "mnist" \ + type: "Data" \ + top: "data" \ + top: "label" \ + include { \ + phase: TRAIN \ + } \ + transform_param { \ + scale: 0.00390625 \ + } \ + data_param { \ + source: "caffe/examples/mnist/mnist_train_lmdb" \ + batch_size: 64 \ + backend: LMDB \ + } \ + }', + flat = flat, + num_examples = 60000 + # float32 is the default, so left out here in order to illustrate + ) + + val = mx.io.CaffeDataIter( + prototxt = + 'layer { \ + name: "mnist" \ + type: "Data" \ + top: "data" \ + top: "label" \ + include { \ + phase: TEST \ + } \ + transform_param { \ + scale: 0.00390625 \ + } \ + data_param { \ + source: "caffe/examples/mnist/mnist_test_lmdb" \ + batch_size: 100 \ + backend: LMDB \ + } \ + }', + flat = flat, + num_examples = 10000, + dtype = "float32" # float32 is the default + ) + + return train, val + + if use_caffe_data: + return get_iterator_impl_caffe + else: + return get_iterator_impl_mnist diff --git a/plugin/caffe/caffe_data_iter.cc b/plugin/caffe/caffe_data_iter.cc new file mode 100644 index 000000000000..c94c315ee0da --- /dev/null +++ b/plugin/caffe/caffe_data_iter.cc @@ -0,0 +1,251 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file caffe_data_iter.cc + * \brief register mnist iterator +*/ +#include +#include +#include +#include + +#include "caffe_common.h" +#include "caffe_stream.h" +#include "caffe_fieldentry.h" +#include "caffe_blob.h" +#include "../../src/io/inst_vector.h" +#include "../../src/io/iter_prefetcher.h" +#include "../../src/operator/cast-inl.h" + +#define CHECK_NEXT_TIMING + +#ifdef CHECK_NEXT_TIMING +#define IF_CHECK_TIMING(__t$) __t$ +#else +#define IF_CHECK_TIMING(__t$) +#endif + +namespace mxnet { +namespace io { + +struct CaffeDataParam : public dmlc::Parameter { + /*! \brief protobuf text */ + ::caffe::LayerParameter prototxt; + /*! \brief number of iterations per epoch */ + int num_examples; + /*! \brief data mode */ + bool flat; + + DMLC_DECLARE_PARAMETER(CaffeDataParam) { + DMLC_DECLARE_FIELD(prototxt).set_default("layer{}") + .describe("Caffe's layer parameter"); + DMLC_DECLARE_FIELD(flat).set_default(false) + .describe("Augmentation Param: Whether to flat the data into 1D."); + DMLC_DECLARE_FIELD(num_examples).set_lower_bound(1).set_default(10000) + .describe("Number of examples in the epoch."); + } +}; + +template +class CaffeDataIter : public IIterator { + public: + explicit CaffeDataIter(int type_flag) : batch_size_(0), channels_(0), width_(1), height_(1) + , type_flag_(type_flag), loc_(0) + {} + virtual ~CaffeDataIter(void) {} + + // intialize iterator loads data in + virtual void Init(const std::vector >& kwargs) { + std::map kmap(kwargs.begin(), kwargs.end()); + param_.InitAllowUnknown(kmap); + + // Caffe seems to understand phase inside an "include {}" block + if (!param_.prototxt.has_phase()) { + if (param_.prototxt.include().size()) { + if (param_.prototxt.include(0).has_phase()) { + param_.prototxt.set_phase(param_.prototxt.include(0).phase()); + } + } + } + + std::string type = param_.prototxt.type(); + caffe_data_layer_ = caffe::LayerRegistry::CreateLayer(param_.prototxt); + CHECK(caffe_data_layer_ != nullptr) << "Failed creating caffe data layer"; + const size_t top_size = param_.prototxt.top_size(); + if (top_size > 0) { + if (top_size > NR_SUPPORTED_TOP_ITEMS) { + LOG(WARNING) + << "Too may \"top\" items, only two (one data, one label) are currently supported"; + } + top_.reserve(top_size); + for (size_t x = 0; x < top_size; ++x) { + ::caffe::Blob *blob = new ::caffe::Blob(); + cleanup_blobs_.push_back(std::unique_ptr<::caffe::Blob>(blob)); + top_.push_back(blob); + } + caffe_data_layer_->SetUp(bottom_, top_); + const std::vector &shape = top_[DATA]->shape(); + const size_t shapeDimCount = shape.size(); + if (shapeDimCount > 0) { + batch_size_ = shape[0]; + if (shapeDimCount > 1) { + channels_ = shape[1]; + if (shapeDimCount > 2) { + width_ = shape[2]; + if (shapeDimCount > 3) { + height_ = shape[3]; + } + } + } + } + + if (top_size > DATA) { + if (param_.flat) { + batch_data_ = TBlob(nullptr, mshadow::Shape2(batch_size_, width_ * height_), + cpu::kDevCPU, type_flag_); + } else { + batch_data_ = TBlob(nullptr, mxnet::TShape(top_[DATA]->shape().begin(), + top_[DATA]->shape().end()), + cpu::kDevCPU, type_flag_); + } + } + out_.data.clear(); + if (top_size > LABEL) { + batch_label_ = TBlob(nullptr, mxnet::TShape(top_[LABEL]->shape().begin(), + top_[LABEL]->shape().end()), + cpu::kDevCPU, type_flag_); + } + out_.batch_size = batch_size_; + } + } + + virtual void BeforeFirst(void) { + loc_ = 0; + } + + virtual bool Next(void) { + // MxNet iterator is expected to return CPU-accessible memory + if (::caffe::Caffe::mode() != ::caffe::Caffe::CPU) { + ::caffe::Caffe::set_mode(::caffe::Caffe::CPU); + CHECK_EQ(::caffe::Caffe::mode(), ::caffe::Caffe::CPU); + } + caffe_data_layer_->Forward(bottom_, top_); + CHECK_GT(batch_size_, 0) << "batch size must be greater than zero"; + CHECK_EQ(out_.batch_size, batch_size_) << "Internal Error: batch size mismatch"; + + if (loc_ + batch_size_ <= param_.num_examples) { + batch_data_.dptr_ = top_[DATA]->mutable_cpu_data(); + batch_label_.dptr_ = top_[LABEL]->mutable_cpu_data(); + + out_.data.clear(); + out_.data.push_back(batch_data_); + out_.data.push_back(batch_label_); + loc_ += batch_size_; + return true; + } + + return false; + } + + virtual const TBlobBatch &Value(void) const { + return out_; + } + + private: + /*! \brief indexes into top_ */ + enum { DATA = 0, LABEL, NR_SUPPORTED_TOP_ITEMS }; + + /*! \brief MNISTCass iter params */ + CaffeDataParam param_; + /*! \brief Shape scalar values */ + index_t batch_size_, channels_, width_, height_; + /*! \brief Caffe data layer */ + boost::shared_ptr > caffe_data_layer_; + /*! \brief batch data blob */ + mxnet::TBlob batch_data_; + /*! \brief batch label blob */ + mxnet::TBlob batch_label_; + /*! \brief Output blob data for this iteration */ + TBlobBatch out_; + /*! \brief Bottom and top connection-point blob data */ + std::vector<::caffe::Blob*> bottom_, top_; + /*! \brief Cleanup these blobs on exit */ + std::list>> cleanup_blobs_; + /*! \brief type flag of the tensor blob */ + const int type_flag_; + /*! \brief Blobs done so far */ + std::atomic loc_; +}; // class CaffeDataIter + +class CaffeDataIterWrapper : public PrefetcherIter { + public: + CaffeDataIterWrapper() : PrefetcherIter(NULL), next_time_(0) {} + virtual ~CaffeDataIterWrapper() { + IF_CHECK_TIMING( + if (next_time_.load() > 0) { + LOG(WARNING) << "Caffe data loader was blocked for " + << next_time_.load() + << " ms waiting for incoming data"; + } + ) + } + virtual void Init(const std::vector >& kwargs) { + // We need to init prefetcher args in order to get dtype + this->param_.InitAllowUnknown(kwargs); + switch (this->param_.dtype) { + case mshadow::kFloat32: + this->loader_.reset(new CaffeDataIter(this->param_.dtype)); + break; + case mshadow::kFloat64: + this->loader_.reset(new CaffeDataIter(this->param_.dtype)); + break; + case mshadow::kFloat16: + LOG(FATAL) << "float16 layer is not supported by caffe"; + return; + default: + LOG(FATAL) << "Unsupported type " << this->param_.dtype; + return; + } + PrefetcherIter::Init(kwargs); + this->param_.prefetch_buffer = 1; + } + virtual void BeforeFirst(void) { + return PrefetcherIter::BeforeFirst(); + } + virtual bool Next(void) { + IF_CHECK_TIMING( + const uint64_t start_time = GetTickCountMS(); + ) + const bool rc = PrefetcherIter::Next(); + IF_CHECK_TIMING( + const uint64_t diff_time = GetTickCountMS() - start_time; + next_time_.fetch_add(diff_time); + ) + return rc; + } + + protected: + IF_CHECK_TIMING( + static uint64_t GetTickCountMS() { + struct timeval tv; + gettimeofday(&tv, 0); + return uint64_t( tv.tv_sec ) * 1000 + tv.tv_usec / 1000; + } + ) + + /*! \brief milliseconds spent in Next() */ + std::atomic next_time_; +}; // class CaffeDataIterWrapper + +DMLC_REGISTER_PARAMETER(CaffeDataParam); + +MXNET_REGISTER_IO_ITER(CaffeDataIter) +.describe("Create MxNet iterator for a Caffe data layer.") +.add_arguments(CaffeDataParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.set_body([]() { + return new CaffeDataIterWrapper(); +}); + +} // namespace io +} // namespace mxnet + diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 0765827df13a..665a5504ae73 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -25,10 +25,19 @@ namespace io { struct PrefetcherParam : public dmlc::Parameter { /*! \brief number of prefetched batches */ size_t prefetch_buffer; + /*! \brief data type */ + int dtype; + // declare parameters DMLC_DECLARE_PARAMETER(PrefetcherParam) { DMLC_DECLARE_FIELD(prefetch_buffer).set_default(4) .describe("Backend Param: Number of prefetched parameters"); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::default_type_flag) + .describe("Data type."); } }; @@ -36,7 +45,7 @@ struct PrefetcherParam : public dmlc::Parameter { class PrefetcherIter : public IIterator { public: explicit PrefetcherIter(IIterator* base) - : out_(nullptr), loader_(base) { + : loader_(base), out_(nullptr) { } ~PrefetcherIter() { @@ -70,7 +79,9 @@ class PrefetcherIter : public IIterator { (*dptr)->data.resize(batch.data.size()); (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { - (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); + (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, + Context::CPU(), false, + param_.dtype); } } CHECK(batch.data.size() == (*dptr)->data.size()); @@ -102,7 +113,7 @@ class PrefetcherIter : public IIterator { // do recycle if (recycle_queue_.size() == param_.prefetch_buffer) { DataBatch *old_batch = recycle_queue_.front(); - // can be more efficienct on engine + // can be more efficient on engine for (NDArray& arr : old_batch->data) { arr.WaitToWrite(); } @@ -115,17 +126,19 @@ class PrefetcherIter : public IIterator { return *out_; } - private: + protected: /*! \brief prefetcher parameters */ PrefetcherParam param_; - // output data + /*! \brief internal batch loader */ + std::unique_ptr > loader_; + + private: + /*! \brief output data */ DataBatch *out_; - // queue to be recycled + /*! \brief queue to be recycled */ std::queue recycle_queue_; - // backend thread + /*! \brief backend thread */ dmlc::ThreadedIter iter_; - // internal batch loader - std::unique_ptr > loader_; }; } // namespace io } // namespace mxnet diff --git a/tools/caffe_converter/convert_model.py b/tools/caffe_converter/convert_model.py index 891681fb347a..a139db111b64 100644 --- a/tools/caffe_converter/convert_model.py +++ b/tools/caffe_converter/convert_model.py @@ -72,9 +72,11 @@ def main(): wmat_dim = list(layer_blobs[0].shape) wmat = np.array(layer_blobs[0].data).reshape(wmat_dim) bias = np.array(layer_blobs[1].data) - if first_conv: - print 'Swapping BGR of caffe into RGB in mxnet' - wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :] + channels = layer_blobs[0].channels; + if channels == 3 or channels == 4: # RGB or RGBA + if first_conv: + print 'Swapping BGR of caffe into RGB in mxnet' + wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :] assert(wmat.flags['C_CONTIGUOUS'] is True) assert(bias.flags['C_CONTIGUOUS'] is True)