Skip to content

Commit

Permalink
Extend the COM-based API to support whole program compilation. (shade…
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe committed Jun 12, 2024
1 parent 318adcc commit ccc26c2
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-check-entrypoint.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-get-target-code.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-io.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-json-native.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-json.cpp" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-get-target-code.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-io.cpp">
<Filter>Source Files</Filter>
</ClCompile>
Expand Down
5 changes: 5 additions & 0 deletions slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -4953,6 +4953,11 @@ namespace slang
uint32_t compilerOptionEntryCount,
CompilerOptionEntry* compilerOptionEntries,
ISlangBlob** outDiagnostics = nullptr) = 0;

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
IBlob** outCode,
IBlob** outDiagnostics = nullptr) = 0;
};
#define SLANG_UUID_IComponentType IComponentType::getTypeGuid()

Expand Down
10 changes: 10 additions & 0 deletions source/slang-capture-replay/slang-composite-component-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ namespace SlangCapture
return res;
}

SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult res = m_actualCompositeComponentType->getTargetCode(targetIndex, outCode, outDiagnostics);
return res;
}

SLANG_NO_THROW SlangResult CompositeComponentTypeCapture::getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
4 changes: 4 additions & 0 deletions source/slang-capture-replay/slang-composite-component-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ namespace SlangCapture
uint32_t compilerOptionEntryCount,
slang::CompilerOptionEntry* compilerOptionEntries,
ISlangBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;

slang::IComponentType* getActualCompositeComponentType() const { return m_actualCompositeComponentType; }
private:
Expand Down
10 changes: 10 additions & 0 deletions source/slang-capture-replay/slang-entrypoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ namespace SlangCapture
return res;
}

SLANG_NO_THROW SlangResult EntryPointCapture::getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult res = m_actualEntryPoint->getTargetCode(targetIndex, outCode, outDiagnostics);
return res;
}

SLANG_NO_THROW SlangResult EntryPointCapture::getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
4 changes: 4 additions & 0 deletions source/slang-capture-replay/slang-entrypoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ namespace SlangCapture
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
10 changes: 10 additions & 0 deletions source/slang-capture-replay/slang-module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ namespace SlangCapture
return res;
}

SLANG_NO_THROW SlangResult ModuleCapture::getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult res = m_actualModule->getTargetCode(targetIndex, outCode, outDiagnostics);
return res;
}

SLANG_NO_THROW SlangResult ModuleCapture::getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
4 changes: 4 additions & 0 deletions source/slang-capture-replay/slang-module.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ namespace SlangCapture
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
10 changes: 10 additions & 0 deletions source/slang-capture-replay/slang-type-conformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ namespace SlangCapture
return res;
}

SLANG_NO_THROW SlangResult TypeConformanceCapture::getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult res = m_actualTypeConformance->getTargetCode(targetIndex, outCode, outDiagnostics);
return res;
}

SLANG_NO_THROW SlangResult TypeConformanceCapture::getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
4 changes: 4 additions & 0 deletions source/slang-capture-replay/slang-type-conformance.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ namespace SlangCapture
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
31 changes: 30 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ namespace Slang
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE;
SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE;

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
Expand Down Expand Up @@ -602,11 +606,12 @@ namespace Slang
Index argCount,
DiagnosticSink* sink) SLANG_OVERRIDE;

private:
public:
CompositeComponentType(
Linkage* linkage,
List<RefPtr<ComponentType>> const& childComponents);

private:
List<RefPtr<ComponentType>> m_childComponents;

// The following arrays hold the concatenated entry points, parameters,
Expand Down Expand Up @@ -879,6 +884,14 @@ namespace Slang
return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down Expand Up @@ -1112,6 +1125,14 @@ namespace Slang
return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down Expand Up @@ -1286,6 +1307,14 @@ namespace Slang
return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
51 changes: 51 additions & 0 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4661,6 +4661,57 @@ void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void
acceptVisitor(&visitor, nullptr);
}

SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCode(
Int targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics)
{
auto linkage = getLinkage();
if (targetIndex < 0 || targetIndex >= linkage->targets.getCount())
return SLANG_E_INVALID_ARG;

// If the user hasn't specified any entry points, then we should
// discover all entrypoints that are defined in linked modules, and
// include all of them in the compile.
//
if (getEntryPointCount() == 0)
{
List<Module*> modules;
this->enumerateModules([&](Module* module)
{
modules.add(module);
});
List<RefPtr<ComponentType>> components;
components.add(this);
for (auto module : modules)
{
for (auto entryPoint : module->getEntryPoints())
{
components.add(entryPoint);
}
}
RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components);
ComPtr<IComponentType> linkedComponentType;
SLANG_RETURN_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));
return linkedComponentType->getTargetCode(targetIndex, outCode, outDiagnostics);
}

auto target = linkage->targets[targetIndex];
auto targetProgram = getTargetProgram(target);

DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);

IArtifact* artifact = targetProgram->getOrCreateWholeProgramResult(&sink);
sink.getBlobIfNeeded(outDiagnostics);

if (artifact == nullptr)
return SLANG_FAIL;

return artifact->loadBlob(ArtifactKeep::Yes, outCode);
}

//
// CompositeComponentType
//
Expand Down
69 changes: 69 additions & 0 deletions tools/slang-unit-test/unit-test-get-target-code.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// unit-test-translation-unit-import.cpp

#include "../../slang.h"

#include <stdio.h>
#include <stdlib.h>

#include "tools/unit-test/slang-unit-test.h"
#include "../../slang-com-ptr.h"
#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"

using namespace Slang;

// Test that the IComponentType::getTargetCode API supports
// compiling a program with multiple entrypoints and retrieving a single
// compiled module that contains all the entrypoints.
//
SLANG_UNIT_TEST(getTargetCode)
{
// Source for a module that contains an undecorated entrypoint.
const char* userSourceBody = R"(
[shader("fragment")]
float4 fragMain(float4 pos:SV_Position) : SV_Target
{
return pos;
}
[shader("vertex")]
float4 vertMain(float4 pos) : SV_Position
{
return pos;
}
)";

String userSource = userSourceBody;
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
// Request SPIR-V disassembly so we can check the content.
targetDesc.format = SLANG_SPIRV_ASM;
targetDesc.profile = globalSession->findProfile("sm_5_0");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;

ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IComponentType> linkedProgram;
module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(linkedProgram != nullptr);

ComPtr<slang::IBlob> code;
linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(code != nullptr);

SLANG_CHECK(code->getBufferSize() != 0);

UnownedStringSlice resultStr = UnownedStringSlice((char*)code->getBufferPointer());

// Make sure the spirv disassembly contains both entrypoint names.
SLANG_CHECK(resultStr.indexOf(toSlice("fragMain")) != -1);
SLANG_CHECK(resultStr.indexOf(toSlice("vertMain")) != -1);
}

0 comments on commit ccc26c2

Please sign in to comment.