diff --git a/Framework/Core/include/Framework/ASoAHelpers.h b/Framework/Core/include/Framework/ASoAHelpers.h index dbdabcffb74f5..c9ca3ad3a8031 100644 --- a/Framework/Core/include/Framework/ASoAHelpers.h +++ b/Framework/Core/include/Framework/ASoAHelpers.h @@ -13,6 +13,8 @@ #define O2_FRAMEWORK_ASOAHELPERS_H_ #include "Framework/ASoA.h" +#include "Framework/BinningPolicy.h" +#include "Framework/Logger.h" #include "Framework/RuntimeError.h" #include @@ -63,7 +65,7 @@ inline bool diffCategory(std::pair const& a, std::pair -std::vector> doGroupTable(const T& table, const std::string& categoryColumnName, int minCatSize, const T2& outsider) +std::vector> oldGroupTable(const T& table, const std::string& categoryColumnName, int minCatSize, const T2& outsider) { auto columnIndex = table.asArrowTable()->schema()->GetFieldIndex(categoryColumnName); auto chunkedArray = table.asArrowTable()->column(columnIndex); @@ -139,30 +141,129 @@ std::vector> doGroupTable(const T& table, const st return groupedIndices; } -// TODO: With arrow table we lose the filtered information! -// Can we group directly on T? Otherwise we need to extract the selection vector from T -template -auto groupTable(const T& table, const std::string& categoryColumnName, int minCatSize, const T2& outsider) +template +std::vector> doGroupTable(const T& table, const BinningPolicy& binningPolicy, int minCatSize, int outsider) { - auto columnIndex = table.asArrowTable()->schema()->GetFieldIndex(categoryColumnName); - auto dataType = table.asArrowTable()->column(columnIndex)->type(); - if (dataType->id() == arrow::Type::UINT64) { - return doGroupTable(table, categoryColumnName, minCatSize, outsider); + arrow::Table* arrowTable = table.asArrowTable().get(); + + uint64_t ind = 0; + uint64_t selInd = 0; + gsl::span selectedRows; + std::vector> groupedIndices; + + // Separate check to account for Filtered size different from arrow table + if (table.size() == 0) { + return groupedIndices; } - if (dataType->id() == arrow::Type::INT64) { - return doGroupTable(table, categoryColumnName, minCatSize, outsider); + + if constexpr (soa::is_soa_filtered_t::value) { + selectedRows = table.getSelectedRows(); // vector } - if (dataType->id() == arrow::Type::UINT32) { - return doGroupTable(table, categoryColumnName, minCatSize, outsider); + + auto binningColumns = binningPolicy.getColumns(); + auto arrowColumns = o2::framework::binning_helpers::getArrowColumns(arrowTable, binningColumns); + auto chunksCount = arrowColumns[0]->num_chunks(); + // TODO: Are such checks needed or can we safely assume chunks are always the same? + for (int i = 1; i < binningPolicy.mColumnsCount; i++) { + if (arrowColumns[i]->num_chunks() != chunksCount) { + throw o2::framework::runtime_error("Combinations: data size varies between selected columns"); + } } - if (dataType->id() == arrow::Type::INT32) { - return doGroupTable(table, categoryColumnName, minCatSize, outsider); + + for (uint64_t ci = 0; ci < chunksCount; ++ci) { + auto chunks = o2::framework::binning_helpers::getChunks(arrowTable, binningColumns, ci); + auto chunkLength = std::get<0>(chunks)->length(); + // TODO: Are such checks needed or can we safely assume chunks are always the same? + //constexpr auto cn = binningPolicy.mColumnsCount - 1; + //for_([&chunks, &chunkLength](auto i) { + // if (std::get(chunks)->length() != chunkLength) { + // throw o2::framework::runtime_error("Combinations: data size varies between selected columns"); + // } + //}); + + if constexpr (soa::is_soa_filtered_t::value) { + if (selectedRows[ind] >= selInd + chunkLength) { + selInd += chunkLength; + continue; // Go to the next chunk, no value selected in this chunk + } + } + + uint64_t ai = 0; + while (ai < chunkLength) { + if constexpr (soa::is_soa_filtered_t::value) { + ai += selectedRows[ind] - selInd; + selInd = selectedRows[ind]; + } + + auto rowData = o2::framework::binning_helpers::getRowData(arrowTable, binningColumns, ci, ai); + int val = binningPolicy.getBin(rowData); + if (val != outsider) { + groupedIndices.emplace_back(val, ind); + } + ind++; + + if constexpr (soa::is_soa_filtered_t::value) { + if (ind >= selectedRows.size()) { + break; + } + } else { + ai++; + } + } + + if constexpr (soa::is_soa_filtered_t::value) { + if (ind == selectedRows.size()) { + break; + } + } } - if (dataType->id() == arrow::Type::FLOAT) { - return doGroupTable(table, categoryColumnName, minCatSize, outsider); + + // Do a stable sort so that same categories entries are + // grouped together. + std::stable_sort(groupedIndices.begin(), groupedIndices.end()); + + // Remove categories of too small size + if (minCatSize > 1) { + auto catBegin = groupedIndices.begin(); + while (catBegin != groupedIndices.end()) { + auto catEnd = std::upper_bound(catBegin, groupedIndices.end(), *catBegin, sameCategory); + if (std::distance(catBegin, catEnd) < minCatSize) { + catEnd = groupedIndices.erase(catBegin, catEnd); + } + catBegin = catEnd; + } + } + + return groupedIndices; +} + +namespace old_interface +{ +template +struct is_string { + static const bool value = false; +}; + +template +struct is_string> { + static const bool value = true; +}; + +template +struct is_string { + static const bool value = true; +}; +} // namespace old_interface + +template +std::vector> groupTable(const T& table, const BinningPolicy& binningPolicy, int minCatSize, int outsider) +{ + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return oldGroupTable(table, binningPolicy, minCatSize, outsider); + } else { + return doGroupTable(table, binningPolicy, minCatSize, outsider); } - // FIXME: Should we support other types as well? - throw o2::framework::runtime_error("Combinations: category column must be of integral type"); } // Synchronize categories so as groupedIndices contain elements only of categories common to all tables @@ -411,19 +512,19 @@ struct CombinationsFullIndexPolicy : public CombinationsIndexPolicyBase { }; // For upper and full only -template +template struct CombinationsBlockIndexPolicyBase : public CombinationsIndexPolicyBase { using CombinationType = typename CombinationsIndexPolicyBase::CombinationType; using IndicesType = typename NTupleType::type; - CombinationsBlockIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider) : CombinationsIndexPolicyBase(), mSlidingWindowSize(categoryNeighbours + 1), mCategoryColumnName(categoryColumnName), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider) {} - CombinationsBlockIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsIndexPolicyBase(tables...), mSlidingWindowSize(categoryNeighbours + 1), mCategoryColumnName(categoryColumnName), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider) + CombinationsBlockIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider) : CombinationsIndexPolicyBase(), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider) {} + CombinationsBlockIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsIndexPolicyBase(tables...), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider) { if (!this->mIsEnd) { setRanges(tables...); } } - CombinationsBlockIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsIndexPolicyBase(std::forward(tables)...), mSlidingWindowSize(categoryNeighbours + 1) + CombinationsBlockIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsIndexPolicyBase(std::forward(tables)...), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider) { if (!this->mIsEnd) { setRanges(); @@ -450,7 +551,7 @@ struct CombinationsBlockIndexPolicyBase : public CombinationsIndexPolicyBasemGroupedIndices[tableIndex++] = groupTable(tables, this->mCategoryColumnName, 1, this->mOutsider)), ...); + ((this->mGroupedIndices[tableIndex++] = groupTable(tables, this->mBinningPolicy, 1, this->mOutsider)), ...); // Synchronize categories across tables syncCategories(this->mGroupedIndices); @@ -479,7 +580,7 @@ struct CombinationsBlockIndexPolicyBase : public CombinationsIndexPolicyBasemGroupedIndices[tableIndex++] = groupTable(x, mCategoryColumnName, 1, mOutsider)), ...); + ((this->mGroupedIndices[tableIndex++] = groupTable(x, this->mBinningPolicy, 1, this->mOutsider)), ...); }, *this->mTables); @@ -502,23 +603,23 @@ struct CombinationsBlockIndexPolicyBase : public CombinationsIndexPolicyBase -struct CombinationsBlockUpperIndexPolicy : public CombinationsBlockIndexPolicyBase { - using CombinationType = typename CombinationsBlockIndexPolicyBase::CombinationType; +template +struct CombinationsBlockUpperIndexPolicy : public CombinationsBlockIndexPolicyBase { + using CombinationType = typename CombinationsBlockIndexPolicyBase::CombinationType; - CombinationsBlockUpperIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider) {} - CombinationsBlockUpperIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, tables...) + CombinationsBlockUpperIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider) {} + CombinationsBlockUpperIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, tables...) { if (!this->mIsEnd) { setRanges(); } } - CombinationsBlockUpperIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, std::forward(tables)...) + CombinationsBlockUpperIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, std::forward(tables)...) { if (!this->mIsEnd) { setRanges(); @@ -527,12 +628,12 @@ struct CombinationsBlockUpperIndexPolicy : public CombinationsBlockIndexPolicyBa void setTables(const Ts&... tables) { - CombinationsBlockIndexPolicyBase::setTables(tables...); + CombinationsBlockIndexPolicyBase::setTables(tables...); setRanges(); } void setTables(Ts&&... tables) { - CombinationsBlockIndexPolicyBase::setTables(std::forward(tables)...); + CombinationsBlockIndexPolicyBase::setTables(std::forward(tables)...); setRanges(); } @@ -619,19 +720,19 @@ struct CombinationsBlockUpperIndexPolicy : public CombinationsBlockIndexPolicyBa } }; -template -struct CombinationsBlockFullIndexPolicy : public CombinationsBlockIndexPolicyBase { - using CombinationType = typename CombinationsBlockIndexPolicyBase::CombinationType; +template +struct CombinationsBlockFullIndexPolicy : public CombinationsBlockIndexPolicyBase { + using CombinationType = typename CombinationsBlockIndexPolicyBase::CombinationType; using IndicesType = typename NTupleType::type; - CombinationsBlockFullIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider), mCurrentlyFixed(0) {} - CombinationsBlockFullIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, tables...), mCurrentlyFixed(0) + CombinationsBlockFullIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider), mCurrentlyFixed(0) {} + CombinationsBlockFullIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, const Ts&... tables) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, tables...), mCurrentlyFixed(0) { if (!this->mIsEnd) { setRanges(); } } - CombinationsBlockFullIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsBlockIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, std::forward(tables)...), mCurrentlyFixed(0) + CombinationsBlockFullIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T& outsider, Ts&&... tables) : CombinationsBlockIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, std::forward(tables)...), mCurrentlyFixed(0) { if (!this->mIsEnd) { setRanges(); @@ -640,12 +741,12 @@ struct CombinationsBlockFullIndexPolicy : public CombinationsBlockIndexPolicyBas void setTables(const Ts&... tables) { - CombinationsBlockIndexPolicyBase::setTables(tables...); + CombinationsBlockIndexPolicyBase::setTables(tables...); setRanges(); } void setTables(Ts&&... tables) { - CombinationsBlockIndexPolicyBase::setTables(std::forward(tables)...); + CombinationsBlockIndexPolicyBase::setTables(std::forward(tables)...); setRanges(); } @@ -751,19 +852,19 @@ struct CombinationsBlockFullIndexPolicy : public CombinationsBlockIndexPolicyBas uint64_t mCurrentlyFixed; }; -template +template struct CombinationsBlockSameIndexPolicyBase : public CombinationsIndexPolicyBase { using CombinationType = typename CombinationsIndexPolicyBase::CombinationType; using IndicesType = typename NTupleType::type; - CombinationsBlockSameIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, int minWindowSize) : CombinationsIndexPolicyBase(), mSlidingWindowSize(categoryNeighbours + 1), mCategoryColumnName(categoryColumnName), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider), mMinWindowSize(minWindowSize) {} - CombinationsBlockSameIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, int minWindowSize, const T& table, const Ts&... tables) : CombinationsIndexPolicyBase(table, tables...), mSlidingWindowSize(categoryNeighbours + 1), mCategoryColumnName(categoryColumnName), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider), mMinWindowSize(minWindowSize) + CombinationsBlockSameIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, int minWindowSize) : CombinationsIndexPolicyBase(), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider), mMinWindowSize(minWindowSize) {} + CombinationsBlockSameIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, int minWindowSize, const T& table, const Ts&... tables) : CombinationsIndexPolicyBase(table, tables...), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider), mMinWindowSize(minWindowSize) { if (!this->mIsEnd) { setRanges(table); } } - CombinationsBlockSameIndexPolicyBase(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, int minWindowSize, T&& table, Ts&&... tables) : CombinationsIndexPolicyBase(std::forward(table), std::forward(tables)...), mSlidingWindowSize(categoryNeighbours + 1) + CombinationsBlockSameIndexPolicyBase(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, int minWindowSize, T&& table, Ts&&... tables) : CombinationsIndexPolicyBase(std::forward(table), std::forward(tables)...), mSlidingWindowSize(categoryNeighbours + 1), mBinningPolicy(binningPolicy), mCategoryNeighbours(categoryNeighbours), mOutsider(outsider), mMinWindowSize(minWindowSize) { if (!this->mIsEnd) { setRanges(); @@ -794,7 +895,7 @@ struct CombinationsBlockSameIndexPolicyBase : public CombinationsIndexPolicyBase return; } - this->mGroupedIndices = groupTable(table, mCategoryColumnName, mMinWindowSize, mOutsider); + this->mGroupedIndices = groupTable(table, mBinningPolicy, mMinWindowSize, mOutsider); if (this->mGroupedIndices.size() == 0) { this->mIsEnd = true; @@ -812,7 +913,7 @@ struct CombinationsBlockSameIndexPolicyBase : public CombinationsIndexPolicyBase return; } - this->mGroupedIndices = groupTable(std::get<0>(*this->mTables), mCategoryColumnName, mMinWindowSize, mOutsider); + this->mGroupedIndices = groupTable(std::get<0>(*this->mTables), mBinningPolicy, mMinWindowSize, mOutsider); if (this->mGroupedIndices.size() == 0) { this->mIsEnd = true; @@ -826,23 +927,23 @@ struct CombinationsBlockSameIndexPolicyBase : public CombinationsIndexPolicyBase IndicesType mCurrentIndices; const uint64_t mSlidingWindowSize; const int mMinWindowSize; - const std::string mCategoryColumnName; + const BinningPolicy mBinningPolicy; const int mCategoryNeighbours; const T1 mOutsider; }; -template -struct CombinationsBlockUpperSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { - using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; +template +struct CombinationsBlockUpperSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { + using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; - CombinationsBlockUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1) {} - CombinationsBlockUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1, tables...) + CombinationsBlockUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1) {} + CombinationsBlockUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1, tables...) { if (!this->mIsEnd) { setRanges(); } } - CombinationsBlockUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1, std::forward(tables)...) + CombinationsBlockUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1, std::forward(tables)...) { if (!this->mIsEnd) { setRanges(); @@ -851,12 +952,12 @@ struct CombinationsBlockUpperSameIndexPolicy : public CombinationsBlockSameIndex void setTables(const Ts&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(tables...); + CombinationsBlockSameIndexPolicyBase::setTables(tables...); setRanges(); } void setTables(Ts&&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); + CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); setRanges(); } @@ -927,18 +1028,19 @@ struct CombinationsBlockUpperSameIndexPolicy : public CombinationsBlockSameIndex } }; -template -struct CombinationsBlockStrictlyUpperSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { - using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; +template +struct CombinationsBlockStrictlyUpperSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { + using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; - CombinationsBlockStrictlyUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, sizeof...(Ts)) {} - CombinationsBlockStrictlyUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, sizeof...(Ts), tables...) + CombinationsBlockStrictlyUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, sizeof...(Ts)) {} + CombinationsBlockStrictlyUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, sizeof...(Ts), tables...) { if (!this->mIsEnd) { setRanges(); } } - CombinationsBlockStrictlyUpperSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, sizeof...(Ts) + 1, std::forward(tables)...) + + CombinationsBlockStrictlyUpperSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, sizeof...(Ts), std::forward(tables)...) { if (!this->mIsEnd) { setRanges(); @@ -947,14 +1049,14 @@ struct CombinationsBlockStrictlyUpperSameIndexPolicy : public CombinationsBlockS void setTables(const Ts&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(tables...); + CombinationsBlockSameIndexPolicyBase::setTables(tables...); if (!this->mIsEnd) { setRanges(); } } void setTables(Ts&&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); + CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); if (!this->mIsEnd) { setRanges(); } @@ -1030,18 +1132,18 @@ struct CombinationsBlockStrictlyUpperSameIndexPolicy : public CombinationsBlockS } }; -template -struct CombinationsBlockFullSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { - using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; +template +struct CombinationsBlockFullSameIndexPolicy : public CombinationsBlockSameIndexPolicyBase { + using CombinationType = typename CombinationsBlockSameIndexPolicyBase::CombinationType; - CombinationsBlockFullSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1), mCurrentlyFixed(0) {} - CombinationsBlockFullSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1, tables...), mCurrentlyFixed(0) + CombinationsBlockFullSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1), mCurrentlyFixed(0) {} + CombinationsBlockFullSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const Ts&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1, tables...), mCurrentlyFixed(0) { if (!this->mIsEnd) { setRanges(); } } - CombinationsBlockFullSameIndexPolicy(const std::string& categoryColumnName, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(categoryColumnName, categoryNeighbours, outsider, 1, std::forward(tables)...), mCurrentlyFixed(0) + CombinationsBlockFullSameIndexPolicy(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, Ts&&... tables) : CombinationsBlockSameIndexPolicyBase(binningPolicy, categoryNeighbours, outsider, 1, std::forward(tables)...), mCurrentlyFixed(0) { if (!this->mIsEnd) { setRanges(); @@ -1050,12 +1152,12 @@ struct CombinationsBlockFullSameIndexPolicy : public CombinationsBlockSameIndexP void setTables(const Ts&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(tables...); + CombinationsBlockSameIndexPolicyBase::setTables(tables...); setRanges(); } void setTables(Ts&&... tables) { - CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); + CombinationsBlockSameIndexPolicyBase::setTables(std::forward(tables)...); setRanges(); } @@ -1246,54 +1348,97 @@ constexpr bool isSameType() return std::conjunction_v...>; } -template -auto selfCombinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider, const T2s&... tables) +template +auto selfCombinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const T2s&... tables) { static_assert(isSameType(), "Tables must have the same type for self combinations"); - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider, tables...)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, tables...)); + } else { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider, tables...)); + } } -template -auto selfPairCombinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider) +template +auto selfPairCombinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider) { - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider)); + } else { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider)); + } } -template -auto selfPairCombinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider, const T2& table) +template +auto selfPairCombinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const T2& table) { - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider, table, table)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, table, table)); + } else { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider, table, table)); + } } -template -auto selfTripleCombinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider) +template +auto selfTripleCombinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider) { - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider)); + } else { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider)); + } } -template -auto selfTripleCombinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider, const T2& table) +template +auto selfTripleCombinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const T2& table) { - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider, table, table, table)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, table, table, table)); + } else { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider, table, table, table)); + } } -template -auto combinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider, const T2s&... tables) +template +auto combinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const T2s&... tables) { - if constexpr (isSameType()) { - return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider, tables...)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + if constexpr (isSameType()) { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, tables...)); + } else { + return CombinationsGenerator>(CombinationsBlockUpperIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, tables...)); + } } else { - return CombinationsGenerator>(CombinationsBlockUpperIndexPolicy(categoryColumnName, categoryNeighbours, outsider, tables...)); + if constexpr (isSameType()) { + return CombinationsGenerator>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider, tables...)); + } else { + return CombinationsGenerator>(CombinationsBlockUpperIndexPolicy(binningPolicy, categoryNeighbours, outsider, tables...)); + } } } -template -auto combinations(const char* categoryColumnName, int categoryNeighbours, const T1& outsider, const o2::framework::expressions::Filter& filter, const T2s&... tables) +template +auto combinations(const BinningPolicy& binningPolicy, int categoryNeighbours, const T1& outsider, const o2::framework::expressions::Filter& filter, const T2s&... tables) { - if constexpr (isSameType()) { - return CombinationsGenerator...>>(CombinationsBlockStrictlyUpperSameIndexPolicy(categoryColumnName, categoryNeighbours, outsider, tables.select(filter)...)); + if constexpr (old_interface::is_string::value) { + LOG(warn) << "You are using obsolete interface for block combinations / event mixing, please update"; + if constexpr (isSameType()) { + return CombinationsGenerator...>>(CombinationsBlockStrictlyUpperSameIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, tables.select(filter)...)); + } else { + return CombinationsGenerator...>>(CombinationsBlockUpperIndexPolicy(std::string(binningPolicy), categoryNeighbours, outsider, tables.select(filter)...)); + } } else { - return CombinationsGenerator...>>(CombinationsBlockUpperIndexPolicy(categoryColumnName, categoryNeighbours, outsider, tables.select(filter)...)); + if constexpr (isSameType()) { + return CombinationsGenerator...>>(CombinationsBlockStrictlyUpperSameIndexPolicy(binningPolicy, categoryNeighbours, outsider, tables.select(filter)...)); + } else { + return CombinationsGenerator...>>(CombinationsBlockUpperIndexPolicy(binningPolicy, categoryNeighbours, outsider, tables.select(filter)...)); + } } } diff --git a/Framework/Core/include/Framework/BinningPolicy.h b/Framework/Core/include/Framework/BinningPolicy.h new file mode 100644 index 0000000000000..420812bf8373e --- /dev/null +++ b/Framework/Core/include/Framework/BinningPolicy.h @@ -0,0 +1,204 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#ifndef FRAMEWORK_BINNINGPOLICY_H +#define FRAMEWORK_BINNINGPOLICY_H + +#include "Framework/HistogramSpec.h" // only for VARIABLE_WIDTH +#include "Framework/ASoAHelpers.h" +#include "Framework/Pack.h" +#include "Framework/ArrowTypes.h" +#include + +namespace o2::framework +{ + +namespace binning_helpers +{ +template +std::array getArrowColumns(arrow::Table* table, pack) +{ + static_assert(std::conjunction_v, "BinningPolicy: only persistent columns accepted (not dynamic and not index ones"); + return std::array{o2::soa::getIndexFromLabel(table, Cs::columnLabel())...}; +} + +template +std::array, sizeof...(Cs)> getChunks(arrow::Table* table, pack, uint64_t ci) +{ + static_assert(std::conjunction_v, "BinningPolicy: only persistent columns accepted (not dynamic and not index ones"); + return std::array, sizeof...(Cs)>{o2::soa::getIndexFromLabel(table, Cs::columnLabel())->chunk(ci)...}; +} + +template +std::tuple getRowData(arrow::Table* table, pack, uint64_t ci, uint64_t ai) +{ + static_assert(std::conjunction_v, "BinningPolicy: only persistent columns accepted (not dynamic and not index ones"); + return std::make_tuple(std::static_pointer_cast>(o2::soa::getIndexFromLabel(table, Cs::columnLabel())->chunk(ci))->raw_values()[ai]...); +} +} // namespace binning_helpers + +template +struct BinningPolicy { + BinningPolicy(std::array, sizeof...(Cs) + 1> bins, bool ignoreOverflows = true) : mBins(bins), mIgnoreOverflows(ignoreOverflows) + { + static_assert(sizeof...(Cs) < 3, "No default binning for more than 3 columns, you need to implement a binning class yourself"); + for (int i = 0; i < sizeof...(Cs) + 1; i++) { + expandConstantBinning(bins[i], i); + } + } + + int getBin(std::tuple const& data) const + { + if (this->mIgnoreOverflows) { + // underflow + if (std::get<0>(data) < this->mBins[0][1]) { // xBins[0] is a dummy VARIABLE_WIDTH + return -1; + } + if constexpr (sizeof...(Cs) > 0) { + if (std::get<1>(data) < this->mBins[1][1]) { // this->mBins[1][0] is a dummy VARIABLE_WIDTH + return -1; + } + } + if constexpr (sizeof...(Cs) > 1) { + if (std::get<2>(data) < this->mBins[2][1]) { // this->mBins[2][0] is a dummy VARIABLE_WIDTH + return -1; + } + } + } + + unsigned int i = 2, j = 2, k = 2; + for (; i < this->mBins[0].size(); i++) { + if (std::get<0>(data) < this->mBins[0][i]) { + + if constexpr (sizeof...(Cs) > 0) { + for (; j < this->mBins[1].size(); j++) { + if (std::get<1>(data) < this->mBins[1][j]) { + + if constexpr (sizeof...(Cs) > 1) { + for (; k < this->mBins[2].size(); k++) { + if (std::get<2>(data) < this->mBins[2][k]) { + return getBinAt(i, j, k); + } + } + if (this->mIgnoreOverflows) { + return -1; + } + } + + // overflow for this->mBins[2] only + return getBinAt(i, j, k); + } + } + + if (this->mIgnoreOverflows) { + return -1; + } + + // overflow for this->mBins[1] only + if constexpr (sizeof...(Cs) > 1) { + for (k = 2; k < this->mBins[2].size(); k++) { + if (std::get<2>(data) < this->mBins[2][k]) { + return getBinAt(i, j, k); + } + } + } + } + + // overflow for this->mBins[2] and this->mBins[1] + return getBinAt(i, j, k); + } + } + + if (this->mIgnoreOverflows) { + // overflow + return -1; + } + + // overflow for this->mBins[0] only + if constexpr (sizeof...(Cs) > 0) { + for (j = 2; j < this->mBins[1].size(); j++) { + if (std::get<1>(data) < this->mBins[1][j]) { + + if constexpr (sizeof...(Cs) > 1) { + for (k = 2; k < this->mBins[2].size(); k++) { + if (std::get<2>(data) < this->mBins[2][k]) { + return getBinAt(i, j, k); + } + } + } + + // overflow for this->mBins[0] and this->mBins[2] + return getBinAt(i, j, k); + } + } + } + + // overflow for this->mBins[0] and this->mBins[1] + if constexpr (sizeof...(Cs) > 1) { + for (k = 2; k < this->mBins[2].size(); k++) { + if (std::get<2>(data) < this->mBins[2][k]) { + return getBinAt(i, j, k); + } + } + } + + // overflow for all bins + return getBinAt(i, j, k); + } + + pack getColumns() const { return pack{}; } + static constexpr int mColumnsCount = sizeof...(Cs) + 1; + + private: + int getBinAt(unsigned int i, unsigned int j, unsigned int k) const + { + if constexpr (sizeof...(Cs) == 0) { + return i - 1; + } else if constexpr (sizeof...(Cs) == 1) { + return (i - 1) + (j - 1) * this->mBins[0].size(); + } else if constexpr (sizeof...(Cs) == 2) { + return (i - 1) + (j - 1) * this->mBins[0].size() + (k - 1) * (this->mBins[0].size() + this->mBins[1].size()); + } else { + return -1; + } + } + + void expandConstantBinning(std::vector const& bins, int ind) + { + if (bins[0] != VARIABLE_WIDTH) { + int nBins = static_cast(bins[0]); + this->mBins[ind].clear(); + this->mBins[ind].resize(nBins + 2); + this->mBins[ind][0] = VARIABLE_WIDTH; + std::iota(std::begin(this->mBins[ind]) + 1, std::end(this->mBins[ind]), bins[2] - bins[1] / nBins); + } + } + + std::array, sizeof...(Cs) + 1> mBins; + bool mIgnoreOverflows; +}; + +template +struct NoBinningPolicy { + // Just take the bin number from the column data + NoBinningPolicy() = default; + + int getBin(std::tuple const& data) const + { + return std::get<0>(data); + } + + pack getColumns() const { return pack{}; } + static constexpr int mColumnsCount = 1; +}; + +} // namespace o2::framework +#endif // FRAMEWORK_BINNINGPOLICY_H_ diff --git a/Framework/Core/include/Framework/GroupedCombinations.h b/Framework/Core/include/Framework/GroupedCombinations.h index a4f64e1a4d94e..6d1a918c44214 100644 --- a/Framework/Core/include/Framework/GroupedCombinations.h +++ b/Framework/Core/include/Framework/GroupedCombinations.h @@ -258,15 +258,15 @@ struct GroupedCombinationsGenerator, As... // 'Pair' and 'Triple' can be used for same kind pair/triple, too, just specify the same type twice template using joinedCollisions = typename soa::Join::table_t; -template , joinedCollisions>> +template , joinedCollisions>> using Pair = GroupedCombinationsGenerator>, A1, A2>; -template , joinedCollisions>> +template , joinedCollisions>> using SameKindPair = GroupedCombinationsGenerator, A, A>; // Aliases for 3-particle correlations -template , joinedCollisions, joinedCollisions>> +template , joinedCollisions, joinedCollisions>> using Triple = GroupedCombinationsGenerator>, A1, A2, A3>; -template , joinedCollisions, joinedCollisions>> +template , joinedCollisions, joinedCollisions>> using SameKindTriple = GroupedCombinationsGenerator, A, A, A>; } // namespace o2::framework diff --git a/Framework/Core/test/benchmark_ASoAHelpers.cxx b/Framework/Core/test/benchmark_ASoAHelpers.cxx index 59775bf02ff01..3ee7c4a062293 100644 --- a/Framework/Core/test/benchmark_ASoAHelpers.cxx +++ b/Framework/Core/test/benchmark_ASoAHelpers.cxx @@ -454,12 +454,13 @@ static void BM_ASoAHelpersCombGenSimplePairsSameCategories(benchmark::State& sta using Test = o2::soa::Table; Test tests{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy("x", 2, -1, tests, tests))) { + for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy(noBinning, 2, -1, tests, tests))) { count++; } benchmark::DoNotOptimize(count); @@ -486,12 +487,13 @@ static void BM_ASoAHelpersCombGenSimpleFivesSameCategories(benchmark::State& sta using Test = o2::soa::Table; Test tests{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy("x", 5, -1, tests, tests, tests, tests, tests))) { + for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy(noBinning, 5, -1, tests, tests, tests, tests, tests))) { count++; } benchmark::DoNotOptimize(count); @@ -518,12 +520,13 @@ static void BM_ASoAHelpersCombGenSimplePairsCategories(benchmark::State& state) using Test = o2::soa::Table; Test tests{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy("x", 2, -1, tests, tests))) { + for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy(noBinning, 2, -1, tests, tests))) { count++; } benchmark::DoNotOptimize(count); @@ -550,12 +553,13 @@ static void BM_ASoAHelpersCombGenSimpleFivesCategories(benchmark::State& state) using Test = o2::soa::Table; Test tests{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy("x", 2, -1, tests, tests, tests, tests, tests))) { + for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy(noBinning, 2, -1, tests, tests, tests, tests, tests))) { count++; } benchmark::DoNotOptimize(count); @@ -587,12 +591,13 @@ static void BM_ASoAHelpersCombGenCollisionsPairsSameCategories(benchmark::State& auto table = builder.finalize(); o2::aod::Collisions collisions{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy("fNumContrib", 2, -1, collisions, collisions))) { + for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy(noBinning, 2, -1, collisions, collisions))) { count++; } benchmark::DoNotOptimize(count); @@ -624,12 +629,13 @@ static void BM_ASoAHelpersCombGenCollisionsFivesSameCategories(benchmark::State& auto table = builder.finalize(); o2::aod::Collisions collisions{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy("fNumContrib", 5, -1, collisions, collisions, collisions, collisions, collisions))) { + for (auto& comb : combinations(CombinationsBlockUpperSameIndexPolicy(noBinning, 5, -1, collisions, collisions, collisions, collisions, collisions))) { count++; } benchmark::DoNotOptimize(count); @@ -661,12 +667,13 @@ static void BM_ASoAHelpersCombGenCollisionsPairsCategories(benchmark::State& sta auto table = builder.finalize(); o2::aod::Collisions collisions{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy("fNumContrib", 2, -1, collisions, collisions))) { + for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy(noBinning, 2, -1, collisions, collisions))) { count++; } benchmark::DoNotOptimize(count); @@ -698,12 +705,13 @@ static void BM_ASoAHelpersCombGenCollisionsFivesCategories(benchmark::State& sta auto table = builder.finalize(); o2::aod::Collisions collisions{table}; + NoBinningPolicy noBinning; int64_t count = 0; for (auto _ : state) { count = 0; - for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy("fNumContrib", 5, -1, collisions, collisions, collisions, collisions, collisions))) { + for (auto& comb : combinations(CombinationsBlockUpperIndexPolicy(noBinning, 5, -1, collisions, collisions, collisions, collisions, collisions))) { count++; } benchmark::DoNotOptimize(count); diff --git a/Framework/Core/test/test_ASoAHelpers.cxx b/Framework/Core/test/test_ASoAHelpers.cxx index 6dbbf7f3fa416..1e4c963c44011 100644 --- a/Framework/Core/test/test_ASoAHelpers.cxx +++ b/Framework/Core/test/test_ASoAHelpers.cxx @@ -31,42 +31,6 @@ DECLARE_SOA_COLUMN_FULL(FloatZ, floatZ, float, "floatZ"); DECLARE_SOA_DYNAMIC_COLUMN(Sum, sum, [](int32_t x, int32_t y) { return x + y; }); } // namespace test -// Calculate hash for an element based on 2 properties and their bins. -int32_t getHash(const std::vector& yBins, const std::vector& zBins, uint32_t colY, float colZ, bool ignoreOverflows = false) -{ - if (ignoreOverflows) { - if (colY < yBins[0] || colZ < zBins[0]) { - return -1; - } - } - - for (int i = 0; i < yBins.size(); i++) { - if (colY < yBins[i]) { - for (int j = 0; j < zBins.size(); j++) { - if (colZ < zBins[j]) { - return i + j * (yBins.size() + 1); - } - } - // overflow for zBins only - return ignoreOverflows ? -1 : i + zBins.size() * (yBins.size() + 1); - } - } - - if (ignoreOverflows) { - return -1; - } - - // overflow for yBins only - for (int j = 0; j < zBins.size(); j++) { - if (colZ < zBins[j]) { - return yBins.size() + j * (yBins.size() + 1); - } - } - - // overflow for both bins - return (zBins.size() + 1) * (yBins.size() + 1) - 1; -} - BOOST_AUTO_TEST_CASE(IteratorTuple) { TableBuilder builderA; @@ -166,20 +130,10 @@ BOOST_AUTO_TEST_CASE(CombinationsGeneratorConstruction) // Grouped data: // [3, 5] [0, 4, 7], [1, 6], [2] // Assuming bins intervals: [ , ) - std::vector yBins{0, 5, 10, 20, 30, 40, 50, 101}; - std::vector zBins{-7.0f, -5.0f, -3.0f, -1.0f, 1.0f, 3.0f, 5.0f, 7.0f}; - - TableBuilder builderAux; - auto rowWriterAux = builderAux.persist({"x", "y"}); - for (auto it = testsA.begin(); it != testsA.end(); it++) { - auto& elem = *it; - rowWriterAux(0, elem.x(), getHash(yBins, zBins, elem.y(), elem.floatZ())); - } - auto tableAux = builderAux.finalize(); - BOOST_REQUIRE_EQUAL(tableAux->num_rows(), 8); - using TestsAux = o2::soa::Table, test::X, test::Y>; - TestsAux testAux{tableAux}; - BOOST_REQUIRE_EQUAL(8, testAux.size()); + std::vector yBins{VARIABLE_WIDTH, 0, 5, 10, 20, 30, 40, 50, 101}; + std::vector zBins{VARIABLE_WIDTH, -7.0, -5.0, -3.0, -1.0, 1.0, 3.0, 5.0, 7.0}; + + BinningPolicy pairBinning{{yBins, zBins}, false}; CombinationsGenerator>::CombinationsIterator combIt(CombinationsStrictlyUpperIndexPolicy(testsA, testsA)); BOOST_REQUIRE_NE(static_cast(std::get<0>(*(combIt))).getIterator().mCurrentPos, nullptr); @@ -337,10 +291,10 @@ BOOST_AUTO_TEST_CASE(CombinationsGeneratorConstruction) BOOST_REQUIRE_EQUAL(*(static_cast(std::get<4>(endBadCombination)).getIterator().mCurrentPos), 4); BOOST_REQUIRE_EQUAL(static_cast(std::get<4>(endBadCombination)).getIterator().mCurrentChunk, 0); - auto combBlock = combinations(CombinationsBlockStrictlyUpperSameIndexPolicy("y", 2, -1, testAux, testAux)); + auto combBlock = combinations(CombinationsBlockStrictlyUpperSameIndexPolicy(pairBinning, 2, -1, testsA, testsA)); - static_assert(std::is_same_v>::CombinationsIterator>, "Wrong iterator type"); - static_assert(std::is_same_v::CombinationType&>, "Wrong combination type"); + static_assert(std::is_same_v, int32_t, TestA, TestA>>::CombinationsIterator>, "Wrong iterator type"); + static_assert(std::is_same_v, int32_t, TestA, TestA>::CombinationType&>, "Wrong combination type"); auto beginBlockCombination = *(combBlock.begin()); BOOST_REQUIRE_NE(static_cast(std::get<0>(beginBlockCombination)).getIterator().mCurrentPos, nullptr); @@ -975,55 +929,33 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) TestA testA{tableA}; BOOST_REQUIRE_EQUAL(10, testA.size()); + TableBuilder builderAHalf; + auto rowWriterAHalf = builderAHalf.persist({"x", "y", "floatZ"}); + rowWriterAHalf(0, 0, 25, -6.0f); + rowWriterAHalf(0, 1, 18, 0.0f); + rowWriterAHalf(0, 2, 48, 8.0f); + rowWriterAHalf(0, 3, 103, 2.0f); + rowWriterAHalf(0, 4, 28, -6.0f); + auto tableAHalf = builderAHalf.finalize(); + BOOST_REQUIRE_EQUAL(tableAHalf->num_rows(), 5); + + TestA testAHalf{tableAHalf}; + BOOST_REQUIRE_EQUAL(5, testAHalf.size()); + // Grouped data: // [3, 5] [0, 4, 7], [1, 6], [2, 8, 9] // Assuming bins intervals: [ , ) - std::vector yBins{0, 5, 10, 20, 30, 40, 50, 101}; - std::vector zBins{-7.0f, -5.0f, -3.0f, -1.0f, 1.0f, 3.0f, 5.0f, 7.0f}; - - TableBuilder builderAux; - TableBuilder builderAuxHalf; - auto rowWriterAux = builderAux.persist({"x", "y"}); - auto rowWriterAuxHalf = builderAuxHalf.persist({"x", "y"}); - int size = 0; - for (auto it = testA.begin(); it != testA.end(); it++) { - auto& elem = *it; - rowWriterAux(0, elem.x(), getHash(yBins, zBins, elem.y(), elem.floatZ())); - if (size < 5) { - rowWriterAuxHalf(0, elem.x(), getHash(yBins, zBins, elem.y(), elem.floatZ())); - } - size++; - } - auto tableAux = builderAux.finalize(); - auto tableAuxHalf = builderAuxHalf.finalize(); - BOOST_REQUIRE_EQUAL(tableAux->num_rows(), 10); - BOOST_REQUIRE_EQUAL(tableAuxHalf->num_rows(), 5); - - // Auxiliary table: testsAux with id and hash, hash is the category for grouping - using TestsAux = o2::soa::Table, test::X, test::Y>; - TestsAux testAux{tableAux}; - TestsAux testAuxHalf{tableAuxHalf}; - BOOST_REQUIRE_EQUAL(10, testAux.size()); - BOOST_REQUIRE_EQUAL(5, testAuxHalf.size()); - - // Omitting values outside bins - TableBuilder builderAuxNoOverflows; - auto rowWriterAux2 = builderAuxNoOverflows.persist({"x", "y"}); - for (auto it = testA.begin(); it != testA.end(); it++) { - auto& elem = *it; - rowWriterAux2(0, elem.x(), getHash(yBins, zBins, elem.y(), elem.floatZ(), true)); - } - auto tableAuxNoOverflows = builderAuxNoOverflows.finalize(); - BOOST_REQUIRE_EQUAL(tableAuxNoOverflows->num_rows(), 10); + std::vector yBins{VARIABLE_WIDTH, 0, 5, 10, 20, 30, 40, 50, 101}; + std::vector zBins{VARIABLE_WIDTH, -7.0, -5.0, -3.0, -1.0, 1.0, 3.0, 5.0, 7.0}; - TestsAux testAuxNoOverflows{tableAuxNoOverflows}; - BOOST_REQUIRE_EQUAL(10, testAuxNoOverflows.size()); + BinningPolicy pairBinning{{yBins, zBins}, false}; + BinningPolicy pairBinningNoOverflows{{yBins, zBins}, true}; // 2, 3, 5, 8, 9 have overflows in testA std::vector> expectedFullPairsNoOverflows{ {0, 0}, {0, 4}, {4, 0}, {4, 4}, {4, 7}, {7, 4}, {7, 7}, {1, 1}, {1, 6}, {6, 1}, {6, 6}}; int count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy("y", 1, -1, testAuxNoOverflows, testAuxNoOverflows))) { + for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy(pairBinningNoOverflows, 1, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullPairsNoOverflows[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullPairsNoOverflows[count])); count++; @@ -1033,7 +965,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedFullPairs{ {0, 0}, {0, 4}, {0, 7}, {4, 0}, {7, 0}, {4, 4}, {4, 7}, {7, 4}, {7, 7}, {1, 1}, {1, 6}, {6, 1}, {6, 6}, {3, 3}, {3, 5}, {5, 3}, {5, 5}, {2, 2}, {2, 8}, {2, 9}, {8, 2}, {9, 2}, {8, 8}, {8, 9}, {9, 8}, {9, 9}}; count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy("y", 2, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy(pairBinning, 2, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullPairs[count])); count++; @@ -1043,7 +975,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedFullTriples{ {0, 0, 0}, {0, 0, 4}, {0, 0, 7}, {0, 4, 0}, {0, 4, 4}, {0, 4, 7}, {0, 7, 0}, {0, 7, 4}, {0, 7, 7}, {4, 0, 0}, {4, 0, 4}, {4, 0, 7}, {7, 0, 0}, {7, 0, 4}, {7, 0, 7}, {4, 4, 0}, {4, 7, 0}, {7, 4, 0}, {7, 7, 0}, {4, 4, 4}, {4, 4, 7}, {4, 7, 4}, {4, 7, 7}, {7, 4, 4}, {7, 4, 7}, {7, 7, 4}, {7, 7, 7}, {1, 1, 1}, {1, 1, 6}, {1, 6, 1}, {1, 6, 6}, {6, 1, 1}, {6, 1, 6}, {6, 6, 1}, {6, 6, 6}, {3, 3, 3}, {3, 3, 5}, {3, 5, 3}, {3, 5, 5}, {5, 3, 3}, {5, 3, 5}, {5, 5, 3}, {5, 5, 5}, {2, 2, 2}, {2, 2, 8}, {2, 2, 9}, {2, 8, 2}, {2, 8, 8}, {2, 8, 9}, {2, 9, 2}, {2, 9, 8}, {2, 9, 9}, {8, 2, 2}, {8, 2, 8}, {8, 2, 9}, {9, 2, 2}, {9, 2, 8}, {9, 2, 9}, {8, 8, 2}, {8, 9, 2}, {9, 8, 2}, {9, 9, 2}, {8, 8, 8}, {8, 8, 9}, {8, 9, 8}, {8, 9, 9}, {9, 8, 8}, {9, 8, 9}, {9, 9, 8}, {9, 9, 9}}; count = 0; - for (auto& [c0, c1, c2] : combinations(CombinationsBlockFullIndexPolicy("y", 2, -1, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2] : combinations(CombinationsBlockFullIndexPolicy(pairBinning, 2, -1, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedFullTriples[count])); @@ -1054,7 +986,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedUpperPairs{ {0, 0}, {0, 4}, {0, 7}, {4, 4}, {4, 7}, {7, 7}, {1, 1}, {1, 6}, {6, 6}, {3, 3}, {3, 5}, {5, 5}, {2, 2}, {2, 8}, {2, 9}, {8, 8}, {8, 9}, {9, 9}}; count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockUpperIndexPolicy("y", 2, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockUpperIndexPolicy(pairBinning, 2, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperPairs[count])); count++; @@ -1064,7 +996,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedUpperTriples{ {0, 0, 0}, {0, 0, 4}, {0, 4, 4}, {4, 4, 4}, {4, 4, 7}, {4, 7, 7}, {7, 7, 7}, {1, 1, 1}, {1, 1, 6}, {1, 6, 6}, {6, 6, 6}, {3, 3, 3}, {3, 3, 5}, {3, 5, 5}, {5, 5, 5}, {2, 2, 2}, {2, 2, 8}, {2, 8, 8}, {8, 8, 8}, {8, 8, 9}, {8, 9, 9}, {9, 9, 9}}; count = 0; - for (auto& [c0, c1, c2] : combinations(CombinationsBlockUpperIndexPolicy("y", 1, -1, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2] : combinations(CombinationsBlockUpperIndexPolicy(pairBinning, 1, -1, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedUpperTriples[count])); @@ -1074,7 +1006,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedUpperFives{{0, 0, 0, 0, 0}, {0, 0, 0, 0, 4}, {0, 0, 0, 0, 7}, {0, 0, 0, 4, 4}, {0, 0, 0, 4, 7}, {0, 0, 0, 7, 7}, {0, 0, 4, 4, 4}, {0, 0, 4, 4, 7}, {0, 0, 4, 7, 7}, {0, 0, 7, 7, 7}, {0, 4, 4, 4, 4}, {0, 4, 4, 4, 7}, {0, 4, 4, 7, 7}, {0, 4, 7, 7, 7}, {0, 7, 7, 7, 7}, {4, 4, 4, 4, 4}, {4, 4, 4, 4, 7}, {4, 4, 4, 7, 7}, {4, 4, 7, 7, 7}, {4, 7, 7, 7, 7}, {7, 7, 7, 7, 7}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 6}, {1, 1, 1, 6, 6}, {1, 1, 6, 6, 6}, {1, 6, 6, 6, 6}, {6, 6, 6, 6, 6}, {3, 3, 3, 3, 3}, {3, 3, 3, 3, 5}, {3, 3, 3, 5, 5}, {3, 3, 5, 5, 5}, {3, 5, 5, 5, 5}, {5, 5, 5, 5, 5}, {2, 2, 2, 2, 2}, {2, 2, 2, 2, 8}, {2, 2, 2, 2, 9}, {2, 2, 2, 8, 8}, {2, 2, 2, 8, 9}, {2, 2, 2, 9, 9}, {2, 2, 8, 8, 8}, {2, 2, 8, 8, 9}, {2, 2, 8, 9, 9}, {2, 2, 9, 9, 9}, {2, 8, 8, 8, 8}, {2, 8, 8, 8, 9}, {2, 8, 8, 9, 9}, {2, 8, 9, 9, 9}, {2, 9, 9, 9, 9}, {8, 8, 8, 8, 8}, {8, 8, 8, 8, 9}, {8, 8, 8, 9, 9}, {8, 8, 9, 9, 9}, {8, 9, 9, 9, 9}, {9, 9, 9, 9, 9}}; count = 0; - for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockUpperIndexPolicy("y", 2, -1, testAux, testAux, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockUpperIndexPolicy(pairBinning, 2, -1, testA, testA, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperFives[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperFives[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedUpperFives[count])); @@ -1087,7 +1019,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedStrictlyUpperPairsSmaller{ {0, 4}, {4, 7}, {1, 6}, {3, 5}, {2, 8}, {8, 9}}; count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy("y", 1, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy(pairBinning, 1, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperPairsSmaller[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperPairsSmaller[count])); count++; @@ -1097,7 +1029,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedStrictlyUpperPairs{ {0, 4}, {0, 7}, {4, 7}, {1, 6}, {3, 5}, {2, 8}, {2, 9}, {8, 9}}; count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy("y", 2, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy(pairBinning, 2, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperPairs[count])); count++; @@ -1107,7 +1039,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedStrictlyUpperTriples{ {0, 4, 7}, {2, 8, 9}}; count = 0; - for (auto& [c0, c1, c2] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy("y", 2, -1, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy(pairBinning, 2, -1, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedStrictlyUpperTriples[count])); @@ -1116,7 +1048,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedStrictlyUpperTriples.size()); count = 0; - for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy("y", 1, -1, testAux, testAux, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockStrictlyUpperSameIndexPolicy(pairBinning, 1, -1, testA, testA, testA, testA, testA))) { count++; } BOOST_CHECK_EQUAL(count, 0); @@ -1125,7 +1057,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedFullPairsFirstSmaller{ {0, 0}, {0, 4}, {4, 0}, {4, 4}, {4, 7}, {1, 1}, {1, 6}, {3, 3}, {3, 5}, {2, 2}, {2, 8}}; count = 0; - for (auto& [x0, x1] : combinations(CombinationsBlockFullIndexPolicy("y", 1, -1, testAuxHalf, testAux))) { + for (auto& [x0, x1] : combinations(CombinationsBlockFullIndexPolicy(pairBinning, 1, -1, testAHalf, testA))) { BOOST_CHECK_EQUAL(x0.x(), std::get<0>(expectedFullPairsFirstSmaller[count])); BOOST_CHECK_EQUAL(x1.x(), std::get<1>(expectedFullPairsFirstSmaller[count])); count++; @@ -1135,7 +1067,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) count = 0; std::vector> expectedFullPairsSecondSmaller{ {0, 0}, {0, 4}, {4, 0}, {4, 4}, {7, 4}, {1, 1}, {6, 1}, {3, 3}, {5, 3}, {2, 2}, {8, 2}}; - for (auto& [x0, x1] : combinations(CombinationsBlockFullIndexPolicy("y", 1, -1, testAux, testAuxHalf))) { + for (auto& [x0, x1] : combinations(CombinationsBlockFullIndexPolicy(pairBinning, 1, -1, testA, testAHalf))) { BOOST_CHECK_EQUAL(x0.x(), std::get<0>(expectedFullPairsSecondSmaller[count])); BOOST_CHECK_EQUAL(x1.x(), std::get<1>(expectedFullPairsSecondSmaller[count])); count++; @@ -1145,7 +1077,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) std::vector> expectedUpperPairsFirstSmaller{ {0, 0}, {0, 4}, {4, 4}, {4, 7}, {1, 1}, {1, 6}, {3, 3}, {3, 5}, {2, 2}, {2, 8}}; count = 0; - for (auto& [x0, x1] : combinations(CombinationsBlockUpperIndexPolicy("y", 1, -1, testAuxHalf, testAux))) { + for (auto& [x0, x1] : combinations(CombinationsBlockUpperIndexPolicy(pairBinning, 1, -1, testAHalf, testA))) { BOOST_CHECK_EQUAL(x0.x(), std::get<0>(expectedUpperPairsFirstSmaller[count])); BOOST_CHECK_EQUAL(x1.x(), std::get<1>(expectedUpperPairsFirstSmaller[count])); count++; @@ -1155,7 +1087,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) count = 0; std::vector> expectedUpperPairsSecondSmaller{ {0, 0}, {0, 4}, {4, 4}, {1, 1}, {3, 3}, {2, 2}}; - for (auto& [x0, x1] : combinations(CombinationsBlockUpperIndexPolicy("y", 1, -1, testAux, testAuxHalf))) { + for (auto& [x0, x1] : combinations(CombinationsBlockUpperIndexPolicy(pairBinning, 1, -1, testA, testAHalf))) { BOOST_CHECK_EQUAL(x0.x(), std::get<0>(expectedUpperPairsSecondSmaller[count])); BOOST_CHECK_EQUAL(x1.x(), std::get<1>(expectedUpperPairsSecondSmaller[count])); count++; @@ -1164,7 +1096,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) // Using same index combinations for better performance count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockFullSameIndexPolicy("y", 2, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockFullSameIndexPolicy(pairBinning, 2, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullPairs[count])); count++; @@ -1172,7 +1104,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedFullPairs.size()); count = 0; - for (auto& [c0, c1, c2] : combinations(CombinationsBlockFullSameIndexPolicy("y", 2, -1, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2] : combinations(CombinationsBlockFullSameIndexPolicy(pairBinning, 2, -1, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedFullTriples[count])); @@ -1181,7 +1113,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedFullTriples.size()); count = 0; - for (auto& [c0, c1] : combinations(CombinationsBlockUpperSameIndexPolicy("y", 2, -1, testAux, testAux))) { + for (auto& [c0, c1] : combinations(CombinationsBlockUpperSameIndexPolicy(pairBinning, 2, -1, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperPairs[count])); count++; @@ -1189,7 +1121,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedUpperPairs.size()); count = 0; - for (auto& [c0, c1, c2] : combinations(CombinationsBlockUpperSameIndexPolicy("y", 1, -1, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2] : combinations(CombinationsBlockUpperSameIndexPolicy(pairBinning, 1, -1, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedUpperTriples[count])); @@ -1198,7 +1130,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedUpperTriples.size()); count = 0; - for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockUpperSameIndexPolicy("y", 2, -1, testAux, testAux, testAux, testAux, testAux))) { + for (auto& [c0, c1, c2, c3, c4] : combinations(CombinationsBlockUpperSameIndexPolicy(pairBinning, 2, -1, testA, testA, testA, testA, testA))) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedUpperFives[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedUpperFives[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedUpperFives[count])); @@ -1209,7 +1141,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedUpperFives.size()); count = 0; - for (auto& [c0, c1] : selfCombinations("y", 2, -1, testAux, testAux)) { + for (auto& [c0, c1] : selfCombinations(pairBinning, 2, -1, testA, testA)) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperPairs[count])); count++; @@ -1217,7 +1149,7 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedStrictlyUpperPairs.size()); count = 0; - for (auto& [c0, c1, c2] : selfCombinations("y", 2, -1, testAux, testAux, testAux)) { + for (auto& [c0, c1, c2] : selfCombinations(pairBinning, 2, -1, testA, testA, testA)) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedStrictlyUpperTriples[count])); @@ -1226,10 +1158,39 @@ BOOST_AUTO_TEST_CASE(BlockCombinations) BOOST_CHECK_EQUAL(count, expectedStrictlyUpperTriples.size()); count = 0; - for (auto& [c0, c1, c2, c3, c4] : selfCombinations("y", 2, -1, testAux, testAux, testAux, testAux, testAux)) { + for (auto& [c0, c1, c2, c3, c4] : selfCombinations(pairBinning, 2, -1, testA, testA, testA, testA, testA)) { count++; } BOOST_CHECK_EQUAL(count, 0); + + // Testing bin calculations for triple binning + // Grouped data: + // [3, 5] [0, 4], [7], [1, 6], [2], [8, 9] + // Assuming bins intervals: [ , ) + std::vector xBins{VARIABLE_WIDTH, 0, 7, 10}; + BinningPolicy tripleBinning{{xBins, yBins, zBins}, false}; + BinningPolicy tripleBinningNoOverflows{{xBins, yBins, zBins}, true}; + + // 2, 3, 5, 8, 9 have overflows in testA + std::vector> expectedFullPairsTripleBinningNoOverflows{ + {0, 0}, {0, 4}, {4, 0}, {4, 4}, {7, 7}, {1, 1}, {1, 6}, {6, 1}, {6, 6}}; + count = 0; + for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy(tripleBinningNoOverflows, 1, -1, testA, testA))) { + BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullPairsTripleBinningNoOverflows[count])); + BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullPairsTripleBinningNoOverflows[count])); + count++; + } + BOOST_CHECK_EQUAL(count, expectedFullPairsTripleBinningNoOverflows.size()); + + std::vector> expectedFullPairsTripleBinning{ + {0, 0}, {0, 4}, {4, 0}, {4, 4}, {7, 7}, {1, 1}, {1, 6}, {6, 1}, {6, 6}, {3, 3}, {3, 5}, {5, 3}, {5, 5}, {2, 2}, {8, 8}, {8, 9}, {9, 8}, {9, 9}}; + count = 0; + for (auto& [c0, c1] : combinations(CombinationsBlockFullIndexPolicy(tripleBinning, 2, -1, testA, testA))) { + BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedFullPairsTripleBinning[count])); + BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedFullPairsTripleBinning[count])); + count++; + } + BOOST_CHECK_EQUAL(count, expectedFullPairsTripleBinning.size()); } BOOST_AUTO_TEST_CASE(CombinationsHelpers) @@ -1312,27 +1273,15 @@ BOOST_AUTO_TEST_CASE(CombinationsHelpers) // Grouped data: // [3, 5] [0, 4, 7], [1, 6], [2, 8, 9] // Assuming bins intervals: [ , ) - std::vector yBins{0, 5, 10, 20, 30, 40, 50, 101}; - std::vector zBins{-7.0f, -5.0f, -3.0f, -1.0f, 1.0f, 3.0f, 5.0f, 7.0f}; - - TableBuilder builderAux; - auto rowWriterAux = builderAux.persist({"x", "y"}); - for (auto it = testB.begin(); it != testB.end(); it++) { - auto& elem = *it; - rowWriterAux(0, elem.x(), getHash(yBins, zBins, elem.y(), elem.floatZ())); - } - auto tableAux = builderAux.finalize(); - BOOST_REQUIRE_EQUAL(tableAux->num_rows(), 10); + std::vector yBins{VARIABLE_WIDTH, 0, 5, 10, 20, 30, 40, 50, 101}; + std::vector zBins{VARIABLE_WIDTH, -7.0, -5.0, -3.0, -1.0, 1.0, 3.0, 5.0, 7.0}; - // Auxiliary table: testsAux with id and hash, hash is the category for grouping - using TestsAux = o2::soa::Table, test::X, test::Y>; - TestsAux testAux{tableAux}; - BOOST_REQUIRE_EQUAL(10, testAux.size()); + BinningPolicy pairBinning{{yBins, zBins}, false}; std::vector> expectedStrictlyUpperPairs{ {0, 4}, {0, 7}, {4, 7}, {1, 6}, {3, 5}, {2, 8}, {2, 9}, {8, 9}}; count = 0; - for (auto& [c0, c1] : selfPairCombinations("y", 2, -1, testAux)) { + for (auto& [c0, c1] : selfPairCombinations(pairBinning, 2, -1, testB)) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperPairs[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperPairs[count])); count++; @@ -1342,7 +1291,7 @@ BOOST_AUTO_TEST_CASE(CombinationsHelpers) std::vector> expectedStrictlyUpperTriples{ {0, 4, 7}, {2, 8, 9}}; count = 0; - for (auto& [c0, c1, c2] : selfTripleCombinations("y", 2, -1, testAux)) { + for (auto& [c0, c1, c2] : selfTripleCombinations(pairBinning, 2, -1, testB)) { BOOST_CHECK_EQUAL(c0.x(), std::get<0>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c1.x(), std::get<1>(expectedStrictlyUpperTriples[count])); BOOST_CHECK_EQUAL(c2.x(), std::get<2>(expectedStrictlyUpperTriples[count])); @@ -1354,6 +1303,7 @@ BOOST_AUTO_TEST_CASE(CombinationsHelpers) BOOST_AUTO_TEST_CASE(ConstructorsWithoutTables) { using TestA = o2::soa::Table, test::X, test::Y>; + NoBinningPolicy noBinning; int count = 0; for (auto& [t0, t1] : pairCombinations()) { @@ -1368,13 +1318,13 @@ BOOST_AUTO_TEST_CASE(ConstructorsWithoutTables) BOOST_CHECK_EQUAL(count, 0); count = 0; - for (auto& [c0, c1] : selfPairCombinations("y", 2, -1)) { + for (auto& [c0, c1] : selfPairCombinations, int, TestA>(noBinning, 2, -1)) { count++; } BOOST_CHECK_EQUAL(count, 0); count = 0; - for (auto& [c0, c1, c2] : selfTripleCombinations("y", 2, -1)) { + for (auto& [c0, c1, c2] : selfTripleCombinations, int, TestA>(noBinning, 2, -1)) { count++; } BOOST_CHECK_EQUAL(count, 0);