Skip to content

Commit

Permalink
pnnx convert torch narrow (#4918)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyt1024 committed Sep 4, 2023
1 parent 14b000d commit b3fbbcc
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Expand Up @@ -237,6 +237,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_mean.cpp
pass_level2/torch_min.cpp
pass_level2/torch_mm.cpp
pass_level2/torch_narrow.cpp
pass_level2/torch_ne.cpp
pass_level2/torch_norm.cpp
pass_level2/torch_normal.cpp
Expand Down
43 changes: 43 additions & 0 deletions tools/pnnx/src/pass_level2/torch_narrow.cpp
@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_narrow : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dim
pnnx.Input input_2 0 1 start
pnnx.Input input_3 0 1 length
aten::narrow op_0 4 1 input dim start length out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.narrow";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_narrow, 20)

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
Expand Up @@ -212,6 +212,7 @@ pnnx_add_test(torch_max)
pnnx_add_test(torch_mean)
pnnx_add_test(torch_min)
pnnx_add_test(torch_mm)
pnnx_add_test(torch_narrow)
pnnx_add_test(torch_ne)
pnnx_add_test(torch_norm)
pnnx_add_test(torch_ones)
Expand Down
63 changes: 63 additions & 0 deletions tools/pnnx/tests/test_torch_narrow.py
@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
out0 = torch.narrow(x, 0, 0, 2)
out1 = torch.narrow(x, 1, 1, 2)
out2 = torch.narrow(y, 0, 0, 2)
out3 = torch.narrow(y, 1, 1, 2)
out4 = torch.narrow(z, 0, 0, 2)
out5 = torch.narrow(z, 1, 1, 2)
return out0, out1, out2, out3, out4, out5

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 3)
y = torch.rand(5, 3)
z = torch.rand(3, 5)
a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_narrow.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_narrow.pt inputshape=[3,3],[5,3],[3,5]")

# pnnx inference
import test_torch_narrow_pnnx
b = test_torch_narrow_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

0 comments on commit b3fbbcc

Please sign in to comment.