Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly track used entries in UniformSort #250

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 24 additions & 30 deletions src/uniformsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ namespace UniformSort {

inline int64_t const BUF_SIZE = 262144;

inline static bool IsPositionEmpty(const uint8_t *memory, uint32_t const entry_len)
{
for (uint32_t i = 0; i < entry_len; i++)
if (memory[i] != 0)
return false;
return true;
}

inline void SortToMemory(
FileDisk &input_disk,
uint64_t const input_disk_begin,
Expand All @@ -44,60 +36,62 @@ namespace UniformSort {
uint64_t const num_entries,
uint32_t const bits_begin)
{
uint64_t const memory_len = Util::RoundSize(num_entries) * entry_len;
uint64_t const rounded_entries = Util::RoundSize(num_entries);
auto const swap_space = std::make_unique<uint8_t[]>(entry_len);
auto const buffer = std::make_unique<uint8_t[]>(BUF_SIZE);
uint64_t bucket_length = 0;
// The number of buckets needed (the smallest power of 2 greater than 2 * num_entries).
while ((1ULL << bucket_length) < 2 * num_entries) bucket_length++;
memset(memory, 0, memory_len);
bitfield is_used(rounded_entries);

uint64_t read_pos = input_disk_begin;
uint64_t buf_size = 0;
uint64_t buf_ptr = 0;
uint64_t swaps = 0;
uint64_t buf_ofs = 0;
for (uint64_t i = 0; i < num_entries; i++) {
if (buf_size == 0) {
// If read buffer is empty, read from disk and refill it.
buf_size = std::min((uint64_t)BUF_SIZE / entry_len, num_entries - i);
buf_ptr = 0;
buf_ofs = 0;
input_disk.Read(read_pos, buffer.get(), buf_size * entry_len);
read_pos += buf_size * entry_len;
}
buf_size--;
// First unique bits in the entry give the expected position of it in the sorted array.
// We take 'bucket_length' bits starting with the first unique one.
uint64_t pos =
Util::ExtractNum(buffer.get() + buf_ptr, entry_len, bits_begin, bucket_length) *
entry_len;
uint64_t idx =
Util::ExtractNum(buffer.get() + buf_ofs, entry_len, bits_begin, bucket_length);
uint64_t mem_ofs = idx * entry_len;
// As long as position is occupied by a previous entry...
while (!IsPositionEmpty(memory + pos, entry_len) && pos < memory_len) {
while (is_used.get(idx) && idx < rounded_entries) {
// ...store there the minimum between the two and continue to push the higher one.
if (Util::MemCmpBits(
memory + pos, buffer.get() + buf_ptr, entry_len, bits_begin) > 0) {
memcpy(swap_space.get(), memory + pos, entry_len);
memcpy(memory + pos, buffer.get() + buf_ptr, entry_len);
memcpy(buffer.get() + buf_ptr, swap_space.get(), entry_len);
swaps++;
memory + mem_ofs, buffer.get() + buf_ofs, entry_len, bits_begin) > 0) {
memcpy(swap_space.get(), memory + mem_ofs, entry_len);
memcpy(memory + mem_ofs, buffer.get() + buf_ofs, entry_len);
memcpy(buffer.get() + buf_ofs, swap_space.get(), entry_len);
}
pos += entry_len;
idx++;
mem_ofs += entry_len;
}
// Push the entry in the first free spot.
memcpy(memory + pos, buffer.get() + buf_ptr, entry_len);
buf_ptr += entry_len;
memcpy(memory + mem_ofs, buffer.get() + buf_ofs, entry_len);
is_used.set(idx);
buf_ofs += entry_len;
}
uint64_t entries_written = 0;
uint64_t entries_ofs = 0;
// Search the memory buffer for occupied entries.
for (uint64_t pos = 0; entries_written < num_entries && pos < memory_len;
pos += entry_len) {
if (!IsPositionEmpty(memory + pos, entry_len)) {
for (uint64_t idx = 0, mem_ofs = 0; entries_written < num_entries && idx < rounded_entries;
idx++, mem_ofs += entry_len) {
if (is_used.get(idx)) {
// We've found an entry.
// write the stored entry itself.
memcpy(
memory + entries_written * entry_len,
memory + pos,
memory + entries_ofs,
memory + mem_ofs,
entry_len);
entries_written++;
entries_ofs += entry_len;
}
}

Expand Down