Multi-threaded construction
This is still quite slow because BuRR is not parallelized.
We can only parallelize over the 3 BuRR data structures, which
does not properly utilize the threads. Constructing BuRR takes
about 50% of the construction time, so this is quite relevant.
ByteHamster committed Mar 19, 2024
1 parent f686ff6 commit 4c65c51
Showing 3 changed files with 103 additions and 21 deletions.
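To illustrate the bottleneck described in the commit message above: the per-bucket cuckoo construction is split across all threads, but the ribbon phase offers only three independent tasks (one per SimpleRibbon), so threads beyond the third sit idle during roughly half of the construction time. Below is a minimal standalone OpenMP sketch of that shape, with stand-in work instead of the real BuRR construction; everything in it is illustrative and not part of the commit.

// Standalone sketch (not part of the commit): the bucket phase scales with the
// thread count, but the ribbon phase has only three independent tasks.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const size_t threads = 8;
    const size_t numBuckets = 1 << 20;
    std::vector<uint64_t> bucketResult(numBuckets);

    // Phase 1: per-bucket work, parallel over all available threads.
    #pragma omp parallel for num_threads(threads)
    for (size_t b = 0; b < numBuckets; b++) {
        bucketResult[b] = b * 0x9e3779b97f4a7c15ull; // stand-in for cuckoo insertion
    }

    // Phase 2: one task per ribbon, so at most three threads can help here.
    std::vector<uint64_t> ribbonChecksum(3);
    #pragma omp parallel for num_threads(threads)
    for (size_t i = 0; i < 3; i++) {
        ribbonChecksum[i] = std::accumulate(bucketResult.begin(), bucketResult.end(), uint64_t{i});
    }

    std::printf("checksums: %llu %llu %llu\n",
                (unsigned long long) ribbonChecksum[0],
                (unsigned long long) ribbonChecksum[1],
                (unsigned long long) ribbonChecksum[2]);
    return 0;
}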
13 changes: 12 additions & 1 deletion CMakeLists.txt
@@ -16,9 +16,20 @@ target_compile_features(SicHash INTERFACE cxx_std_20)
add_subdirectory(extlib/util EXCLUDE_FROM_ALL)
target_link_libraries(SicHash INTERFACE ByteHamsterUtil)

add_subdirectory(extlib/simpleRibbon)
if(NOT TARGET simpleRibbon)
set(IPS2RA_DISABLE_PARALLEL ON CACHE PATH "ips2ra's FindTBB greps a file that does not exist in recent TBB versions")
add_subdirectory(extlib/simpleRibbon)
find_package(TBB)
target_compile_options(ips2ra INTERFACE -D_REENTRANT)
target_link_libraries(ips2ra INTERFACE pthread atomic TBB::tbb)
endif()
target_link_libraries(SicHash INTERFACE SimpleRibbon ips2ra)

find_package(OpenMP)
if(OpenMP_CXX_FOUND)
target_link_libraries(SicHash INTERFACE OpenMP::OpenMP_CXX)
endif()

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
################### Benchmark build targets ###################
add_executable(Solvers solvers.cpp)
106 changes: 87 additions & 19 deletions include/SicHash.h
@@ -96,7 +96,7 @@ struct SicHashConfig {
* loadFactor < ~0.94 ==> use minimalFanoLowerBits=4
* loadFactor < ~0.97 ==> use minimalFanoLowerBits=5
*/
template<bool minimal=false, size_t ribbonWidth=64, int minimalFanoLowerBits = 3>
template<bool minimal=true, size_t ribbonWidth=64, int minimalFanoLowerBits = 5>
class SicHash {
public:
static constexpr size_t HASH_FUNCTION_BUCKET_ASSIGNMENT = 42;
@@ -115,7 +115,7 @@ class SicHash {
size_t unnecessaryConstructions = 0;

// Keys parameter must be an std::vector<std::string> or an std::vector<HashedKey>.
SicHash(const auto &keys, SicHashConfig _config)
SicHash(const auto &keys, SicHashConfig _config, size_t threads = 1)
: config(_config),
N(keys.size()),
numSmallTables(N / config.smallTableSize + 1),
@@ -125,8 +125,18 @@
std::cout<<"Creating MHCs"<<std::endl;
}
std::vector<std::pair<size_t, HashedKey>> hashedKeys(N);
initialHash(0, N, keys, hashedKeys);
ips2ra::sort(hashedKeys.begin(), hashedKeys.end(),
if (threads == 1) {
initialHash(0, N, keys, hashedKeys);
} else {
size_t keysPerThread = (N + threads) / threads;
#pragma omp parallel for num_threads(threads)
for (size_t i = 0; i < threads; i++) {
size_t from = i * keysPerThread;
size_t to = std::min(N, (i + 1) * keysPerThread);
initialHash(from, to, keys, hashedKeys);
}
}
ips2ra::parallel::sort(hashedKeys.begin(), hashedKeys.end(),
[](const std::pair<size_t, HashedKey> &pair) { return pair.first; });

if (!config.silent) {
@@ -137,19 +147,68 @@
}
std::vector<std::vector<std::pair<uint64_t, uint8_t>>> maps; // Avoids conditional jumps later
maps.resize(0b111 + 1);
maps[0b001].reserve(keys.size() * config.class1Percentage());
maps[0b011].reserve(keys.size() * config.class2Percentage());
maps[0b111].reserve(keys.size() * config.class3Percentage());
unnecessaryConstructions = 0;
hashedKeys.emplace_back(numSmallTables + 1, 0); // Sentinel

std::vector<size_t> emptySlots;
if constexpr (minimal) {
emptySlots.reserve(N / config.loadFactor - N);
if (threads == 1) {
if constexpr (minimal) {
emptySlots.reserve(N / config.loadFactor - N);
}
maps[0b001].reserve(keys.size() * config.class1Percentage());
maps[0b011].reserve(keys.size() * config.class2Percentage());
maps[0b111].reserve(keys.size() * config.class3Percentage());
constructSmallTables(0, numSmallTables, hashedKeys, emptySlots, maps);
bucketInfo[numSmallTables].offset = bucketInfo[0].offset;
bucketInfo[0].offset = 0;
} else {
size_t bucketsPerThread = (numSmallTables + threads) / threads;
std::vector<std::vector<size_t>> emptySlotsThreads(threads);
std::vector<std::vector<std::vector<std::pair<uint64_t, uint8_t>>>> mapsThreads;
mapsThreads.resize(threads);
for (size_t i = 0; i < threads; i++) {
if constexpr (minimal) {
emptySlotsThreads[i].reserve((N / threads) / config.loadFactor - N / threads);
}
mapsThreads[i].resize(0b111 + 1);
maps[0b001].reserve(bucketsPerThread * config.smallTableSize * config.class1Percentage());
maps[0b011].reserve(bucketsPerThread * config.smallTableSize * config.class2Percentage());
maps[0b111].reserve(bucketsPerThread * config.smallTableSize * config.class3Percentage());
}

#pragma omp parallel for num_threads(threads)
for (size_t i = 0; i < threads; i++) {
size_t from = i * bucketsPerThread;
size_t to = std::min(numSmallTables, (i + 1) * bucketsPerThread);
constructSmallTables(from, to, hashedKeys, emptySlotsThreads[i], mapsThreads[i]);
}
size_t sizePrefix = 0;
for (size_t i = 0; i < threads; i++) {
size_t size = bucketInfo[i * bucketsPerThread].offset;
bucketInfo[i * bucketsPerThread].offset = sizePrefix;
sizePrefix += size;
}
bucketInfo[0].offset = 0;
bucketInfo[numSmallTables].offset = sizePrefix;
#pragma omp parallel for num_threads(threads)
for (size_t i = 1; i < threads; i++) { // Offsets of first thread do not have to be changed
size_t from = i * bucketsPerThread;
size_t to = std::min(numSmallTables, (i + 1) * bucketsPerThread);
for (size_t k = from + 1; k < to; k++) {
bucketInfo[k].offset += bucketInfo[from].offset;
}
for (size_t k = 0; k < emptySlotsThreads[i].size(); k++) {
emptySlotsThreads[i][k] += bucketInfo[from].offset;
}
}
// Append thread-local arrays to global ones
for (size_t i = 0; i < threads; i++) {
emptySlots.insert(emptySlots.end(), emptySlotsThreads[i].begin(), emptySlotsThreads[i].end());
maps[0b001].insert(maps[0b001].end(), mapsThreads[i][0b001].begin(), mapsThreads[i][0b001].end());
maps[0b011].insert(maps[0b011].end(), mapsThreads[i][0b011].begin(), mapsThreads[i][0b011].end());
maps[0b111].insert(maps[0b111].end(), mapsThreads[i][0b111].begin(), mapsThreads[i][0b111].end());
}
}
constructSmallTables(0, numSmallTables, hashedKeys, emptySlots, maps);
bucketInfo[numSmallTables].offset = bucketInfo[0].offset;
bucketInfo[0].offset = 0;

if (!config.silent) {
std::cout << "Buckets took " << std::chrono::duration_cast<std::chrono::milliseconds>(
@@ -160,9 +219,17 @@
std::cout<<"Constructing Ribbon"<<std::endl;
}

ribbon1 = new SimpleRibbon<1, ribbonWidth>(maps[0b001]);
ribbon2 = new SimpleRibbon<2, ribbonWidth>(maps[0b011]);
ribbon3 = new SimpleRibbon<3, ribbonWidth>(maps[0b111]);
#pragma omp parallel for num_threads(threads)
for (size_t i = 0; i < 3; i++) {
if (i == 0) {
ribbon1 = new SimpleRibbon<1, ribbonWidth>(maps[0b001]);
} else if (i == 1) {
ribbon2 = new SimpleRibbon<2, ribbonWidth>(maps[0b011]);
} else if (i == 2) {
ribbon3 = new SimpleRibbon<3, ribbonWidth>(maps[0b111]);
}
}

if (!config.silent) {
std::cout << "Ribbon took " << std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - begin).count() << std::endl;
@@ -200,13 +267,14 @@

// Find key to start with
size_t keyIdx = (double(from) / double(numSmallTables)) * N; // Rough estimate
while (keyIdx > 0 && hashedKeys[keyIdx].first >= from) {
keyIdx--;
}
while (hashedKeys[keyIdx].first < from) {
keyIdx++;
}
while (hashedKeys[keyIdx].first > from) {
keyIdx--;
}

assert(hashedKeys[keyIdx].first == from);
assert(keyIdx == 0 || hashedKeys[keyIdx - 1].first < from);
for (size_t bucketIdx = from; bucketIdx < to; bucketIdx++) {
irregularCuckooHashTable.clear();
while (hashedKeys[keyIdx].first == bucketIdx) {
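The trickiest part of the hunk above is the merge of thread-local results: each thread constructs its bucket range into local buffers while recording only its partition's total size, a short sequential pass turns those sizes into global base offsets, each partition's local indices are shifted by its base, and the locals are concatenated. The following standalone sketch shows that pattern with plain vectors; the names and the stand-in per-bucket work are illustrative, not the library's code.

// Standalone sketch of the thread-local build / prefix sum / concatenate
// pattern used above (illustrative names, not the library's code).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t threads = 4;
    const size_t numBuckets = 1000;
    const size_t bucketsPerThread = (numBuckets + threads - 1) / threads;

    std::vector<std::vector<size_t>> localSlots(threads); // per-thread output
    std::vector<size_t> partitionSize(threads, 0);        // slots consumed per partition

    // Phase 1: each thread fills its own buffer using partition-local offsets.
    #pragma omp parallel for num_threads(threads)
    for (size_t t = 0; t < threads; t++) {
        const size_t from = t * bucketsPerThread;
        const size_t to = std::min(numBuckets, (t + 1) * bucketsPerThread);
        size_t localOffset = 0;
        for (size_t b = from; b < to; b++) {
            const size_t bucketSize = 3 + (b % 5);  // stand-in for the real bucket size
            localSlots[t].push_back(localOffset);   // stand-in for an empty-slot index
            localOffset += bucketSize;
        }
        partitionSize[t] = localOffset;
    }

    // Phase 2: sequential prefix sum turns partition sizes into global base offsets.
    std::vector<size_t> base(threads, 0);
    for (size_t t = 1; t < threads; t++) {
        base[t] = base[t - 1] + partitionSize[t - 1];
    }

    // Phase 3: rebase each partition's local indices, then append to the global array.
    std::vector<size_t> globalSlots;
    for (size_t t = 0; t < threads; t++) {
        for (size_t &slot : localSlots[t]) {
            slot += base[t];
        }
        globalSlots.insert(globalSlots.end(), localSlots[t].begin(), localSlots[t].end());
    }

    std::printf("merged %zu entries\n", globalSlots.size());
    return 0;
}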
5 changes: 4 additions & 1 deletion sicHashBenchmark.cpp
@@ -55,7 +55,10 @@ int main(int argc, char** argv) {
std::vector<bool> taken(keys.size() / sicHashTable.config.loadFactor + 100, false); // +100 for rounding
for (std::string &key : keys) {
size_t retrieved = sicHashTable(key);
if (taken[retrieved]) {
if (retrieved > taken.size()) {
std::cerr << "Error: out of range" << std::endl;
return -1;
} else if (taken[retrieved]) {
std::cerr << "Error: not minimal" << std::endl;
return -1;
}
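For completeness, calling the new constructor from user code would look roughly like the sketch below. The default-constructed config, the include path, and the thread count are assumptions for illustration, and any namespace qualification is omitted; only the constructor signature, the changed template defaults, and the operator() lookup are taken from the diff.

// Hypothetical caller of the new threaded constructor; config defaults,
// include path, and thread count are illustrative assumptions.
#include <string>
#include <vector>
#include "SicHash.h"

int main() {
    std::vector<std::string> keys;
    keys.reserve(1000000);
    for (size_t i = 0; i < 1000000; i++) {
        keys.push_back("key" + std::to_string(i));
    }

    SicHashConfig config;                        // assuming usable defaults, as in the benchmark
    SicHash<true> sicHashTable(keys, config, 8); // minimal MPHF, 8 construction threads

    size_t index = sicHashTable(keys[0]);        // lookup, as in sicHashBenchmark.cpp
    return index < keys.size() ? 0 : 1;
}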

