Skip to content

Commit

Permalink
Add preparation support of row export for mesh shader
Browse files Browse the repository at this point in the history
Future generations allow row export for mesh shaders, so we no longer have
to set the thread group size to max(vertexCount, primitiveCount) in order to
launch enough threads for vertex/primitive export. Consider this mesh shader
test: the mesh shader exports 256 points while its workgroup size is
(8, 2, 4). In wave32 mode, only 2 waves can be launched, so the 8 rows of
32 points are split between them: wave0 exports rows 0, 2, 4, 6 and wave1
exports rows 1, 3, 5, 7. Future generations are able to achieve this.

In this change, we add a loop structure to handle row export. It is
something like this:

  loopIndex = 0
  primOrVertexIndex = threadIdInSubgroup

  while (primOrVertexIndex < primOrVertexCount) {
    Export primitive/vertex

    loopIndex += numWaves
    primOrVertexIndex = threadIdInSubgroup + loopIndex * waveSize
  }

The row exports are distributed uniformly across the existing waves. This
loop structure degenerates if row export is disabled, and the final
optimized CFG is then equivalent to the previous mesh shader implementation,
without unnecessary control flow instructions.
  • Loading branch information
amdrexu committed Oct 6, 2022
1 parent 733897b commit ea29e54
Showing 1 changed file with 136 additions and 17 deletions.
153 changes: 136 additions & 17 deletions lgc/patch/MeshTaskShader.cpp
Expand Up @@ -483,18 +483,26 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
m_payloadRingEntryOffset = nullptr;

auto &meshMode = m_pipelineState->getShaderModes()->getMeshShaderMode();
const unsigned waveSize = m_pipelineState->getShaderWaveSize(ShaderStageMesh);

// Setup LDS layout
layoutMeshShaderLds(m_pipelineState, entryPoint, &m_ldsLayout);
m_lds = getOrCreateMeshLds(entryPoint->getParent());

// Mutate mesh shader entry-point
entryPoint = mutateMeshShaderEntryPoint(entryPoint);

// Force s_barrier to be present if necessary (ignore optimization)
const unsigned numMeshThreads = meshMode.workgroupSizeX * meshMode.workgroupSizeY * meshMode.workgroupSizeZ;
auto primAmpFactor =
m_pipelineState->getShaderResourceUsage(ShaderStageGeometry)->inOutUsage.gs.calcFactor.primAmpFactor;
// If we enable row export, the actual thread group size is determined by work group size provided from API mesh
// shader.
const unsigned flatWorkgroupSize = m_pipelineState->enableMeshRowExport() ? numMeshThreads : primAmpFactor;
entryPoint->addFnAttr("amdgpu-flat-work-group-size",
std::to_string(primAmpFactor) + std::string(",") + std::to_string(primAmpFactor));
std::to_string(primAmpFactor) + std::string(",") + std::to_string(flatWorkgroupSize));

const unsigned numWaves = alignTo(flatWorkgroupSize, waveSize) / waveSize;

// API mesh shader entry block
BasicBlock *beginMeshShaderBlock = &entryPoint->getEntryBlock();
Expand All @@ -518,7 +526,8 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
};

auto entryBlock = createBlock(".entry", beginMeshShaderBlock);
auto initPrimitiveIndicesBlock = createBlock(".initPrimitiveIndices", beginMeshShaderBlock);
auto initPrimitiveIndicesHeaderBlock = createBlock(".initPrimitiveIndicesHeader", beginMeshShaderBlock);
auto initPrimitiveIndicesBodyBlock = createBlock(".initPrimitiveIndicesBody", beginMeshShaderBlock);
auto endInitPrimitiveIndicesBlock = createBlock(".endInitPrimitiveIndices", beginMeshShaderBlock);

auto writeSpecialValueBlock = createBlock(".writeSpecialValue", beginMeshShaderBlock);
Expand All @@ -529,10 +538,12 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
auto endDummyAllocReqBlock = createBlock(".endDummyAllocReq");
auto checkExportPrimitiveBlock = createBlock(".checkExportPrimitive");

auto exportPrimitiveBlock = createBlock(".exportPrimitive");
auto exportPrimitiveHeaderBlock = createBlock(".exportPrimitiveHeader");
auto exportPrimitiveBodyBlock = createBlock(".exportPrimitiveBody");
auto endExportPrimitiveBlock = createBlock(".endExportPrimitive");

auto exportVertexBlock = createBlock(".exportVertex");
auto exportVertexHeaderBlock = createBlock(".exportVertexHeader");
auto exportVertexBodyBlock = createBlock(".exportVertexBody");
auto endExportVertexBlock = createBlock(".endExportVertex");

auto collectMeshStatsBlock = createBlock(".collectMeshStats");
Expand All @@ -544,24 +555,60 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {

initWaveThreadInfo(entryPoint);

m_builder->CreateBr(initPrimitiveIndicesHeaderBlock);
}

// Construct ".initPrimitiveIndicesHeader" block
PHINode *loopIndexPhi = nullptr;
{
m_builder->SetInsertPoint(initPrimitiveIndicesHeaderBlock);

if (m_pipelineState->enableMeshRowExport()) {
loopIndexPhi = m_builder->CreatePHI(m_builder->getInt32Ty(), 2);
loopIndexPhi->addIncoming(m_builder->getInt32(0), entryBlock); // loopIndex = 0

// primitiveIndex = threadIdInSubgroup + loopIndex * waveSize
m_waveThreadInfo.primOrVertexIndex =
m_builder->CreateAdd(m_waveThreadInfo.threadIdInSubgroup,
m_builder->CreateMul(loopIndexPhi, m_builder->getInt32(waveSize)), "primitiveIndex");
}

auto validPrimitive =
m_builder->CreateICmpULT(m_waveThreadInfo.primOrVertexIndex, m_builder->getInt32(meshMode.outputPrimitives));
m_builder->CreateCondBr(validPrimitive, initPrimitiveIndicesBlock, endInitPrimitiveIndicesBlock);
m_builder->CreateCondBr(validPrimitive, initPrimitiveIndicesBodyBlock, endInitPrimitiveIndicesBlock);
}

// Construct ".initPrimitiveIndices" block
// Construct ".initPrimitiveIndicesBody" block
{
m_builder->SetInsertPoint(initPrimitiveIndicesBlock);
m_builder->SetInsertPoint(initPrimitiveIndicesBodyBlock);

if (m_pipelineState->enableMeshRowExport()) {
//
// Row export is something like this:
//
// loopIndex = 0
// primitiveIndex = threadIdInSubgroup
//
// while (primitiveIndex < outputPrimitives) {
// Zero primitive connectivity data
//
// loopIndex += numWaves
// primitiveIndex = threadIdInSubgroup + loopIndex * waveSize
// }
//
auto loopIndex = m_builder->CreateAdd(loopIndexPhi, m_builder->getInt32(numWaves)); // loopIndex += numWaves
loopIndexPhi->addIncoming(loopIndex, initPrimitiveIndicesBodyBlock);
}

auto ldsStart = m_builder->getInt32(getMeshShaderLdsRegionStart(MeshLdsRegion::PrimitiveIndices));
auto ldsOffset = m_builder->CreateAdd(ldsStart, m_waveThreadInfo.primOrVertexIndex);

writeValueToLds(m_builder->getInt32(0), ldsOffset);
m_builder->CreateBr(endInitPrimitiveIndicesBlock);
m_builder->CreateBr(m_pipelineState->enableMeshRowExport() ? initPrimitiveIndicesHeaderBlock
: endInitPrimitiveIndicesBlock);
}

// Construct ".endInitPrimitiveIndices" block
unsigned numMeshThreads = meshMode.workgroupSizeX * meshMode.workgroupSizeY * meshMode.workgroupSizeZ;
Value *firstThreadInSubgroup = nullptr;
{
m_builder->SetInsertPoint(endInitPrimitiveIndicesBlock);
Expand Down Expand Up @@ -664,32 +711,104 @@ void MeshTaskShader::processMeshShader(Function *entryPoint) {
{
m_builder->SetInsertPoint(checkExportPrimitiveBlock);

m_builder->CreateBr(exportPrimitiveHeaderBlock);
}

// Construct ".exportPrimitiveHeader" block
{
m_builder->SetInsertPoint(exportPrimitiveHeaderBlock);

if (m_pipelineState->enableMeshRowExport()) {
loopIndexPhi = m_builder->CreatePHI(m_builder->getInt32Ty(), 2);
loopIndexPhi->addIncoming(m_builder->getInt32(0), checkExportPrimitiveBlock); // loopIndex = 0

// primitiveIndex = threadIdInSubgroup + loopIndex * waveSize
m_waveThreadInfo.primOrVertexIndex =
m_builder->CreateAdd(m_waveThreadInfo.threadIdInSubgroup,
m_builder->CreateMul(loopIndexPhi, m_builder->getInt32(waveSize)), "primitiveIndex");
}

auto validPrimitive = m_builder->CreateICmpULT(m_waveThreadInfo.primOrVertexIndex, primitiveCount);
m_builder->CreateCondBr(validPrimitive, exportPrimitiveBlock, endExportPrimitiveBlock);
m_builder->CreateCondBr(validPrimitive, exportPrimitiveBodyBlock, endExportPrimitiveBlock);
}

// Construct ".exportPrimitive" block
// Construct ".exportPrimitiveBody" block
{
m_builder->SetInsertPoint(exportPrimitiveBlock);
m_builder->SetInsertPoint(exportPrimitiveBodyBlock);

if (m_pipelineState->enableMeshRowExport()) {
//
// Row export is something like this:
//
// loopIndex = 0
// primitiveIndex = threadIdInSubgroup
//
// while (primitiveIndex < primitiveCount) {
// Export primitive
// Export primitive attributes
//
// loopIndex += numWaves
// primitiveIndex = threadIdInSubgroup + loopIndex * waveSize
// }
//
auto loopIndex = m_builder->CreateAdd(loopIndexPhi, m_builder->getInt32(numWaves)); // loopIndex += numWaves
loopIndexPhi->addIncoming(loopIndex, exportPrimitiveBodyBlock);
}

exportPrimitive();
m_builder->CreateBr(endExportPrimitiveBlock);
m_builder->CreateBr(m_pipelineState->enableMeshRowExport() ? exportPrimitiveHeaderBlock : endExportPrimitiveBlock);
}

// Construct ".endExportPrimitive" block
{
m_builder->SetInsertPoint(endExportPrimitiveBlock);

m_builder->CreateBr(exportVertexHeaderBlock);
}

// Construct ".exportVertexHeader" block
{
m_builder->SetInsertPoint(exportVertexHeaderBlock);

if (m_pipelineState->enableMeshRowExport()) {
loopIndexPhi = m_builder->CreatePHI(m_builder->getInt32Ty(), 2);
loopIndexPhi->addIncoming(m_builder->getInt32(0), endExportPrimitiveBlock); // loopIndex = 0

// vertexIndex = threadIdInSubgroup + loopIndex * waveSize
m_waveThreadInfo.primOrVertexIndex =
m_builder->CreateAdd(m_waveThreadInfo.threadIdInSubgroup,
m_builder->CreateMul(loopIndexPhi, m_builder->getInt32(waveSize)), "vertexIndex");
}

auto validVertex = m_builder->CreateICmpULT(m_waveThreadInfo.primOrVertexIndex, vertexCount);
m_builder->CreateCondBr(validVertex, exportVertexBlock, endExportVertexBlock);
m_builder->CreateCondBr(validVertex, exportVertexBodyBlock, endExportVertexBlock);
}

// Construct "exportVertex" block
// Construct "exportVertexBody" block
{
m_builder->SetInsertPoint(exportVertexBlock);
m_builder->SetInsertPoint(exportVertexBodyBlock);

if (m_pipelineState->enableMeshRowExport()) {
//
// Row export is something like this:
//
// loopIndex = 0
// vertexIndex = threadIdInSubgroup
//
// while (vertexIndex < vertexCount) {
// Export vertex position data
// Export vertex attributes
//
// loopIndex += numWaves
// vertexIndex = threadIdInSubgroup + loopIndex * waveSize
// }
//
auto loopIndex = m_builder->CreateAdd(loopIndexPhi, m_builder->getInt32(numWaves)); // loopIndex += numWaves
loopIndexPhi->addIncoming(loopIndex, exportVertexBodyBlock);
}

exportVertex();
m_builder->CreateBr(endExportVertexBlock);
m_builder->CreateBr(m_pipelineState->enableMeshRowExport() ? exportVertexHeaderBlock : endExportVertexBlock);
}

// Construct ".endExportVertex" block
Expand Down

0 comments on commit ea29e54

Please sign in to comment.