Skip to content

Commit

Permalink
refactored tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub committed Oct 2, 2019
1 parent 922a0f7 commit 5b1d552
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 51 deletions.
42 changes: 21 additions & 21 deletions test.sh
Expand Up @@ -7,24 +7,24 @@ touch $OUTPUT_FILE
echo "| Name | Data Type | Input Shapes | torch2trt kwargs | Max Error | Throughput (PyTorch) | Throughput (TensorRT) | Latency (PyTorch) | Latency (TensorRT) |" >> $OUTPUT_FILE
echo "|------|-----------|--------------|------------------|-----------|----------------------|-----------------------|-------------------|--------------------|" >> $OUTPUT_FILE

python3 -m torch2trt.test -o $OUTPUT_FILE --name alexnet
python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_0
python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_1
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet18
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet34
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet50
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet101
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet152
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet121
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet169
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet201
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet161
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11$
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13$
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16$
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19$
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11_bn
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13_bn
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16_bn
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19_bn
python3 -m torch2trt.test -o $OUTPUT_FILE --name mobilenet_v2
python3 -m torch2trt.test -o $OUTPUT_FILE --name alexnet --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_0 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_1 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet18 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet34 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet50 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet101 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet152 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet121 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet169 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet201 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet161 --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11$ --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13$ --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16$ --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19$ --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11_bn --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13_bn --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16_bn --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19_bn --include=torch2trt.tests.torchvision.classification
python3 -m torch2trt.test -o $OUTPUT_FILE --name mobilenet_v2 --include=torch2trt.tests.torchvision.classification
31 changes: 1 addition & 30 deletions torch2trt/module_test.py
Expand Up @@ -15,35 +15,6 @@ def module_name(self):


MODULE_TESTS = [
ModuleTest(torchvision.models.alexnet, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.squeezenet1_0, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.squeezenet1_1, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.resnet18, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.resnet34, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.resnet50, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.resnet101, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.resnet152, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.densenet121, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.densenet169, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.densenet201, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.densenet161, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg11, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg13, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg16, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg19, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg11_bn, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg13_bn, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg16_bn, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.vgg19_bn, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.mobilenet_v2, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.shufflenet_v2_x0_5, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.shufflenet_v2_x1_0, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.shufflenet_v2_x1_5, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.shufflenet_v2_x2_0, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.mnasnet0_5, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.mnasnet0_75, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.mnasnet1_0, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
ModuleTest(torchvision.models.mnasnet1_3, torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True),
]


Expand All @@ -52,4 +23,4 @@ def register_module_test(module):
global MODULE_TESTS
MODULE_TESTS += [ModuleTest(module, dtype, device, input_shapes, **torch2trt_kwargs)]
return module
return register_module_test
return register_module_test
5 changes: 5 additions & 0 deletions torch2trt/test.py
Expand Up @@ -3,6 +3,7 @@
import time
import argparse
import re
import runpy
from termcolor import colored


Expand Down Expand Up @@ -87,7 +88,11 @@ def run(self):
parser.add_argument('--output', '-o', help='Test output file path', type=str, default='torch2trt_test.md')
parser.add_argument('--name', help='Regular expression to filter modules to test by name', type=str, default='.*')
parser.add_argument('--tolerance', help='Maximum error to print warning for entry', type=float, default='-1')
parser.add_argument('--include', help='Addition python file to include defining additional tests', action='append')
args = parser.parse_args()

for include in args.include:
runpy.run_module(include)

for test in MODULE_TESTS:

Expand Down
Empty file added torch2trt/tests/__init__.py
Empty file.
Empty file.
148 changes: 148 additions & 0 deletions torch2trt/tests/torchvision/classification.py
@@ -0,0 +1,148 @@
import torch
import torchvision
from torch2trt.module_test import add_module_test


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def alexnet():
return torchvision.models.alexnet(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def squeezenet1_0():
return torchvision.models.squeezenet1_0(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def squeezenet1_1():
return torchvision.models.squeezenet1_1(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def resnet18():
return torchvision.models.resnet18(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def resnet34():
return torchvision.models.resnet34(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def resnet50():
return torchvision.models.resnet50(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def resnet101():
return torchvision.models.resnet101(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def resnet152():
return torchvision.models.resnet152(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def densenet121():
return torchvision.models.densenet121(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def densenet169():
return torchvision.models.densenet169(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def densenet201():
return torchvision.models.densenet201(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def densenet161():
return torchvision.models.densenet161(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg11():
return torchvision.models.vgg11(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg13():
return torchvision.models.vgg13(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg16():
return torchvision.models.vgg16(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg19():
return torchvision.models.vgg19(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg11_bn():
return torchvision.models.vgg11_bn(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg13_bn():
return torchvision.models.vgg13_bn(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg16_bn():
return torchvision.models.vgg16_bn(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def vgg19_bn():
return torchvision.models.vgg19_bn(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def mobilenet_v2():
return torchvision.models.mobilenet_v2(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def shufflenet_v2_x0_5():
return torchvision.models.shufflenet_v2_x0_5(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def shufflenet_v2_x1_0():
return torchvision.models.shufflenet_v2_x1_0(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def shufflenet_v2_x1_5():
return torchvision.models.shufflenet_v2_x1_5(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def shufflenet_v2_x2_0():
return torchvision.models.shufflenet_v2_x2_0(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def mnasnet0_5():
return torchvision.models.mnasnet0_5(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def mnasnet0_75():
return torchvision.models.mnasnet0_75(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def mnasnet1_0():
return torchvision.models.mnasnet1_0(pretrained=False)


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def mnasnet1_3():
return torchvision.models.mnasnet1_3(pretrained=False)
39 changes: 39 additions & 0 deletions torch2trt/tests/torchvision/segmentation.py
@@ -0,0 +1,39 @@
import torch
import torchvision
from torch2trt.module_test import add_module_test


class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super(ModelWrapper, self).__init__()
self.model = model
def forward(self, x):
return self.model(x)['out']


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def deeplabv3_resnet50():
bb = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False)
model = ModelWrapper(bb)
return model


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def deeplabv3_resnet101():
bb = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False)
model = ModelWrapper(bb)
return model


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def fcn_resnet50():
bb = torchvision.models.segmentation.fcn_resnet50(pretrained=False)
model = ModelWrapper(bb)
return model


@add_module_test(torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True)
def fcn_resnet101():
bb = torchvision.models.segmentation.fcn_resnet101(pretrained=False)
model = ModelWrapper(bb)
return model

0 comments on commit 5b1d552

Please sign in to comment.