diff --git a/examples/runners/ash/src/main.rs b/examples/runners/ash/src/main.rs index 445ca75ee6..1551c22901 100644 --- a/examples/runners/ash/src/main.rs +++ b/examples/runners/ash/src/main.rs @@ -80,7 +80,6 @@ use winit::{ use std::{ borrow::Cow, - collections::HashMap, ffi::{CStr, CString}, fs::File, os::raw::c_char, @@ -104,7 +103,6 @@ pub enum RustGPUShader { } impl RustGPUShader { - // The form with dashes, e.g. `sky-shader`. fn crate_name(&self) -> &'static str { match self { RustGPUShader::Simplest => "simplest-shader", @@ -112,15 +110,6 @@ impl RustGPUShader { RustGPUShader::Mouse => "mouse-shader", } } - - // The form with underscores, e.g. `sky_shader`. - fn crate_ident(&self) -> &'static str { - match self { - RustGPUShader::Simplest => "simplest_shader", - RustGPUShader::Sky => "sky_shader", - RustGPUShader::Mouse => "mouse_shader", - } - } } #[derive(Debug, Parser)] @@ -147,7 +136,7 @@ pub fn main() { } let options = Options::parse(); - let shaders = compile_shaders(&options.shader); + let (vert_data, frag_data) = compile_shaders(&options.shader); // runtime setup let event_loop = EventLoop::new().unwrap(); @@ -166,27 +155,13 @@ pub fn main() { .unwrap(); let mut ctx = RenderBase::new(window, &options).into_ctx(); - // Create shader module and pipelines - for SpvFile { name, data } in shaders { - ctx.insert_shader_module(name, &data); - } + // Insert shader modules. + ctx.update_shader_modules(&vert_data, &frag_data); - let crate_ident = options.shader.crate_ident(); - ctx.build_pipelines( - vk::PipelineCache::null(), - vec![( - VertexShaderEntryPoint { - module: format!("{crate_ident}::main_vs"), - entry_point: "main_vs".into(), - }, - FragmentShaderEntryPoint { - module: format!("{crate_ident}::main_fs"), - entry_point: "main_fs".into(), - }, - )], - ); + // Create pipeline. + ctx.rebuild_pipeline(vk::PipelineCache::null()); - let (compiler_sender, compiler_receiver) = sync_channel(1); + let (compiler_sender, compiler_receiver) = sync_channel::<(Vec, Vec)>(1); // FIXME(eddyb) incomplete `winit` upgrade, follow the guides in: // https://github.com/rust-windowing/winit/releases/tag/v0.30.0 @@ -206,12 +181,10 @@ pub fn main() { ctx.render(); } } - Ok(new_shaders) => { - for SpvFile { name, data } in new_shaders { - ctx.insert_shader_module(name, &data); - } + Ok((new_vert_data, new_frag_data)) => { + ctx.update_shader_modules(&new_vert_data, &new_frag_data); ctx.recompiling_shaders = false; - ctx.rebuild_pipelines(vk::PipelineCache::null()); + ctx.rebuild_pipeline(vk::PipelineCache::null()); } Err(TryRecvError::Disconnected) => { panic!("compiler receiver disconnected unexpectedly"); @@ -253,7 +226,7 @@ pub fn main() { // HACK(eddyb) to see any changes, re-specializing the // shader module is needed (e.g. during pipeline rebuild). - ctx.rebuild_pipelines(vk::PipelineCache::null()); + ctx.rebuild_pipeline(vk::PipelineCache::null()); } _ => {} }, @@ -268,15 +241,14 @@ pub fn main() { .unwrap(); } -pub fn compile_shaders(shader: &RustGPUShader) -> Vec { +pub fn compile_shaders(shader: &RustGPUShader) -> (Vec, Vec) { let manifest_dir = env!("CARGO_MANIFEST_DIR"); let crate_path = [manifest_dir, "..", "..", "shaders", shader.crate_name()] .iter() .copied() .collect::(); - let crate_ident = shader.crate_ident(); - SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1") + let mut shaders = SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1") .print_metadata(MetadataPrintout::None) .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::DebugPrintfThenExit { print_inputs: true, @@ -284,24 +256,33 @@ pub fn compile_shaders(shader: &RustGPUShader) -> Vec { }) // TODO: `multimodule` is no longer needed since // https://github.com/KhronosGroup/SPIRV-Tools/issues/4892 was fixed, but removing it is - // non-trivial and hasn't been done et. + // non-trivial and hasn't been done yet. .multimodule(true) .build() .unwrap() .module .unwrap_multi() .iter() - .map(|(name, path)| SpvFile { - name: format!("{crate_ident}::{name}"), - data: read_spv(&mut File::open(path).unwrap()).unwrap(), + .map(|(name, path)| { + ( + name.clone(), + read_spv(&mut File::open(path).unwrap()).unwrap(), + ) }) - .collect() -} - -#[derive(Debug)] -pub struct SpvFile { - pub name: String, - pub data: Vec, + .collect::>(); + + // We always have two shaders. And the fragment shader is always before the + // vertex shader in `shaders`. This is because `unwrap_multi` returns a + // `BTreeMap` sorted by shader name, and `main_fs` comes before `main_vs`, + // alphabetically. We still check the names to make sure they are in the + // order we expect. That way if the order ever changes we'll get an + // assertion failure here as opposed to a harder-to-debug failure later on. + assert_eq!(shaders.len(), 2); + assert_eq!(shaders[0].0, "main_fs"); + assert_eq!(shaders[1].0, "main_vs"); + let vert = shaders.pop().unwrap().1; + let frag = shaders.pop().unwrap().1; + (vert, frag) } pub struct RenderBase { @@ -714,9 +695,9 @@ pub struct RenderCtx { pub commands: RenderCommandPool, pub viewports: Box<[vk::Viewport]>, pub scissors: Box<[vk::Rect2D]>, - pub pipelines: Vec, - pub shader_modules: HashMap, - pub shader_set: Vec<(VertexShaderEntryPoint, FragmentShaderEntryPoint)>, + pub pipeline: Option, + pub vert_module: Option, + pub frag_module: Option, pub rendering_paused: bool, pub recompiling_shaders: bool, @@ -764,9 +745,9 @@ impl RenderCtx { framebuffers, viewports, scissors, - pipelines: Vec::new(), - shader_modules: HashMap::new(), - shader_set: Vec::new(), + pipeline: None, + vert_module: None, + frag_module: None, rendering_paused: false, recompiling_shaders: false, start: std::time::Instant::now(), @@ -792,7 +773,7 @@ impl RenderCtx { } } - pub fn rebuild_pipelines(&mut self, pipeline_cache: vk::PipelineCache) { + pub fn rebuild_pipeline(&mut self, pipeline_cache: vk::PipelineCache) { // NOTE(eddyb) this acts like an integration test for specialization constants. let spec_const_entries = [vk::SpecializationMapEntry::default() .constant_id(0x5007) @@ -804,83 +785,66 @@ impl RenderCtx { .map_entries(&spec_const_entries) .data(&spec_const_data); - self.cleanup_pipelines(); + self.cleanup_pipeline(); let pipeline_layout = self.create_pipeline_layout(); let viewport = vk::PipelineViewportStateCreateInfo::default() .scissor_count(1) .viewport_count(1); - let modules_names = self - .shader_set - .iter() - .map(|(vert, frag)| { - let vert_module = *self.shader_modules.get(&vert.module).unwrap(); - let vert_name = CString::new(vert.entry_point.clone()).unwrap(); - let frag_module = *self.shader_modules.get(&frag.module).unwrap(); - let frag_name = CString::new(frag.entry_point.clone()).unwrap(); - ((frag_module, frag_name), (vert_module, vert_name)) - }) - .collect::>(); - let descs = modules_names - .iter() - .map(|((frag_module, frag_name), (vert_module, vert_name))| { - PipelineDescriptor::new(Box::new([ - vk::PipelineShaderStageCreateInfo { - module: *vert_module, - p_name: (*vert_name).as_ptr(), - stage: vk::ShaderStageFlags::VERTEX, - ..Default::default() - }, - vk::PipelineShaderStageCreateInfo { - s_type: vk::StructureType::PIPELINE_SHADER_STAGE_CREATE_INFO, - module: *frag_module, - p_name: (*frag_name).as_ptr(), - stage: vk::ShaderStageFlags::FRAGMENT, - p_specialization_info: &specialization_info, - ..Default::default() - }, - ])) - }) - .collect::>(); - let descs_indirect_parts = descs - .iter() - .map(|desc| desc.indirect_parts()) - .collect::>(); - let pipeline_info = descs - .iter() - .zip(&descs_indirect_parts) - .map(|(desc, desc_indirect_parts)| { - vk::GraphicsPipelineCreateInfo::default() - .stages(&desc.shader_stages) - .vertex_input_state(&desc.vertex_input) - .input_assembly_state(&desc.input_assembly) - .rasterization_state(&desc.rasterization) - .multisample_state(&desc.multisample) - .depth_stencil_state(&desc.depth_stencil) - .color_blend_state(&desc_indirect_parts.color_blend) - .dynamic_state(&desc_indirect_parts.dynamic_state_info) - .viewport_state(&viewport) - .layout(pipeline_layout) - .render_pass(self.render_pass) - }) - .collect::>(); - self.pipelines = unsafe { + + let vs_entry_point = "main_vs"; + let fs_entry_point = "main_fs"; + let vert_module = self.vert_module.as_ref().unwrap(); + let frag_module = self.frag_module.as_ref().unwrap(); + let vert_name = CString::new(vs_entry_point).unwrap(); + let frag_name = CString::new(fs_entry_point).unwrap(); + let desc = PipelineDescriptor::new(Box::new([ + vk::PipelineShaderStageCreateInfo { + module: *vert_module, + p_name: (*vert_name).as_ptr(), + stage: vk::ShaderStageFlags::VERTEX, + ..Default::default() + }, + vk::PipelineShaderStageCreateInfo { + s_type: vk::StructureType::PIPELINE_SHADER_STAGE_CREATE_INFO, + module: *frag_module, + p_name: (*frag_name).as_ptr(), + stage: vk::ShaderStageFlags::FRAGMENT, + p_specialization_info: &specialization_info, + ..Default::default() + }, + ])); + let desc_indirect_parts = desc.indirect_parts(); + let pipeline_info = vk::GraphicsPipelineCreateInfo::default() + .stages(&desc.shader_stages) + .vertex_input_state(&desc.vertex_input) + .input_assembly_state(&desc.input_assembly) + .rasterization_state(&desc.rasterization) + .multisample_state(&desc.multisample) + .depth_stencil_state(&desc.depth_stencil) + .color_blend_state(&desc_indirect_parts.color_blend) + .dynamic_state(&desc_indirect_parts.dynamic_state_info) + .viewport_state(&viewport) + .layout(pipeline_layout) + .render_pass(self.render_pass); + + let mut pipelines = unsafe { self.base .device - .create_graphics_pipelines(pipeline_cache, &pipeline_info, None) + .create_graphics_pipelines(pipeline_cache, &[pipeline_info], None) .expect("Unable to create graphics pipeline") - } - .into_iter() - .map(|pipeline| Pipeline { + }; + // A single `pipeline_info` results in a single pipeline. + assert_eq!(pipelines.len(), 1); + self.pipeline = pipelines.pop().map(|pipeline| Pipeline { pipeline, pipeline_layout, - }) - .collect(); + }); } - pub fn cleanup_pipelines(&mut self) { + pub fn cleanup_pipeline(&mut self) { unsafe { self.base.device.device_wait_idle().unwrap(); - for pipeline in self.pipelines.drain(..) { + if let Some(pipeline) = self.pipeline.take() { self.base.device.destroy_pipeline(pipeline.pipeline, None); self.base .device @@ -889,31 +853,35 @@ impl RenderCtx { } } - pub fn build_pipelines( - &mut self, - pipeline_cache: vk::PipelineCache, - shader_set: Vec<(VertexShaderEntryPoint, FragmentShaderEntryPoint)>, - ) { - self.shader_set = shader_set; - self.rebuild_pipelines(pipeline_cache); - } - - /// Add a shader module to the hash map of shader modules. returns a handle to the module, and the - /// old shader module if there was one with the same name already. Does not rebuild pipelines - /// that may be using the shader module, nor does it invalidate them. - pub fn insert_shader_module(&mut self, name: String, spirv: &[u32]) { - let shader_info = vk::ShaderModuleCreateInfo::default().code(spirv); + /// Update the vertex and fragment shader modules. Does not rebuild + /// pipelines that may be using the shader module, nor does it invalidate + /// them. + pub fn update_shader_modules(&mut self, vert_data: &[u32], frag_data: &[u32]) { + let shader_info = vk::ShaderModuleCreateInfo::default().code(vert_data); let shader_module = unsafe { self.base .device .create_shader_module(&shader_info, None) - .expect("Shader module error") + .expect("Vertex shader module error") }; - if let Some(old_module) = self.shader_modules.insert(name, shader_module) { + if let Some(old_module) = self.vert_module.replace(shader_module) { unsafe { self.base.device.destroy_shader_module(old_module, None); } + } + + let shader_info = vk::ShaderModuleCreateInfo::default().code(frag_data); + let shader_module = unsafe { + self.base + .device + .create_shader_module(&shader_info, None) + .expect("Fragment shader module error") }; + if let Some(old_module) = self.frag_module.replace(shader_module) { + unsafe { + self.base.device.destroy_shader_module(old_module, None); + } + } } /// Destroys the swapchain, as well as the renderpass and frame and command buffers @@ -993,11 +961,7 @@ impl RenderCtx { }, }]; - // There should only be one pipeline because compile_shaders only loads the last spirv - // file it produced. - for pipeline in self.pipelines.iter() { - self.draw(pipeline, framebuffer, &clear_values); - } + self.draw(self.pipeline.as_ref().unwrap(), framebuffer, &clear_values); let wait_semaphors = [self.sync.rendering_complete_semaphore]; let swapchains = [self.swapchain]; @@ -1161,14 +1125,17 @@ impl Drop for RenderCtx { .device .free_command_buffers(self.commands.pool, &[self.commands.draw_command_buffer]); self.base.device.destroy_render_pass(self.render_pass, None); - self.cleanup_pipelines(); + self.cleanup_pipeline(); self.cleanup_swapchain(); self.base .device .destroy_command_pool(self.commands.pool, None); - for (_, shader_module) in self.shader_modules.drain() { - self.base.device.destroy_shader_module(shader_module, None); - } + self.base + .device + .destroy_shader_module(self.vert_module.unwrap(), None); + self.base + .device + .destroy_shader_module(self.frag_module.unwrap(), None); } } } @@ -1345,16 +1312,6 @@ impl<'a> PipelineDescriptor<'a> { } } -pub struct VertexShaderEntryPoint { - pub module: String, - pub entry_point: String, -} - -pub struct FragmentShaderEntryPoint { - module: String, - entry_point: String, -} - unsafe fn any_as_u8_slice(p: &T) -> &[u8] { unsafe { ::std::slice::from_raw_parts((p as *const T).cast::(), ::std::mem::size_of::())