Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to flattenSubgraphs #1101

Merged
merged 1 commit into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 77 additions & 133 deletions source/MaterialXCore/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,45 +251,42 @@ NodePtr GraphElement::addMaterialNode(const string& name, ConstNodePtr shaderNod

void GraphElement::flattenSubgraphs(const string& target, NodePredicate filter)
{
vector<NodePtr> processNodeVec = getNodes();
while (!processNodeVec.empty())
vector<NodePtr> nodeQueue = getNodes();
while (!nodeQueue.empty())
{
// Precompute graph implementations and downstream ports for this node vector.
// Determine which nodes require processing, and precompute declarations
// and graph implementations for these nodes.
using PortElementVec = vector<PortElementPtr>;
std::vector<NodePtr> processNodeVec;
std::unordered_map<NodePtr, NodeGraphPtr> graphImplMap;
std::unordered_map<NodePtr, ConstNodeDefPtr> declarationMap;
std::unordered_map<NodePtr, PortElementVec> downstreamPortMap;
for (NodePtr cacheNode : processNodeVec)
for (NodePtr node : nodeQueue)
{
InterfaceElementPtr implement = cacheNode->getImplementation(target);
if (!implement || !implement->isA<NodeGraph>())
if (filter && !filter(node))
{
continue;
}
NodeGraphPtr subNodeGraph = implement->asA<NodeGraph>();
graphImplMap[cacheNode] = subNodeGraph;
downstreamPortMap[cacheNode] = cacheNode->getDownstreamPorts();
for (NodePtr subNode : subNodeGraph->getNodes())

InterfaceElementPtr implement = node->getImplementation(target);
if (implement && implement->isA<NodeGraph>())
{
downstreamPortMap[subNode] = subNode->getDownstreamPorts();
processNodeVec.push_back(node);
graphImplMap[node] = implement->asA<NodeGraph>();
declarationMap[node] = node->getDeclaration(target);
downstreamPortMap[node] = node->getDownstreamPorts();
for (NodePtr sourceSubNode : implement->asA<NodeGraph>()->getNodes())
{
downstreamPortMap[sourceSubNode] = sourceSubNode->getDownstreamPorts();
}
}
}
processNodeVec.clear();

// Attributes in addition to value to copy over
StringVec copyAttributes = { ValueElement::UNIT_ATTRIBUTE,
ValueElement::UNITTYPE_ATTRIBUTE,
ValueElement::COLOR_SPACE_ATTRIBUTE };
nodeQueue.clear();

// Iterate through nodes with graph implementations.
for (const auto& pair : graphImplMap)
for (NodePtr processNode : processNodeVec)
{
NodePtr processNode = pair.first;
if (filter && !filter(processNode))
{
continue;
}

NodeGraphPtr sourceSubGraph = pair.second;
NodeGraphPtr sourceSubGraph = graphImplMap[processNode];
std::unordered_map<NodePtr, NodePtr> subNodeMap;

// Create a new instance of each original subnode.
Expand All @@ -302,150 +299,97 @@ void GraphElement::flattenSubgraphs(const string& target, NodePredicate filter)
destSubNode->copyContentFrom(sourceSubNode);
setChildIndex(destSubNode->getName(), getChildIndex(processNode->getName()));

// Transfer interface properties from the reference node to the new subnode.
for (ValueElementPtr destValue : destSubNode->getChildrenOfType<ValueElement>())
{
if (!destValue->hasInterfaceName())
{
continue;
}

ValueElementPtr refValue = processNode->getChildOfType<ValueElement>(destValue->getInterfaceName());
if (refValue)
{
if (refValue->hasValueString())
{
destValue->setValueString(refValue->getValueString());
}
for (auto copyAttribute : copyAttributes)
{
if (refValue->hasAttribute(copyAttribute))
{
destValue->setAttribute(copyAttribute, refValue->getAttribute(copyAttribute));
}
}
if (destValue->isA<Input>() && refValue->isA<Input>())
{
InputPtr refInput = refValue->asA<Input>();
InputPtr newInput = destValue->asA<Input>();
if (refInput->hasNodeName())
{
newInput->setNodeName(refInput->getNodeName());
}
if (refInput->hasOutputString())
{
newInput->setOutputString(refInput->getOutputString());
}
if (refInput->hasNodeGraphString())
{
newInput->setNodeGraphString(refInput->getNodeGraphString());
}
}
}
destValue->removeAttribute(ValueElement::INTERFACE_NAME_ATTRIBUTE);
}

// Store the mapping between subgraphs.
subNodeMap[sourceSubNode] = destSubNode;

// Add the subnode to the queue, allowing processing of nested subgraphs.
processNodeVec.push_back(destSubNode);
nodeQueue.push_back(destSubNode);
}

// Transfer internal connections between subgraphs.
// Update properties of generated subnodes.
for (const auto& subNodePair : subNodeMap)
{
NodePtr sourceSubNode = subNodePair.first;
NodePtr destSubNode = subNodePair.second;

// Update node connections.
for (PortElementPtr sourcePort : downstreamPortMap[sourceSubNode])
{
if (sourcePort->isA<Input>())
{
auto it = subNodeMap.find(sourcePort->getParent()->asA<Node>());
if (it != subNodeMap.end())
{
it->second->setConnectedNode(sourcePort->getName(), destSubNode);
InputPtr processNodeInput = it->second->getInput(sourcePort->getName());
if (processNodeInput)
{
processNodeInput->setNodeName(destSubNode->getName());
}
}
}
else if (sourcePort->isA<Output>())
{
for (PortElementPtr processNodePort : downstreamPortMap[processNode])
{
processNodePort->setConnectedNode(destSubNode);
processNodePort->setNodeName(destSubNode->getName());
}
}
}
}

// Connect any nodegraph outputs within the graph which point to another
// flatten node within the nodegraph. As it's been flattened the previous
// reference is incorrect and needs to be updated.
if (sourceSubGraph->getOutputCount())
{
for (OutputPtr sourceOutput : getOutputs())
// Transfer interface properties.
for (InputPtr destInput : destSubNode->getInputs())
{
const string& nodeNameString = sourceOutput->getNodeName();
const string& outputString = sourceOutput->getOutputString();

if (nodeNameString != processNode->getName())
{
continue;
}

// Look for what the original output pointed to.
OutputPtr sourceSubGraphOutput = outputString.empty() ? sourceSubGraph->getOutputs()[0] : sourceSubGraph->getOutput(outputString);
if (!sourceSubGraphOutput)
if (destInput->hasInterfaceName())
{
continue;
}

string destName = sourceSubGraphOutput->getNodeName();
if (destName.empty())
{
destName = sourceSubGraphOutput->getNodeGraphString();
}
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
InputPtr sourceInput = processNode->getInput(destInput->getInterfaceName());
if (sourceInput)
{
destInput->copyContentFrom(sourceInput);
}
else
{
ConstNodeDefPtr declaration = declarationMap[processNode];
InputPtr declInput = declaration ? declaration->getActiveInput(destInput->getInterfaceName()) : nullptr;
if (declInput)
{
if (declInput->hasValueString())
{
destInput->setValueString(declInput->getValueString());
}
if (declInput->hasDefaultGeomPropString())
{
ConstGeomPropDefPtr geomPropDef = getDocument()->getGeomPropDef(declInput->getDefaultGeomPropString());
if (geomPropDef)
{
destInput->setConnectedNode(addGeomNode(geomPropDef, "geomNode"));
}
}
}
}
destInput->removeAttribute(ValueElement::INTERFACE_NAME_ATTRIBUTE);
}

// Point original output to this one
sourceOutput->setNodeName(destName);
}
}

// If the node was flattened then any downstream references
// need to be updated to point to the new root of the flatten node.
PortElementVec downstreamPorts = downstreamPortMap[processNode];
for (auto downstreamPort : downstreamPorts)
// Update downstream ports with connections to subgraph outputs.
for (PortElementPtr downstreamPort : downstreamPortMap[processNode])
{
const string& outputString = downstreamPort->getOutputString();

// Look for an output on the flattened graph
OutputPtr sourceSubGraphOutput = outputString.empty() ? sourceSubGraph->getOutputs()[0] : sourceSubGraph->getOutput(outputString);
if (!sourceSubGraphOutput)
{
continue;
}

// Find connected node to the output
string destName = sourceSubGraphOutput->getNodeName();
if (destName.empty())
if (downstreamPort->hasOutputString())
{
destName = sourceSubGraphOutput->getNodeGraphString();
}
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
OutputPtr subGraphOutput = sourceSubGraph->getOutput(downstreamPort->getOutputString());
if (subGraphOutput)
{
string destName = subGraphOutput->getNodeName();
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
}
downstreamPort->setNodeName(destName);
downstreamPort->setOutputString(EMPTY_STRING);
}
}

// Use that node to overwrite downstream port connection
downstreamPort->setNodeName(destName);
downstreamPort->setOutputString(EMPTY_STRING);
}

// The processed node has been replaced, so remove it from the graph.
Expand Down
11 changes: 8 additions & 3 deletions source/MaterialXCore/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,14 @@ class MX_CORE_API GraphElement : public InterfaceElement
/// @}
/// @name Utility
/// @{

/// Flatten any references to graph-based node definitions within this
/// node graph, replacing each reference with the equivalent node network.

/// Flatten all subgraphs at the root scope of this graph element,
/// recursively replacing each graph-defined node with its equivalent
/// node network.
/// @param target An optional target string to be used in specifying
/// which node definitions are used in this process.
/// @param filter An optional node predicate specifying which nodes
/// should be included and excluded from this process.
void flattenSubgraphs(const string& target = EMPTY_STRING, NodePredicate filter = nullptr);

/// Return a vector of all children (nodes and outputs) sorted in
Expand Down
12 changes: 6 additions & 6 deletions source/MaterialXTest/MaterialXFormat/XmlIo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ TEST_CASE("Load content", "[xmlio]")
mx::readFromXmlBuffer(writtenDoc, xmlString.c_str());
REQUIRE(*writtenDoc == *doc);

// Flatten subgraph references.
for (mx::NodeGraphPtr nodeGraph : doc->getNodeGraphs())
// Flatten all subgraphs.
doc->flattenSubgraphs();
for (mx::NodeGraphPtr graph : doc->getNodeGraphs())
{
if (nodeGraph->getActiveSourceUri() != doc->getSourceUri())
if (graph->getActiveSourceUri() == doc->getSourceUri())
{
continue;
graph->flattenSubgraphs();
}
nodeGraph->flattenSubgraphs();
REQUIRE(nodeGraph->validate());
}
REQUIRE(doc->validate());

// Verify that all referenced types and nodes are declared.
bool referencesValid = true;
Expand Down
21 changes: 21 additions & 0 deletions source/MaterialXView/Viewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ Viewer::Viewer(const std::string& materialFilename,
_splitByUdims(true),
_mergeMaterials(false),
_showAllInputs(false),
_flattenSubgraphs(false),
_targetShader("standard_surface"),
_captureRequested(false),
_exitRequested(false),
Expand Down Expand Up @@ -892,6 +893,13 @@ void Viewer::createAdvancedSettings(Widget* parent)
_showAllInputs = enable;
});

ng::CheckBox* flattenBox = new ng::CheckBox(advancedPopup, "Flatten Subgraphs");
flattenBox->set_checked(_flattenSubgraphs);
flattenBox->set_callback([this](bool enable)
{
_flattenSubgraphs = enable;
});

ng::CheckBox* splitDirectLightBox = new ng::CheckBox(advancedPopup, "Split Direct Light");
splitDirectLightBox->set_checked(_splitDirectLight);
splitDirectLightBox->set_callback([this](bool enable)
Expand Down Expand Up @@ -1173,6 +1181,19 @@ void Viewer::loadDocument(const mx::FilePath& filename, mx::DocumentPtr librarie
// Apply modifiers to the content document.
applyModifiers(doc, _modifiers);

// Flatten subgraphs if requested.
if (_flattenSubgraphs)
{
doc->flattenSubgraphs();
for (mx::NodeGraphPtr graph : doc->getNodeGraphs())
{
if (graph->getActiveSourceUri() == doc->getActiveSourceUri())
{
graph->flattenSubgraphs();
}
}
}

// Validate the document.
std::string message;
if (!doc->validate(&message))
Expand Down
1 change: 1 addition & 0 deletions source/MaterialXView/Viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ class Viewer : public ng::Screen
bool _splitByUdims;
bool _mergeMaterials;
bool _showAllInputs;
bool _flattenSubgraphs;

// Shader translation
std::string _targetShader;
Expand Down