Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

working compute shaders in dx12 #1182

Merged
merged 1 commit into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions h3d/Buffer.hx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ enum BufferFlag {
Used for shader input buffer
**/
UniformBuffer;
/**
Can be written
**/
ReadWriteBuffer;
}

@:allow(h3d.impl.MemoryManager)
Expand Down
175 changes: 127 additions & 48 deletions h3d/impl/DX12Driver.hx
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class ShaderRegisters {
public var samplers : Int;
public var texturesCount : Int;
public var textures2DCount : Int;
public var bufferTypes : Array<hxsl.Ast.BufferKind>;
public function new() {
}
}
Expand All @@ -111,6 +112,8 @@ class CompiledShader {
public var inputLayout : hl.CArray<InputElementDesc>;
public var inputCount : Int;
public var shader : hxsl.RuntimeShader;
public var isCompute : Bool;
public var computePipeline : ComputePipelineState;
public function new() {
}
}
Expand All @@ -133,6 +136,7 @@ class CompiledShader {
@:packed public var bufferSRV(default,null) : BufferSRV;
@:packed public var samplerDesc(default,null) : SamplerDesc;
@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;

public var pass : h3d.mat.Pass;
Expand All @@ -156,6 +160,7 @@ class CompiledShader {
samplerDesc.comparisonFunc = NEVER;
samplerDesc.maxLod = 1e30;
descriptors2 = new hl.NativeArray(2);
uavDesc.viewDimension = BUFFER;
barrier.subResource = -1; // all
}

Expand Down Expand Up @@ -341,7 +346,7 @@ class DX12Driver extends h3d.impl.Driver {
public static var INITIAL_RT_COUNT = 1024;
public static var BUFFER_COUNT = 2;
public static var DEVICE_NAME = null;
public static var DEBUG = false;
public static var DEBUG = false; // requires dxil.dll when set to true

public function new() {
window = @:privateAccess dx.Window.windows[0];
Expand Down Expand Up @@ -875,7 +880,7 @@ class DX12Driver extends h3d.impl.Driver {

static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];

function getBinaryPayload( vertex : Bool, code : String ) {
function getBinaryPayload( code : String ) {
var bin = code.indexOf("//BIN=");
if( bin >= 0 ) {
var end = code.indexOf("#", bin);
Expand All @@ -895,7 +900,7 @@ class DX12Driver extends h3d.impl.Driver {
sh.code = out.run(sh.data);
sh.code = rootStr + sh.code;
}
var bytes = getBinaryPayload(sh.vertex, sh.code);
var bytes = getBinaryPayload(sh.code);
if ( bytes == null ) {
return compiler.compile(sh.code, profile, args);
}
Expand All @@ -905,6 +910,8 @@ class DX12Driver extends h3d.impl.Driver {
override function getNativeShaderCode( shader : hxsl.RuntimeShader ) {
var out = new hxsl.HlslOut();
var vsSource = out.run(shader.vertex.data);
if( shader.mode == Compute )
return vsSource;
var out = new hxsl.HlslOut();
var psSource = out.run(shader.fragment.data);
return vsSource+"\n\n\n\n"+psSource;
Expand Down Expand Up @@ -985,14 +992,14 @@ class DX12Driver extends h3d.impl.Driver {
return range;
}

function allocConsts(size,vis,useCBV) {
function allocConsts(size,vis,type) {
var reg = regCount++;
if( size == 0 ) return -1;

if( useCBV ) {
if( type != null ) {
var pid = paramsCount;
var r = allocDescTable(vis);
r.rangeType = CBV;
r.rangeType = type;
r.numDescriptors = 1;
r.baseShaderRegister = reg;
r.registerSpace = 0;
Expand All @@ -1010,14 +1017,30 @@ class DX12Driver extends h3d.impl.Driver {


function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
var vis = sh.vertex ? VERTEX : PIXEL;
var vis = switch( sh.kind ) {
case Vertex: VERTEX;
case Fragment: PIXEL;
default: ALL;
}
var regs = new ShaderRegisters();
regs.globals = allocConsts(sh.globalsSize, vis, false);
regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
regs.globals = allocConsts(sh.globalsSize, vis, null);
regs.params = allocConsts(sh.paramsSize, vis, (sh.kind == Fragment ? fragmentParamsCBV : vertexParamsCBV) ? CBV : null);
if( sh.bufferCount > 0 ) {
regs.buffers = paramsCount;
for( i in 0...sh.bufferCount )
allocConsts(1, vis, true);
regs.bufferTypes = [];
var p = sh.buffers;
while( p != null ) {
var kind = switch( p.type ) {
case TBuffer(_,_,kind): kind;
default: throw "assert";
}
regs.bufferTypes.push(kind);
allocConsts(1, vis, switch( kind ) {
case Uniform: CBV;
case RW: UAV;
});
p = p.next;
}
}
if( sh.texturesCount > 0 ) {
regs.texturesCount = sh.texturesCount;
Expand Down Expand Up @@ -1061,7 +1084,7 @@ class DX12Driver extends h3d.impl.Driver {
}

var totalVertex = calcSize(shader.vertex);
var totalFragment = calcSize(shader.fragment);
var totalFragment = shader.mode == Compute ? 0 : calcSize(shader.fragment);
var total = totalVertex + totalFragment;

if( total > 64 ) {
Expand All @@ -1083,39 +1106,57 @@ class DX12Driver extends h3d.impl.Driver {
throw "Too many globals";
}

var vertexRegisters = allocParams(shader.vertex);
var fragmentRegStart = regCount;
var fragmentRegisters = allocParams(shader.fragment);

var regs = [];
for( s in shader.getShaders() )
regs.push({ start : regCount, registers : allocParams(s) });
if( paramsCount > allocatedParams )
throw "ASSERT : Too many parameters";

var sign = new RootSignatureDesc();
sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
if( shader.mode == Compute ) {
sign.flags.set(DENY_PIXEL_SHADER_ROOT_ACCESS);
sign.flags.set(DENY_VERTEX_SHADER_ROOT_ACCESS);
} else
sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
sign.numParameters = paramsCount;
sign.parameters = cast params;

return { sign : sign, fragmentRegStart : fragmentRegStart, vertexRegisters : vertexRegisters, fragmentRegisters : fragmentRegisters, params : params, paramsCount : paramsCount, texDescs : texDescs };
return { sign : sign, registers : regs, params : params, paramsCount : paramsCount, texDescs : texDescs };
}

function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {

var res = computeRootSignature(shader);

var c = new CompiledShader();
c.vertexRegisters = res.vertexRegisters;
c.fragmentRegisters = res.fragmentRegisters;

var rootStr = stringifyRootSignature(res.sign, "ROOT_SIGNATURE", res.params, res.paramsCount);
var vs = compileSource(shader.vertex, "vs_6_0", 0, rootStr);
var ps = compileSource(shader.fragment, "ps_6_0", res.fragmentRegStart, rootStr);
var vs = shader.mode == Compute ? null : compileSource(shader.vertex, "vs_6_0", 0, rootStr);
var ps = shader.mode == Compute ? null : compileSource(shader.fragment, "ps_6_0", res.registers[1].start, rootStr);
var cs = shader.mode == Compute ? compileSource(shader.compute, "cs_6_0", 0, rootStr) : null;

var signSize = 0;
var signBytes = Driver.serializeRootSignature(res.sign, 1, signSize);
var sign = new RootSignature(signBytes,signSize);
c.rootSignature = sign;
c.shader = shader;

if( shader.mode == Compute ) {
c.isCompute = true;
var desc = new ComputePipelineStateDesc();
desc.rootSignature = sign;
desc.cs.shaderBytecode = cs;
desc.cs.bytecodeLength = cs.length;
c.computePipeline = Driver.createComputePipelineState(desc);
c.vertexRegisters = res.registers[0].registers;
return c;
}

c.vertexRegisters = res.registers[0].registers;
c.fragmentRegisters = res.registers[1].registers;

var inputs = [];
for( v in shader.vertex.data.vars )
Expand Down Expand Up @@ -1166,10 +1207,8 @@ class DX12Driver extends h3d.impl.Driver {

c.format = hxd.BufferFormat.make(format);
c.pipeline = p;
c.rootSignature = sign;
c.inputLayout = inputLayout;
c.inputCount = inputs.length;
c.shader = shader;

for( i in 0...inputs.length )
inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
Expand All @@ -1184,7 +1223,7 @@ class DX12Driver extends h3d.impl.Driver {

// ----- BUFFERS

function allocGPU( size : Int, heapType, state ) {
function allocGPU( size : Int, heapType, state, uav=false ) {
var desc = new ResourceDesc();
var flags = new haxe.EnumFlags();
desc.dimension = BUFFER;
Expand All @@ -1194,16 +1233,17 @@ class DX12Driver extends h3d.impl.Driver {
desc.mipLevels = 1;
desc.sampleDesc.count = 1;
desc.layout = ROW_MAJOR;
if( uav ) desc.flags.set(ALLOW_UNORDERED_ACCESS);
tmp.heap.type = heapType;
return Driver.createCommittedResource(tmp.heap, flags, desc, state, null);
}

override function allocBuffer( m : h3d.Buffer ) : GPUBuffer {
var buf = new VertexBufferData();
var size = m.getMemSize();
var bufSize = m.flags.has(UniformBuffer) ? calcCBVSize(size) : size;
var bufSize = m.flags.has(UniformBuffer) || m.flags.has(ReadWriteBuffer) ? calcCBVSize(size) : size;
buf.state = COPY_DEST;
buf.res = allocGPU(bufSize, DEFAULT, COMMON);
buf.res = allocGPU(bufSize, DEFAULT, COMMON, m.flags.has(ReadWriteBuffer));
if( !m.flags.has(UniformBuffer) ) {
var view = new VertexBufferView();
view.bufferLocation = buf.res.getGpuVirtualAddress();
Expand Down Expand Up @@ -1488,7 +1528,8 @@ class DX12Driver extends h3d.impl.Driver {

override function uploadShaderBuffers(buffers:h3d.shader.Buffers, which:h3d.shader.Buffers.BufferKind) {
uploadBuffers(buffers, buffers.vertex, which, currentShader.shader.vertex, currentShader.vertexRegisters);
uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
if( !currentShader.isCompute )
uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
}

function calcCBVSize( dataSize : Int ) {
Expand Down Expand Up @@ -1547,13 +1588,22 @@ class DX12Driver extends h3d.impl.Driver {
desc.bufferLocation = cbv.getGpuVirtualAddress();
desc.sizeInBytes = calcCBVSize(dataSize);
Driver.createConstantBufferView(desc, srv);
frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
} else
if( currentShader.isCompute )
frame.commandList.setComputeRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
else
frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
} else if( currentShader.isCompute )
frame.commandList.setComputeRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
else
frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
}
case Globals:
if( shader.globalsSize > 0 )
frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
if( shader.globalsSize > 0 ) {
if( currentShader.isCompute )
frame.commandList.setComputeRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
else
frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
}
case Textures:
if( regs.texturesCount > 0 ) {
var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
Expand Down Expand Up @@ -1612,10 +1662,10 @@ class DX12Driver extends h3d.impl.Driver {
t.lastFrame = frameCount;
var state = if ( t.isDepth() )
DEPTH_READ;
else if ( shader.vertex )
NON_PIXEL_SHADER_RESOURCE;
else
else if ( shader.kind == Fragment )
PIXEL_SHADER_RESOURCE;
else
NON_PIXEL_SHADER_RESOURCE;
transition(t.t, state);
Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));

Expand All @@ -1634,24 +1684,43 @@ class DX12Driver extends h3d.impl.Driver {
Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
}

frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
if( currentShader.isCompute ) {
frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
} else {
frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
}
}
case Buffers:
if( shader.bufferCount > 0 ) {
for( i in 0...shader.bufferCount ) {
var srv = frame.shaderResourceViews.alloc(1);
var b = buf.buffers[i];
var cbv = b.vbuf;
if( cbv.view != null )
throw "Buffer was allocated without UniformBuffer flag";
transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
var desc = tmp.cbvDesc;
desc.bufferLocation = cbv.res.getGpuVirtualAddress();
desc.sizeInBytes = cbv.size;
Driver.createConstantBufferView(desc, srv);
frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
}
switch( regs.bufferTypes[i] ) {
case Uniform:
if( cbv.view != null )
throw "Buffer was allocated without UniformBuffer flag";
transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
var desc = tmp.cbvDesc;
desc.bufferLocation = cbv.res.getGpuVirtualAddress();
desc.sizeInBytes = cbv.size;
Driver.createConstantBufferView(desc, srv);
case RW:
if( !b.flags.has(ReadWriteBuffer) )
throw "Buffer was allocated without ReadWriteBuffer flag";
transition(cbv, UNORDERED_ACCESS);
var desc = tmp.uavDesc;
desc.numElements = b.vertices;
desc.structureSizeInBytes = b.format.strideBytes;
Driver.createUnorderedAccessView(cbv.res, null, desc, srv);
}
if( currentShader.isCompute )
frame.commandList.setComputeRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
else
frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
}
}
}
}
Expand All @@ -1665,8 +1734,14 @@ class DX12Driver extends h3d.impl.Driver {
if( currentShader == sh )
return false;
currentShader = sh;
needPipelineFlush = true;
frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
if( sh.isCompute ) {
needPipelineFlush = false;
frame.commandList.setComputeRootSignature(currentShader.rootSignature);
frame.commandList.setPipelineState(currentShader.computePipeline);
} else {
needPipelineFlush = true;
frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
}
return true;
}

Expand Down Expand Up @@ -2026,6 +2101,10 @@ class DX12Driver extends h3d.impl.Driver {
}
}

override function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
frame.commandList.dispatch(x,y,z);
}

}

#end
6 changes: 6 additions & 0 deletions h3d/impl/Driver.hx
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,10 @@ class Driver {
return 0.;
}

// --- COMPUTE

public function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
throw "Not implemented";
}

}