Skip to content

Commit

Permalink
feat: Implement test case for aten::to.dtype lowering
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Jan 21, 2022
1 parent 4b3ae3a commit bde8ee0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
8 changes: 4 additions & 4 deletions core/lowering/passes/reduce_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
return (%out))IR";
std::string to_dtype_layout_pattern = R"IR(
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
%out : Tensor = aten::to.dtype_layout(%x, %device, %dtype, %layout, %nb, %copy, %format, %other)
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
return (%out))IR";

std::string to_dtype_multi_input_pattern = R"IR(
graph(%x, %device, %dtype, %layout, %nb, %copy, %format, %other):
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
return (%out))IR";

std::string to_type_as_pattern = R"IR(
Expand Down
22 changes: 22 additions & 0 deletions tests/core/lowering/test_reduce_to_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ TEST(LoweringPasses, ReduceToCorrectly) {
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
std::string source_graph = R"IR(
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
return (%out))IR";
std::string target_graph = R"IR(
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
return (%out))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
torch_tensorrt::core::lowering::passes::ReduceToOperation(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
std::string source_graph = R"IR(
graph(%input, %other):
Expand Down

0 comments on commit bde8ee0

Please sign in to comment.