diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..9ba433b --- /dev/null +++ b/.clang-format @@ -0,0 +1,28 @@ +# This file is used by clang-format to autoformat paddle source code +# +# The clang-format is part of llvm toolchain. +# It need to install llvm and clang to format source code style. +# +# The basic usage is, +# clang-format -i -style=file PATH/TO/SOURCE/CODE +# +# The -style=file implicit use ".clang-format" file located in one of +# parent directory. +# The -i means inplace change. +# +# The document of clang-format is +# http://clang.llvm.org/docs/ClangFormat.html +# http://clang.llvm.org/docs/ClangFormatStyleOptions.html +--- +Language: Cpp +BasedOnStyle: Google +IndentWidth: 2 +TabWidth: 2 +ContinuationIndentWidth: 4 +AccessModifierOffset: -2 # The private/protected/public has no indent in class +Standard: Cpp11 +AllowAllParametersOfDeclarationOnNextLine: true +BinPackParameters: false +BinPackArguments: false +... + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..74c9e2a --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +build/ +third_party/thread diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..41be4e3 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/thread"] + path = third_party/thread + url = https://github.com/progschj/ThreadPool.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..fcf55b7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + sha: c25201a00e6b0514370501050cf2a8538ac12270 + hooks: + - id: remove-crlf + files: (?!.*thread)^.*$ +- repo: https://github.com/reyoung/mirrors-yapf.git + sha: v0.13.2 + hooks: + - id: yapf + files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ # Bazel BUILD files follow Python syntax. +- repo: https://github.com/pre-commit/pre-commit-hooks + sha: 7539d8bd1a00a3c1bfd34cdb606d3a6372e83469 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + - id: end-of-file-fixer + files: (?!.*thread)^.*$ +- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git + sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 + hooks: + - id: clang-formater diff --git a/.set_python_path.sh b/.set_python_path.sh new file mode 100755 index 0000000..8e40a24 --- /dev/null +++ b/.set_python_path.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# A simple test driver for cmake. +# set PYTHONPATH before run command. +# Usage: +# ./.set_python_pash.sh -p YOUR_PYTHON_PATH {exec...} +# +# It same as PYTHONPATH=${YOUR_PYTHON_PATH}:$PYTHONPATH {exec...} +# + +PYPATH="" +set -x +while getopts "d:" opt; do + case $opt in + d) + PYPATH=$OPTARG + ;; + esac +done +shift $(($OPTIND - 1)) +echo $PYPATH +export PYTHONPATH=$PYPATH:${PYTHONPATH} +$@ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..5cbe6d2 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.2) + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake") +set(PROJ_ROOT ${CMAKE_SOURCE_DIR}) + +include(flags) + +find_package(PythonLibs 2.7 REQUIRED) +find_package(PythonInterp 2.7 REQUIRED) +find_package(Glog REQUIRED) +find_package(NumPy REQUIRED) + +include_directories(${PROJ_ROOT}) +include_directories(${PYTHON_INCLUDE_DIR}) +include_directories(${LIBGLOG_INCLUDE_DIR}) +include_directories(${PYTHON_NUMPY_INCLUDE_DIR}) + +enable_testing() + +add_subdirectory(transformer) diff --git a/cmake/FindGlog.cmake b/cmake/FindGlog.cmake new file mode 100644 index 0000000..142e2ca --- /dev/null +++ b/cmake/FindGlog.cmake @@ -0,0 +1,24 @@ +# +# Find libglog +# +# LIBGLOG_INCLUDE_DIR - where to find glog/logging.h, etc. +# LIBGLOG_LIBRARY - List of libraries when using libglog. +# LIBGLOG_FOUND - True if libglog found. +# +# from https://github.com/facebook/hhvm/blob/master/CMake/FindGlog.cmake + +IF (LIBGLOG_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBGLOG_FIND_QUIETLY TRUE) +ENDIF () + +FIND_PATH(LIBGLOG_INCLUDE_DIR glog/logging.h) + +FIND_LIBRARY(LIBGLOG_LIBRARY glog) + +# handle the QUIETLY and REQUIRED arguments and set LIBGLOG_FOUND to TRUE if +# all listed variables are TRUE +INCLUDE(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(LIBGLOG DEFAULT_MSG LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) + +MARK_AS_ADVANCED(LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) \ No newline at end of file diff --git a/cmake/FindNumPy.cmake b/cmake/FindNumPy.cmake new file mode 100644 index 0000000..8cdd642 --- /dev/null +++ b/cmake/FindNumPy.cmake @@ -0,0 +1,38 @@ +# Find the Python NumPy package +# PYTHON_NUMPY_INCLUDE_DIR +# NUMPY_FOUND +# will be set by this script + +cmake_minimum_required(VERSION 2.6) + +if(NOT PYTHON_EXECUTABLE) + if(NumPy_FIND_QUIETLY) + find_package(PythonInterp QUIET) + else() + find_package(PythonInterp) + set(_numpy_out 1) + endif() +endif() + +if (PYTHON_EXECUTABLE) + # write a python script that finds the numpy path + file(WRITE ${PROJECT_BINARY_DIR}/FindNumpyPath.py + "try: import numpy; print(numpy.get_include())\nexcept:pass\n") + + # execute the find script + exec_program("${PYTHON_EXECUTABLE}" ${PROJECT_BINARY_DIR} + ARGS "FindNumpyPath.py" + OUTPUT_VARIABLE NUMPY_PATH) +elseif(_numpy_out) + message(STATUS "Python executable not found.") +endif(PYTHON_EXECUTABLE) + +find_path(PYTHON_NUMPY_INCLUDE_DIR numpy/arrayobject.h + HINTS "${NUMPY_PATH}" "${PYTHON_INCLUDE_PATH}") + +if(PYTHON_NUMPY_INCLUDE_DIR) + set(PYTHON_NUMPY_FOUND 1 CACHE INTERNAL "Python numpy found") +endif(PYTHON_NUMPY_INCLUDE_DIR) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_INCLUDE_DIR) diff --git a/cmake/flags.cmake b/cmake/flags.cmake new file mode 100644 index 0000000..d77a910 --- /dev/null +++ b/cmake/flags.cmake @@ -0,0 +1,77 @@ +include(CheckCXXCompilerFlag) +include(CheckCCompilerFlag) +include(CheckCXXSymbolExists) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" + FORCE) +endif() + +function(CheckCompilerCXX11Flag) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) + message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") + endif() + endif() +endfunction() + +CheckCompilerCXX11Flag() +LIST(APPEND CMAKE_CXX_FLAGS -std=c++11) + +# safe_set_flag +# +# Set a compile flag only if compiler is support +# is_c: is C flag or C++ flag, bool type. +# src_list: The list name which the flag name will be append to. +# flag_name: the flag name for compiler, such as '-Werror' '-Wall' etc +# rest arguments: not used. +function(safe_set_flag is_c src_list flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + if(is_c) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + else() + CHECK_CXX_COMPILER_FLAG(${flag_name} CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + endif() + if(${safe_name}) + set(${src_list} "${${src_list}} ${flag_name}" PARENT_SCOPE) + endif() +endfunction() + +# helper macro to set cflag +macro(safe_set_cflag src_list flag_name) + safe_set_flag(ON ${src_list} ${flag_name}) +endmacro() + +# helper macro to set cxxflag +macro(safe_set_cxxflag src_list flag_name) + safe_set_flag(OFF ${src_list} ${flag_name}) +endmacro() + +# helper macro to set nvcc flag +macro(safe_set_nvflag flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + if(${safe_name}) + LIST(APPEND CUDA_NVCC_FLAGS -Xcompiler ${flag_name}) + endif() +endmacro() + +set(COMMON_FLAGS + -fPIC + -fno-omit-frame-pointer + -Wextra + -Wno-unused-parameter + -Wno-unused-function + -Wno-error=literal-suffix + -Wno-error=unused-local-typedefs) + +foreach(flag ${COMMON_FLAGS}) + safe_set_cflag(CMAKE_C_FLAGS ${flag}) + safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) +endforeach() diff --git a/third_party/.gitignore b/third_party/.gitignore new file mode 100644 index 0000000..8502cd4 --- /dev/null +++ b/third_party/.gitignore @@ -0,0 +1 @@ +thread diff --git a/transformer/CMakeLists.txt b/transformer/CMakeLists.txt new file mode 100644 index 0000000..37826c4 --- /dev/null +++ b/transformer/CMakeLists.txt @@ -0,0 +1,32 @@ +project(DeJpeg CXX C) +set(DEJPEG_LINKER_LIBS "") + +# OpenCV +find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc) +include_directories(${OpenCV_INCLUDE_DIRS}) + +# Boost +set(Boost_NO_SYSTEM_PATHS ON) +if (Boost_NO_SYSTEM_PATHS) + set(BOOST_ROOT $ENV{BOOST_ROOT}) + set(Boost_DIR ${BOOST_ROOT}) + set(Boost_INCLUDE_DIR "${BOOST_ROOT}/include") + set(Boost_LIBRARIES "${BOOST_ROOT}/lib/") +endif (Boost_NO_SYSTEM_PATHS) +find_package(Boost 1.63 COMPONENTS python numpy) +include_directories(${Boost_INCLUDE_DIR}) + +list(APPEND DEJPEG_LINKER_LIBS ${OpenCV_LIBS}) +list(APPEND DEJPEG_LINKER_LIBS ${Boost_LIBRARIES}) +list(APPEND DEJPEG_LINKER_LIBS ${LIBGLOG_LIBRARY}) + +set(DEJPEG_SOURCES + DataTransformer.cpp + PyDecodejpeg.cpp) + +add_library(DeJpeg SHARED ${DEJPEG_SOURCES}) +target_compile_options(DeJpeg BEFORE PRIVATE ${BUILD_FLAGS}) +target_link_libraries(DeJpeg ${DEJPEG_LINKER_LIBS}) +set_target_properties(DeJpeg PROPERTIES PREFIX "") + +add_subdirectory(tests) diff --git a/transformer/DataTransformer.cpp b/transformer/DataTransformer.cpp new file mode 100644 index 0000000..69cdcbc --- /dev/null +++ b/transformer/DataTransformer.cpp @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "DataTransformer.h" + +DataTransformer::DataTransformer( + std::unique_ptr&& config) + : config_(std::move(config)), eng_(time(NULL)) {} + +void DataTransformer::transfromFile(const char* imgFile, float* trg) { + int cvFlag = + config_->isColor_ ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE; + cv::Mat im = cv::imread(imgFile, cvFlag); + if (!im.data) { + LOG(FATAL) << "Could not read image, image shape"; + } + this->transform(im, trg); +} + +void DataTransformer::transfromString(const char* src, + const int size, + float* trg) { + cv::_InputArray imbuf(src, size); + int cvFlag = + config_->isColor_ ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE; + cv::Mat im = cv::imdecode(imbuf, cvFlag); + if (!im.data) { + LOG(FATAL) << "Could not decode image"; + } + this->transform(im, trg); +} + +int DataTransformer::rand(const int min, const int max) { + std::uniform_int_distribution dist(min, max); + return dist(eng_); +} + +// TODO(qingqing): add more data argumentation operation +// and split this function. +void DataTransformer::transform(cv::Mat& cvImgOri, float* target) { + const int imgChannels = cvImgOri.channels(); + const int imgHeight = cvImgOri.rows; + const int imgWidth = cvImgOri.cols; + const bool doMirror = (!config_->isTest_) && rand(0, 1); + int hoff = 0; + int woff = 0; + int th = imgHeight; + int tw = imgWidth; + cv::Mat img; + int imsz = config_->imgSize_; + if (imsz > 0) { + double ratio = imgHeight < imgWidth ? double(imsz) / double(imgHeight) + : double(imsz) / double(imgWidth); + th = int(double(imgHeight) * ratio); + tw = int(double(imgWidth) * ratio); + cv::resize(cvImgOri, img, cv::Size(tw, th)); + } else { + img = cvImgOri; + } + + cv::Mat cv_cropped_img = img; + int cropH = config_->cropHeight_; + int cropW = config_->cropWidth_; + if (cropH && cropW) { + if (!config_->isTest_) { + hoff = rand(0, th - cropH); + woff = rand(0, tw - cropW); + } else { + hoff = (th - cropH) / 2; + woff = (tw - cropW) / 2; + } + cv::Rect roi(woff, hoff, cropW, cropH); + cv_cropped_img = img(roi); + } else { + CHECK_EQ(cropH, imgHeight); + CHECK_EQ(cropW, imgWidth); + } + int height = cropH; + int width = cropW; + int top_index; + float scale = config_->scale_; + float* meanVal = config_->meanValues_; + for (int h = 0; h < height; ++h) { + const uint8_t* ptr = cv_cropped_img.ptr(h); + int img_index = 0; + for (int w = 0; w < width; ++w) { + for (int c = 0; c < imgChannels; ++c) { + if (doMirror) { + top_index = (c * height + h) * width + width - 1 - w; + } else { + top_index = (c * height + h) * width + w; + } + float pixel = static_cast(ptr[img_index++]); + switch (config_->meanType_) { + case CHANNEL_MEAN: { + target[top_index] = (pixel - meanVal[c]) * scale; + break; + } + case ELEMENT_MEAN: { + int mean_index = (c * height + h) * width + w; + target[top_index] = (pixel - meanVal[mean_index]) * scale; + break; + } + case NULL_MEAN: { + target[top_index] = pixel * scale; + break; + } + default: + LOG(FATAL) << "Unsupport type"; + } + } + } + } // target: BGR +} diff --git a/transformer/DataTransformer.h b/transformer/DataTransformer.h new file mode 100644 index 0000000..898abc8 --- /dev/null +++ b/transformer/DataTransformer.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef DATATRANSFORMER_H_ +#define DATATRANSFORMER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#define DISABLE_COPY(T) \ + T(T&&) = delete; \ + T(T const&) = delete; \ + void operator=(T const& t) = delete + +enum MeanType { CHANNEL_MEAN = 0, ELEMENT_MEAN = 1, NULL_MEAN = 2 }; + +struct DataTransformerConfig { + bool isTest_; + bool isColor_; + int cropHeight_; + int cropWidth_; + int imgSize_; // short side + MeanType meanType_; + float scale_; + int imgPixels_; // the total pixels of transformed image + float* meanValues_; +}; + +/** + * This is an image processing module with OpenCV, such as + * resizing, scaling, mirroring, substracting the image mean... + */ +class DataTransformer { +public: + DISABLE_COPY(DataTransformer); + + DataTransformer(std::unique_ptr&& config); + virtual ~DataTransformer() {} + + /** + * @brief Applies the transformation on one image Mat. + * + * @param img The input image Mat to be transformed. + * @param target target is used to save the transformed data. + */ + void transform(cv::Mat& img, float* target); + + /** + * @brief Save image Mat as file. + * + * @param filename The file name. + * @param im The image to be saved. + */ + void imsave(std::string filename, cv::Mat& im) const { + cv::imwrite(filename, im); + } + + /** + * @brief Decode the image buffer, then calls transform() function. + * + * @param src The input image buffer. + * @param size The length of string buffer. + * @param trg trg is used to save the transformed data. + */ + void transfromString(const char* src, const int size, float* trg); + + /** + * @brief Load image form image file, then calls transform() function. + * + * @param src The input image file. + * @param trg trg is used to save the transformed data. + */ + void transfromFile(const char* imgFile, float* trg); + +private: + std::unique_ptr config_; + + /** + * @brief Generates a random integer from Uniform({min, min + 1, ..., max}). + * @param min The lower bound (inclusive) value of the random number. + * @param max The upper bound (inclusive) value of the random number. + * + * @return + * A uniformly random integer value from ({min, min + 1, ..., max}). + */ + int rand(const int min, const int max); + std::default_random_engine eng_; + +}; // class DataTransformer + +#endif // DATATRANSFORMER_H_ diff --git a/transformer/Parallel.h b/transformer/Parallel.h new file mode 100644 index 0000000..9383297 --- /dev/null +++ b/transformer/Parallel.h @@ -0,0 +1,143 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include + +#include "DataTransformer.h" +#include "third_party/thread/ThreadPool.h" + +namespace bn = boost::python::numpy; + +class Parallel { +public: + Parallel(int threadNum, + bool isTest, + bool isColor, + int resizeMinSize, + int cropSizeH, + int cropSizeW, + PyObject* meanValues) + : threadPool_(threadNum) { + int channel = isColor ? 3 : 1; + MeanType meanType = NULL_MEAN; + float* mean = NULL; + if (meanValues || meanValues != Py_None) { + if (!PyArray_Check(meanValues)) { + LOG(FATAL) << "Object is not a numpy array"; + } + pyTypeCheck(meanValues); + int size = PyArray_SIZE(reinterpret_cast(meanValues)); + mean = (float*)PyArray_DATA(reinterpret_cast(meanValues)); + meanType = (size == channel) ? CHANNEL_MEAN : meanType; + meanType = + (size == channel * cropSizeH * cropSizeW) ? ELEMENT_MEAN : meanType; + } + + imgPixels_ = channel * cropSizeH * cropSizeW; + + DataTransformerConfig* conf = new DataTransformerConfig; + conf->isTest_ = isTest; + conf->isColor_ = isColor; + conf->cropHeight_ = cropSizeH; + conf->cropWidth_ = cropSizeW; + conf->imgSize_ = resizeMinSize; + conf->meanType_ = meanType; + conf->scale_ = 1.0; + conf->imgPixels_ = imgPixels_; + conf->meanValues_ = mean; + + transformerPtr_ = std::unique_ptr( + new DataTransformer(std::unique_ptr(conf))); + } + + ~Parallel() {} + + int start(boost::python::list& pysrc, PyObject* pylabel, int mode) { + int num = len(pysrc); + int* labels = (int*)PyArray_DATA(reinterpret_cast(pylabel)); + for (int i = 0; i < num; ++i) { + const char* buf = boost::python::extract(pysrc[i]); + int buflen = len(pysrc[i]); + int label = labels[i]; + Py_intptr_t shape[1] = {this->imgPixels_}; + DataTypePtr trg = std::make_shared( + boost::python::numpy::zeros( + 1, shape, boost::python::numpy::dtype::get_builtin()), + 0); + results_.emplace_back( + threadPool_.enqueue([this, buf, buflen, label, trg, mode]() { + trg->second = label; + float* data = (float*)((trg->first).get_data()); + if (mode == 0) { + this->transformerPtr_->transfromString(buf, buflen, data); + } else if (mode == 1) { + this->transformerPtr_->transfromFile(buf, data); + } else { + LOG(FATAL) << "Unsupport mode " << mode; + } + return trg; + })); + } + return 0; + } + + boost::python::tuple get() { + DataTypePtr ret = results_.front().get(); + results_.pop_front(); + return boost::python::make_tuple(ret->first, ret->second); + } + +private: + /** + * @brief Check whether the type of PyObject is valid or not. + */ + void pyTypeCheck(PyObject* o) { + int typenum = PyArray_TYPE(reinterpret_cast(o)); + + // clang-format off + int type = + typenum == NPY_UBYTE ? CV_8U : + typenum == NPY_BYTE ? CV_8S : + typenum == NPY_USHORT ? CV_16U : + typenum == NPY_SHORT ? CV_16S : + typenum == NPY_INT || typenum == NPY_LONG ? CV_32S : + typenum == NPY_FLOAT ? CV_32F : + typenum == NPY_DOUBLE ? CV_64F : -1; + // clang-format on + + if (type < 0) { + LOG(FATAL) << "toMat: Data type = " << type << " is not supported"; + } + } + + int imgPixels_; + + /** + * @brief An object of DataTransformer, which is used to call + * the image processing funtions. + */ + std::unique_ptr transformerPtr_; + + ThreadPool threadPool_; + + typedef std::pair DataType; + typedef std::shared_ptr DataTypePtr; + std::deque> results_; + +}; // Parallel diff --git a/transformer/PyDecodejpeg.cpp b/transformer/PyDecodejpeg.cpp new file mode 100644 index 0000000..ea88ff9 --- /dev/null +++ b/transformer/PyDecodejpeg.cpp @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Parallel.h" + +/** + * DecodeJpeg is an image processing API for interfacing Python and + * C++ code. The Boost Python Library is used to wrap C++ interfaces. + * This class is only an interface and there is no specific calculation. + */ +class DecodeJpeg { +public: + DecodeJpeg(int threadNum, + bool isTest, + bool isColor, + int resize_min_size, + int cropSizeH, + int cropSizeW, + PyObject* meanValues) { + tfhandlerPtr_ = std::make_shared(threadNum, + isTest, + isColor, + resize_min_size, + cropSizeH, + cropSizeW, + meanValues); + } + + ~DecodeJpeg() {} + + /** + * @brief This function calls the start function of Parallel + * to process image with multi-threads. + * @param pysrc The input image list with string type. + * @param pylabel The input label of image. + * Its type is numpy.array with int32. + * @param mode Two mode: + * 0: the input is image buffer + * 1: the input is image file path. + */ + int start(boost::python::list& pysrc, PyObject* pylabel, int mode) { + int ret = tfhandlerPtr_->start(pysrc, pylabel, mode); + return 0; + } + + /** + * @brief Return a tuple: (image, label). + * The image is transformed image. + */ + boost::python::tuple get() { return tfhandlerPtr_->get(); } + +private: + std::shared_ptr tfhandlerPtr_; +}; // DecodeJpeg + +/** + * @brief Initialize the Python interpreter and numpy. + */ +static void initPython() { + Py_Initialize(); + PyOS_sighandler_t sighandler = PyOS_getsig(SIGINT); + import_array(); + PyOS_setsig(SIGINT, sighandler); +} + +/** + * Use Boost.Python to expose C++ interface to Python. + */ +BOOST_PYTHON_MODULE(DeJpeg) { + initPython(); + boost::python::numpy::initialize(); + boost::python::class_( + "DecodeJpeg", + boost::python::init()) + .def("start", &DecodeJpeg::start) + .def("get", &DecodeJpeg::get); +}; diff --git a/transformer/tests/CMakeLists.txt b/transformer/tests/CMakeLists.txt new file mode 100644 index 0000000..0216b91 --- /dev/null +++ b/transformer/tests/CMakeLists.txt @@ -0,0 +1,4 @@ +add_test(NAME test_dejpeg + COMMAND ${PROJ_ROOT}/.set_python_path.sh -d ${CMAKE_CURRENT_BINARY_DIR}/.. + python ${PROJ_ROOT}/transformer/tests/test_dejpeg.py + WORKING_DIRECTORY ${PROJ_ROOT}) diff --git a/transformer/tests/cat.jpg b/transformer/tests/cat.jpg new file mode 100644 index 0000000..47b01db Binary files /dev/null and b/transformer/tests/cat.jpg differ diff --git a/transformer/tests/test_dejpeg.py b/transformer/tests/test_dejpeg.py new file mode 100644 index 0000000..afab3d1 --- /dev/null +++ b/transformer/tests/test_dejpeg.py @@ -0,0 +1,54 @@ +import os +import unittest +import numpy as np +from PIL import Image +import StringIO + +import DeJpeg + + +class TestDataTransformer(unittest.TestCase): + def test_image_buf(self): + im_size = 0 # not resize + crop_size = 128 + tmp_name = './transformer/tests/cat.jpg' + + data = [] + with open(tmp_name) as f: + data.append(f.read()) + + data.append(data[0]) + mean = np.array([103.939, 116.779, 123.68], dtype=np.float32) + + # transform by DeJpeg + op = DeJpeg.DecodeJpeg(2, True, True, im_size, crop_size, crop_size, + mean) + labels = np.array([3, 3], dtype=np.int32) + ret = op.start(data, labels, 0) + self.assertEqual(ret, 0) + lab = np.zeros(1, dtype=np.int32) + im, lab = op.get() + im, lab = op.get() + self.assertEqual(lab, 3) + + # transform by PIL + img = Image.open(tmp_name) + im_array = np.array(img) + + h, w = im_array.shape[:2] + hoff = (h - crop_size) / 2 + woff = (w - crop_size) / 2 + pyim = im_array[hoff:hoff + crop_size, woff:woff + crop_size, :] + + pyim = pyim.astype(np.float32) + pyim = pyim.transpose((2, 0, 1)) + mean = mean[:, np.newaxis, np.newaxis] + pyim = pyim[(2, 1, 0), :, :] + pyim -= mean + pyim = pyim.flatten() + + self.assertEqual(im.all(), pyim.all()) + + +if __name__ == '__main__': + unittest.main()