Skip to content

Commit

Permalink
Provide both a safe and an unsafe implementation of array filling
Browse files Browse the repository at this point in the history
  • Loading branch information
hniksic committed May 9, 2021
1 parent 459107b commit 47c71fa
Showing 1 changed file with 85 additions and 47 deletions.
132 changes: 85 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ use serde::{
de::{self, Deserialize, Deserializer, SeqAccess, Visitor},
ser::{Serialize, SerializeTuple, Serializer},
};
use std::{fmt, marker::PhantomData, mem::MaybeUninit};
use std::{fmt, marker::PhantomData};

/// Serialize const generic or arbitrarily-large arrays
///
Expand Down Expand Up @@ -113,6 +113,89 @@ struct ArrayVisitor<T, const N: usize> {
_marker: PhantomData<T>,
}

#[cfg(inhibit_unsafe)]
fn into_array<T, E: serde::de::Error, const N: usize>(
mut source: impl Iterator<Item = Result<T, E>>,
) -> Result<[T; N], E> {
use std::convert::TryInto;

// Build a temporary container to hold our data as we deserialize it
// We can't rely on a Default<T> implementation, so we can't use an array here
let mut arr = Vec::with_capacity(N);

while let Some(val) = source.next() {
arr.push(val?);
}

// We can convert a Vec into an array via TryInto, which will fail if the length of the Vec
// doesn't match that of the array.
match arr.try_into() {
Ok(arr) => Ok(arr),

Err(arr) => Err(de::Error::invalid_length(
arr.len(),
&format!("an array of size {}", N).as_str(),
)),
}
}

#[cfg(not(inhibit_unsafe))]
fn into_array<T, E: serde::de::Error, const N: usize>(
mut source: impl Iterator<Item = Result<T, E>>,
) -> Result<[T; N], E> {
use std::mem::MaybeUninit;

// Safety: `assume_init` is sound because the type we are claiming to have
// initialized here is a bunch of `MaybeUninit`s, which do not require
// initialization.
let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };

// Iterate over the array and fill the elemenets with the ones obtained from
// `seq`.
let mut place_iter = arr.iter_mut();
let mut cnt_filled = 0;
let err = loop {
match (source.next(), place_iter.next()) {
(Some(Ok(val)), Some(place)) => *place = MaybeUninit::new(val),
// no error, we're done
(None, None) => break None,
// error from serde, propagate it
(Some(Err(e)), _) => break Some(e),
// lengths do not match, report invalid_length
(None, Some(_)) | (Some(Ok(_)), None) => {
break Some(de::Error::invalid_length(
cnt_filled,
&format!("an array of size {}", N).as_str(),
))
}
}
cnt_filled += 1;
};
if let Some(err) = err {
if std::mem::needs_drop::<T>() {
for elem in std::array::IntoIter::new(arr).take(cnt_filled) {
// Safety: `assume_init()` is sound because we did initialize CNT_FILLED
// elements. We call it to drop the deserialized values.
unsafe {
elem.assume_init();
}
}
}
return Err(err);
}

// Safety: everything is initialized and we are ready to transmute to the
// initialized array type.

// See https://github.com/rust-lang/rust/issues/62875#issuecomment-513834029
//let ret = unsafe { std::mem::transmute::<_, [T; N]>(arr) };

let ret = unsafe { std::mem::transmute_copy(&arr) };
std::mem::forget(arr);

Ok(ret)
}

impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
where
T: Deserialize<'de>,
Expand All @@ -129,52 +212,7 @@ where
where
A: SeqAccess<'de>,
{
// Safety: `assume_init` is sound because the type we are claiming to have
// initialized here is a bunch of `MaybeUninit`s, which do not require
// initialization.
let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };

// Iterate over the array and fill the elemenets with the ones obtained from
// `seq`.
let mut place_iter = arr.iter_mut();
let mut cnt_filled = 0;
let err = loop {
match (seq.next_element(), place_iter.next()) {
(Ok(Some(val)), Some(place)) => *place = MaybeUninit::new(val),
// no error, we're done
(Ok(None), None) => break None,
// error from serde, propagate it
(Err(e), _) => break Some(e),
// lengths do not match, report invalid_length
(Ok(None), Some(_)) | (Ok(Some(_)), None) => {
break Some(de::Error::invalid_length(cnt_filled, &self))
}
}
cnt_filled += 1;
};
if let Some(err) = err {
if std::mem::needs_drop::<T>() {
for elem in std::array::IntoIter::new(arr).take(cnt_filled) {
// Safety: `assume_init()` is sound because we did initialize CNT_FILLED
// elements. We call it to drop the deserialized values.
unsafe {
elem.assume_init();
}
}
}
return Err(err);
}

// Safety: everything is initialized and we are ready to transmute to the
// initialized array type.

// See https://github.com/rust-lang/rust/issues/62875#issuecomment-513834029
//let ret = unsafe { std::mem::transmute::<_, [T; N]>(arr) };

let ret = unsafe { std::mem::transmute_copy(&arr) };
std::mem::forget(arr);

Ok(ret)
into_array(std::iter::from_fn(move || seq.next_element().transpose()))
}
}

Expand Down

0 comments on commit 47c71fa

Please sign in to comment.