Skip to content

Commit

Permalink
Unsafe improvements: core parquet crate. (#6024)
Browse files Browse the repository at this point in the history
* Unsafe improvements: core `parquet` crate.

* Make FromBytes an unsafe trait.
  • Loading branch information
veluca93 committed Jul 9, 2024
1 parent c47f230 commit 3ce8e84
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 17 deletions.
3 changes: 2 additions & 1 deletion parquet/src/bloom_filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ impl Block {

/// Reinterprets this block as 32 bytes in native byte order (no byte swapping is performed).
#[inline]
fn to_ne_bytes(self) -> [u8; 32] {
// Diff view: the next line is the pre-change form deleted by this commit; it transmuted the
// whole wrapper value instead of its inner array.
unsafe { std::mem::transmute(self) }
// Diff view: the two lines below are the replacement added by this commit.
// SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns.
unsafe { std::mem::transmute(self.0) }
}

#[inline]
Expand Down
32 changes: 22 additions & 10 deletions parquet/src/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ macro_rules! gen_as_bytes {
impl AsBytes for $source_ty {
#[allow(clippy::size_of_in_element_count)]
fn as_bytes(&self) -> &[u8] {
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory.
unsafe {
std::slice::from_raw_parts(
self as *const $source_ty as *const u8,
Expand All @@ -481,6 +483,8 @@ macro_rules! gen_as_bytes {
#[inline]
#[allow(clippy::size_of_in_element_count)]
fn slice_as_bytes(self_: &[Self]) -> &[u8] {
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory.
unsafe {
std::slice::from_raw_parts(
self_.as_ptr() as *const u8,
Expand All @@ -492,10 +496,15 @@ macro_rules! gen_as_bytes {
/// Reinterprets a mutable slice of `Self` as a mutable byte slice covering the same memory
/// (`size_of_val(self_)` bytes starting at `self_.as_mut_ptr()`).
///
/// # Safety
/// NOTE(review): the function stays `unsafe fn`, but the caller-side obligation is not spelled
/// out in this hunk — confirm against the trait declaration.
#[inline]
#[allow(clippy::size_of_in_element_count)]
unsafe fn slice_as_bytes_mut(self_: &mut [Self]) -> &mut [u8] {
// Diff view: the following four lines are the pre-change body deleted by this commit; it
// relied on the enclosing `unsafe fn` instead of an explicit `unsafe` block.
std::slice::from_raw_parts_mut(
self_.as_mut_ptr() as *mut u8,
std::mem::size_of_val(self_),
)
// Diff view: everything below is the replacement body added by this commit.
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory. Moreover, self has no
// invalid bit patterns, so all writes to the resulting slice will be valid.
unsafe {
std::slice::from_raw_parts_mut(
self_.as_mut_ptr() as *mut u8,
std::mem::size_of_val(self_),
)
}
}
}
};
Expand Down Expand Up @@ -534,12 +543,15 @@ unimplemented_slice_as_bytes!(FixedLenByteArray);

impl AsBytes for bool {
    /// Returns the single byte that backs this `bool` in memory.
    fn as_bytes(&self) -> &[u8] {
        let byte_ptr = (self as *const bool).cast::<u8>();
        // SAFETY: `byte_ptr` is derived from a live `&bool`, which is exactly one initialized
        // byte (always 0x00 or 0x01), and the returned slice borrows from `self`.
        unsafe { std::slice::from_raw_parts(byte_ptr, 1) }
    }
}

impl AsBytes for Int96 {
    /// Returns the 12 bytes backing this value's `u32` words, exactly as laid out in memory.
    fn as_bytes(&self) -> &[u8] {
        let words = self.data() as *const [u32];
        // SAFETY: Int96::data refers to three `u32` words, i.e. 12 initialized bytes with no
        // padding, and `u8` has no alignment requirement.
        unsafe { std::slice::from_raw_parts(words.cast::<u8>(), 12) }
    }
}
Expand Down Expand Up @@ -718,6 +730,7 @@ pub(crate) mod private {

#[inline]
fn encode<W: std::io::Write>(values: &[Self], writer: &mut W, _: &mut BitWriter) -> Result<()> {
// SAFETY: Self is one of i32, i64, f32, f64, which have no padding.
let raw = unsafe {
std::slice::from_raw_parts(
values.as_ptr() as *const u8,
Expand Down Expand Up @@ -747,9 +760,10 @@ pub(crate) mod private {
return Err(eof_err!("Not enough bytes to decode"));
}

// SAFETY: Raw types should be as per the standard rust bit-vectors
unsafe {
let raw_buffer = &mut Self::slice_as_bytes_mut(buffer)[..bytes_to_decode];
{
// SAFETY: Self has no invalid bit patterns, so writing to the slice
// obtained with slice_as_bytes_mut is always safe.
let raw_buffer = &mut unsafe { Self::slice_as_bytes_mut(buffer) }[..bytes_to_decode];
raw_buffer.copy_from_slice(data.slice(
decoder.start..decoder.start + bytes_to_decode
).as_ref());
Expand Down Expand Up @@ -810,9 +824,7 @@ pub(crate) mod private {
_: &mut BitWriter,
) -> Result<()> {
for value in values {
let raw = unsafe {
std::slice::from_raw_parts(value.data() as *const [u32] as *const u8, 12)
};
let raw = SliceAsBytes::slice_as_bytes(value.data());
writer.write_all(raw)?;
}
Ok(())
Expand Down
45 changes: 39 additions & 6 deletions parquet/src/util/bit_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
}
}

pub trait FromBytes: Sized {
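/// # Safety
/// Implementors must guarantee that every bit pattern in which only the lowest
/// `BIT_CAPACITY` bits may be set (all higher bits are zero) is a valid value of `Self`.
/// This requirement does not apply when `BIT_CAPACITY` is 0.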
pub unsafe trait FromBytes: Sized {
const BIT_CAPACITY: usize;
type Buffer: AsMut<[u8]> + Default;
fn try_from_le_slice(b: &[u8]) -> Result<Self>;
fn from_le_bytes(bs: Self::Buffer) -> Self;
Expand All @@ -51,7 +55,9 @@ pub trait FromBytes: Sized {
macro_rules! from_le_bytes {
($($ty: ty),*) => {
$(
impl FromBytes for $ty {
// SAFETY: this macro is used for types for which all bit patterns are valid.
unsafe impl FromBytes for $ty {
const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
type Buffer = [u8; size_of::<Self>()];
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Ok(Self::from_le_bytes(array_from_slice(b)?))
Expand All @@ -66,7 +72,9 @@ macro_rules! from_le_bytes {

from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 }

impl FromBytes for bool {
// SAFETY: the 0000000x bit pattern is always valid for `bool`.
unsafe impl FromBytes for bool {
const BIT_CAPACITY: usize = 1;
type Buffer = [u8; 1];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -77,7 +85,9 @@ impl FromBytes for bool {
}
}

impl FromBytes for Int96 {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for Int96 {
const BIT_CAPACITY: usize = 0;
type Buffer = [u8; 12];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -95,7 +105,9 @@ impl FromBytes for Int96 {
}
}

impl FromBytes for ByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for ByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -106,7 +118,9 @@ impl FromBytes for ByteArray {
}
}

impl FromBytes for FixedLenByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for FixedLenByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand Down Expand Up @@ -457,10 +471,17 @@ impl BitReader {
}
}

assert_ne!(T::BIT_CAPACITY, 0);
assert!(num_bits <= T::BIT_CAPACITY);

// Read directly into output buffer
match size_of::<T>() {
1 => {
let ptr = batch.as_mut_ptr() as *mut u8;
// SAFETY: batch is properly aligned and sized. The `FromBytes` contract guarantees that
// every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked above that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 8 {
let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
Expand All @@ -471,6 +492,10 @@ impl BitReader {
}
2 => {
let ptr = batch.as_mut_ptr() as *mut u16;
// SAFETY: batch is properly aligned and sized. The `FromBytes` contract guarantees that
// every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked above that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 16 {
let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
Expand All @@ -481,6 +506,10 @@ impl BitReader {
}
4 => {
let ptr = batch.as_mut_ptr() as *mut u32;
// SAFETY: batch is properly aligned and sized. The `FromBytes` contract guarantees that
// every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked above that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 32 {
let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
Expand All @@ -491,6 +520,10 @@ impl BitReader {
}
8 => {
let ptr = batch.as_mut_ptr() as *mut u64;
// SAFETY: batch is properly aligned and sized. The `FromBytes` contract guarantees that
// every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked above that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 64 {
let out_slice = (&mut out[i..i + 64]).try_into().unwrap();
Expand Down

0 comments on commit 3ce8e84

Please sign in to comment.