Skip to content

Commit

Permalink
Add application variable tracking for code generation (#1037)
Browse files Browse the repository at this point in the history
Add the ability to track when applicable variables are handled during code generation.

* Add in single new base method for calling each node implementations `createVariables()`.  `ShaderGenerator::createVariables()`. 
* Add in a callback on `GenContext` to allow custom tracking of application variables:: `GenContext::setApplicationVariableHandler()`. The callback is performed when `ShaderGenerator::createVariables()` is invoked.
* For general filtering workflows including when callbacks are invoked add in `geometric` classification on 
`ShaderNode`.
* Note: As unique path is also available this can be used to back-reference to the original document or compare against 
input / output blocks on a Shader.
  • Loading branch information
kwokcb authored Aug 19, 2022
1 parent f3c9f00 commit 6745af3
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 14 deletions.
5 changes: 1 addition & 4 deletions source/MaterialXGenMdl/MdlShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,7 @@ ShaderPtr MdlShaderGenerator::createShader(const string& name, ElementPtr elemen
VariableBlockPtr outputs = stage->createOutputBlock(MDL::OUTPUTS);

// Create shader variables for all nodes that need this.
for (ShaderNode* node : graph->getNodes())
{
node->getImplementation().createVariables(*node, context, *shader);
}
createVariables(graph, context, *shader);

// Create inputs for the published graph interface.
for (ShaderGraphInputSocket* inputSocket : graph->getInputSockets())
Expand Down
5 changes: 1 addition & 4 deletions source/MaterialXGenOsl/OslShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,7 @@ ShaderPtr OslShaderGenerator::createShader(const string& name, ElementPtr elemen
stage->createOutputBlock(OSL::OUTPUTS);

// Create shader variables for all nodes that need this.
for (ShaderNode* node : graph->getNodes())
{
node->getImplementation().createVariables(*node, context, *shader);
}
createVariables(graph, context, *shader);

// Create uniforms for the published graph interface.
VariableBlock& uniforms = stage->getUniformBlock(OSL::UNIFORMS);
Expand Down
2 changes: 2 additions & 0 deletions source/MaterialXGenShader/GenContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ GenContext::GenContext(ShaderGeneratorPtr sg) :
}

addReservedWords(reservedWords);

_applicationVariableHandler = nullptr;
}

void GenContext::addNodeImplementation(const string& name, ShaderNodeImplPtr impl)
Expand Down
17 changes: 17 additions & 0 deletions source/MaterialXGenShader/GenContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ MATERIALX_NAMESPACE_BEGIN

class ClosureContext;

/// A standard function to allow for handling of application variables for a given node
using ApplicationVariableHandler = std::function<void(ShaderNode*, GenContext&)>;

/// @class GenContext
/// A context class for shader generation.
/// Used for thread local storage of data needed during shader generation.
Expand Down Expand Up @@ -187,6 +190,18 @@ class MX_GENSHADER_API GenContext
/// @param suffix Suffix string returned. Is empty if not found.
void getOutputSuffix(const ShaderOutput* output, string& suffix) const;

/// Set handler for application variables
void setApplicationVariableHandler(ApplicationVariableHandler handler)
{
_applicationVariableHandler = handler;
}

/// Get handler for application variables
ApplicationVariableHandler getApplicationVariableHandler() const
{
return _applicationVariableHandler;
}

protected:
GenContext() = delete;

Expand All @@ -201,6 +216,8 @@ class MX_GENSHADER_API GenContext
std::unordered_map<const ShaderOutput*, string> _outputSuffix;

vector<ClosureContext*> _closureContexts;

ApplicationVariableHandler _applicationVariableHandler;
};


Expand Down
5 changes: 1 addition & 4 deletions source/MaterialXGenShader/HwShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,7 @@ ShaderPtr HwShaderGenerator::createShader(const string& name, ElementPtr element
output->setPath(outputSocket->getPath());

// Create shader variables for all nodes that need this.
for (ShaderNode* node : graph->getNodes())
{
node->getImplementation().createVariables(*node, context, *shader);
}
createVariables(graph, context, *shader);

HwLightShadersPtr lightShaders = context.getUserData<HwLightShaders>(HW::USER_DATA_LIGHT_SHADERS);

Expand Down
13 changes: 13 additions & 0 deletions source/MaterialXGenShader/ShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,17 @@ ShaderStagePtr ShaderGenerator::createStage(const string& name, Shader& shader)
return shader.createStage(name, _syntax);
}

void ShaderGenerator::createVariables(ShaderGraphPtr graph, GenContext& context, Shader& shader) const
{
ApplicationVariableHandler handler = context.getApplicationVariableHandler();
for (ShaderNode* node : graph->getNodes())
{
if (handler)
{
handler(node, context);
}
node->getImplementation().createVariables(*node, context, shader);
}
}

MATERIALX_NAMESPACE_END
4 changes: 4 additions & 0 deletions source/MaterialXGenShader/ShaderGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ class MX_GENSHADER_API ShaderGenerator
/// Replace tokens with identifiers according to the given substitutions map.
void replaceTokens(const StringMap& substitutions, ShaderStage& stage) const;

/// Create shader variables (e.g. uniforms, inputs and outputs) for
/// nodes that require input data from the application.
void createVariables(ShaderGraphPtr graph, GenContext& context, Shader& shader) const;

protected:
static const string T_FILE_TRANSFORM_UV;

Expand Down
5 changes: 5 additions & 0 deletions source/MaterialXGenShader/ShaderNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ const string ShaderNode::TEXTURE2D_GROUPNAME = "texture2d";
const string ShaderNode::TEXTURE3D_GROUPNAME = "texture3d";
const string ShaderNode::PROCEDURAL2D_GROUPNAME = "procedural2d";
const string ShaderNode::PROCEDURAL3D_GROUPNAME = "procedural3d";
const string ShaderNode::GEOMETRIC_GROUPNAME = "geometric";

//
// ShaderNode methods
Expand Down Expand Up @@ -374,6 +375,10 @@ ShaderNodePtr ShaderNode::create(const ShaderGraph* parent, const string& name,
{
newNode->_classification |= Classification::SAMPLE3D;
}
else if (groupName == GEOMETRIC_GROUPNAME)
{
newNode->_classification |= Classification::GEOMETRIC;
}

// Create any metadata.
newNode->createMetadata(nodeDef, context);
Expand Down
2 changes: 2 additions & 0 deletions source/MaterialXGenShader/ShaderNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class MX_GENSHADER_API ShaderNode
// Types based on nodegroup
static const uint32_t SAMPLE2D = 1 << 20; /// Can be sampled in 2D (uv space)
static const uint32_t SAMPLE3D = 1 << 21; /// Can be sampled in 3D (position)
static const uint32_t GEOMETRIC = 1 << 22; /// Geometric input
};

/// @struct ScopeInfo
Expand Down Expand Up @@ -403,6 +404,7 @@ class MX_GENSHADER_API ShaderNode
static const string TEXTURE3D_GROUPNAME;
static const string PROCEDURAL2D_GROUPNAME;
static const string PROCEDURAL3D_GROUPNAME;
static const string GEOMETRIC_GROUPNAME;

public:
/// Constructor.
Expand Down
102 changes: 102 additions & 0 deletions source/MaterialXTest/MaterialXGenShader/GenShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,105 @@ TEST_CASE("GenShader: Track Dependencies", "[genshader]")
}
#endif
}

void variableTracker(mx::ShaderNode* node, mx::GenContext& /*context*/)
{
static mx::StringMap results;
results["primvar_one"] = "geompropvalue1/geomprop";
results["primvar_two"] = "geompropvalue2/geomprop";
results["0"] = "Tworld";
results["upstream_primvar"] = "constant/value";

if (node->hasClassification(mx::ShaderNode::Classification::GEOMETRIC))
{
const mx::ShaderInput* geomPropInput = node->getInput("geomprop");
if (geomPropInput && geomPropInput->getValue())
{
std::string prop = geomPropInput->getValue()->getValueString();
REQUIRE(results.count(prop));
REQUIRE(results[prop] == geomPropInput->getPath());
}
else
{
const mx::ShaderInput* indexIput = node->getInput("index");
if (indexIput && indexIput->getValue())
{
std::string prop = indexIput->getValue()->getValueString();
REQUIRE(results.count(prop));
REQUIRE(results[prop] == indexIput->getPath());
}
}
}
}

TEST_CASE("GenShader: Track Application Variables", "[genshader]")
{
std::string testDocumentString =
"<?xml version=\"1.0\"?> \
<materialx version=\"1.38\"> \
<geompropvalue name=\"geompropvalue\" type=\"color3\" > \
<input name=\"geomprop\" type=\"string\" uniform=\"true\" nodename=\"constant\" /> \
</geompropvalue> \
<geompropvalue name=\"geompropvalue1\" type=\"color3\" > \
<input name=\"geomprop\" type=\"string\" uniform=\"true\" value=\"primvar_one\" /> \
</geompropvalue> \
<geompropvalue name=\"geompropvalue2\" type=\"color3\" > \
<input name=\"geomprop\" type=\"string\" uniform=\"true\" value=\"primvar_two\" /> \
</geompropvalue> \
<multiply name=\"multiply\" type=\"color3\" > \
<input name=\"in1\" type=\"color3\" nodename=\"geompropvalue\" /> \
<input name=\"in2\" type=\"color3\" nodename=\"geompropvalue1\" /> \
</multiply> \
<add name=\"add\" type=\"color3\" > \
<input name=\"in1\" type=\"color3\" nodename=\"multiply\" /> \
<input name=\"in2\" type=\"color3\" nodename=\"geompropvalue2\" /> \
</add> \
<standard_surface name=\"standard_surface\" type=\"surfaceshader\" > \
<input name=\"base_color\" type=\"color3\" nodename=\"add\" /> \
</standard_surface> \
<constant name=\"constant\" type=\"string\" > \
<input name=\"value\" type=\"string\" uniform=\"true\" value=\"upstream_primvar\" /> \
</constant> \
<surfacematerial name=\"surfacematerial\" type=\"material\" > \
<input name=\"surfaceshader\" type=\"surfaceshader\" nodename=\"standard_surface\" /> \
</surfacematerial> \
</materialx>";

const mx::string testElement = "surfacematerial";

mx::DocumentPtr libraries = mx::createDocument();
mx::FileSearchPath searchPath(mx::FilePath::getCurrentPath());
mx::loadLibraries({ "libraries/targets", "libraries/stdlib", "libraries/pbrlib", "libraries/bxdf" }, searchPath, libraries);

mx::DocumentPtr testDoc = mx::createDocument();
mx::readFromXmlString(testDoc, testDocumentString);
testDoc->importLibrary(libraries);

mx::ElementPtr element = testDoc->getChild(testElement);
CHECK(element);

#ifdef MATERIALX_BUILD_GEN_GLSL
{
mx::GenContext context(mx::GlslShaderGenerator::create());
context.registerSourceCodeSearchPath(searchPath);
context.setApplicationVariableHandler(variableTracker);
mx::ShaderPtr shader = context.getShaderGenerator().generate(testElement, element, context);
}
#endif
#ifdef MATERIALX_BUILD_GEN_OSL
{
mx::GenContext context(mx::OslShaderGenerator::create());
context.registerSourceCodeSearchPath(searchPath);
context.setApplicationVariableHandler(variableTracker);
mx::ShaderPtr shader = context.getShaderGenerator().generate(testElement, element, context);
}
#endif
#ifdef MATERIALX_BUILD_GEN_MDL
{
mx::GenContext context(mx::MdlShaderGenerator::create());
context.registerSourceCodeSearchPath(searchPath);
context.setApplicationVariableHandler(variableTracker);
mx::ShaderPtr shader = context.getShaderGenerator().generate(testElement, element, context);
}
#endif
}
8 changes: 6 additions & 2 deletions source/PyMaterialX/PyMaterialXGenShader/PyGenContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ namespace mx = MaterialX;

void bindPyGenContext(py::module& mod)
{
py::class_<mx::ApplicationVariableHandler>(mod, "ApplicationVariableHandler");

py::class_<mx::GenContext, mx::GenContextPtr>(mod, "GenContext")
.def(py::init<mx::ShaderGeneratorPtr>())
.def("getShaderGenerator", &mx::GenContext::getShaderGenerator)
.def("getOptions", static_cast<mx::GenOptions& (mx::GenContext::*)()>(&mx::GenContext::getOptions), py::return_value_policy::reference)
.def("getOptions", static_cast<mx::GenOptions & (mx::GenContext::*)()>(&mx::GenContext::getOptions), py::return_value_policy::reference)
.def("registerSourceCodeSearchPath", static_cast<void (mx::GenContext::*)(const mx::FilePath&)>(&mx::GenContext::registerSourceCodeSearchPath))
.def("registerSourceCodeSearchPath", static_cast<void (mx::GenContext::*)(const mx::FileSearchPath&)>(&mx::GenContext::registerSourceCodeSearchPath))
.def("resolveSourceFile", &mx::GenContext::resolveSourceFile)
.def("pushUserData", &mx::GenContext::pushUserData);
.def("pushUserData", &mx::GenContext::pushUserData)
.def("setApplicationVariableHandler", &mx::GenContext::setApplicationVariableHandler)
.def("getApplicationVariableHandler", &mx::GenContext::getApplicationVariableHandler);
}

void bindPyGenUserData(py::module& mod)
Expand Down

0 comments on commit 6745af3

Please sign in to comment.