#9478 was introduced to allow External Graph Transformers to be compiled along with ORT through the flag --external_graph_transformer_path PATH. This PR also introduced the flag --test_external_transformer_example to test it on CI.
However, #15416 (accidentally?!) removed the --test_external_transformer_example flag, and CI no longer tests --external_graph_transformer_path, which is now broken (I didn't track down the exact commit that introduced the regression, though).
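For context, PR #9478 wired the flag into ORT's build.py; a hypothetical invocation (the transformer path and the accompanying flags are assumptions, not taken from this issue) might have looked like:

```shell
# Hypothetical path and extra flags -- only --external_graph_transformer_path
# is the flag under discussion in this issue.
./build.sh --config Debug --build_wheel \
    --external_graph_transformer_path /path/to/my_transformer
```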
The repro for --external_graph_transformer_path is from the original test in PR #9478:
# run_repro.py
import os
import sys
import threading
import time
import unittest

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnxruntime.capi import _pybind_state as torch_ort_eager
from onnxruntime.training import optim, orttrainer, orttrainer_options


class OutputGrabber(object):
    """Class used to grab standard output or another stream."""

    escape_char = "\b"

    def __init__(self, stream=None, threaded=False):
        self.origstream = stream
        self.threaded = threaded
        if self.origstream is None:
            self.origstream = sys.stdout
        self.origstreamfd = self.origstream.fileno()
        self.capturedtext = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        self.stop()

    def start(self):
        """Start capturing the stream data."""
        self.capturedtext = ""
        # Save a copy of the stream:
        self.streamfd = os.dup(self.origstreamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.origstreamfd)
        if self.threaded:
            # Start thread that will read the stream:
            self.workerThread = threading.Thread(target=self.readOutput)
            self.workerThread.start()
            # Make sure that the thread is running and os.read() has executed:
            time.sleep(0.01)

    def stop(self):
        """Stop capturing the stream data and save the text in `capturedtext`."""
        # Print the escape character to make the readOutput method stop:
        self.origstream.write(self.escape_char)
        # Flush the stream to make sure all our data goes in before
        # the escape character:
        self.origstream.flush()
        if self.threaded:
            # wait until the thread finishes so we are sure that
            # we have until the last character:
            self.workerThread.join()
        else:
            self.readOutput()
        # Close the pipe:
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream:
        os.dup2(self.streamfd, self.origstreamfd)
        # Close the duplicate stream:
        os.close(self.streamfd)

    def readOutput(self):
        """Read the stream data (one byte at a time) and save the text in `capturedtext`."""
        while True:
            char = os.read(self.pipe_out, 1).decode(self.origstream.encoding)
            if not char or self.escape_char in char:
                break
            self.capturedtext += char


def my_loss(x, target):
    return F.nll_loss(F.log_softmax(x, dim=1), target)


class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, target):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return my_loss(out, target)


class OrtEPTests(unittest.TestCase):
    def test_external_graph_transformer_triggering(self):
        input_size = 784
        hidden_size = 500
        num_classes = 10
        batch_size = 128
        model = NeuralNet(input_size, hidden_size, num_classes)
        model_desc = {'inputs': [('x', [batch_size, input_size]),
                                 ('target', [batch_size,])],
                      'outputs': [('loss', [], True)]}
        optim_config = optim.SGDConfig()
        opts = orttrainer.ORTTrainerOptions({'device': {'id': 'cpu'}})
        model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
        # because orttrainer is lazily initialized, feed in random data to trigger the graph transformer
        data = torch.rand(batch_size, input_size)
        target = torch.randint(0, 10, (batch_size,))
        with OutputGrabber() as out:
            loss = model.train_step(data, target)
        assert '******************Trigger Customized Graph Transformer: MyGraphTransformer!' in out.capturedtext


if __name__ == '__main__':
    unittest.main()
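As an aside on the repro itself: the OutputGrabber above redirects the stdout file descriptor with os.dup2 because the trigger message is printed from native code. For output produced by Python-level print alone, the stdlib contextlib.redirect_stdout achieves the same capture with far less machinery (a minimal sketch, independent of ORT):

```python
import contextlib
import io

# Capture Python-level stdout writes into an in-memory buffer.
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print("Trigger Customized Graph Transformer: MyGraphTransformer!")

captured = buf.getvalue()
print("MyGraphTransformer" in captured)  # True
```

This would not catch the transformer's message, which is why the repro needs the file-descriptor-level grabber.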
On May 21, 2024, thiagocrepaldi changed the title from "[Build] --external_graph_transformer_path doesn't work after --test_external_transformer_example was removed from build.py" to "[Build] --external_graph_transformer_path doesn't work. --test_external_transformer_example removed from build.py?"
The --test_external_transformer_example flag was only used in eager mode, and that PR deleted the eager mode code.
Would it be possible to re-add it? For hardware vendors it is very useful to have external graph transformers available, so they can experiment with new fusions before proposing them as contributions to ORT, if they can be published publicly.
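For illustration only, here is a toy, pure-Python sketch of the kind of fusion pass a vendor might prototype as an external graph transformer. This is not the ORT GraphTransformer API; the node representation and the fuse_matmul_add helper are hypothetical:

```python
# Toy IR: each node is (op_type, input_names, output_names).
# Hypothetical representation -- not the ORT GraphTransformer API.
nodes = [
    ("MatMul", ["X", "W"], ["mm_out"]),
    ("Add", ["mm_out", "B"], ["Y"]),
]

def fuse_matmul_add(nodes):
    """Fold each MatMul -> Add chain into a single Gemm node."""
    # Map every tensor name to the node that produces it.
    producers = {out: n for n in nodes for out in n[2]}
    consumed = set()
    result = []
    for n in nodes:
        op, ins, outs = n
        if op == "Add":
            matmuls = [producers[i] for i in ins
                       if i in producers and producers[i][0] == "MatMul"]
            if matmuls:
                p = matmuls[0]
                consumed.add(id(p))
                # Gemm takes the MatMul's inputs plus the Add's other input.
                other = [i for i in ins if i not in p[2]]
                result.append(("Gemm", list(p[1]) + other, outs))
                continue
        result.append(n)
    # Drop MatMul nodes that were folded away.
    return [n for n in result if id(n) not in consumed]

print(fuse_matmul_add(nodes))  # [('Gemm', ['X', 'W', 'B'], ['Y'])]
```

A real external transformer would do the same pattern-match-and-rewrite over onnxruntime's in-memory Graph in C++, which is exactly what --external_graph_transformer_path was meant to let vendors compile in.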
Describe the issue
@snnn FYI
Urgency
No response
Target platform
Linux
Build script
Then build with
Error / output
Visual Studio Version
No response
GCC / Compiler Version
No response