Skip to content

Commit

Permalink
Fix ByteAddressableBuffer PassMode::Pair (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
khyperia committed Jan 10, 2022
1 parent b99fc51 commit fe5c771
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 47 deletions.
28 changes: 15 additions & 13 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Expand Up @@ -2188,7 +2188,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
for (argument, argument_type) in args.iter().zip(argument_types) {
assert_ty_eq!(self, argument.ty, argument_type);
}
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).cloned();
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).copied();
let buffer_load_intrinsic = self
.buffer_load_intrinsic_fn_id
.borrow()
.get(&callee_val)
.copied();
let buffer_store_intrinsic = self
.buffer_store_intrinsic_fn_id
.borrow()
.get(&callee_val)
.copied();
if let Some(libm_intrinsic) = libm_intrinsic {
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
if result_type != result.ty {
Expand All @@ -2207,18 +2217,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// needing to materialize `&core::panic::Location` or `format_args!`.
self.abort();
self.undef(result_type)
} else if self
.buffer_load_intrinsic_fn_id
.borrow()
.contains(&callee_val)
{
self.codegen_buffer_load_intrinsic(result_type, args)
} else if self
.buffer_store_intrinsic_fn_id
.borrow()
.contains(&callee_val)
{
self.codegen_buffer_store_intrinsic(args);
} else if let Some(mode) = buffer_load_intrinsic {
self.codegen_buffer_load_intrinsic(result_type, args, mode)
} else if let Some(mode) = buffer_store_intrinsic {
self.codegen_buffer_store_intrinsic(args, mode);

let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
SpirvValue {
Expand Down
85 changes: 68 additions & 17 deletions crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs
@@ -1,10 +1,12 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
use rustc_errors::ErrorReported;
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
use rustc_target::abi::call::PassMode;
use rustc_target::abi::{Align, Size};

impl<'a, 'tcx> Builder<'a, 'tcx> {
fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
Expand Down Expand Up @@ -168,7 +170,25 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
&mut self,
result_type: Word,
args: &[SpirvValue],
pass_mode: PassMode,
) -> SpirvValue {
match pass_mode {
PassMode::Ignore => {
return SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(result_type),
ty: result_type,
}
}
// PassMode::Pair is identical to PassMode::Direct - it's returned as a struct
PassMode::Direct(_) | PassMode::Pair(_, _) => (),
PassMode::Cast(_) => {
self.fatal("PassMode::Cast not supported in codegen_buffer_load_intrinsic")
}
PassMode::Indirect { .. } => {
self.fatal("PassMode::Indirect not supported in codegen_buffer_load_intrinsic")
}
}

// Signature: fn load<T>(array: &[u32], index: u32) -> T;
if args.len() != 3 {
self.fatal(&format!(
Expand All @@ -184,15 +204,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.recurse_load_type(result_type, result_type, array, word_index, 0)
}

fn store_err(&mut self, original_type: Word, value: SpirvValue) {
fn store_err(&mut self, original_type: Word, value: SpirvValue) -> Result<(), ErrorReported> {
let mut err = self.struct_err(&format!(
"Cannot load type {} in an untyped buffer store",
"Cannot store type {} in an untyped buffer store",
self.debug_type(original_type)
));
if original_type != value.ty {
err.note(&format!("due to containing type {}", value.ty));
}
err.emit();
Err(ErrorReported)
}

fn store_u32(
Expand All @@ -201,7 +222,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
dynamic_index: SpirvValue,
constant_offset: u32,
value: SpirvValue,
) {
) -> Result<(), ErrorReported> {
let actual_index = if constant_offset != 0 {
let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
self.add(dynamic_index, const_offset_val)
Expand All @@ -216,6 +237,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.unwrap()
.with_type(u32_ptr);
self.store(value, ptr, Align::ONE);
Ok(())
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -228,7 +250,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
constant_word_offset: u32,
element: Word,
count: u32,
) {
) -> Result<(), ErrorReported> {
let element_size_bytes = match self.lookup_type(element).sizeof(self) {
Some(size) => size,
None => return self.store_err(original_type, value),
Expand All @@ -245,8 +267,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array,
dynamic_word_index,
constant_word_offset + element_size_words * index,
);
)?;
}
Ok(())
}

fn recurse_store_type(
Expand All @@ -256,17 +279,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array: SpirvValue,
dynamic_word_index: SpirvValue,
constant_word_offset: u32,
) {
) -> Result<(), ErrorReported> {
match self.lookup_type(value.ty) {
SpirvType::Integer(32, signed) => {
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let value_u32 = self.intcast(value, u32_ty, signed);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
}
SpirvType::Float(32) => {
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let value_u32 = self.bitcast(value, u32_ty);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
}
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
.store_vec_mat_arr(
Expand All @@ -291,7 +314,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
constant_word_offset,
element,
count,
);
)
}
SpirvType::Adt {
size: Some(_),
Expand All @@ -310,20 +333,35 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array,
dynamic_word_index,
constant_word_offset + word_offset,
);
)?;
}
Ok(())
}

_ => self.store_err(original_type, value),
}
}

/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue]) {
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: PassMode) {
// Signature: fn store<T>(array: &[u32], index: u32, value: T);
if args.len() != 4 {
let is_pair = match pass_mode {
// haha shrug
PassMode::Ignore => return,
PassMode::Direct(_) => false,
PassMode::Pair(_, _) => true,
PassMode::Cast(_) => {
self.fatal("PassMode::Cast not supported in codegen_buffer_store_intrinsic")
}
PassMode::Indirect { .. } => {
self.fatal("PassMode::Indirect not supported in codegen_buffer_store_intrinsic")
}
};
let expected_args = if is_pair { 5 } else { 4 };
if args.len() != expected_args {
self.fatal(&format!(
"buffer_store_intrinsic should have 4 args, it has {}",
"buffer_store_intrinsic should have {} args, it has {}",
expected_args,
args.len()
));
}
Expand All @@ -332,7 +370,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let byte_index = args[2];
let two = self.constant_u32(DUMMY_SP, 2);
let word_index = self.lshr(byte_index, two);
let value = args[3];
self.recurse_store_type(value.ty, value, array, word_index, 0);
if is_pair {
let value_one = args[3];
let value_two = args[4];
let one_result = self.recurse_store_type(value_one.ty, value_one, array, word_index, 0);

let size_of_one = self.lookup_type(value_one.ty).sizeof(self);
if one_result.is_ok() && size_of_one != Some(Size::from_bytes(4)) {
self.fatal("Expected PassMode::Pair first element to have size 4");
}

let _ = self.recurse_store_type(value_two.ty, value_two, array, word_index, 1);
} else {
let value = args[3];
let _ = self.recurse_store_type(value.ty, value, array, word_index, 0);
}
}
}
10 changes: 8 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Expand Up @@ -120,10 +120,16 @@ impl<'tcx> CodegenCx<'tcx> {
self.unroll_loops_decorations.borrow_mut().insert(fn_id);
}
if attrs.buffer_load_intrinsic.is_some() {
self.buffer_load_intrinsic_fn_id.borrow_mut().insert(fn_id);
let mode = fn_abi.ret.mode;
self.buffer_load_intrinsic_fn_id
.borrow_mut()
.insert(fn_id, mode);
}
if attrs.buffer_store_intrinsic.is_some() {
self.buffer_store_intrinsic_fn_id.borrow_mut().insert(fn_id);
let mode = fn_abi.args.last().unwrap().mode;
self.buffer_store_intrinsic_fn_id
.borrow_mut()
.insert(fn_id, mode);
}

let instance_def_id = instance.def_id();
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Expand Up @@ -66,7 +66,7 @@ impl<'tcx> CodegenCx<'tcx> {
}
// FIXME(eddyb) support these (by just ignoring them) - if there
// is any validation concern, it should be done on the types.
PassMode::Ignore => self.tcx.sess.span_err(
PassMode::Ignore => self.tcx.sess.span_fatal(
hir_param.ty_span,
&format!(
"entry point parameter type not yet supported \
Expand Down
10 changes: 5 additions & 5 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Expand Up @@ -29,7 +29,7 @@ use rustc_session::Session;
use rustc_span::def_id::{DefId, LOCAL_CRATE};
use rustc_span::symbol::{sym, Symbol};
use rustc_span::{SourceFile, Span, DUMMY_SP};
use rustc_target::abi::call::FnAbi;
use rustc_target::abi::call::{FnAbi, PassMode};
use rustc_target::abi::{HasDataLayout, TargetDataLayout};
use rustc_target::spec::{HasTargetSpec, Target};
use std::cell::{Cell, RefCell};
Expand Down Expand Up @@ -66,10 +66,10 @@ pub struct CodegenCx<'tcx> {

/// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`.
pub panic_fn_id: Cell<Option<Word>>,
/// Intrinsic for loading a <T> from a &[u32]
pub buffer_load_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
/// Intrinsic for storing a <T> into a &[u32]
pub buffer_store_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
/// Intrinsic for loading a <T> from a &[u32]. The PassMode is the mode of the <T>.
pub buffer_load_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
/// Intrinsic for storing a <T> into a &[u32]. The PassMode is the mode of the <T>.
pub buffer_store_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
/// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`.
pub panic_bounds_check_fn_id: Cell<Option<Word>>,

Expand Down
8 changes: 2 additions & 6 deletions crates/spirv-std/src/byte_addressable_buffer.rs
Expand Up @@ -5,18 +5,14 @@ use core::mem;
#[spirv(buffer_load_intrinsic)]
#[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
unsafe fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
unimplemented!()
} // actually implemented in the compiler

#[spirv(buffer_store_intrinsic)]
#[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_store_intrinsic<T>(
_buffer: &mut [u32],
_offset: u32,
_value: T,
) {
unsafe fn buffer_store_intrinsic<T>(_buffer: &mut [u32], _offset: u32, _value: T) {
unimplemented!()
} // actually implemented in the compiler

Expand Down
1 change: 0 additions & 1 deletion crates/spirv-std/src/lib.rs
Expand Up @@ -2,7 +2,6 @@
#![cfg_attr(
target_arch = "spirv",
feature(
abi_unadjusted,
asm,
asm_const,
asm_experimental_arch,
Expand Down
23 changes: 23 additions & 0 deletions tests/README.md
@@ -0,0 +1,23 @@
# Compiletests

This folder contains tests known as "compiletests". Each file in the `ui` folder corresponds to a
single compiletest. The way they work is a tool iterates over every file, and tries to compile it.
At the start of the file, there's some meta-comments about the expected result of the compile:
whether it should succeed compilation, or fail. If it is expected to fail, there's a corresponding
.stderr file next to the file that contains the expected compiler error message.

The `src` folder here is the tool that iterates over every file in the `ui` folder. It uses the
`compiletests` library, taken from rustc's own compiletest framework.

You can run compiletests via `cargo compiletests`. This is an alias set up in `.cargo/config` for
`cargo run --release -p compiletests --`. You can filter to run specific tests by passing the
(partial) filenames to `cargo compiletests some_file_name`, and update the `.stderr` files to
contain new output via the `--bless` flag (with `--bless`, make sure you're actually supposed to be
changing the .stderr files due to an intentional change, and hand-validate the output is correct
afterwards).

Keep in mind that tests here here are not executed, merely checked for errors (including validating
the resulting binary with spirv-val). Because of this, there might be some strange code in here -
the point isn't to make a fully functional shader every time (that would take an annoying amount of
effort), but rather validate that specific parts of the compiler are doing their job correctly
(either succeeding as they should, or erroring as they should).
4 changes: 2 additions & 2 deletions tests/ui/arch/debug_printf_type_checking.stderr
Expand Up @@ -96,9 +96,9 @@ error[E0277]: the trait bound `{float}: Vector<f32, 2_usize>` is not satisfied
<DVec2 as Vector<f64, 2_usize>>
and 13 others
note: required by a bound in `debug_printf_assert_is_vector`
--> $SPIRV_STD_SRC/lib.rs:146:8
--> $SPIRV_STD_SRC/lib.rs:145:8
|
146 | V: crate::vector::Vector<TY, SIZE>,
145 | V: crate::vector::Vector<TY, SIZE>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`

error[E0308]: mismatched types
Expand Down
25 changes: 25 additions & 0 deletions tests/ui/byte_addressable_buffer/empty_struct.rs
@@ -0,0 +1,25 @@
// build-pass

use spirv_std::ByteAddressableBuffer;

pub struct EmptyStruct {}

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] out: &mut EmptyStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32]) {
let val = EmptyStruct {};
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
buf.store(5, val);
}
}

0 comments on commit fe5c771

Please sign in to comment.