diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 98e076bf5fe..11bf33ae55d 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -492,6 +492,11 @@ pub fn link( simple_passes::remove_non_uniform_decorations(sess, &mut output)?; } + { + let _timer = sess.timer("link_promote_int8_to_int32"); + simple_passes::promote_int8_to_int32(&mut output); + } + // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block. { let (spv_words, module_or_err, lower_from_spv_timer) = diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs index ebd37a9ad78..9da1b49c260 100644 --- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs +++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs @@ -1,5 +1,5 @@ use super::{get_name, get_names}; -use rspirv::dr::{Block, Function, Module}; +use rspirv::dr::{Block, Function, Module, Operand}; use rspirv::spirv::{Decoration, ExecutionModel, Op, Word}; use rustc_codegen_spirv_types::Capability; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; @@ -365,3 +365,79 @@ pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> s } Ok(()) } + +/// When `OpCapability Int8` is not declared, promote all implicit `i8`/`u8` types to `i32`/`u32`. +pub fn promote_int8_to_int32(module: &mut Module) { + let has_int8 = module.capabilities.iter().any(|inst| { + inst.class.opcode == Op::Capability + && inst.operands[0].unwrap_capability() == Capability::Int8 + }); + if has_int8 { + return; + } + + let narrow_types: FxHashMap = module + .types_global_values + .iter() + .filter_map(|inst| { + if inst.class.opcode == Op::TypeInt && inst.operands[0].unwrap_literal_bit32() == 8 { + let signedness = inst.operands[1].unwrap_literal_bit32(); + Some((inst.result_id?, signedness)) + } else { + None + } + }) + .collect(); + + if narrow_types.is_empty() { + return; + } + + // skip any 8-bit type that is used as the element type of an OpTypePointer. + // such types are explicit interface/storage types chosen by the user + let pointer_element_types: FxHashSet = module + .types_global_values + .iter() + .filter_map(|inst| { + if inst.class.opcode == Op::TypePointer { + // operands: [StorageClass, element_type_id] + Some(inst.operands[1].unwrap_id_ref()) + } else { + None + } + }) + .collect(); + + let narrow_types: FxHashMap = narrow_types + .into_iter() + .filter(|(id, _)| !pointer_element_types.contains(id)) + .collect(); + + if narrow_types.is_empty() { + return; + } + + for inst in &mut module.types_global_values { + // widen each 8-bit OpTypeInt to 32 bits + if inst.class.opcode == Op::TypeInt + && let Some(id) = inst.result_id + && narrow_types.contains_key(&id) + { + inst.operands[0] = Operand::LiteralBit32(32); + } + + // fix OpConstant values: sign-extend signed 8-bit constants to 32 bits. + if inst.class.opcode == Op::Constant + && let Some(ty) = inst.result_type + && let Some(&signedness) = narrow_types.get(&ty) + && let Operand::LiteralBit32(ref mut val) = inst.operands[0] + { + let narrow = *val as u8; + *val = if signedness != 0 { + (narrow as i8 as i32) as u32 + } else { + narrow as u32 + }; + } + } +} diff --git a/tests/compiletests/ui/lang/core/promote-u8-to-u32.rs b/tests/compiletests/ui/lang/core/promote-u8-to-u32.rs new file mode 100644 index 00000000000..03dfb410c3c --- /dev/null +++ b/tests/compiletests/ui/lang/core/promote-u8-to-u32.rs @@ -0,0 +1,27 @@ +// build-pass +//PartialOrd on CustomPosition(u32) internally returns Option, +//where Ordering is represented as i8 in Rust's layout. +//This caused rust-gpu to emit OpTypeInt 8 declarations requiring OpCapability Int8 +#![no_std] + +use spirv_std::{glam::Vec4, spirv}; + +pub struct ShaderInputs { + pub x: CustomPosition, + pub y: CustomPosition, +} + +#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq)] +pub struct CustomPosition(u32); + +#[spirv(vertex)] +pub fn test_vs( + #[spirv(push_constant)] inputs: &ShaderInputs, + #[spirv(position)] out_pos: &mut Vec4, +) { + let mut result: f32 = 0.; + if inputs.x < inputs.y { + result = 1.0; + } + *out_pos = Vec4::new(inputs.x.0 as f32, inputs.y.0 as f32, result as f32, 1.0); +}