Skip to content

Commit

Permalink
Change a few local buffer allocations to use unique_ptr so they'll be…
Browse files Browse the repository at this point in the history
… automatically cleaned up and use RAII to avoid hazards that might allow them to leak
  • Loading branch information
prozacchiwawa committed Dec 6, 2023
1 parent 6e44d89 commit c2d9629
Showing 1 changed file with 33 additions and 41 deletions.
74 changes: 33 additions & 41 deletions src/prover_disk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ class DiskProver {
// The list of C2 entries is small enough to keep in memory. When proving, we can
// read from disk the C1 and C3 entries.
uint64_t prev_c2_f7 = 0;
auto* c2_buf = new uint8_t[c2_size];
auto c2_buf = std::make_unique<uint8_t[]>(c2_size);
for (uint32_t i = 0; i < c2_entries - 1; i++) {
SafeRead(disk_file, c2_buf, c2_size);
SafeRead(disk_file, c2_buf.get(), c2_size);

const uint64_t f7 = Bits(c2_buf, c2_size, c2_size * 8).Slice(0, k).GetValue();
const uint64_t f7 = Bits(c2_buf.get(), c2_size, c2_size * 8).Slice(0, k).GetValue();

// Short-circuit reading of the C2 table as soon as we encounter an f7 entry whose
// value is lesser than the previous f7 read. This ensures that we don't read
Expand All @@ -429,8 +429,6 @@ class DiskProver {
this->C2.push_back(f7);
prev_c2_f7 = f7;
}

delete[] c2_buf;
}

explicit DiskProver(const std::vector<uint8_t>& vecBytes)
Expand Down Expand Up @@ -869,18 +867,21 @@ class DiskProver {

// This is the checkpoint at the beginning of the park
uint16_t line_point_size = EntrySizes::CalculateLinePointSize(k);
auto* line_point_bin = new uint8_t[line_point_size + 7];
SafeRead(disk_file, line_point_bin, line_point_size);
uint128_t line_point = Util::SliceInt128FromBytes(line_point_bin, 0, k * 2);
// Using unique_ptr here allows this object to be automatically destructed at the throw below.
auto line_point_bin = std::make_unique<uint8_t[]>(line_point_size + 7);
SafeRead(disk_file, line_point_bin.get(), line_point_size);
uint128_t line_point = Util::SliceInt128FromBytes(line_point_bin.get(), 0, k * 2);

// Reads EPP stubs
uint32_t stubs_size_bits = (is_compressed ? (Util::ByteAlign((kEntriesPerPark - 1) * compressed_stub_size_bits) / 8) : EntrySizes::CalculateStubsSize(k)) * 8;
auto* stubs_bin = new uint8_t[stubs_size_bits / 8 + 7];
SafeRead(disk_file, stubs_bin, stubs_size_bits / 8);
// As above: avoid leak via throw.
auto stubs_bin = std::make_unique<uint8_t[]>(stubs_size_bits / 8 + 7);
SafeRead(disk_file, stubs_bin.get(), stubs_size_bits / 8);

// Reads EPP deltas
uint32_t max_deltas_size_bits = (is_compressed ? compressed_park_size - (line_point_size + stubs_size_bits) : EntrySizes::CalculateMaxDeltasSize(k, table_index)) * 8;
auto* deltas_bin = new uint8_t[max_deltas_size_bits / 8];
// Avoid leak via throw.
auto deltas_bin = std::make_unique<uint8_t[]>(max_deltas_size_bits / 8);

// Reads the size of the encoded deltas object
uint16_t encoded_deltas_size = 0;
Expand All @@ -899,12 +900,12 @@ class DiskProver {
SafeRead(disk_file, deltas.data(), encoded_deltas_size);
} else {
// Compressed
SafeRead(disk_file, deltas_bin, encoded_deltas_size);
SafeRead(disk_file, deltas_bin.get(), encoded_deltas_size);

// Decodes the deltas
double R = (is_compressed ? compressed_ans_r_value : kRValues[table_index - 1]);
deltas =
Encoding::ANSDecodeDeltas(deltas_bin, encoded_deltas_size, kEntriesPerPark - 1, R);
Encoding::ANSDecodeDeltas(deltas_bin.get(), encoded_deltas_size, kEntriesPerPark - 1, R);
}

uint32_t start_bit = 0;
Expand All @@ -914,7 +915,7 @@ class DiskProver {
for (uint32_t i = 0;
i < std::min((uint32_t)(position % kEntriesPerPark), (uint32_t)deltas.size());
i++) {
uint64_t stub = Util::EightBytesToInt(stubs_bin + start_bit / 8);
uint64_t stub = Util::EightBytesToInt(stubs_bin.get() + start_bit / 8);
stub <<= start_bit % 8;
stub >>= 64 - stub_size;

Expand All @@ -926,10 +927,6 @@ class DiskProver {
uint128_t big_delta = ((uint128_t)sum_deltas << stub_size) + sum_stubs;
uint128_t final_line_point = line_point + big_delta;

delete[] line_point_bin;
delete[] stubs_bin;
delete[] deltas_bin;

return final_line_point;
}

Expand Down Expand Up @@ -1013,16 +1010,16 @@ class DiskProver {

uint32_t c1_entry_size = Util::ByteAlign(k) / 8;

auto* c1_entry_bytes = new uint8_t[c1_entry_size];
auto c1_entry_bytes = std::make_unique<uint8_t[]>(c1_entry_size);
SafeSeek(disk_file, table_begin_pointers[8] + c1_index * Util::ByteAlign(k) / 8);

uint64_t curr_f7 = c2_entry_f;
uint64_t prev_f7 = c2_entry_f;
broke = false;
// Goes through C2 entries until we find the correct C1 checkpoint.
for (uint64_t start = 0; start < kCheckpoint1Interval; start++) {
SafeRead(disk_file, c1_entry_bytes, c1_entry_size);
Bits c1_entry = Bits(c1_entry_bytes, Util::ByteAlign(k) / 8, Util::ByteAlign(k));
SafeRead(disk_file, c1_entry_bytes.get(), c1_entry_size);
Bits c1_entry = Bits(c1_entry_bytes.get(), Util::ByteAlign(k) / 8, Util::ByteAlign(k));
uint64_t read_f7 = c1_entry.Slice(0, k).GetValue();

if (start != 0 && read_f7 == 0) {
Expand All @@ -1048,7 +1045,8 @@ class DiskProver {
}

uint32_t c3_entry_size = EntrySizes::CalculateC3Size(k);
auto* bit_mask = new uint8_t[c3_entry_size];
// Use a unique_ptr so the early returns below don't leak.
auto bit_mask = std::make_unique<uint8_t[]>(c3_entry_size);

// Double entry means that our entries are in more than one checkpoint park.
bool double_entry = f7 == curr_f7 && c1_index > 0;
Expand All @@ -1063,8 +1061,8 @@ class DiskProver {
// In this case, we read the previous park as well as the current one
c1_index -= 1;
SafeSeek(disk_file, table_begin_pointers[8] + c1_index * Util::ByteAlign(k) / 8);
SafeRead(disk_file, c1_entry_bytes, Util::ByteAlign(k) / 8);
Bits c1_entry_bits = Bits(c1_entry_bytes, Util::ByteAlign(k) / 8, Util::ByteAlign(k));
SafeRead(disk_file, c1_entry_bytes.get(), Util::ByteAlign(k) / 8);
Bits c1_entry_bits = Bits(c1_entry_bytes.get(), Util::ByteAlign(k) / 8, Util::ByteAlign(k));
next_f7 = curr_f7;
curr_f7 = c1_entry_bits.Slice(0, k).GetValue();

Expand All @@ -1079,10 +1077,10 @@ class DiskProver {
return std::vector<uint64_t>();
}

SafeRead(disk_file, bit_mask, c3_entry_size - 2);
SafeRead(disk_file, bit_mask.get(), c3_entry_size - 2);

p7_positions =
GetP7Positions(curr_f7, f7, curr_p7_pos, bit_mask, encoded_size, c1_index);
GetP7Positions(curr_f7, f7, curr_p7_pos, bit_mask.get(), encoded_size, c1_index);

SafeRead(disk_file, encoded_size_buf, 2);
encoded_size = Bits(encoded_size_buf, 2, 16).GetValue();
Expand All @@ -1093,12 +1091,12 @@ class DiskProver {
return std::vector<uint64_t>();
}

SafeRead(disk_file, bit_mask, c3_entry_size - 2);
SafeRead(disk_file, bit_mask.get(), c3_entry_size - 2);

c1_index++;
curr_p7_pos = c1_index * kCheckpoint1Interval;
auto second_positions =
GetP7Positions(next_f7, f7, curr_p7_pos, bit_mask, encoded_size, c1_index);
GetP7Positions(next_f7, f7, curr_p7_pos, bit_mask.get(), encoded_size, c1_index);

p7_positions.insert(
p7_positions.end(), second_positions.begin(), second_positions.end());
Expand All @@ -1114,17 +1112,15 @@ class DiskProver {
return std::vector<uint64_t>();
}

SafeRead(disk_file, bit_mask, c3_entry_size - 2);
SafeRead(disk_file, bit_mask.get(), c3_entry_size - 2);

p7_positions =
GetP7Positions(curr_f7, f7, curr_p7_pos, bit_mask, encoded_size, c1_index);
GetP7Positions(curr_f7, f7, curr_p7_pos, bit_mask.get(), encoded_size, c1_index);
}

// p7_positions is a list of all the positions into table P7, where the output is equal to
// f7. If it's empty, no proofs are present for this f7.
if (p7_positions.empty()) {
delete[] bit_mask;
delete[] c1_entry_bytes;
return std::vector<uint64_t>();
}

Expand All @@ -1134,28 +1130,24 @@ class DiskProver {

// Given the p7 positions, which are all adjacent, we can read the pos6 values from table
// P7.
auto* p7_park_buf = new uint8_t[p7_park_size_bytes];
auto p7_park_buf = std::make_unique<uint8_t[]>(p7_park_size_bytes);
uint64_t park_index = (p7_positions[0] == 0 ? 0 : p7_positions[0]) / kEntriesPerPark;
SafeSeek(disk_file, table_begin_pointers[7] + park_index * p7_park_size_bytes);
SafeRead(disk_file, p7_park_buf, p7_park_size_bytes);
ParkBits p7_park = ParkBits(p7_park_buf, p7_park_size_bytes, p7_park_size_bytes * 8);
SafeRead(disk_file, p7_park_buf.get(), p7_park_size_bytes);
ParkBits p7_park = ParkBits(p7_park_buf.get(), p7_park_size_bytes, p7_park_size_bytes * 8);
for (uint64_t i = 0; i < p7_positions[p7_positions.size() - 1] - p7_positions[0] + 1; i++) {
uint64_t new_park_index = (p7_positions[i]) / kEntriesPerPark;
if (new_park_index > park_index) {
SafeSeek(disk_file, table_begin_pointers[7] + new_park_index * p7_park_size_bytes);
SafeRead(disk_file, p7_park_buf, p7_park_size_bytes);
p7_park = ParkBits(p7_park_buf, p7_park_size_bytes, p7_park_size_bytes * 8);
SafeRead(disk_file, p7_park_buf.get(), p7_park_size_bytes);
p7_park = ParkBits(p7_park_buf.get(), p7_park_size_bytes, p7_park_size_bytes * 8);
}
uint32_t start_bit_index = (p7_positions[i] % kEntriesPerPark) * (k + 1);

uint64_t p7_int = p7_park.Slice(start_bit_index, start_bit_index + k + 1).GetValue();
p7_entries.push_back(p7_int);
}

delete[] bit_mask;
delete[] c1_entry_bytes;
delete[] p7_park_buf;

return p7_entries;
}

Expand Down

0 comments on commit c2d9629

Please sign in to comment.