From 9f019f6aefcf08c54cc16c94e67ab0be9007e70d Mon Sep 17 00:00:00 2001 From: firestar99 Date: Tue, 21 Oct 2025 16:56:37 +0200 Subject: [PATCH] SpecConstant: add arrayed spec constants --- crates/rustc_codegen_spirv/src/attr.rs | 3 + .../src/codegen_cx/entry.rs | 82 +++++++++++++------ .../ui/dis/spec_constant_array.rs | 28 +++++++ .../ui/dis/spec_constant_array.stderr | 38 +++++++++ 4 files changed, 128 insertions(+), 23 deletions(-) create mode 100644 tests/compiletests/ui/dis/spec_constant_array.rs create mode 100644 tests/compiletests/ui/dis/spec_constant_array.stderr diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index 9be755e687..e56f96c2a5 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -75,6 +75,7 @@ pub enum IntrinsicType { pub struct SpecConstant { pub id: u32, pub default: Option, + pub array_count: Option, } // NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside @@ -661,6 +662,8 @@ fn parse_spec_constant_attr( Ok(SpecConstant { id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?, default, + // to be set later + array_count: None, }) } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 5de823d261..0c16adfd6c 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -11,6 +11,7 @@ use rspirv::dr::Operand; use rspirv::spirv::{ Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, }; +use rustc_abi::FieldsShape; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::MultiSpan; @@ -18,7 +19,7 @@ use rustc_hir as hir; use rustc_middle::span_bug; use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, Instance, Ty}; -use rustc_span::Span; +use rustc_span::{DUMMY_SP, Span}; use rustc_target::callconv::{ArgAbi, FnAbi, PassMode}; use std::assert_matches::assert_matches; @@ -395,23 +396,38 @@ impl<'tcx> CodegenCx<'tcx> { // would've assumed it was actually an implicitly-`Input`. let mut storage_class = Ok(storage_class); if let Some(spec_constant) = attrs.spec_constant { - if ref_or_value_layout.ty != self.tcx.types.u32 { + let ty = ref_or_value_layout; + let valid_array_count = match ty.fields { + FieldsShape::Array { count, .. } => { + let element = ty.field(self, 0); + (element.ty == self.tcx.types.u32).then_some(u32::try_from(count).ok()) + } + FieldsShape::Primitive => (ty.ty == self.tcx.types.u32).then_some(None), + _ => None, + }; + + if let Some(array_count) = valid_array_count { + if let Some(storage_class) = attrs.storage_class { + self.tcx.dcx().span_err( + storage_class.span, + "`#[spirv(spec_constant)]` cannot have a storage class", + ); + } else { + assert_eq!(storage_class, Ok(StorageClass::Input)); + assert!(!is_ref); + storage_class = Err(SpecConstant { + array_count, + ..spec_constant.value + }); + } + } else { self.tcx.dcx().span_err( hir_param.ty_span, format!( - "unsupported `#[spirv(spec_constant)]` type `{}` (expected `{}`)", - ref_or_value_layout.ty, self.tcx.types.u32 + "unsupported `#[spirv(spec_constant)]` type `{}` (expected `u32` or `[u32; N]`)", + ref_or_value_layout.ty ), ); - } else if let Some(storage_class) = attrs.storage_class { - self.tcx.dcx().span_err( - storage_class.span, - "`#[spirv(spec_constant)]` cannot have a storage class", - ); - } else { - assert_eq!(storage_class, Ok(StorageClass::Input)); - assert!(!is_ref); - storage_class = Err(spec_constant.value); } } @@ -448,18 +464,38 @@ impl<'tcx> CodegenCx<'tcx> { Ok(self.emit_global().id()), Err("entry-point interface variable is not a `#[spirv(spec_constant)]`"), ), - Err(SpecConstant { id, default }) => { - let mut emit = self.emit_global(); - let spec_const_id = - emit.spec_constant_bit32(value_spirv_type, default.unwrap_or(0)); - emit.decorate( - spec_const_id, - Decoration::SpecId, - [Operand::LiteralBit32(id)], - ); + Err(SpecConstant { + id, + default, + array_count, + }) => { + let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); + let single = |id: u32| { + let mut emit = self.emit_global(); + let spec_const_id = emit.spec_constant_bit32(u32_ty, default.unwrap_or(0)); + emit.decorate( + spec_const_id, + Decoration::SpecId, + [Operand::LiteralBit32(id)], + ); + spec_const_id + }; + let param_word = if let Some(array_count) = array_count { + let array = (0..array_count).map(|i| single(id + i)).collect::>(); + let array_ty = SpirvType::Array { + element: u32_ty, + count: self.constant_u32(DUMMY_SP, array_count), + } + .def(DUMMY_SP, self); + bx.emit() + .composite_construct(array_ty, None, array) + .unwrap() + } else { + single(id) + }; ( Err("`#[spirv(spec_constant)]` is not an entry-point interface variable"), - Ok(spec_const_id), + Ok(param_word), ) } }; diff --git a/tests/compiletests/ui/dis/spec_constant_array.rs b/tests/compiletests/ui/dis/spec_constant_array.rs new file mode 100644 index 0000000000..8c70003051 --- /dev/null +++ b/tests/compiletests/ui/dis/spec_constant_array.rs @@ -0,0 +1,28 @@ +// Tests the various forms of `#[spirv(spec_constant)]`. + +// build-pass +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-vulkan1.0 +// ignore-vulkan1.1 + +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "; .*\n" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +// HACK(eddyb) `compiletest` handles `ui\dis\`, but not `ui\\dis\\`, on Windows. +// normalize-stderr-test "ui/dis/" -> "$$DIR/" + +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main( + #[spirv(spec_constant(id = 42, default = 69))] array: [u32; 4], + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut u32, +) { + *out = array[0] + array[1] + array[2] + array[3]; +} diff --git a/tests/compiletests/ui/dis/spec_constant_array.stderr b/tests/compiletests/ui/dis/spec_constant_array.stderr new file mode 100644 index 0000000000..9533cdb38a --- /dev/null +++ b/tests/compiletests/ui/dis/spec_constant_array.stderr @@ -0,0 +1,38 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" %2 +OpExecutionMode %1 LocalSize 1 1 1 +%3 = OpString "$DIR/spec_constant_array.rs" +OpDecorate %4 Block +OpMemberDecorate %4 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %5 SpecId 42 +OpDecorate %6 SpecId 43 +OpDecorate %7 SpecId 44 +OpDecorate %8 SpecId 45 +%9 = OpTypeInt 32 0 +%4 = OpTypeStruct %9 +%10 = OpTypePointer StorageBuffer %4 +%11 = OpTypeVoid +%12 = OpTypeFunction %11 +%13 = OpTypePointer StorageBuffer %9 +%2 = OpVariable %10 StorageBuffer +%14 = OpConstant %9 0 +%5 = OpSpecConstant %9 69 +%6 = OpSpecConstant %9 69 +%7 = OpSpecConstant %9 69 +%8 = OpSpecConstant %9 69 +%1 = OpFunction %11 None %12 +%15 = OpLabel +OpLine %3 25 4 +%16 = OpInBoundsAccessChain %13 %2 %14 +OpLine %3 27 11 +%17 = OpIAdd %9 %5 %6 +%18 = OpIAdd %9 %17 %7 +OpLine %3 27 4 +%19 = OpIAdd %9 %18 %8 +OpStore %16 %19 +OpNoLine +OpReturn +OpFunctionEnd