Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3247,7 +3247,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
&mut self,
callee_ty: Self::Type,
_fn_attrs: Option<&CodegenFnAttrs>,
_fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope just using the one given to me is correct here, previously the FnAbi was computed from:

let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());

But it was in a function that wasn't given an FnAbi but a rustc_middle::ty::Instance, which I never interacted with.

callee: Self::Value,
args: &[Self::Value],
funclet: Option<&Self::Funclet>,
Expand Down Expand Up @@ -3310,9 +3310,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let libm_intrinsic =
instance_def_id.and_then(|def_id| self.libm_intrinsics.borrow().get(&def_id).copied());
let buffer_load_intrinsic = instance_def_id
.and_then(|def_id| self.buffer_load_intrinsics.borrow().get(&def_id).copied());
.is_some_and(|def_id| self.buffer_load_intrinsics.borrow().contains(&def_id));
let buffer_store_intrinsic = instance_def_id
.and_then(|def_id| self.buffer_store_intrinsics.borrow().get(&def_id).copied());
.is_some_and(|def_id| self.buffer_store_intrinsics.borrow().contains(&def_id));
let is_panic_entry_point = instance_def_id
.is_some_and(|def_id| self.panic_entry_points.borrow().contains(&def_id));
let from_trait_impl =
Expand Down Expand Up @@ -4101,14 +4101,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.abort_with_kind_and_message_debug_printf("panic", message, debug_printf_args);
return self.undef(result_type);
}

if let Some(mode) = buffer_load_intrinsic {
return self.codegen_buffer_load_intrinsic(result_type, args, mode);
if buffer_load_intrinsic {
return self.codegen_buffer_load_intrinsic(fn_abi, result_type, args);
}

if let Some(mode) = buffer_store_intrinsic {
self.codegen_buffer_store_intrinsic(args, mode);

if buffer_store_intrinsic {
self.codegen_buffer_store_intrinsic(fn_abi, args);
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
return SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(void_ty),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use rustc_abi::{Align, Size};
use rustc_codegen_spirv_types::Capability;
use rustc_codegen_ssa::traits::BuilderMethods;
use rustc_errors::ErrorGuaranteed;
use rustc_middle::ty::Ty;
use rustc_span::DUMMY_SP;
use rustc_target::callconv::PassMode;
use rustc_target::callconv::{FnAbi, PassMode};

impl<'a, 'tcx> Builder<'a, 'tcx> {
fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
Expand Down Expand Up @@ -181,10 +182,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
pub fn codegen_buffer_load_intrinsic(
&mut self,
fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
result_type: Word,
args: &[SpirvValue],
pass_mode: &PassMode,
) -> SpirvValue {
let pass_mode = &fn_abi.unwrap().ret.mode;
match pass_mode {
PassMode::Ignore => {
return SpirvValue {
Expand Down Expand Up @@ -364,8 +366,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: &PassMode) {
pub fn codegen_buffer_store_intrinsic(
&mut self,
fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
args: &[SpirvValue],
) {
// Signature: fn store<T>(array: &[u32], index: u32, value: T);
let pass_mode = &fn_abi.unwrap().args.last().unwrap().mode;
let is_pair = match pass_mode {
// haha shrug
PassMode::Ignore => return,
Expand Down
10 changes: 2 additions & 8 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,10 @@ impl<'tcx> CodegenCx<'tcx> {
// FIXME(eddyb) should the maps exist at all, now that the `DefId` is known
// at `call` time, and presumably its high-level details can be looked up?
if attrs.buffer_load_intrinsic.is_some() {
let mode = &fn_abi.ret.mode;
self.buffer_load_intrinsics
.borrow_mut()
.insert(def_id, mode);
self.buffer_load_intrinsics.borrow_mut().insert(def_id);
}
if attrs.buffer_store_intrinsic.is_some() {
let mode = &fn_abi.args.last().unwrap().mode;
self.buffer_store_intrinsics
.borrow_mut()
.insert(def_id, mode);
Comment on lines -158 to -161
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow, so being keyed on the DefId of a generic function meant that the last instance would win, and every call would be expected to conform to it.

self.buffer_store_intrinsics.borrow_mut().insert(def_id);
}

// Check for usage of `libm` intrinsics outside of `libm` itself
Expand Down
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use rustc_middle::ty::{self, Instance, Ty, TyCtxt, TypingEnv};
use rustc_session::Session;
use rustc_span::symbol::Symbol;
use rustc_span::{DUMMY_SP, SourceFile, Span};
use rustc_target::callconv::{FnAbi, PassMode};
use rustc_target::callconv::FnAbi;
use rustc_target::spec::{HasTargetSpec, Target, TargetTuple};
use std::cell::RefCell;
use std::collections::BTreeSet;
Expand Down Expand Up @@ -80,9 +80,9 @@ pub struct CodegenCx<'tcx> {
pub fmt_rt_arg_new_fn_ids_to_ty_and_spec: RefCell<FxHashMap<Word, (Ty<'tcx>, char)>>,

/// Intrinsic for loading a `<T>` from a `&[u32]`. The `PassMode` is the mode of the `<T>`.
pub buffer_load_intrinsics: RefCell<FxHashMap<DefId, &'tcx PassMode>>,
pub buffer_load_intrinsics: RefCell<FxHashSet<DefId>>,
/// Intrinsic for storing a `<T>` into a `&[u32]`. The `PassMode` is the mode of the `<T>`.
pub buffer_store_intrinsics: RefCell<FxHashMap<DefId, &'tcx PassMode>>,
pub buffer_store_intrinsics: RefCell<FxHashSet<DefId>>,

/// Maps `DefId`s of `From::from` method implementations to their source and target types.
/// Used to optimize constant conversions like `u32::from(42u8)` to avoid creating the source type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// build-pass

use spirv_std::ByteAddressableBuffer;
use spirv_std::spirv;

#[derive(Copy, Clone, Debug)]
pub struct MyScalar(i32);

#[spirv(fragment)]
pub fn load_scalar(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
scalar: &mut MyScalar,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_slice(buf);
*scalar = buf.load(5);
}
}

#[derive(Copy, Clone, Debug)]
pub struct MyScalarPair(i32, i32);

#[spirv(fragment)]
pub fn load_scalar_pair(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
scalar_pair: &mut MyScalarPair,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_slice(buf);
*scalar_pair = buf.load(5);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// build-pass

use spirv_std::ByteAddressableBuffer;
use spirv_std::spirv;

#[derive(Copy, Clone, Debug)]
pub struct MyScalar(i32);

#[spirv(fragment)]
pub fn load_scalar(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
scalar: &mut MyScalar,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
*scalar = buf.load(5);
}
}

#[derive(Copy, Clone, Debug)]
pub struct MyScalarPair(i32, i32);

#[spirv(fragment)]
pub fn load_scalar_pair(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
scalar_pair: &mut MyScalarPair,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
*scalar_pair = buf.load(5);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// build-pass

use spirv_std::ByteAddressableBuffer;
use spirv_std::spirv;

#[derive(Copy, Clone, Debug)]
pub struct MyScalar(i32);

#[spirv(fragment)]
pub fn store_scalar(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] scalar: MyScalar,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, scalar);
}
}

#[derive(Copy, Clone, Debug)]
pub struct MyScalarPair(i32, i32);

#[spirv(fragment)]
pub fn store_scalar_pair(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] pair0: i32,
#[spirv(flat)] pair1: i32,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, MyScalarPair(pair0, pair1));
}
}
41 changes: 41 additions & 0 deletions tests/compiletests/ui/byte_addressable_buffer/scalar_pair.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// build-pass

use spirv_std::ByteAddressableBuffer;
use spirv_std::spirv;

#[derive(Copy, Clone, Debug)]
pub struct MyScalarPair(i32, i32);

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
scalar_pair: &mut MyScalarPair,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*scalar_pair = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
scalar_pair: &mut MyScalarPair,
) {
unsafe {
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*scalar_pair = buf.load(5);
}
}

#[spirv(fragment)]
pub fn store(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] pair0: i32,
#[spirv(flat)] pair1: i32,
) {
unsafe {
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, MyScalarPair(pair0, pair1));
}
}