Rollup merge of rust-lang#94212 - scottmcm:swapper, r=dtolnay
Stop manually SIMDing in `swap_nonoverlapping`

As I previously did for `reverse` (rust-lang#90821), this leaves it to LLVM to pick how to vectorize the swap, since it knows the right chunk size for the target better than the hard-coded "32 bytes always" approach we currently have.
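
Roughly speaking, the change boils down to a plain per-element loop (a simplified sketch of the `swap_nonoverlapping_simple` helper added below; the name and signature here are illustrative, not the exact library code), which LLVM can then unroll and vectorize however suits the target:

```rust
use std::ptr;

/// Minimal sketch, assuming non-overlapping pointers valid for `count`
/// elements: swap one element at a time and let LLVM choose the chunking.
/// The real helper also has to work in `const` contexts, so it is written
/// slightly differently in the actual PR.
unsafe fn swap_elementwise<T>(x: *mut T, y: *mut T, count: usize) {
    for i in 0..count {
        // SAFETY: the caller guarantees both pointers are valid for `count`
        // elements and the two regions do not overlap.
        unsafe {
            // Read one element into a temporary, copy the other across, write back.
            let tmp = ptr::read(x.add(i));
            ptr::copy_nonoverlapping(y.add(i), x.add(i), 1);
            ptr::write(y.add(i), tmp);
        }
    }
}
```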

A variety of codegen tests are included to confirm that the various cases are still being vectorized.

It does still need type-erasing logic in some cases, though: while LLVM is now smart enough to vectorize over slices of things like `[u8; 4]`, it fails to do so over slices of `[u8; 3]`.
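
Concretely, the type-erasure only kicks in when `T`'s alignment and size allow it; here is a hedged sketch of the condition the new `attempt_swap_as_chunks!` macro checks before re-typing the buffers as `MaybeUninit<usize>` (or, failing that, `MaybeUninit<u8>`) chunks. The helper name `can_swap_as_chunks` is made up for illustration:

```rust
use std::mem::{align_of, size_of};

/// Sketch of the check used before swapping a `[T]` as `MaybeUninit<Chunk>`s
/// instead: the chunk must not require more alignment than `T` already has,
/// and `T` must be a whole number of chunks so no trailing bytes are dropped.
fn can_swap_as_chunks<T, Chunk>() -> bool {
    align_of::<T>() >= align_of::<Chunk>() && size_of::<T>() % size_of::<Chunk>() == 0
}

fn main() {
    // `[u8; 3]` can't be re-typed as `usize` chunks (assuming `usize` needs
    // more than 1-byte alignment, as on mainstream targets), but it can still
    // be swapped as `u8` chunks, which LLVM happily vectorizes.
    assert!(!can_swap_as_chunks::<[u8; 3], usize>());
    assert!(can_swap_as_chunks::<[u8; 3], u8>());
}
```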

As a bonus, this change also means one no longer gets the spurious `memcpy`s at the end of swapping a slice of `__m256`s: <https://rust.godbolt.org/z/joofr4v8Y>
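
For reference, the Rust source behind that godbolt link is essentially the `swap_m256_slice` function that also shows up in the new `swap-simd-types.rs` codegen test below:

```rust
use std::arch::x86_64::__m256;

pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
    // `swap_with_slice` panics on length mismatch, so guard first.
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
```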

<details>

<summary>ASM for this example</summary>

## Before (from godbolt)

Note the `push`/`pop`s and the `memcpy` calls.

```x86
swap_m256_slice:
        push    r15
        push    r14
        push    r13
        push    r12
        push    rbx
        sub     rsp, 32
        cmp     rsi, rcx
        jne     .LBB0_6
        mov     r14, rsi
        shl     r14, 5
        je      .LBB0_6
        mov     r15, rdx
        mov     rbx, rdi
        xor     eax, eax
.LBB0_3:
        mov     rcx, rax
        vmovaps ymm0, ymmword ptr [rbx + rax]
        vmovaps ymm1, ymmword ptr [r15 + rax]
        vmovaps ymmword ptr [rbx + rax], ymm1
        vmovaps ymmword ptr [r15 + rax], ymm0
        add     rax, 32
        add     rcx, 64
        cmp     rcx, r14
        jbe     .LBB0_3
        sub     r14, rax
        jbe     .LBB0_6
        add     rbx, rax
        add     r15, rax
        mov     r12, rsp
        mov     r13, qword ptr [rip + memcpy@GOTPCREL]
        mov     rdi, r12
        mov     rsi, rbx
        mov     rdx, r14
        vzeroupper
        call    r13
        mov     rdi, rbx
        mov     rsi, r15
        mov     rdx, r14
        call    r13
        mov     rdi, r15
        mov     rsi, r12
        mov     rdx, r14
        call    r13
.LBB0_6:
        add     rsp, 32
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        vzeroupper
        ret
```

## After (from my machine)

Note the lack of `rsp` manipulation; apologies for the different ASM syntax (AT&T here vs Intel above).

```x86
swap_m256_slice:
	cmpq	%r9, %rdx
	jne	.LBB1_6
	testq	%rdx, %rdx
	je	.LBB1_6
	cmpq	$1, %rdx
	jne	.LBB1_7
	xorl	%r10d, %r10d
	jmp	.LBB1_4
.LBB1_7:
	movq	%rdx, %r9
	andq	$-2, %r9
	movl	$32, %eax
	xorl	%r10d, %r10d
	.p2align	4, 0x90
.LBB1_8:
	vmovaps	-32(%rcx,%rax), %ymm0
	vmovaps	-32(%r8,%rax), %ymm1
	vmovaps	%ymm1, -32(%rcx,%rax)
	vmovaps	%ymm0, -32(%r8,%rax)
	vmovaps	(%rcx,%rax), %ymm0
	vmovaps	(%r8,%rax), %ymm1
	vmovaps	%ymm1, (%rcx,%rax)
	vmovaps	%ymm0, (%r8,%rax)
	addq	$2, %r10
	addq	$64, %rax
	cmpq	%r10, %r9
	jne	.LBB1_8
.LBB1_4:
	testb	$1, %dl
	je	.LBB1_6
	shlq	$5, %r10
	vmovaps	(%rcx,%r10), %ymm0
	vmovaps	(%r8,%r10), %ymm1
	vmovaps	%ymm1, (%rcx,%r10)
	vmovaps	%ymm0, (%r8,%r10)
.LBB1_6:
	vzeroupper
	retq
```

</details>

This does all its copying operations as either the original type or as `MaybeUninit`s, so as far as I know there should be no potential abstract machine issues with reading padding bytes as integers.
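
To illustrate that point with a hypothetical example (not code from this PR): a padded type can be copied around in `MaybeUninit<usize>`-sized chunks, whereas reading those same bytes as plain `usize` values would mean reading uninitialized padding bytes as integers:

```rust
use std::mem::{size_of, MaybeUninit};

// 1 byte of data followed by 7 padding bytes on typical 64-bit targets.
#[repr(align(8))]
pub struct Padded(pub u8);

/// Copies `src` to `dst` in chunks without ever claiming the padding bytes
/// are initialized integers, because `MaybeUninit<usize>` may hold
/// uninitialized bytes.
///
/// # Safety
/// `src` and `dst` must be valid, properly aligned, and non-overlapping.
pub unsafe fn copy_padded_as_chunks(src: *const Padded, dst: *mut Padded) {
    const CHUNKS: usize = size_of::<Padded>() / size_of::<usize>();
    let src = src.cast::<MaybeUninit<usize>>();
    let dst = dst.cast::<MaybeUninit<usize>>();
    for i in 0..CHUNKS {
        // SAFETY: `Padded` is a whole number of `usize`-sized, suitably
        // aligned chunks, so these reads and writes stay in bounds.
        unsafe { dst.add(i).write(src.add(i).read()) };
    }
}
```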

<details>

<summary>Perf is essentially unchanged</summary>

Though perhaps with more target features enabled this would help more, since LLVM could then pick bigger chunks.

## Before

```
running 10 tests
test slice::swap_with_slice_4x_usize_30                            ... bench:         894 ns/iter (+/- 11)
test slice::swap_with_slice_4x_usize_3000                          ... bench:      99,476 ns/iter (+/- 2,784)
test slice::swap_with_slice_5x_usize_30                            ... bench:       1,257 ns/iter (+/- 7)
test slice::swap_with_slice_5x_usize_3000                          ... bench:     139,922 ns/iter (+/- 959)
test slice::swap_with_slice_rgb_30                                 ... bench:         328 ns/iter (+/- 27)
test slice::swap_with_slice_rgb_3000                               ... bench:      16,215 ns/iter (+/- 176)
test slice::swap_with_slice_u8_30                                  ... bench:         312 ns/iter (+/- 9)
test slice::swap_with_slice_u8_3000                                ... bench:       5,401 ns/iter (+/- 123)
test slice::swap_with_slice_usize_30                               ... bench:         368 ns/iter (+/- 3)
test slice::swap_with_slice_usize_3000                             ... bench:      28,472 ns/iter (+/- 3,913)
```

## After

```
running 10 tests
test slice::swap_with_slice_4x_usize_30                            ... bench:         868 ns/iter (+/- 36)
test slice::swap_with_slice_4x_usize_3000                          ... bench:      99,642 ns/iter (+/- 1,507)
test slice::swap_with_slice_5x_usize_30                            ... bench:       1,194 ns/iter (+/- 11)
test slice::swap_with_slice_5x_usize_3000                          ... bench:     139,761 ns/iter (+/- 5,018)
test slice::swap_with_slice_rgb_30                                 ... bench:         324 ns/iter (+/- 6)
test slice::swap_with_slice_rgb_3000                               ... bench:      15,962 ns/iter (+/- 287)
test slice::swap_with_slice_u8_30                                  ... bench:         281 ns/iter (+/- 5)
test slice::swap_with_slice_u8_3000                                ... bench:       5,324 ns/iter (+/- 40)
test slice::swap_with_slice_usize_30                               ... bench:         275 ns/iter (+/- 5)
test slice::swap_with_slice_usize_3000                             ... bench:      28,277 ns/iter (+/- 277)
```

</details>

Dylan-DPC committed Feb 24, 2022
2 parents 000e38d + 8ca47d7 commit 7fb55b4
Showing 6 changed files with 263 additions and 97 deletions.
43 changes: 39 additions & 4 deletions library/core/benches/slice.rs
@@ -89,6 +89,15 @@ fn binary_search_l3_worst_case(b: &mut Bencher) {
    binary_search_worst_case(b, Cache::L3);
}

#[derive(Clone)]
struct Rgb(u8, u8, u8);

impl Rgb {
    fn gen(i: usize) -> Self {
        Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42))
    }
}

macro_rules! rotate {
    ($fn:ident, $n:expr, $mapper:expr) => {
        #[bench]
@@ -104,17 +113,43 @@ macro_rules! rotate {
    };
}

#[derive(Clone)]
struct Rgb(u8, u8, u8);

rotate!(rotate_u8, 32, |i| i as u8);
rotate!(rotate_rgb, 32, |i| Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42)));
rotate!(rotate_rgb, 32, Rgb::gen);
rotate!(rotate_usize, 32, |i| i);
rotate!(rotate_16_usize_4, 16, |i| [i; 4]);
rotate!(rotate_16_usize_5, 16, |i| [i; 5]);
rotate!(rotate_64_usize_4, 64, |i| [i; 4]);
rotate!(rotate_64_usize_5, 64, |i| [i; 5]);

macro_rules! swap_with_slice {
    ($fn:ident, $n:expr, $mapper:expr) => {
        #[bench]
        fn $fn(b: &mut Bencher) {
            let mut x = (0usize..$n).map(&$mapper).collect::<Vec<_>>();
            let mut y = ($n..($n * 2)).map(&$mapper).collect::<Vec<_>>();
            let mut skip = 0;
            b.iter(|| {
                for _ in 0..32 {
                    x[skip..].swap_with_slice(&mut y[..($n - skip)]);
                    skip = black_box(skip + 1) % 8;
                }
                black_box((x[$n / 3].clone(), y[$n * 2 / 3].clone()))
            })
        }
    };
}

swap_with_slice!(swap_with_slice_u8_30, 30, |i| i as u8);
swap_with_slice!(swap_with_slice_u8_3000, 3000, |i| i as u8);
swap_with_slice!(swap_with_slice_rgb_30, 30, Rgb::gen);
swap_with_slice!(swap_with_slice_rgb_3000, 3000, Rgb::gen);
swap_with_slice!(swap_with_slice_usize_30, 30, |i| i);
swap_with_slice!(swap_with_slice_usize_3000, 3000, |i| i);
swap_with_slice!(swap_with_slice_4x_usize_30, 30, |i| [i; 4]);
swap_with_slice!(swap_with_slice_4x_usize_3000, 3000, |i| [i; 4]);
swap_with_slice!(swap_with_slice_5x_usize_30, 30, |i| [i; 5]);
swap_with_slice!(swap_with_slice_5x_usize_3000, 3000, |i| [i; 5]);

#[bench]
fn fill_byte_sized(b: &mut Bencher) {
    #[derive(Copy, Clone)]
45 changes: 42 additions & 3 deletions library/core/src/mem/mod.rs
@@ -700,10 +700,49 @@ pub unsafe fn uninitialized<T>() -> T {
#[stable(feature = "rust1", since = "1.0.0")]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub const fn swap<T>(x: &mut T, y: &mut T) {
// SAFETY: the raw pointers have been created from safe mutable references satisfying all the
// constraints on `ptr::swap_nonoverlapping_one`
// NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
// reinterpretation of values as (chunkable) byte arrays, and the loop in the
// block optimization in `swap_slice` is hard to rewrite back
// into the (unoptimized) direct swapping implementation, so we disable it.
// FIXME(eddyb) the block optimization also prevents MIR optimizations from
// understanding `mem::replace`, `Option::take`, etc. - a better overall
// solution might be to make `ptr::swap_nonoverlapping` into an intrinsic, which
// a backend can choose to implement using the block optimization, or not.
#[cfg(not(target_arch = "spirv"))]
{
// For types that are larger multiples of their alignment, the simple way
// tends to copy the whole thing to stack rather than doing it one part
// at a time, so instead treat them as one-element slices and piggy-back
// the slice optimizations that will split up the swaps.
if size_of::<T>() / align_of::<T>() > 4 {
// SAFETY: exclusive references always point to one non-overlapping
// element and are non-null and properly aligned.
return unsafe { ptr::swap_nonoverlapping(x, y, 1) };
}
}

// If a scalar consists of just a small number of alignment units, let
// the codegen just swap those pieces directly, as it's likely just a
// few instructions and anything else is probably overcomplicated.
//
// Most importantly, this covers primitives and simd types that tend to
// have size=align where doing anything else can be a pessimization.
// (This will also be used for ZSTs, though any solution works for them.)
swap_simple(x, y);
}

/// Same as [`swap`] semantically, but always uses the simple implementation.
///
/// Used elsewhere in `mem` and `ptr` at the bottom layer of calls.
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
#[inline]
pub(crate) const fn swap_simple<T>(x: &mut T, y: &mut T) {
// SAFETY: exclusive references are always valid to read/write,
// are non-overlapping, and nothing here panics so it's drop-safe.
unsafe {
ptr::swap_nonoverlapping_one(x, y);
let z = ptr::read(x);
ptr::copy_nonoverlapping(y, x, 1);
ptr::write(y, z);
}
}

132 changes: 42 additions & 90 deletions library/core/src/ptr/mod.rs
@@ -419,106 +419,58 @@ pub const unsafe fn swap<T>(x: *mut T, y: *mut T) {
#[stable(feature = "swap_nonoverlapping", since = "1.27.0")]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
let x = x as *mut u8;
let y = y as *mut u8;
let len = mem::size_of::<T>() * count;
// SAFETY: the caller must guarantee that `x` and `y` are
// valid for writes and properly aligned.
unsafe { swap_nonoverlapping_bytes(x, y, len) }
}
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut MaybeUninit<$ChunkTy> = x.cast();
let y: *mut MaybeUninit<$ChunkTy> = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple(x, y, count) };
}
};
}

#[inline]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub(crate) const unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
// NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
// reinterpretation of values as (chunkable) byte arrays, and the loop in the
// block optimization in `swap_nonoverlapping_bytes` is hard to rewrite back
// into the (unoptimized) direct swapping implementation, so we disable it.
// FIXME(eddyb) the block optimization also prevents MIR optimizations from
// understanding `mem::replace`, `Option::take`, etc. - a better overall
// solution might be to make `swap_nonoverlapping` into an intrinsic, which
// a backend can choose to implement using the block optimization, or not.
#[cfg(not(target_arch = "spirv"))]
// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
// Only apply the block optimization in `swap_nonoverlapping_bytes` for types
// at least as large as the block size, to avoid pessimizing codegen.
if mem::size_of::<T>() >= 32 {
// SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
unsafe { swap_nonoverlapping(x, y, 1) };
return;
}
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}

// Direct swapping, for the cases not going through the block optimization.
// SAFETY: the caller must guarantee that `x` and `y` are valid
// for writes, properly aligned, and non-overlapping.
unsafe {
let z = read(x);
copy_nonoverlapping(y, x, 1);
write(y, z);
}
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple(x, y, count) }
}

/// Same behaviour and safety conditions as [`swap_nonoverlapping`]
///
/// LLVM can vectorize this (at least it can for the power-of-two-sized types
/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
#[inline]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
const unsafe fn swap_nonoverlapping_bytes(x: *mut u8, y: *mut u8, len: usize) {
// The approach here is to utilize simd to swap x & y efficiently. Testing reveals
// that swapping either 32 bytes or 64 bytes at a time is most efficient for Intel
// Haswell E processors. LLVM is more able to optimize if we give a struct a
// #[repr(simd)], even if we don't actually use this struct directly.
//
// FIXME repr(simd) broken on emscripten and redox
#[cfg_attr(not(any(target_os = "emscripten", target_os = "redox")), repr(simd))]
struct Block(u64, u64, u64, u64);
struct UnalignedBlock(u64, u64, u64, u64);

let block_size = mem::size_of::<Block>();

// Loop through x & y, copying them `Block` at a time
// The optimizer should unroll the loop fully for most types
// N.B. We can't use a for loop as the `range` impl calls `mem::swap` recursively
const unsafe fn swap_nonoverlapping_simple<T>(x: *mut T, y: *mut T, count: usize) {
let mut i = 0;
while i + block_size <= len {
// Create some uninitialized memory as scratch space
// Declaring `t` here avoids aligning the stack when this loop is unused
let mut t = mem::MaybeUninit::<Block>::uninit();
let t = t.as_mut_ptr() as *mut u8;

// SAFETY: As `i < len`, and as the caller must guarantee that `x` and `y` are valid
// for `len` bytes, `x + i` and `y + i` must be valid addresses, which fulfills the
// safety contract for `add`.
//
// Also, the caller must guarantee that `x` and `y` are valid for writes, properly aligned,
// and non-overlapping, which fulfills the safety contract for `copy_nonoverlapping`.
unsafe {
let x = x.add(i);
let y = y.add(i);
while i < count {
let x: &mut T =
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
unsafe { &mut *x.add(i) };
let y: &mut T =
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
// and it's distinct from `x` since the ranges are non-overlapping
unsafe { &mut *y.add(i) };
mem::swap_simple(x, y);

// Swap a block of bytes of x & y, using t as a temporary buffer
// This should be optimized into efficient SIMD operations where available
copy_nonoverlapping(x, t, block_size);
copy_nonoverlapping(y, x, block_size);
copy_nonoverlapping(t, y, block_size);
}
i += block_size;
}

if i < len {
// Swap any remaining bytes
let mut t = mem::MaybeUninit::<UnalignedBlock>::uninit();
let rem = len - i;

let t = t.as_mut_ptr() as *mut u8;

// SAFETY: see previous safety comment.
unsafe {
let x = x.add(i);
let y = y.add(i);

copy_nonoverlapping(x, t, rem);
copy_nonoverlapping(y, x, rem);
copy_nonoverlapping(t, y, rem);
}
i += 1;
}
}

64 changes: 64 additions & 0 deletions src/test/codegen/swap-large-types.rs
@@ -0,0 +1,64 @@
// compile-flags: -O
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;
use std::ptr::{read, copy_nonoverlapping, write};

type KeccakBuffer = [[u64; 5]; 5];

// A basic read+copy+write swap implementation ends up copying one of the values
// to stack for large types, which is completely unnecessary as the lack of
// overlap means we can just do whatever fits in registers at a time.

// CHECK-LABEL: @swap_basic
#[no_mangle]
pub fn swap_basic(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
    // CHECK: alloca [5 x [5 x i64]]

    // SAFETY: exclusive references are always valid to read/write,
    // are non-overlapping, and nothing here panics so it's drop-safe.
    unsafe {
        let z = read(x);
        copy_nonoverlapping(y, x, 1);
        write(y, z);
    }
}

// This test verifies that the library does something smarter, and thus
// doesn't need any scratch space on the stack.

// CHECK-LABEL: @swap_std
#[no_mangle]
pub fn swap_std(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i64>
    // CHECK: store <{{[0-9]+}} x i64>
    swap(x, y)
}

// CHECK-LABEL: @swap_slice
#[no_mangle]
pub fn swap_slice(x: &mut [KeccakBuffer], y: &mut [KeccakBuffer]) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i64>
    // CHECK: store <{{[0-9]+}} x i64>
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}

type OneKilobyteBuffer = [u8; 1024];

// CHECK-LABEL: @swap_1kb_slices
#[no_mangle]
pub fn swap_1kb_slices(x: &mut [OneKilobyteBuffer], y: &mut [OneKilobyteBuffer]) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i8>
    // CHECK: store <{{[0-9]+}} x i8>
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
32 changes: 32 additions & 0 deletions src/test/codegen/swap-simd-types.rs
@@ -0,0 +1,32 @@
// compile-flags: -O -C target-feature=+avx
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;

// SIMD types are highly-aligned already, so make sure the swap code leaves their
// types alone and doesn't pessimize them (such as by swapping them as `usize`s).
extern crate core;
use core::arch::x86_64::__m256;

// CHECK-LABEL: @swap_single_m256
#[no_mangle]
pub fn swap_single_m256(x: &mut __m256, y: &mut __m256) {
    // CHECK-NOT: alloca
    // CHECK: load <8 x float>{{.+}}align 32
    // CHECK: store <8 x float>{{.+}}align 32
    swap(x, y)
}

// CHECK-LABEL: @swap_m256_slice
#[no_mangle]
pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
    // CHECK-NOT: alloca
    // CHECK: load <8 x float>{{.+}}align 32
    // CHECK: store <8 x float>{{.+}}align 32
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
