Use WebAssembly SIMD instructions
They are stabilized in Rust 1.54.
CryZe committed Jun 12, 2021
1 parent 2e6aa24 commit b1d1ad3
Showing 4 changed files with 195 additions and 20 deletions.
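The new branches are compiled in only when both the crate's `simd` feature and the `simd128` target feature are enabled, e.g. with RUSTFLAGS="-C target-feature=+simd128" when targeting wasm32-unknown-unknown. As a minimal sketch of what the new branches call into (the helper `add4` is illustrative, not part of this commit), the wasm32 intrinsics are plain safe functions once the target feature is on at compile time:

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn add4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    use core::arch::wasm32::*;
    // Build two v128 vectors lane by lane, add them with a single
    // SIMD instruction, and pull the four result lanes back out.
    let sum = f32x4_add(
        f32x4(a[0], a[1], a[2], a[3]),
        f32x4(b[0], b[1], b[2], b[3]),
    );
    [
        f32x4_extract_lane::<0>(sum),
        f32x4_extract_lane::<1>(sum),
        f32x4_extract_lane::<2>(sum),
        f32x4_extract_lane::<3>(sum),
    ]
}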
29 changes: 29 additions & 0 deletions src/wide/f32x4_t.rs
@@ -14,6 +14,25 @@ cfg_if::cfg_if! {
#[derive(Default, Clone, Copy, PartialEq, Debug)]
#[repr(C, align(16))]
pub struct f32x4(m128);
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
use core::arch::wasm32::*;

// repr(transparent) allows the inner v128 to be passed directly on the WASM value stack.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct f32x4(v128);

impl Default for f32x4 {
fn default() -> Self {
Self::splat(0.0)
}
}

impl PartialEq for f32x4 {
fn eq(&self, other: &Self) -> bool {
u32x4_all_true(f32x4_eq(self.0, other.0))
}
}
} else {
#[derive(Default, Clone, Copy, PartialEq, Debug)]
#[repr(C, align(16))]
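The simd128 `PartialEq` above compares all four lanes at once: `f32x4_eq` produces a lane mask (all-ones bits where lanes compare equal) and `u32x4_all_true` folds that mask into a single `bool`. A scalar model of the same reduction (illustrative, not the shipped code):

fn f32x4_eq_all(a: [f32; 4], b: [f32; 4]) -> bool {
    // Every lane must compare equal, mirroring u32x4_all_true
    // over the f32x4_eq lane mask.
    a.iter().zip(b.iter()).all(|(x, y)| x == y)
}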
@@ -33,6 +52,8 @@ impl f32x4 {
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "sse"))] {
Self(max_m128(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_max(self.0, rhs.0))
} else {
Self([
self.0[0].max(rhs.0[0]),
@@ -48,6 +69,8 @@ impl f32x4 {
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "sse"))] {
Self(min_m128(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_min(self.0, rhs.0))
} else {
Self([
self.0[0].min(rhs.0[0]),
@@ -79,6 +102,8 @@ impl core::ops::Add for f32x4 {
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "sse"))] {
Self(add_m128(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_add(self.0, rhs.0))
} else {
Self([
self.0[0] + rhs.0[0],
@@ -104,6 +129,8 @@ impl core::ops::Sub for f32x4 {
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "sse"))] {
Self(sub_m128(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_sub(self.0, rhs.0))
} else {
Self([
self.0[0] - rhs.0[0],
@@ -123,6 +150,8 @@ impl core::ops::Mul for f32x4 {
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "sse"))] {
Self(mul_m128(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_mul(self.0, rhs.0))
} else {
Self([
self.0[0] * rhs.0[0],
108 changes: 94 additions & 14 deletions src/wide/f32x8_t.rs
@@ -25,6 +25,25 @@ cfg_if::cfg_if! {
#[derive(Default, Clone, Copy, PartialEq, Debug)]
#[repr(C, align(32))]
pub struct f32x8(m128, m128);
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
use core::arch::wasm32::*;

#[derive(Clone, Copy, Debug)]
#[repr(C, align(32))]
pub struct f32x8(v128, v128);

impl Default for f32x8 {
fn default() -> Self {
Self::splat(0.0)
}
}

impl PartialEq for f32x8 {
fn eq(&self, other: &Self) -> bool {
u32x4_all_true(f32x4_eq(self.0, other.0)) &
u32x4_all_true(f32x4_eq(self.1, other.1))
}
}
} else {
#[derive(Default, Clone, Copy, PartialEq, Debug)]
#[repr(C, align(32))]
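Note that unlike `f32x4`, the wasm `f32x8` cannot be `repr(transparent)` (it holds two `v128`s), and it keeps `repr(C, align(32))`, presumably so its size and alignment stay in step with the scalar and AVX variants that the crate `cast`s between.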
@@ -41,8 +60,14 @@ impl f32x8 {
}

pub fn floor(self) -> Self {
let roundtrip: f32x8 = cast(self.trunc_int().to_f32x8());
roundtrip - roundtrip.cmp_gt(self).blend(f32x8::splat(1.0), f32x8::default())
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_floor(self.0), f32x4_floor(self.1))
} else {
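// Fallback trick: trunc_int() rounds toward zero, so for negative
// non-integral inputs the roundtrip value lands above self; cmp_gt
// then yields an all-ones mask in exactly those lanes, blend picks
// 1.0 there, and the subtraction produces the true floor.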
let roundtrip: f32x8 = cast(self.trunc_int().to_f32x8());
roundtrip - roundtrip.cmp_gt(self).blend(f32x8::splat(1.0), f32x8::default())
}
}
}

pub fn fract(self) -> Self {
@@ -67,6 +92,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, EqualOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_eq_mask_m128(self.0, rhs.0), cmp_eq_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_eq(self.0, rhs.0), f32x4_eq(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, eq, rhs, f32::from_bits(u32::MAX), 0.0))
}
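These `cmp_*` methods return lane masks rather than `bool`s: a true lane is all-ones bits, which is why the scalar fallback writes `f32::from_bits(u32::MAX)` and why `blend` can consume the result directly. One mask lane, modeled in scalar code (illustrative):

fn eq_mask_lane(x: f32, y: f32) -> f32 {
    // All-ones bits for true, all-zeros for false.
    if x == y { f32::from_bits(u32::MAX) } else { 0.0 }
}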
@@ -79,6 +106,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, GreaterEqualOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_ge_mask_m128(self.0, rhs.0), cmp_ge_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_ge(self.0, rhs.0), f32x4_ge(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, ge, rhs, f32::from_bits(u32::MAX), 0.0))
}
@@ -91,6 +120,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, GreaterThanOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_gt_mask_m128(self.0, rhs.0), cmp_gt_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_gt(self.0, rhs.0), f32x4_gt(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, gt, rhs, f32::from_bits(u32::MAX), 0.0))
}
@@ -103,6 +134,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, NotEqualOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_neq_mask_m128(self.0, rhs.0), cmp_neq_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_ne(self.0, rhs.0), f32x4_ne(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, ne, rhs, f32::from_bits(u32::MAX), 0.0))
}
@@ -115,6 +148,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, LessEqualOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_le_mask_m128(self.0, rhs.0), cmp_le_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_le(self.0, rhs.0), f32x4_le(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, le, rhs, f32::from_bits(u32::MAX), 0.0))
}
@@ -127,6 +162,8 @@ impl f32x8 {
Self(cmp_op_mask_m256!(self.0, LessThanOrdered, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(cmp_lt_mask_m128(self.0, rhs.0), cmp_lt_mask_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_lt(self.0, rhs.0), f32x4_lt(self.1, rhs.1))
} else {
Self(impl_x8_cmp!(self, lt, rhs, f32::from_bits(u32::MAX), 0.0))
}
@@ -139,15 +176,23 @@ impl f32x8 {
Self(blend_varying_m256(f.0, t.0, self.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse4.1"))] {
Self(blend_varying_m128(f.0, t.0, self.0), blend_varying_m128(f.1, t.1, self.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
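// v128_bitselect(t, f, mask) takes t's bit where mask is 1 and f's
// bit where mask is 0, i.e. (t & mask) | (f & !mask), so the
// (true, false, mask) argument order matches blend's semantics.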
Self(v128_bitselect(t.0, f.0, self.0), v128_bitselect(t.1, f.1, self.1))
} else {
super::generic_bit_blend(self, t, f)
}
}
}

pub fn abs(self) -> Self {
let non_sign_bits = f32x8::splat(f32::from_bits(i32::MAX as u32));
self & non_sign_bits
cfg_if::cfg_if! {
if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_abs(self.0), f32x4_abs(self.1))
} else {
let non_sign_bits = f32x8::splat(f32::from_bits(i32::MAX as u32));
self & non_sign_bits
}
}
}

pub fn max(self, rhs: Self) -> Self {
@@ -156,6 +201,8 @@ impl f32x8 {
Self(max_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(max_m128(self.0, rhs.0), max_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_max(self.0, rhs.0), f32x4_max(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, max, rhs))
}
@@ -168,6 +215,8 @@ impl f32x8 {
Self(min_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(min_m128(self.0, rhs.0), min_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_min(self.0, rhs.0), f32x4_min(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, min, rhs))
}
@@ -188,6 +237,8 @@ impl f32x8 {
Self(round_m256!(self.0, Nearest))
} else if #[cfg(all(feature = "simd", target_feature = "sse4.1"))] {
Self(round_m128!(self.0, Nearest), round_m128!(self.1, Nearest))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
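// f32x4_nearest rounds to the nearest integer with ties to even,
// the same behavior as the Nearest mode used by the x86 paths.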
Self(f32x4_nearest(self.0), f32x4_nearest(self.1))
} else {
let to_int = f32x8::splat(1.0 / f32::EPSILON);
let u: u32x8 = cast(self);
@@ -225,6 +276,9 @@ impl f32x8 {
convert_to_i32_m128i_from_m128(self.0),
convert_to_i32_m128i_from_m128(self.1),
)
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
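// i32x4_trunc_sat_f32x4 truncates toward zero, saturating values
// outside the i32 range (and mapping NaN to 0), so rounding first
// gives round-to-nearest semantics.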
let rounded = self.round();
i32x8(i32x4_trunc_sat_f32x4(rounded.0), i32x4_trunc_sat_f32x4(rounded.1))
} else {
let rounded: [f32; 8] = cast(self.round());
let rounded_ints: i32x8 = cast([
@@ -251,6 +305,11 @@ impl f32x8 {
cast(convert_truncate_to_i32_m256i_from_m256(self.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
i32x8(truncate_m128_to_m128i(self.0), truncate_m128_to_m128i(self.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
cast(Self(
i32x4_trunc_sat_f32x4(self.0),
i32x4_trunc_sat_f32x4(self.1),
))
} else {
let n: [f32; 8] = cast(self);
let ints: i32x8 = cast([
@@ -274,6 +333,12 @@ impl f32x8 {
Self(reciprocal_m256(self.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(reciprocal_m128(self.0), reciprocal_m128(self.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
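// simd128 has no reciprocal-estimate instruction, so compute an
// exact 1.0 / x per lane (unlike the approximate x86 rcp above).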
let one = f32x4_splat(1.0);
Self(
f32x4_div(one, self.0),
f32x4_div(one, self.1),
)
} else {
Self::from([
1.0 / self.0[0],
@@ -295,6 +360,12 @@ impl f32x8 {
Self(reciprocal_sqrt_m256(self.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(reciprocal_sqrt_m128(self.0), reciprocal_sqrt_m128(self.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
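// Likewise there is no rsqrt estimate in simd128: take the exact
// square root, then divide.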
let one = f32x4_splat(1.0);
Self(
f32x4_div(one, f32x4_sqrt(self.0)),
f32x4_div(one, f32x4_sqrt(self.1)),
)
} else {
Self::from([
1.0 / self.0[0].sqrt(),
@@ -316,6 +387,8 @@ impl f32x8 {
Self(sqrt_m256(self.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(sqrt_m128(self.0), sqrt_m128(self.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_sqrt(self.0), f32x4_sqrt(self.1))
} else {
Self::from([
self.0[0].sqrt(),
@@ -353,6 +426,8 @@ impl core::ops::Add for f32x8 {
Self(add_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(add_m128(self.0, rhs.0), add_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_add(self.0, rhs.0), f32x4_add(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, add, rhs))
}
@@ -375,6 +450,8 @@ impl core::ops::Sub for f32x8 {
Self(sub_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(sub_m128(self.0, rhs.0), sub_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_sub(self.0, rhs.0), f32x4_sub(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, sub, rhs))
}
@@ -391,6 +468,8 @@ impl core::ops::Mul for f32x8 {
Self(mul_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(mul_m128(self.0, rhs.0), mul_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_mul(self.0, rhs.0), f32x4_mul(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, mul, rhs))
}
@@ -413,6 +492,8 @@ impl core::ops::Div for f32x8 {
Self(div_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(div_m128(self.0, rhs.0), div_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(f32x4_div(self.0, rhs.0), f32x4_div(self.1, rhs.1))
} else {
Self(impl_x8_op!(self, div, rhs))
}
@@ -429,6 +510,8 @@ impl core::ops::BitAnd for f32x8 {
Self(bitand_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(bitand_m128(self.0, rhs.0), bitand_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(v128_and(self.0, rhs.0), v128_and(self.1, rhs.1))
} else {
Self([
f32::from_bits(self.0[0].to_bits() & rhs.0[0].to_bits()),
@@ -454,6 +537,8 @@ impl core::ops::BitOr for f32x8 {
Self(bitor_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(bitor_m128(self.0, rhs.0), bitor_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(v128_or(self.0, rhs.0), v128_or(self.1, rhs.1))
} else {
Self([
f32::from_bits(self.0[0].to_bits() | rhs.0[0].to_bits()),
@@ -479,6 +564,8 @@ impl core::ops::BitXor for f32x8 {
Self(bitxor_m256(self.0, rhs.0))
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(bitxor_m128(self.0, rhs.0), bitxor_m128(self.1, rhs.1))
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(v128_xor(self.0, rhs.0), v128_xor(self.1, rhs.1))
} else {
Self([
f32::from_bits(self.0[0].to_bits() ^ rhs.0[0].to_bits()),
@@ -512,17 +599,10 @@ impl core::ops::Not for f32x8 {
Self(self.0.not())
} else if #[cfg(all(feature = "simd", target_feature = "sse2"))] {
Self(self.0.not(), self.1.not())
} else if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
Self(v128_not(self.0), v128_not(self.1))
} else {
Self::from([
f32::from_bits(self.0[0].to_bits() ^ u32::MAX),
f32::from_bits(self.0[1].to_bits() ^ u32::MAX),
f32::from_bits(self.0[2].to_bits() ^ u32::MAX),
f32::from_bits(self.0[3].to_bits() ^ u32::MAX),
f32::from_bits(self.0[4].to_bits() ^ u32::MAX),
f32::from_bits(self.0[5].to_bits() ^ u32::MAX),
f32::from_bits(self.0[6].to_bits() ^ u32::MAX),
f32::from_bits(self.0[7].to_bits() ^ u32::MAX),
])
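// The new fallback: XOR against all-ones bits flips every bit,
// i.e. a lane-wise NOT, replacing the unrolled per-lane version above.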
self ^ Self::splat(cast(u32::MAX))
}
}
}
