diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 8e7fc3a727..a6e7ee4da3 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -3247,7 +3247,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { &mut self, callee_ty: Self::Type, _fn_attrs: Option<&CodegenFnAttrs>, - _fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>, + fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>, callee: Self::Value, args: &[Self::Value], funclet: Option<&Self::Funclet>, @@ -3310,9 +3310,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let libm_intrinsic = instance_def_id.and_then(|def_id| self.libm_intrinsics.borrow().get(&def_id).copied()); let buffer_load_intrinsic = instance_def_id - .and_then(|def_id| self.buffer_load_intrinsics.borrow().get(&def_id).copied()); + .is_some_and(|def_id| self.buffer_load_intrinsics.borrow().contains(&def_id)); let buffer_store_intrinsic = instance_def_id - .and_then(|def_id| self.buffer_store_intrinsics.borrow().get(&def_id).copied()); + .is_some_and(|def_id| self.buffer_store_intrinsics.borrow().contains(&def_id)); let is_panic_entry_point = instance_def_id .is_some_and(|def_id| self.panic_entry_points.borrow().contains(&def_id)); let from_trait_impl = @@ -4101,14 +4101,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.abort_with_kind_and_message_debug_printf("panic", message, debug_printf_args); return self.undef(result_type); } - - if let Some(mode) = buffer_load_intrinsic { - return self.codegen_buffer_load_intrinsic(result_type, args, mode); + if buffer_load_intrinsic { + return self.codegen_buffer_load_intrinsic(fn_abi, result_type, args); } - - if let Some(mode) = buffer_store_intrinsic { - self.codegen_buffer_store_intrinsic(args, mode); - + if buffer_store_intrinsic { + self.codegen_buffer_store_intrinsic(fn_abi, args); let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self); return SpirvValue { kind: SpirvValueKind::IllegalTypeUsed(void_ty), diff --git a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs index 884f5a88db..0aaee7f8d7 100644 --- a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs +++ b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs @@ -9,8 +9,9 @@ use rustc_abi::{Align, Size}; use rustc_codegen_spirv_types::Capability; use rustc_codegen_ssa::traits::BuilderMethods; use rustc_errors::ErrorGuaranteed; +use rustc_middle::ty::Ty; use rustc_span::DUMMY_SP; -use rustc_target::callconv::PassMode; +use rustc_target::callconv::{FnAbi, PassMode}; impl<'a, 'tcx> Builder<'a, 'tcx> { fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue { @@ -181,10 +182,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller. pub fn codegen_buffer_load_intrinsic( &mut self, + fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>, result_type: Word, args: &[SpirvValue], - pass_mode: &PassMode, ) -> SpirvValue { + let pass_mode = &fn_abi.unwrap().ret.mode; match pass_mode { PassMode::Ignore => { return SpirvValue { @@ -364,8 +366,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller. - pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: &PassMode) { + pub fn codegen_buffer_store_intrinsic( + &mut self, + fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>, + args: &[SpirvValue], + ) { // Signature: fn store(array: &[u32], index: u32, value: T); + let pass_mode = &fn_abi.unwrap().args.last().unwrap().mode; let is_pair = match pass_mode { // haha shrug PassMode::Ignore => return, diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 104a98e14a..af14f8b037 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -149,16 +149,10 @@ impl<'tcx> CodegenCx<'tcx> { // FIXME(eddyb) should the maps exist at all, now that the `DefId` is known // at `call` time, and presumably its high-level details can be looked up? if attrs.buffer_load_intrinsic.is_some() { - let mode = &fn_abi.ret.mode; - self.buffer_load_intrinsics - .borrow_mut() - .insert(def_id, mode); + self.buffer_load_intrinsics.borrow_mut().insert(def_id); } if attrs.buffer_store_intrinsic.is_some() { - let mode = &fn_abi.args.last().unwrap().mode; - self.buffer_store_intrinsics - .borrow_mut() - .insert(def_id, mode); + self.buffer_store_intrinsics.borrow_mut().insert(def_id); } // Check for usage of `libm` intrinsics outside of `libm` itself diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f3811f8a19..c8acc4bc1d 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -34,7 +34,7 @@ use rustc_middle::ty::{self, Instance, Ty, TyCtxt, TypingEnv}; use rustc_session::Session; use rustc_span::symbol::Symbol; use rustc_span::{DUMMY_SP, SourceFile, Span}; -use rustc_target::callconv::{FnAbi, PassMode}; +use rustc_target::callconv::FnAbi; use rustc_target::spec::{HasTargetSpec, Target, TargetTuple}; use std::cell::RefCell; use std::collections::BTreeSet; @@ -80,9 +80,9 @@ pub struct CodegenCx<'tcx> { pub fmt_rt_arg_new_fn_ids_to_ty_and_spec: RefCell, char)>>, /// Intrinsic for loading a `` from a `&[u32]`. The `PassMode` is the mode of the ``. - pub buffer_load_intrinsics: RefCell>, + pub buffer_load_intrinsics: RefCell>, /// Intrinsic for storing a `` into a `&[u32]`. The `PassMode` is the mode of the ``. - pub buffer_store_intrinsics: RefCell>, + pub buffer_store_intrinsics: RefCell>, /// Maps `DefId`s of `From::from` method implementations to their source and target types. /// Used to optimize constant conversions like `u32::from(42u8)` to avoid creating the source type. diff --git a/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load.rs b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load.rs new file mode 100644 index 0000000000..8c44291935 --- /dev/null +++ b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load.rs @@ -0,0 +1,32 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; +use spirv_std::spirv; + +#[derive(Copy, Clone, Debug)] +pub struct MyScalar(i32); + +#[spirv(fragment)] +pub fn load_scalar( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + scalar: &mut MyScalar, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_slice(buf); + *scalar = buf.load(5); + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MyScalarPair(i32, i32); + +#[spirv(fragment)] +pub fn load_scalar_pair( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + scalar_pair: &mut MyScalarPair, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_slice(buf); + *scalar_pair = buf.load(5); + } +} diff --git a/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load_mut.rs b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load_mut.rs new file mode 100644 index 0000000000..990392d8ce --- /dev/null +++ b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_load_mut.rs @@ -0,0 +1,32 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; +use spirv_std::spirv; + +#[derive(Copy, Clone, Debug)] +pub struct MyScalar(i32); + +#[spirv(fragment)] +pub fn load_scalar( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + scalar: &mut MyScalar, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); + *scalar = buf.load(5); + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MyScalarPair(i32, i32); + +#[spirv(fragment)] +pub fn load_scalar_pair( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + scalar_pair: &mut MyScalarPair, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); + *scalar_pair = buf.load(5); + } +} diff --git a/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_store.rs b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_store.rs new file mode 100644 index 0000000000..d82ffeb106 --- /dev/null +++ b/tests/compiletests/ui/byte_addressable_buffer/scalar_and_pair_mix_store.rs @@ -0,0 +1,33 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; +use spirv_std::spirv; + +#[derive(Copy, Clone, Debug)] +pub struct MyScalar(i32); + +#[spirv(fragment)] +pub fn store_scalar( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + #[spirv(flat)] scalar: MyScalar, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); + buf.store(5, scalar); + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MyScalarPair(i32, i32); + +#[spirv(fragment)] +pub fn store_scalar_pair( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + #[spirv(flat)] pair0: i32, + #[spirv(flat)] pair1: i32, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); + buf.store(5, MyScalarPair(pair0, pair1)); + } +} diff --git a/tests/compiletests/ui/byte_addressable_buffer/scalar_pair.rs b/tests/compiletests/ui/byte_addressable_buffer/scalar_pair.rs new file mode 100644 index 0000000000..40a706836b --- /dev/null +++ b/tests/compiletests/ui/byte_addressable_buffer/scalar_pair.rs @@ -0,0 +1,41 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; +use spirv_std::spirv; + +#[derive(Copy, Clone, Debug)] +pub struct MyScalarPair(i32, i32); + +#[spirv(fragment)] +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + scalar_pair: &mut MyScalarPair, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *scalar_pair = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + scalar_pair: &mut MyScalarPair, +) { + unsafe { + let buf = ByteAddressableBuffer::from_mut_slice(buf); + *scalar_pair = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn store( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + #[spirv(flat)] pair0: i32, + #[spirv(flat)] pair1: i32, +) { + unsafe { + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); + buf.store(5, MyScalarPair(pair0, pair1)); + } +}