Skip to content

Significant Jitter in multiple inferences #3085

@zwj536

Description

@zwj536

Description

I'm trying to run inference with a GPT-2 model on an A30. When using TensorRT for inference, a single call takes a stable 8 ms (same input), but when the model is called repeatedly in a loop (with the output of the previous call used as the input of the next call), the latency jitters between 8 ms and 80 ms.

The results are as follows:
        ...
        generate cost:9(ms)
        generate cost:71(ms)
        generate cost:9(ms)
        generate cost:8(ms)
        generate cost:10(ms)
        generate cost:69(ms)
        generate cost:9(ms)
        generate cost:9(ms)
        generate cost:80(ms)
        generate cost:8(ms)
        generate cost:8(ms)
        generate cost:8(ms)
        generate cost:8(ms)
        generate cost:72(ms)
        generate cost:8(ms)
        generate cost:8(ms)
        generate cost:9(ms)
        generate cost:72(ms)
        generate cost:9(ms)
        ...

The nsys timeline is as follows:
image

class TrtModel():
    """Run a serialized TensorRT engine with a fixed six-slot binding layout.

    Bindings 0-3 are the inputs named in ``inputs_name`` (in that order);
    binding 4 receives ``past_key_values_out`` and binding 5 receives
    ``logits``. Call ``prepare`` before every ``predict`` so the binding
    pointers and output buffers match the current inputs.
    """

    def __init__(self, engine_file) -> None:
        """Deserialize ``engine_file`` and create an execution context for it."""
        super().__init__()

        self.load_engine(engine_file)
        self.context = self.engine.create_execution_context()
        # Listed in binding-index order: predict() uses the position in this
        # list as the binding index when it updates input shapes.
        self.inputs_name = [
            "input_ids",
            "past_key_values_in",
            "attention_mask",
            "position_ids",
        ]

    def load_engine(self, engine_file):
        """Read ``engine_file`` from disk and store the engine in ``self.engine``.

        Raises:
            AssertionError: if ``engine_file`` does not exist.
        """
        assert os.path.exists(engine_file)
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        with open(engine_file, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
            # Deserialize ICudaEngine
            self.engine = runtime.deserialize_cuda_engine(f.read())

    def prepare(self, feed_dict):
        """Allocate output buffers and collect device pointers for one call.

        Args:
            feed_dict: maps each name in ``inputs_name`` to a tensor; the
                tensors' ``data_ptr()`` values are handed to TensorRT, so they
                are presumably expected to live on the GPU already.
        """
        past_len = feed_dict['past_key_values_in'].shape[-1]
        new_len = feed_dict['input_ids'].shape[1]

        # NOTE(review): the original allocation line was garbled (two variants
        # fused together). The last axis is reconstructed here as cached length
        # plus number of new tokens -- confirm against the engine's output.
        # The leading dims (24, 2, 1, 128) and 49280 are taken verbatim from the
        # original; presumably model-specific sizes -- confirm.
        # Allocating with device="cuda" avoids building the buffer on the host
        # and copying it over on every call.
        self.past_key_values_out = torch.zeros(
            (24, 2, 1, 128, past_len + new_len),
            dtype=torch.float,
            device="cuda",
        )
        self.logits = torch.zeros(
            (1, new_len, 49280), dtype=torch.float, device="cuda"
        )

        self.bindings = [0] * self.engine.num_bindings
        self.bindings[0] = feed_dict["input_ids"].data_ptr()
        self.bindings[1] = feed_dict["past_key_values_in"].data_ptr()
        self.bindings[2] = feed_dict["attention_mask"].data_ptr()
        self.bindings[3] = feed_dict["position_ids"].data_ptr()
        self.bindings[4] = self.past_key_values_out.data_ptr()
        self.bindings[5] = self.logits.data_ptr()

    def predict(self, feed_dict):
        """Execute the engine on the buffers registered by ``prepare``.

        Args:
            feed_dict: the same mapping that was passed to ``prepare``; only
                the tensor shapes are read here.

        Returns:
            A tuple ``(logits[:, -1, :], past_key_values_out)``.

        Raises:
            RuntimeError: if TensorRT reports that execution failed.
        """
        # reset input shape, eg pkv: only touch a binding whose shape changed.
        for idx, name in enumerate(self.inputs_name):
            if self.context.get_binding_shape(idx) != tuple(feed_dict[name].shape):
                self.context.set_binding_shape(idx, feed_dict[name].shape)

        # execute_v2 returns a success flag; surface a failure instead of
        # silently returning whatever is left in the output buffers.
        if not self.context.execute_v2(self.bindings):
            raise RuntimeError("TensorRT execute_v2 failed")
        return self.logits[:, -1, :], self.past_key_values_out

def prepare_inputs_for_generation(input_ids, past_key_values, attention_mask=None, position_ids=None, past_kv_is_valid=True, device="cuda"):
    """Build the model input dict for one generation step.

    Args:
        input_ids: 2-D token tensor; when ``past_kv_is_valid`` is true only
            its last column is kept.
        past_key_values: passed through unchanged. When ``past_kv_is_valid``
            is true and ``position_ids`` is not given, its last dimension is
            read as the position of the new token, so it must be a tensor
            (not ``None``) in that case.
        attention_mask: optional mask; defaults to all ones shaped like the
            (possibly truncated) ``input_ids``.
        position_ids: optional positions; derived as described above when
            omitted.
        past_kv_is_valid: whether ``past_key_values`` holds a usable cache.
        device: device the returned ``input_ids``, ``position_ids`` and
            ``attention_mask`` are moved to. Defaults to ``"cuda"``.

    Returns:
        A dict with keys ``input_ids``, ``past_key_values``, ``use_cache``,
        ``position_ids`` and ``attention_mask``.
    """
    if past_kv_is_valid:
        # With a valid cache only the newest token has to be fed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    if position_ids is None:
        if past_kv_is_valid:
            # The new token's position is taken from the cache's last dimension.
            position_ids = torch.tensor([[past_key_values.shape[-1]]], dtype=torch.long)
        else:
            position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])

    return {
        "input_ids": input_ids.to(device),
        "past_key_values": past_key_values,
        "use_cache": True,
        "position_ids": position_ids.to(device),
        "attention_mask": attention_mask.to(device),
    }

# Sampling loop: repeatedly prepares inputs, runs the TensorRT engine, samples
# one token and appends it to input_ids, until a stop condition is met; then
# returns input_ids.
#
# This is an elided snippet: the `...` placeholders stand for parameters and
# setup code that are not shown. logits_processor, logits_warper,
# stopping_criteria, unfinished_sequences, scores and get_time_ms are not
# defined in the visible code -- presumably they come from the elided parts.
#
# NOTE(review): the statements from `trt_engine.prepare(...)` onwards are
# indented one column less than the lines above them, and the final `return`
# uses two spaces; this looks like a paste artifact and needs fixing before
# the code can run -- confirm the intended nesting.
def generate(
    trt_engine,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[np.array] = None,
    ...
):
    ...
    while True:
        # Start of the timed region (input preparation + buffer setup + inference).
        loop_start = get_time_ms()
        # past_kv_is_valid is left at its default here.
        model_inputs = prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        # Re-key the prepared inputs to the names the engine wrapper expects
        # ("past_key_values" becomes "past_key_values_in").
        feed_dict_device = {
            "input_ids": model_inputs['input_ids'],
            "past_key_values_in": model_inputs['past_key_values'],
            "attention_mask": model_inputs['attention_mask'],
            "position_ids": model_inputs['position_ids'],
        }
       # Called on every iteration, before predict().
       trt_engine.prepare(feed_dict_device)

       # infer
       next_token_logits, past_key_values_out_tensor = trt_engine.predict(feed_dict_device)
       # End of the timed region; post-processing below is not included in the
       # printed cost.
       loop_start3 = get_time_ms()
       print('\tgenerate cost:%d(ms)' % (loop_start3 - loop_start))

       # postprocess: the engine's output cache becomes the next iteration's
       # input cache.
       past_key_values = past_key_values_out_tensor
       next_token_scores = logits_processor(input_ids, next_token_logits)
       next_token_scores = logits_warper(input_ids, next_token_scores)
       # Sample one token per row from the softmax distribution and append it.
       probs = nn.functional.softmax(next_token_scores, dim=-1)
       next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
       input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

       # NOTE(review): unfinished_sequences and scores are not updated anywhere
       # in the visible loop body -- confirm this happens in elided code.
       if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break

  return input_ids

if __name__ == "__main__":
    # Tokenize a fixed code-completion prompt into PyTorch tensors.
    text = "# language: Python\n\ndef binary_search("
    text_input = tokenizer(text, return_tensors='pt')

    # NOTE(review): BackendGPU is not defined in the visible code (the wrapper
    # class shown in this file is TrtModel) -- confirm which class is intended.
    trt_engine = BackendGPU("xxx.engine")

    # Start generation with the prompt ids on the GPU and no initial cache;
    # the trailing `...` stands for arguments elided from the snippet.
    generate(trt_engine, text_input['input_ids'].cuda(), None, ...)

Summary

The trtexec summary is as follows:

[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.57195 ms - Host latency: 7.62485 ms (enqueue 3.16179 ms)
[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.57727 ms - Host latency: 7.62781 ms (enqueue 2.67576 ms)
[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.5772 ms - Host latency: 7.63237 ms (enqueue 2.6843 ms)
[06/26/2023-11:43:24] [I] 
...
[06/26/2023-11:43:24] [I] === Performance summary ===
[06/26/2023-11:43:24] [I] Throughput: 131.187 qps
[06/26/2023-11:43:24] [I] Latency: min = 7.61353 ms, max = 11.1703 ms, mean = 7.65891 ms, median = 7.63509 ms, percentile(90%) = 7.69922 ms, percentile(95%) = 7.72717 ms, percentile(99%) = 7.79211 ms
[06/26/2023-11:43:24] [I] Enqueue Time: min = 1.97473 ms, max = 12.2169 ms, mean = 2.84481 ms, median = 2.67487 ms, percentile(90%) = 2.83813 ms, percentile(95%) = 3.03088 ms, percentile(99%) = 10.4944 ms
[06/26/2023-11:43:24] [I] H2D Latency: min = 0.017395 ms, max = 0.103943 ms, mean = 0.0314965 ms, median = 0.0293579 ms, percentile(90%) = 0.0379639 ms, percentile(95%) = 0.050293 ms, percentile(99%) = 0.097168 ms
[06/26/2023-11:43:24] [I] GPU Compute Time: min = 7.56323 ms, max = 11.1268 ms, mean = 7.60134 ms, median = 7.57861 ms, percentile(90%) = 7.63599 ms, percentile(95%) = 7.66663 ms, percentile(99%) = 7.72607 ms
[06/26/2023-11:43:24] [I] D2H Latency: min = 0.0244141 ms, max = 0.0296631 ms, mean = 0.026073 ms, median = 0.0258789 ms, percentile(90%) = 0.0275879 ms, percentile(95%) = 0.0279541 ms, percentile(99%) = 0.0288086 ms
[06/26/2023-11:43:24] [I] Total Host Walltime: 3.01858 s
[06/26/2023-11:43:24] [I] Total GPU Compute Time: 3.01013 s
[06/26/2023-11:43:24] [W] * GPU compute time is unstable, with coefficient of variance = 2.45062%.

The onnx2trt script is as follows

def onnx_to_trt(onnx_path, engine_path="xxx.engine"):
    """Build a TensorRT engine from an ONNX model and save it to disk.

    The engine is built with TF32 enabled, one dynamic-shape optimization
    profile covering the four model inputs, and the
    ``FASTER_DYNAMIC_SHAPES_0805`` preview feature.

    Args:
        onnx_path: path of the ONNX model to convert.
        engine_path: where the serialized engine is written. Defaults to
            ``"xxx.engine"``.
    """
    # Imported locally so the module can be loaded without Polygraphy/TensorRT.
    from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path, save_engine
    from polygraphy.backend.trt import Profile
    from tensorrt import PreviewFeature

    # min / opt / max shapes for every dynamic input.
    profile = Profile()
    profile.add("input_ids", min=(1,1), opt=(1,1), max=(1,256))
    profile.add("past_key_values_in", min=(24,2,0,128,0), opt=(24,2,1,128,1), max=(24,2,1,128,256))
    profile.add("attention_mask", min=(1,1), opt=(1,1), max=(1,256))
    profile.add("position_ids", min=(1,1), opt=(1,1), max=(1,256))

    engine = engine_from_network(
        network_from_onnx_path(onnx_path),
        config=CreateConfig(
            tf32=True,
            #fp16=True,
            profiles=[profile],
            preview_features=[PreviewFeature.FASTER_DYNAMIC_SHAPES_0805],
        )
    )
    save_engine(engine, path=engine_path)

Environment

TensorRT Version: 8.6

NVIDIA GPU: A30

NVIDIA Driver Version: 515.65.01

CUDA Version: 11.7

CUDNN Version:

Operating System: ubuntu-20.04.1

Python Version (if applicable): 3.9

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions