Weaver jet flavour inference #188

Merged
merged 54 commits into from
Aug 1, 2022
Changes from 49 commits
33193b4
First skeleton of the ONNXRuntime wrapper object
forthommel Oct 15, 2021
21380d5
Added a few operators to ONNXRuntime, derivation from TObject
forthommel Feb 16, 2022
fcb6b76
Replaced the operator() by a run() method
forthommel Mar 14, 2022
21be3e6
Registered run method
forthommel Mar 14, 2022
ea94e6a
Added ONNX to runtime libraries search path
forthommel Mar 14, 2022
d36a785
Removed TObject hierarchy
forthommel Mar 14, 2022
9344e43
API slightly adapted ; using a singleton (!)
forthommel Mar 17, 2022
771fc61
Working version
forthommel Mar 23, 2022
8b41130
Added test
forthommel Mar 23, 2022
06f2948
Added an intermediate weaver interface object
forthommel Mar 23, 2022
ef90158
Work on python API
forthommel Mar 23, 2022
40c65b0
Cleanup
forthommel Mar 23, 2022
aa9ffd0
Handling the parsing of Weaver's preprocess json file
forthommel Mar 24, 2022
c50b333
Clang-format
forthommel Mar 24, 2022
49d99b7
Using new jet constituents tool to feed the NN inference
forthommel Mar 25, 2022
7398e6f
Renamed onnx test
forthommel Mar 25, 2022
cb25b04
Improved printout for weaver module output
forthommel Mar 25, 2022
144d3bd
JC definitions simplification through aliases
forthommel Mar 28, 2022
2f3fbf6
Updated ONNX interface to JC
forthommel Apr 4, 2022
acef7a4
Updated ONNX version
forthommel Apr 4, 2022
c304e48
Picking up missing rebase merges
forthommel May 4, 2022
c4c26ad
More fiddling around
forthommel May 20, 2022
8df1dd0
Adapted python tests to new scheme
forthommel Jul 4, 2022
e772757
Inference FW skipping simplified through CMake regexes
forthommel Jul 4, 2022
7390451
Added unit test skeleton for weaver interface
forthommel Jul 4, 2022
e4bb51a
Cleanup of includes
forthommel Jul 4, 2022
6a94797
Updated jet constituents FW
forthommel Jul 14, 2022
2e99636
Intermediate jet flavour utilities
forthommel Jul 15, 2022
cdf5d19
Fixed input format for constituents
forthommel Jul 15, 2022
f9e15dd
Registration of variables names
forthommel Jul 15, 2022
2ceb203
Non-crashing version
forthommel Jul 20, 2022
825c3c4
Fixed test
forthommel Jul 20, 2022
b9a11b9
Cleanup
forthommel Jul 20, 2022
92f3f01
First use of experimental API
forthommel Jul 20, 2022
2bff937
Further simplification through the use of 'experimental' Ort API
forthommel Jul 21, 2022
1c2d3a6
Added a method to retrieve a specific weight for all jets in collection
forthommel Jul 21, 2022
60a62f1
Skip JetFlavourUtils compilation if ONNXRuntime is not found
forthommel Jul 21, 2022
10fb853
Added some tests in flavtagging unit tests
forthommel Jul 21, 2022
809cddc
Added Weaver inference test into CTest collection
forthommel Jul 25, 2022
b19478a
Improved documentation of jet flavour utils
forthommel Jul 25, 2022
d83ae7a
Improved documentation of Weaver interface
forthommel Jul 25, 2022
159e3ea
Only run Weaver inference test if library was linked against ONNXRuntime
forthommel Jul 25, 2022
641590f
Fixed global functions order for jet flavour utils
forthommel Jul 27, 2022
e17553b
Import ONNXRuntime from py-onnx-runtime spack repo
forthommel Aug 1, 2022
d96170f
Retrieve Weaver input test files
forthommel Aug 1, 2022
19e544b
Making JSON parsing more informative when input file is not found
forthommel Aug 1, 2022
8393b09
Added utility to find input test files in unit test
forthommel Aug 1, 2022
a40bf24
Specify input data dir as environment variable for python tests
forthommel Aug 1, 2022
02e231d
Made ONNX linking a configuration parameter
forthommel Aug 1, 2022
b0a00d6
Added WITH_ONNX=ON CMake option, cleaned analysers/dataframe CMake di…
forthommel Aug 1, 2022
346e97d
Fixed CMake logic
forthommel Aug 1, 2022
fa7db53
Propagate WITH_ONNX state to the parent scope
forthommel Aug 1, 2022
8c15282
Implemented tristate (on/auto/off) for the WITH_ONNX variable to acco…
forthommel Aug 1, 2022
5ccbaec
Reverted back WITH_ONNX=ON behaviour for CI
forthommel Aug 1, 2022
17 changes: 16 additions & 1 deletion CMakeLists.txt
@@ -16,6 +16,7 @@ option(USE_EXTERNAL_CATCH2 "Link against an external Catch2 v3 static library, o


option(WITH_DD4HEP "Build analyzers that need DD4hep" OFF)
option(WITH_ONNX "Build analyzers that need ONNXRuntime" OFF)

#--- Set a better default for installation directory---------------------------
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
@@ -58,7 +59,21 @@ set(INSTALL_INCLUDE_DIR include CACHE PATH
"Installation directory for header files")



# Grab the test files into a cached directory
if(NOT DEFINED CACHE{TEST_INPUT_DATA_DIR})
  message(STATUS "Getting test input files")
  execute_process(COMMAND bash ${CMAKE_CURRENT_LIST_DIR}/tests/get_test_inputs.sh
                  OUTPUT_VARIABLE test_input_data_dir
                  RESULT_VARIABLE test_inputs_available)
  if(NOT "${test_inputs_available}" STREQUAL "0")
    message(WARNING "Failed to retrieve input test files. Some tests will need to be skipped.")
    unset(TEST_INPUT_DATA_DIR CACHE)
  else()
    message(STATUS "Test input files stored in ${test_input_data_dir}")
    set(TEST_INPUT_DATA_DIR ${test_input_data_dir} CACHE INTERNAL "directory for input test files")
    mark_as_advanced(TEST_INPUT_DATA_DIR)
  endif()
endif()

#--- add CMake infrastructure --------------------------------------------------
include(cmake/FCCAnalysesCreateConfig.cmake)
35 changes: 30 additions & 5 deletions analyzers/dataframe/CMakeLists.txt
@@ -17,11 +17,25 @@ include_directories(${EDM4HEP_INCLUDE_DIRS}
${VDT_INCLUDE_DIR}
)

message(STATUS "includes-------------------------- dataframe awkward: ${AWKWARD_INCLUDE}")
message(STATUS "includes-------------------------- dataframe edm4hep: ${EDM4HEP_INCLUDE_DIRS}")
message(STATUS "includes-------------------------- dataframe podio : ${podio_INCLUDE_DIR}")
message(STATUS "includes-------------------------- dataframe fastjet: ${FASTJET_INCLUDE_DIRS}")

set(EXTRA_INCLUDE_DIRS)
set(EXTRA_LIBRARIES)
set(EXTRA_ROOT_INCLUDES)

if(${WITH_ONNX})
  find_package(ONNXRuntime REQUIRED)
  find_package(nlohmann_json QUIET REQUIRED)
  message(STATUS "includes-------------------------- dataframe onnxruntime: ${ONNXRUNTIME_INCLUDE_DIRS}")
  message(STATUS "includes-------------------------- dataframe nlohmann_json")
  list(APPEND EXTRA_INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS})
  list(APPEND EXTRA_LIBRARIES ${ONNXRUNTIME_LIBRARIES} nlohmann_json::nlohmann_json)
  list(APPEND EXTRA_ROOT_INCLUDES FCCAnalyses/ONNXRuntime.h FCCAnalyses/WeaverInterface.h)
  include_directories(${ONNXRUNTIME_INCLUDE_DIRS})
endif()


file(GLOB sources src/*.cc)
@@ -35,6 +49,14 @@ if(NOT ${WITH_DD4HEP})
  list(FILTER headers EXCLUDE REGEX "CaloNtupleizer.h")
  list(FILTER sources EXCLUDE REGEX "CaloNtupleizer.cc")
endif()
if(NOT ${WITH_ONNX})
  list(FILTER headers EXCLUDE REGEX "ONNXRuntime.h")
  list(FILTER sources EXCLUDE REGEX "ONNXRuntime.cc")
  list(FILTER headers EXCLUDE REGEX "WeaverInterface.h")
  list(FILTER sources EXCLUDE REGEX "WeaverInterface.cc")
  list(FILTER headers EXCLUDE REGEX "JetFlavourUtils.h")
  list(FILTER sources EXCLUDE REGEX "JetFlavourUtils.cc")
endif()


message(STATUS "CMAKE_CURRENT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}")
@@ -47,6 +69,7 @@ target_include_directories(FCCAnalyses PUBLIC
${FASTJET_INCLUDE_DIR}
${acts_INCLUDE_DIR}
${AWKWARD_INCLUDE}
${EXTRA_INCLUDE_DIRS}
)

target_link_libraries(FCCAnalyses
@@ -65,6 +88,7 @@ target_link_libraries(FCCAnalyses
${LIBAWKWARD}
${CPU-KERNELS}
${LIBDL}
${EXTRA_LIBRARIES}
gfortran # todo: why necessary?
)

@@ -78,6 +102,7 @@ set_target_properties(FCCAnalyses PROPERTIES

ROOT_GENERATE_DICTIONARY(G__FCCAnalyses
${headers}
${EXTRA_ROOT_INCLUDES}
MODULE FCCAnalyses
LINKDEF FCCAnalyses/LinkDef.h
)
28 changes: 28 additions & 0 deletions analyzers/dataframe/FCCAnalyses/JetFlavourUtils.h
@@ -0,0 +1,28 @@
#ifndef FCCAnalyses_JetFlavourUtils_h
#define FCCAnalyses_JetFlavourUtils_h

#include <ROOT/RVec.hxx>

#include <string>
#include <utility>
#include <vector>

namespace FCCAnalyses {
  namespace JetFlavourUtils {
    namespace rv = ROOT::VecOps;
    using FCCAnalysesJetConstituentsData = rv::RVec<float>;
    using Variables = rv::RVec<FCCAnalysesJetConstituentsData>;

    /// Compute all weights given a collection of input variables
    /// \note This helper should not be used directly in RDataFrame examples
    rv::RVec<rv::RVec<float> > compute_weights(const rv::RVec<Variables>&);

    /// Setup the ONNXRuntime instance using Weaver-provided parameters
    void setup_weaver(const std::string&, const std::string&, const rv::RVec<std::string>&);
    /// Compute all weights given an unspecified collection of input variables
    template <typename... Args>
    ROOT::VecOps::RVec<ROOT::VecOps::RVec<float> > get_weights(Args&&... args) {
      return compute_weights(std::vector<Variables>{std::forward<Args>(args)...});
    }
    /// Get one specific weight previously computed
    rv::RVec<float> get_weight(const rv::RVec<rv::RVec<float> >&, int);
  }  // namespace JetFlavourUtils
}  // namespace FCCAnalyses

#endif
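
A note on the `get_weights` template in this header: it forwards every argument into one brace-initialised collection before handing it to `compute_weights`. The sketch below reproduces that packing step with `std::vector` standing in for `ROOT::VecOps::RVec`, so it builds without ROOT. `count_jets` and `pack_and_count_jets` are hypothetical names introduced only for illustration and are not part of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Stand-ins for the RVec aliases of the header (assumption: std::vector is
// enough to illustrate the packing).
using ConstituentsData = std::vector<float>;      // one value per constituent
using Variables = std::vector<ConstituentsData>;  // one entry per jet

// Hypothetical stand-in for compute_weights: returns the number of jets
// described by the packed variables.
inline std::size_t count_jets(const std::vector<Variables>& vars) {
  if (vars.empty())
    return 0;
  for (const auto& var : vars)
    assert(var.size() == vars.front().size());  // every variable must cover the same jets
  return vars.front().size();
}

// Same shape as get_weights: each argument becomes one element of the collection.
template <typename... Args>
std::size_t pack_and_count_jets(Args&&... args) {
  return count_jets(std::vector<Variables>{std::forward<Args>(args)...});
}
```

Calling `pack_and_count_jets(pt, eta)` with two variables that each describe two jets builds a two-element collection and returns 2. Arguments passed as lvalues are copied into the collection, which is worth keeping in mind for large constituent lists.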
43 changes: 43 additions & 0 deletions analyzers/dataframe/FCCAnalyses/ONNXRuntime.h
@@ -0,0 +1,43 @@
#ifndef FCCAnalyses_ONNXRuntime_h
#define FCCAnalyses_ONNXRuntime_h

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include <map>
#include <memory>

namespace Ort {
  class Env;
  namespace Experimental {
    class Session;
  }
}  // namespace Ort

class ONNXRuntime {
public:
  explicit ONNXRuntime(const std::string& = "", const std::vector<std::string>& = {});
  virtual ~ONNXRuntime();

  template <typename T>
  using Tensor = std::vector<std::vector<T>>;

  ONNXRuntime(const ONNXRuntime&) = delete;
  ONNXRuntime& operator=(const ONNXRuntime&) = delete;

  const std::vector<std::string>& inputNames() const { return input_names_; }

  template <typename T>
  Tensor<T> run(Tensor<T>&, const Tensor<long>& = {}, unsigned long long = 1ull) const;

private:
  size_t variablePos(const std::string&) const;

  std::unique_ptr<Ort::Env> env_;
  std::unique_ptr<Ort::Experimental::Session> session_;

  std::vector<std::string> input_node_strings_, output_node_strings_;
  std::vector<std::string> input_names_;
  std::map<std::string, std::vector<int64_t>> input_node_dims_, output_node_dims_;
};

#endif
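
This header only forward-declares the `Ort` types and holds them through `std::unique_ptr`, which keeps the ONNXRuntime headers out of every file that includes it. That pattern only compiles if the destructor is defined where the pointed-to type is complete. A minimal single-file sketch of the idea follows; `Engine` and `Wrapper` are hypothetical stand-ins, not classes from this PR.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// --- what would live in the header -------------------------------------------
class Engine;  // hypothetical heavy type, only forward-declared here

class Wrapper {
public:
  explicit Wrapper(const std::string& name);
  ~Wrapper();  // declared only: Engine is still incomplete at this point

  Wrapper(const Wrapper&) = delete;
  Wrapper& operator=(const Wrapper&) = delete;

  const std::string& name() const;

private:
  std::unique_ptr<Engine> engine_;
};

// --- what would live in the source file ---------------------------------------
class Engine {
public:
  explicit Engine(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }

private:
  std::string name_;
};

Wrapper::Wrapper(const std::string& name) : engine_(std::make_unique<Engine>(name)) {}
Wrapper::~Wrapper() = default;  // Engine is complete here, so unique_ptr can delete it

const std::string& Wrapper::name() const {
  assert(engine_ != nullptr);  // set in the constructor and never reset
  return engine_->name();
}
```

Defining the destructor inline in the class body, or leaving it implicit, would make any translation unit that destroys a `Wrapper` need the full `Engine` definition, which defeats the purpose of the forward declaration.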
74 changes: 74 additions & 0 deletions analyzers/dataframe/FCCAnalyses/WeaverInterface.h
@@ -0,0 +1,74 @@
#ifndef FCCAnalyses_WeaverInterface_h
#define FCCAnalyses_WeaverInterface_h

#include "FCCAnalyses/ONNXRuntime.h"
#include "ROOT/RVec.hxx"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace FCCAnalyses {
  namespace rv = ROOT::VecOps;

  class WeaverInterface {
  public:
    using ConstituentVars = rv::RVec<float>;

    /// Initialise an inference model from Weaver output ONNX/JSON files and
    /// a list of variables to be provided for each event/jet
    explicit WeaverInterface(const std::string& onnx_filename = "",
                             const std::string& json_filename = "",
                             const rv::RVec<std::string>& vars = {});

    /// Run inference given a list of jet constituents variables
    rv::RVec<float> run(const rv::RVec<ConstituentVars>&);

  private:
    struct PreprocessParams {
      struct VarInfo {
        VarInfo() {}
        VarInfo(float imedian,
                float inorm_factor,
                float ireplace_inf_value,
                float ilower_bound,
                float iupper_bound,
                float ipad)
            : center(imedian),
              norm_factor(inorm_factor),
              replace_inf_value(ireplace_inf_value),
              lower_bound(ilower_bound),
              upper_bound(iupper_bound),
              pad(ipad) {}

        float center{0.};
        float norm_factor{1.};
        float replace_inf_value{0.};
        float lower_bound{-5.};
        float upper_bound{5.};
        float pad{0.};
      };
      std::string name;
      size_t min_length{0}, max_length{0};
      std::vector<std::string> var_names;
      std::unordered_map<std::string, VarInfo> var_info_map;
      VarInfo info(const std::string& name) const { return var_info_map.at(name); }
      void dumpVars() const;
    };
    std::vector<float> center_norm_pad(const rv::RVec<float>& input,
                                       float center,
                                       float scale,
                                       size_t min_length,
                                       size_t max_length,
                                       float pad_value = 0,
                                       float replace_inf_value = 0,
                                       float min = 0,
                                       float max = -1);
    size_t variablePos(const std::string&) const;

    std::unique_ptr<ONNXRuntime> onnx_;
    std::vector<std::string> variables_names_;
    ONNXRuntime::Tensor<long> input_shapes_;
    std::vector<unsigned int> input_sizes_;
    std::unordered_map<std::string, PreprocessParams> prep_info_map_;
    ONNXRuntime::Tensor<float> data_;
  };
}  // namespace FCCAnalyses

#endif
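
The body of `center_norm_pad` is not part of the files shown here, so the following is only a guess at what the `VarInfo` fields (`center`, `norm_factor`, `replace_inf_value`, bounds, `pad`) imply: subtract the centre, scale, clip to the bounds, replace non-finite inputs, and pad or truncate to the allowed length. The function name `center_norm_pad_sketch` and the choice of requiring explicit bounds are assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch: centre, scale, clip and pad one constituent variable.
std::vector<float> center_norm_pad_sketch(const std::vector<float>& input,
                                          float center,
                                          float scale,
                                          std::size_t min_length,
                                          std::size_t max_length,
                                          float pad_value,
                                          float replace_inf_value,
                                          float lower,
                                          float upper) {
  assert(min_length <= max_length);
  assert(lower <= upper);
  // The output length follows the input but is forced into [min_length, max_length].
  const std::size_t target = std::clamp(input.size(), min_length, max_length);
  std::vector<float> out(target, pad_value);
  for (std::size_t i = 0; i < input.size() && i < target; ++i) {
    // Non-finite inputs (inf or NaN) are replaced before normalisation.
    const float raw = std::isfinite(input[i]) ? input[i] : replace_inf_value;
    out[i] = std::clamp((raw - center) * scale, lower, upper);
  }
  return out;
}
```

Entries beyond the input length keep `pad_value`, and inputs longer than `max_length` are truncated. Note that the header's default arguments (`min = 0`, `max = -1`) would not satisfy the `lower <= upper` precondition of this sketch, so the real implementation presumably treats them differently.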
53 changes: 53 additions & 0 deletions analyzers/dataframe/src/JetFlavourUtils.cc
@@ -0,0 +1,53 @@
#include "FCCAnalyses/JetFlavourUtils.h"
#include "FCCAnalyses/WeaverInterface.h"

#include <memory>
#include <stdexcept>

namespace FCCAnalyses {
  std::unique_ptr<WeaverInterface> gWeaver;

  namespace JetFlavourUtils {
    void setup_weaver(const std::string& onnx_filename,
                      const std::string& json_filename,
                      const rv::RVec<std::string>& vars) {
      gWeaver = std::make_unique<WeaverInterface>(onnx_filename, json_filename, vars);
    }

    rv::RVec<rv::RVec<float> > compute_weights(const rv::RVec<Variables>& vars) {
      if (!gWeaver)
        throw std::runtime_error("Weaver interface is not initialised!");
      rv::RVec<rv::RVec<float> > out;
      if (vars.empty())  // no variables registered
        return out;
      size_t num_jets = vars.at(0).size();
      if (num_jets == 0)  // no jets to categorise
        return out;
      // transform a collection of {var1 -> {jet1 -> {constit1, constit2, ...}, jet2 -> {...}, ...}, var2 -> {...}}
      // into a collection of {jet -> {var1 -> {constit1, constit2, ...}, var2 -> {...}, ...}}
      for (size_t i = 0; i < num_jets; ++i) {
        Variables jet_sc_vars;
        size_t num_constits = vars.at(0).at(i).size();
        for (size_t k = 0; k < vars.size(); ++k) {
          FCCAnalysesJetConstituentsData constit_vars;
          for (size_t j = 0; j < num_constits; ++j)
            constit_vars.push_back(static_cast<float>(vars.at(k).at(i).at(j)));
          jet_sc_vars.push_back(constit_vars);
        }
        out.emplace_back(gWeaver->run(jet_sc_vars));
      }
      return out;
    }

    rv::RVec<float> get_weight(const rv::RVec<rv::RVec<float> >& jets_weights, int weight) {
      if (weight < 0)
        throw std::runtime_error("Invalid index requested for jet flavour weight.");
      rv::RVec<float> out;
      for (const auto& jet_weights : jets_weights) {
        // weight is known to be non-negative here, so the cast is safe
        if (static_cast<size_t>(weight) >= jet_weights.size())
          throw std::runtime_error("Flavour weight index exceeds the number of weights registered.");
        out.emplace_back(jet_weights.at(weight));
      }
      return out;
    }
  }  // namespace JetFlavourUtils
}  // namespace FCCAnalyses
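
The two data manipulations in this file, regrouping the inputs from variable-major to jet-major order in `compute_weights` and picking one score per jet in `get_weight`, can be tried in isolation. The sketch below uses `std::vector` instead of `RVec` and leaves out the inference call; `to_jet_major` and `select_score` are hypothetical names. Unlike the loop in the PR, it copies each constituent list whole rather than element by element.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

using Constituents = std::vector<float>;
// Indexed by jet inside one variable, or by variable inside one jet.
using PerJet = std::vector<Constituents>;

// Turn [variable][jet][constituent] into [jet][variable][constituent].
std::vector<PerJet> to_jet_major(const std::vector<PerJet>& vars) {
  std::vector<PerJet> jets;
  if (vars.empty())
    return jets;
  const std::size_t num_jets = vars.front().size();
  jets.reserve(num_jets);
  for (std::size_t i = 0; i < num_jets; ++i) {
    PerJet jet_vars;
    jet_vars.reserve(vars.size());
    for (const auto& var : vars)
      jet_vars.push_back(var.at(i));  // at() throws if a variable has fewer jets
    jets.push_back(std::move(jet_vars));
  }
  assert(jets.size() == num_jets);
  return jets;
}

// Pick score number `index` from every jet's list of scores.
std::vector<float> select_score(const std::vector<std::vector<float>>& scores, std::size_t index) {
  std::vector<float> out;
  out.reserve(scores.size());
  for (const auto& jet_scores : scores) {
    if (index >= jet_scores.size())
      throw std::out_of_range("score index exceeds the number of scores per jet");
    out.push_back(jet_scores[index]);
  }
  return out;
}
```

Taking the index as `std::size_t` removes the need for the separate negative-index check that the `int` parameter of `get_weight` requires.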