Skip to content

Commit

Permalink
fix(//core/conversion/evaluators): A couple fixes for evaluators
Browse files Browse the repository at this point in the history
- Fixes aten::append to correctly append values and not pointers
- Fixes prim::RaiseException and aten::warn to print out strings

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jun 14, 2020
1 parent 6421f3d commit 07ba980
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
auto el = args.at(n->input(1)).IValue();

list.push_back(std::move(el));
list.push_back(std::move(*el));
return list;
},
EvalOptions().validSchemas({
Expand Down Expand Up @@ -430,16 +430,16 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto el = args.at(n->input(0)).unwrapToDouble();

return std::floor(el);
return static_cast<int64_t>(std::floor(el));
},
EvalOptions().validSchemas({
"aten::floor.float(float a) -> (int)",
})
}).evaluator({
c10::Symbol::fromQualString("aten::warn"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto warning = args.at(n->input(0)).IValue()->toString();
LOG_WARNING(warning);
auto warning = args.at(n->input(0)).IValue();
LOG_WARNING("Warning from TorchScript: " << *warning);
return {};
},
EvalOptions()
Expand Down
4 changes: 2 additions & 2 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ auto prim_registrations = RegisterNodeEvaluators()
}).evaluator({
c10::Symbol::fromQualString("prim::RaiseException"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto exception = args.at(n->input(0)).IValue()->toString();
TRTORCH_THROW_ERROR(exception);
auto exception = args.at(n->input(0)).IValue();
TRTORCH_THROW_ERROR("Error from TorchScript: " << *exception);
return {};
}
});
Expand Down

0 comments on commit 07ba980

Please sign in to comment.