Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ pub fn link(
simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
}

{
let _timer = sess.timer("link_promote_int8_to_int32");
simple_passes::promote_int8_to_int32(&mut output);
}

// NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
{
let (spv_words, module_or_err, lower_from_spv_timer) =
Expand Down
78 changes: 77 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{get_name, get_names};
use rspirv::dr::{Block, Function, Module};
use rspirv::dr::{Block, Function, Module, Operand};
use rspirv::spirv::{Decoration, ExecutionModel, Op, Word};
use rustc_codegen_spirv_types::Capability;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
Expand Down Expand Up @@ -365,3 +365,79 @@ pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> s
}
Ok(())
}

/// When `OpCapability Int8` is not declared, promote all implicit `i8`/`u8` types to `i32`/`u32`.
pub fn promote_int8_to_int32(module: &mut Module) {
let has_int8 = module.capabilities.iter().any(|inst| {
inst.class.opcode == Op::Capability
&& inst.operands[0].unwrap_capability() == Capability::Int8
});
if has_int8 {
return;
}

let narrow_types: FxHashMap<Word, u32> = module
.types_global_values
.iter()
.filter_map(|inst| {
if inst.class.opcode == Op::TypeInt && inst.operands[0].unwrap_literal_bit32() == 8 {
let signedness = inst.operands[1].unwrap_literal_bit32();
Some((inst.result_id?, signedness))
} else {
None
}
})
.collect();

if narrow_types.is_empty() {
return;
}

// skip any 8-bit type that is used as the element type of an OpTypePointer.
// such types are explicit interface/storage types chosen by the user
let pointer_element_types: FxHashSet<Word> = module
.types_global_values
.iter()
.filter_map(|inst| {
if inst.class.opcode == Op::TypePointer {
// operands: [StorageClass, element_type_id]
Some(inst.operands[1].unwrap_id_ref())
} else {
None
}
})
.collect();

let narrow_types: FxHashMap<Word, u32> = narrow_types
.into_iter()
.filter(|(id, _)| !pointer_element_types.contains(id))
.collect();

if narrow_types.is_empty() {
return;
}

for inst in &mut module.types_global_values {
// widen each 8-bit OpTypeInt to 32 bits
if inst.class.opcode == Op::TypeInt
&& let Some(id) = inst.result_id
&& narrow_types.contains_key(&id)
{
inst.operands[0] = Operand::LiteralBit32(32);
}

// fix OpConstant values: sign-extend signed 8-bit constants to 32 bits.
if inst.class.opcode == Op::Constant
&& let Some(ty) = inst.result_type
&& let Some(&signedness) = narrow_types.get(&ty)
&& let Operand::LiteralBit32(ref mut val) = inst.operands[0]
{
let narrow = *val as u8;
*val = if signedness != 0 {
(narrow as i8 as i32) as u32
} else {
narrow as u32
};
}
}
}
27 changes: 27 additions & 0 deletions tests/compiletests/ui/lang/core/promote-u8-to-u32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// build-pass
//PartialOrd on CustomPosition(u32) internally returns Option<Ordering>,
//where Ordering is represented as i8 in Rust's layout.
//This caused rust-gpu to emit OpTypeInt 8 declarations requiring OpCapability Int8
#![no_std]

use spirv_std::{glam::Vec4, spirv};

pub struct ShaderInputs {
pub x: CustomPosition,
pub y: CustomPosition,
}

#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq)]
pub struct CustomPosition(u32);

#[spirv(vertex)]
pub fn test_vs(
#[spirv(push_constant)] inputs: &ShaderInputs,
#[spirv(position)] out_pos: &mut Vec4,
) {
let mut result: f32 = 0.;
if inputs.x < inputs.y {
result = 1.0;
}
*out_pos = Vec4::new(inputs.x.0 as f32, inputs.y.0 as f32, result as f32, 1.0);
}
Loading