
Disable layer fusion optimizations? #252

Closed
jasonliu19 opened this issue Dec 4, 2019 · 6 comments
Labels: Precision: FP16, question, TODO

Comments

jasonliu19 commented Dec 4, 2019

Is there a way to disable layer fusion when building an engine? I'm seeing correctness problems when scale layers are fused together in FP16, and being able to disable fusion would help debug the issue.

Environment

TensorRT Version: 6.0.1.5
GPU Type: GTX 1080ti
Nvidia Driver Version: 418.87.00
CUDA Version: 10.1
CUDNN Version: 7.6.3
Operating System + Version: Ubuntu 18.04.3
Python Version (if applicable): 3.6.8
TensorFlow Version (if applicable): n/a
PyTorch Version (if applicable): 1.3.0
Baremetal or Container (if container which image + tag): baremetal

rmccorm4 (Collaborator) commented Dec 4, 2019

Hi @jasonliu19,

Sounds like an interesting problem; I don't think that should be expected to happen. Could you provide a simple model and script to reproduce it so we can look into the underlying issue?

We don't currently expose turning off layer fusion.

Edit: Also, just for completeness, could you share your environment info from the issue template?

rmccorm4 added the needs-info, Precision: FP16, and question labels on Dec 4, 2019
jasonliu19 (Author) commented Dec 5, 2019

Added my environment above.
Here's a simple script with two scale layers. I'm not sure whether fusion is causing the discrepancies I'm seeing between PyTorch and TensorRT, but it's something I noticed. I also observed that single scale layers sometimes differ from the equivalent PyTorch operations for certain inputs.

import numpy as np
import tensorrt as trt
import torch

SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)


def test_output_equality(output_base, output_diff):
    """Test model output equality."""
    for k in range(len(output_base)):
        base = output_base[k]
        diff = output_diff[k]
        assert base.dtype == diff.dtype, "dtypes do not match {} != {}".format(base.dtype, diff.dtype)
        assert base.shape == diff.shape, "shapes do not match {} != {}".format(base.shape, diff.shape)
        total_count = base.numel()
        epsilons = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
        print("---Output {}---".format(k))
        for idx, epsilon in enumerate(epsilons):
            # Count how many tensors are different from base, using the provided epsilon threshold
            failed = torch.gt(torch.abs(torch.add(base, -diff)), epsilon)
            diff_count = torch.sum(failed).item()
            diff_percent = float(diff_count) / float(total_count) * 100.
            print(
                "  Epsilon {}) base and diff are different for {} values ({:.2f}%)".format(
                    epsilon, diff_count, diff_percent
                )
            )
        print("Max difference:", (base - diff).abs().max())


def add_scale(network, trt_input, shift, scale, power, dtype):
    """Add a channel-wise scale layer: y = (x * scale + shift) ** power."""
    scale_layer = network.add_scale(trt_input, trt.ScaleMode.CHANNEL, shift, scale, power)
    scale_layer.precision = dtype
    return scale_layer.get_output(0)


def compare_torch_and_trt_scale():
    scale_a, shift_a, scale_b, shift_b = (
        np.random.randn(64).astype(np.float16) * 3,
        np.random.randn(64).astype(np.float16),
        np.random.randn(64).astype(np.float16) * 3,
        np.random.randn(64).astype(np.float16),
    )
    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)
    builder.fp16_mode = True                # allow fp16 kernels
    builder.max_batch_size = 1
    builder.strict_type_constraints = True  # enforce the per-layer precisions set below
    network = builder.create_network()
    torch_input = torch.ones((1, 64, 512, 512), dtype=torch.float16, device="cuda").contiguous()
    trt_output = torch.zeros((1, 64, 512, 512), dtype=torch.float16, device="cuda").contiguous()
    trt_input = network.add_input(name="input_0", shape=tuple(torch_input.shape)[1:], dtype=trt.float16)
    trt_input.location = trt.TensorLocation.DEVICE
    # The power argument is passed as empty fp16 weights, i.e. an implicit power of 1.
    scale_out = add_scale(network, trt_input, shift_a, scale_a, trt.Weights(trt.float16), trt.float16)
    scale_out = add_scale(network, scale_out, shift_b, scale_b, trt.Weights(trt.float16), trt.float16)
    scale_out.name = "output_0"
    scale_out.location = trt.TensorLocation.DEVICE
    scale_out.dtype = trt.float16
    network.mark_output(scale_out)
    engine = builder.build_cuda_engine(network)
    bindings = [None] * 2
    bindings[engine.get_binding_index("input_0")] = torch_input.data_ptr()
    bindings[engine.get_binding_index("output_0")] = trt_output.data_ptr()
    context = engine.create_execution_context()
    context.execute(1, bindings)

    print("Pytorch no fusion vs trt")
    scale_a_torch = torch.tensor(scale_a, dtype=torch.float16, device="cuda").reshape(1, -1, 1, 1)
    scale_b_torch = torch.tensor(scale_b, dtype=torch.float16, device="cuda").reshape(1, -1, 1, 1)
    shift_a_torch = torch.tensor(shift_a, dtype=torch.float16, device="cuda").reshape(1, -1, 1, 1)
    shift_b_torch = torch.tensor(shift_b, dtype=torch.float16, device="cuda").reshape(1, -1, 1, 1)
    torch_output = torch_input * scale_a_torch
    torch_output = torch_output + shift_a_torch
    torch_output = torch_output * scale_b_torch
    torch_output = torch_output + shift_b_torch
    test_output_equality([trt_output], [torch_output])

    print("Pytorch fusion in fp16 vs trt")
    scale_fused = scale_a_torch * scale_b_torch
    shift_fused = shift_a_torch * scale_b_torch + shift_b_torch
    torch_output = torch_input * scale_fused + shift_fused
    test_output_equality([trt_output], [torch_output])

    print("Pytorch fusion in fp32 vs trt")
    scale_fused = scale_a_torch.to(torch.float32) * scale_b_torch.to(torch.float32)
    shift_fused = shift_a_torch.to(torch.float32) * scale_b_torch.to(torch.float32) + shift_b_torch.to(torch.float32)
    torch_output = torch_input * scale_fused.to(torch.float16) + shift_fused.to(torch.float16)
    test_output_equality([trt_output], [torch_output])


if __name__ == "__main__":
    compare_torch_and_trt_scale()

rmccorm4 (Collaborator) commented Dec 5, 2019

Repro'd

  • V100
  • TensorRT NGC Container 19.10-py3
  • PyTorch 1.3
nvidia-docker run -it -v ${PWD}:/mnt --workdir=/mnt nvcr.io/nvidia/tensorrt:19.10-py3
python torch_vs_trt.py

[TensorRT] WARNING: Tensor DataType is determined at build time for tensors not marked as input or output.
[TensorRT] VERBOSE: Applying generic optimizations to the graph for inference.
[TensorRT] VERBOSE: Original: 2 layers
[TensorRT] VERBOSE: After dead-layer removal: 2 layers
[TensorRT] VERBOSE: After scale fusion: 2 layers
[TensorRT] VERBOSE: Fusing (Unnamed Layer* 0) [Scale] with (Unnamed Layer* 1) [Scale]
[TensorRT] VERBOSE: After vertical fusions: 1 layers
[TensorRT] VERBOSE: After final dead-layer removal: 1 layers
[TensorRT] VERBOSE: After tensor merging: 1 layers
[TensorRT] VERBOSE: After concat removal: 1 layers
[TensorRT] VERBOSE: Graph construction and optimization completed in 0.000323121 seconds.
[TensorRT] VERBOSE: Constructing optimization profile number 0 out of 1
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.234496
[TensorRT] VERBOSE: Tactic: 0 time 0.246784
[TensorRT] VERBOSE: Fastest Tactic: 1002 Time: 0.234496
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.306176
[TensorRT] VERBOSE: Tactic: 0 time 3.7079
[TensorRT] VERBOSE: Fastest Tactic: 1002 Time: 0.306176
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.509952
[TensorRT] VERBOSE: Tactic: 0 time 0.221184
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.221184
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.30208
[TensorRT] VERBOSE: Tactic: 0 time 0.101376
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.101376
[TensorRT] VERBOSE: *************** Autotuning format combination: Float(1,512,262144,16777216) -> Float(1,512,262144,16777216) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale)
[TensorRT] VERBOSE: Tactic: 0 time 0.170976
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.170976
[TensorRT] VERBOSE: *************** Autotuning format combination: Float(1,512,262144:32,524288) -> Float(1,512,262144:32,524288) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale)
[TensorRT] VERBOSE: Scale has no valid tactics for this config, skipping
[TensorRT] VERBOSE: *************** Autotuning format combination: Half(1,512,262144,16777216) -> Half(1,512,262144,16777216) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale)
[TensorRT] VERBOSE: Tactic: 0 time 0.1024
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.1024
[TensorRT] VERBOSE: *************** Autotuning format combination: Half(1,512,262144:2,8388608) -> Half(1,512,262144:2,8388608) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale)
[TensorRT] VERBOSE: Tactic: 0 time 0.104448
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.104448
[TensorRT] VERBOSE: *************** Autotuning format combination: Half(8,4096,1:8,2097152) -> Half(8,4096,1:8,2097152) ***************
[TensorRT] VERBOSE: --------------- Timing Runner: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale)
[TensorRT] VERBOSE: Tactic: 0 time 0.10752
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.10752
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.227328
[TensorRT] VERBOSE: Tactic: 0 time 0.24576
[TensorRT] VERBOSE: Fastest Tactic: 1002 Time: 0.227328
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 1.26464
[TensorRT] VERBOSE: Tactic: 0 time 1.33837
[TensorRT] VERBOSE: Fastest Tactic: 1002 Time: 1.26464
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.521216
[TensorRT] VERBOSE: Tactic: 0 time 0.221184
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.221184
[TensorRT] VERBOSE: --------------- Timing Runner: <reformat> (Reformat)
[TensorRT] VERBOSE: Tactic: 1002 time 0.299008
[TensorRT] VERBOSE: Tactic: 0 time 0.091136
[TensorRT] VERBOSE: Fastest Tactic: 0 Time: 0.091136
[TensorRT] VERBOSE: Formats and tactics selection completed in 0.113438 seconds.
[TensorRT] VERBOSE: After reformat layers: 1 layers
[TensorRT] VERBOSE: Block size 0
[TensorRT] VERBOSE: Total Activation Memory: 0
[TensorRT] INFO: Detected 1 inputs and 1 output network tensors.
[TensorRT] VERBOSE: Engine generation completed in 1.87506 seconds.
[TensorRT] VERBOSE: Engine Layer Information:
[TensorRT] VERBOSE: Layer: (Unnamed Layer* 0) [Scale] + (Unnamed Layer* 1) [Scale] (Scale), Tactic: 0, input_0[Half(64,512,512)] -> output_0[Half(64,512,512)]

===========Pytorch no fusion vs TRT===========
---Output 0---
  Epsilon 0.1) base and diff are different for 0 values (0.00%)
  Epsilon 0.01) base and diff are different for 262144 values (1.56%)
  Epsilon 0.001) base and diff are different for 3670016 values (21.88%)
  Epsilon 0.0001) base and diff are different for 6291456 values (37.50%)
  Epsilon 1e-05) base and diff are different for 6291456 values (37.50%)
  Epsilon 1e-06) base and diff are different for 6291456 values (37.50%)
  Epsilon 1e-07) base and diff are different for 6291456 values (37.50%)
  Epsilon 1e-08) base and diff are different for 6291456 values (37.50%)
Max difference: tensor(0.0312, device='cuda:0', dtype=torch.float16)

===========Pytorch fusion in fp16 vs TRT===========
---Output 0---
  Epsilon 0.1) base and diff are different for 0 values (0.00%)
  Epsilon 0.01) base and diff are different for 0 values (0.00%)
  Epsilon 0.001) base and diff are different for 524288 values (3.12%)
  Epsilon 0.0001) base and diff are different for 1048576 values (6.25%)
  Epsilon 1e-05) base and diff are different for 1048576 values (6.25%)
  Epsilon 1e-06) base and diff are different for 1048576 values (6.25%)
  Epsilon 1e-07) base and diff are different for 1048576 values (6.25%)
  Epsilon 1e-08) base and diff are different for 1048576 values (6.25%)
Max difference: tensor(0.0039, device='cuda:0', dtype=torch.float16)

===========Pytorch fusion in fp32 vs TRT===========
---Output 0---
  Epsilon 0.1) base and diff are different for 0 values (0.00%)
  Epsilon 0.01) base and diff are different for 0 values (0.00%)
  Epsilon 0.001) base and diff are different for 0 values (0.00%)
  Epsilon 0.0001) base and diff are different for 0 values (0.00%)
  Epsilon 1e-05) base and diff are different for 0 values (0.00%)
  Epsilon 1e-06) base and diff are different for 0 values (0.00%)
  Epsilon 1e-07) base and diff are different for 0 values (0.00%)
  Epsilon 1e-08) base and diff are different for 0 values (0.00%)
Max difference: tensor(0., device='cuda:0', dtype=torch.float16)

rmccorm4 (Collaborator) commented Jan 1, 2020

Also repro'd with TensorRT 7.0 + PyTorch 1.3.1.

Might be related to #305

rmccorm4 (Collaborator) commented
Hi @jasonliu19,

Sorry this is super late. Two things:

1. Disabling Layer Fusion

It turns out there is a workaround to disable layer fusion for debugging purposes. When you mark a tensor as a network output (network.mark_output(tensor)), TensorRT must keep that tensor's contents after the producing layer executes, so any fusion that would merge that layer with the layers after it is disabled. This is a very useful way to debug a network: to figure out whether a certain layer is causing problems, you can binary-search by marking layers' outputs until you isolate it, as in the sketch below.
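
For illustration, a minimal sketch of that workaround, assuming a builder and network constructed as in the script above (the layer index is chosen arbitrarily):

# List the layers so you can pick a suspect to inspect.
for i in range(network.num_layers):
    layer = network.get_layer(i)
    print(i, layer.name, layer.type)

# Mark the suspect layer's output (index 0 here purely for illustration).
# TensorRT must materialize marked outputs, so this layer will no longer
# be fused with the layer that follows it.
suspect = network.get_layer(0)
network.mark_output(suspect.get_output(0))

# Rebuild and compare the now-visible intermediate result to a reference.
engine = builder.build_cuda_engine(network)

With the script from this issue, marking the first scale layer's output this way should keep the two Scale layers separate in the verbose build log.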

2. Pytorch vs TensorRT Fusion Output Differences

TensorRT computes the new fused scale and shift values in fp32, which is exactly what the third comparison (Pytorch fusion in fp32 vs TRT) replicates; that's why its results match the TensorRT output exactly.
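
For instance, here's a small self-contained numpy sketch of that arithmetic (the coefficient values are arbitrary, for illustration only): two chained scale layers y = (x * a + p) * b + q fuse into y = x * (a * b) + (p * b + q), with the fused coefficients computed in fp32 and cast back to fp16 at the end:

import numpy as np

# Two fp16 scale layers: y = (x * scale_a + shift_a) * scale_b + shift_b
scale_a, shift_a = np.float16(1.7), np.float16(-0.3)
scale_b, shift_b = np.float16(2.1), np.float16(0.9)

# Fuse the coefficients in fp32, then cast back to fp16.
scale_f = (scale_a.astype(np.float32) * scale_b.astype(np.float32)).astype(np.float16)
shift_f = (shift_a.astype(np.float32) * scale_b.astype(np.float32)
           + shift_b.astype(np.float32)).astype(np.float16)

x = np.float16(1.0)
fused = x * scale_f + shift_f                          # one fused layer
unfused = (x * scale_a + shift_a) * scale_b + shift_b  # two fp16 layers
print(fused, unfused)  # can differ by an fp16 ulp or two: the rounding order differs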

cathy-kim commented
Hi @rmccorm4,
Using the first method increases host latency, since the additional marked outputs have to be copied back from the GPU.

Is there any way to prevent the host latency from increasing?
