Remove IRHLSLExportDecoration and IRKeepAliveDecoration for non-CUDA/Torch targets (shader-slang#4364)

* Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets

* Update hlsl-torch-cross-compile.slang
saipraveenb25 committed Jun 13, 2024
1 parent f0d40ad commit fba316f
Showing 5 changed files with 110 additions and 12 deletions.
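
At a glance: [TorchEntryPoint], [AutoPyBindCUDA], and [CUDAKernel] functions receive IRKeepAliveDecoration and IRHLSLExportDecoration during lowering so that they survive linking, but those same decorations also kept the functions alive in outputs where they do not belong (for example HLSL). The fix introduces a pass that strips both decorations whenever the target is neither CUDA source nor the PyTorch C++ binding, leaving ordinary dead-code elimination to drop the functions. A condensed sketch of the new pass (the full version is in slang-ir-pytorch-cpp-binding.cpp below):

void removeTorchAndCUDAEntryPoints(IRModule* module)
{
    for (auto globalInst : module->getGlobalInsts())
    {
        auto func = as<IRFunc>(globalInst);
        if (!func)
            continue;

        // Only touch functions marked as CUDA/Torch entry points.
        if (func->findDecoration<IRAutoPyBindCudaDecoration>() ||
            func->findDecoration<IRTorchEntryPointDecoration>() ||
            func->findDecoration<IRCudaKernelDecoration>())
        {
            // Without these decorations the linker no longer retains the
            // function, so DCE can remove it on non-CUDA targets.
            if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>())
                keepAlive->removeAndDeallocate();
            if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>())
                hlslExport->removeAndDeallocate();
        }
    }
}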
28 changes: 24 additions & 4 deletions source/slang/slang-emit.cpp
@@ -555,6 +555,17 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.optionalType)
lowerOptionalType(irModule, sink);

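// For any target other than CUDA source or PyTorch bindings, strip the
// keep-alive and HLSL-export decorations from Torch/CUDA entry points so
// that downstream dead-code elimination can remove them.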
switch (target)
{
case CodeGenTarget::CUDASource:
case CodeGenTarget::PyTorchCppBinding:
break;

default:
removeTorchAndCUDAEntryPoints(irModule);
break;
}

switch (target)
{
case CodeGenTarget::CPPSource:
@@ -605,10 +616,19 @@ Result linkAndOptimizeIR(
if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations())
fuseCallsToSaturatedCooperation(irModule);

// Generate any requested derivative wrappers
if (requiredLoweringPassSet.derivativePyBindWrapper)
generateDerivativeWrappers(irModule, sink);

switch (target)
{
case CodeGenTarget::CUDASource:
case CodeGenTarget::PyTorchCppBinding:
{
// Generate any requested derivative wrappers
if (requiredLoweringPassSet.derivativePyBindWrapper)
generateDerivativeWrappers(irModule, sink);
break;
}
default:
break;
}
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
26 changes: 21 additions & 5 deletions source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -632,7 +632,6 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host

builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
builder->addHLSLExportDecoration(reflectionFunc);
builder->addKeepAliveDecoration(reflectionFunc);
}

@@ -817,7 +816,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)

builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addHLSLExportDecoration(reflFunc);
builder.addKeepAliveDecoration(reflFunc);
}

@@ -899,7 +897,6 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
// Mark for host-side emit logic.
builder.addCudaHostDecoration(hostFunc);
// Keep alive. This method will be accessed externally.
builder.addHLSLExportDecoration(hostFunc);
builder.addKeepAliveDecoration(hostFunc);
}

@@ -1163,6 +1160,27 @@ void handleAutoBindNames(IRModule* module)
}
}

void removeTorchAndCUDAEntryPoints(IRModule* module)
{
// Go through global insts, find CUDA- and Torch-related entry points, and remove their keep-alive and HLSL-export decorations.
IRBuilder builder(module);
for (auto globalInst : module->getGlobalInsts())
{
if (auto func = as<IRFunc>(globalInst))
{
if (func->findDecoration<IRAutoPyBindCudaDecoration>() ||
func->findDecoration<IRTorchEntryPointDecoration>() ||
func->findDecoration<IRCudaKernelDecoration>())
{
if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>())
keepAlive->removeAndDeallocate();
if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>())
hlslExport->removeAndDeallocate();
}
}
}
}

void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
@@ -1237,7 +1255,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}

builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);

builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc);
@@ -1296,7 +1313,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}

builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);

builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc);
1 change: 1 addition & 0 deletions source/slang/slang-ir-pytorch-cpp-binding.h
@@ -11,6 +11,7 @@ void removeTorchKernels(IRModule* module);
void handleAutoBindNames(IRModule* module);
void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink);
void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink* sink);
void removeTorchAndCUDAEntryPoints(IRModule* module);

}

12 changes: 9 additions & 3 deletions source/slang/slang-lower-to-ir.cpp
@@ -1385,21 +1385,27 @@ static void addLinkageDecoration(
{
builder->addCudaKernelDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addHLSLExportDecoration(inst);

// Temp decorations to get this function through the linker.
builder->addKeepAliveDecoration(inst);
builder->addHLSLExportDecoration(inst);
}
else if (as<TorchEntryPointAttribute>(modifier))
{
builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addCudaHostDecoration(inst);
builder->addHLSLExportDecoration(inst);
builder->addKeepAliveDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());

// Temp decorations to get this function through the linker.
builder->addKeepAliveDecoration(inst);
builder->addHLSLExportDecoration(inst);
}
else if (as<AutoPyBindCudaAttribute>(modifier))
{
builder->addAutoPyBindCudaDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addAutoPyBindExportInfoDecoration(inst);

// Temp decorations to get this function through the linker.
builder->addKeepAliveDecoration(inst);
builder->addHLSLExportDecoration(inst);
}
55 changes: 55 additions & 0 deletions tests/autodiff/hlsl-torch-cross-compile.slang
@@ -0,0 +1,55 @@
//TEST:SIMPLE(filecheck=HLSL): -target hlsl -line-directive-mode none -entry computeMain -stage compute
//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none

//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;

[Differentiable]
float func1(float x)
{
return x * 4;
}

[AutoPyBindCUDA]
[CUDAKernel]
void torchMain(TensorView<float> v)
{
v[0] = func1(v[0]);
v[1] = func1(v[1]);
}

// Shouldn't see torchMain (or its transformations) anywhere in the HLSL output
// HLSL-NOT:torchMain
// HLSL:func1
// HLSL-NOT:torchMain
// HLSL:computeMain
// HLSL-NOT:torchMain

[Differentiable]
float func2(float a)
{
return a;
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
{
dpfloat dpa = dpfloat(2.0, 1.0);
dpfloat dpb = dpfloat(1.5, 1.0);

outputBuffer[0] = fwd_diff(func1)(dpa).d; // Expect: 1
outputBuffer[1] = fwd_diff(func2)(dpfloat(dpa.p, 0.0)).d; // Expect: 0
}
}

// Ensure that the generated CUDA and Torch kernels do have torchMain & its transformations

// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
// TORCH-NEXT: void __kernel__torchMain(TensorView {{[[:alnum:]_]+}});

// CUDA: __global__ void __kernel__torchMain(TensorView {{[[:alnum:]_]+}})
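
A note on the test harness: each //TEST:SIMPLE(filecheck=PREFIX) line compiles this file once with the listed slangc arguments and runs FileCheck over the resulting output using that prefix, so the HLSL-NOT:torchMain checks above fail if torchMain, or any wrapper generated from it, survives into the HLSL output.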
