Description
I'm trying to run inference on a GPT-2 model on an A30. When using TensorRT, a single call is stable at about 8 ms (same input), but when the engine is called repeatedly in a loop (the output of the previous call becomes the input of the next call), the per-call time jitters between 8 ms and 80 ms.
The results are as follows (a sketch of a finer-grained timing harness appears after the log):
...
generate cost:9(ms)
generate cost:71(ms)
generate cost:9(ms)
generate cost:8(ms)
generate cost:10(ms)
generate cost:69(ms)
generate cost:9(ms)
generate cost:9(ms)
generate cost:80(ms)
generate cost:8(ms)
generate cost:8(ms)
generate cost:8(ms)
generate cost:8(ms)
generate cost:72(ms)
generate cost:8(ms)
generate cost:8(ms)
generate cost:9(ms)
generate cost:72(ms)
generate cost:9(ms)
...
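For reference, here is a minimal sketch of how a single step could be timed more precisely, splitting host wall time from GPU time with CUDA events. This is illustrative and not the code used for the log above; trt_engine and feed_dict_device refer to the objects defined in the code further down.

import time
import torch

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

host_t0 = time.perf_counter()
start_evt.record()
next_token_logits, past_kv = trt_engine.predict(feed_dict_device)  # the step being timed
end_evt.record()
torch.cuda.synchronize()  # wait for all queued GPU work to finish
host_t1 = time.perf_counter()

print('host wall time: %.2f ms' % ((host_t1 - host_t0) * 1000))
print('gpu time:       %.2f ms' % start_evt.elapsed_time(end_evt))

If the ~70 ms spikes show up in host wall time but not in GPU time, the jitter is on the host side (enqueue, allocation, Python overhead) rather than in kernel execution.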
The timeline captured with Nsight Systems (nsys) is as follows:
[nsys timeline screenshot; image not recoverable in text]
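A profile like this can be captured with an invocation along these lines (the exact command is not shown in the report, so the flags and script name here are assumptions):

nsys profile -o gpt2_trt_report python infer.py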
import os
from typing import Optional

import numpy as np
import tensorrt as trt
import torch
from torch import nn

class TrtModel():
    def __init__(self, engine_file) -> None:
        super().__init__()
        self.load_engine(engine_file)
        self.context = self.engine.create_execution_context()
        self.inputs_name = [
            "input_ids",
            "past_key_values_in",
            "attention_mask",
            "position_ids",
        ]

    def load_engine(self, engine_file):
        assert os.path.exists(engine_file)
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        with open(engine_file, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
            # Deserialize ICudaEngine
            self.engine = runtime.deserialize_cuda_engine(f.read())

    def prepare(self, feed_dict):
        # Allocate output buffers for this step; the returned KV cache grows
        # by the number of tokens fed in this call.
        self.past_key_values_out = torch.zeros(
            (24, 2, 1, 128,
             feed_dict['past_key_values_in'].shape[-1] + feed_dict['input_ids'].shape[1]),
            dtype=torch.float).cuda()
        self.logits = torch.zeros((1, feed_dict['input_ids'].shape[1], 49280), dtype=torch.float).cuda()
        self.bindings = [0] * self.engine.num_bindings
        self.bindings[0] = feed_dict["input_ids"].data_ptr()
        self.bindings[1] = feed_dict["past_key_values_in"].data_ptr()
        self.bindings[2] = feed_dict["attention_mask"].data_ptr()
        self.bindings[3] = feed_dict["position_ids"].data_ptr()
        self.bindings[4] = self.past_key_values_out.data_ptr()
        self.bindings[5] = self.logits.data_ptr()

    def predict(self, feed_dict):
        # Reset input shapes when they change between steps, e.g. the growing
        # past_key_values.
        for idx, name in enumerate(self.inputs_name):
            if tuple(self.context.get_binding_shape(idx)) != tuple(feed_dict[name].shape):
                self.context.set_binding_shape(idx, feed_dict[name].shape)
        self.context.execute_v2(self.bindings)
        return self.logits[:, -1, :], self.past_key_values_out
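One thing to note: prepare() hard-codes the binding order (four inputs followed by two outputs). A defensive variant could resolve binding indices by name instead; here is a sketch against the TensorRT 8.x bindings API, where the output tensor names ("past_key_values_out", "logits") are assumptions that depend on how the ONNX graph names its outputs:

name_to_index = {self.engine.get_binding_name(i): i
                 for i in range(self.engine.num_bindings)}
self.bindings = [0] * self.engine.num_bindings
for name in self.inputs_name:
    self.bindings[name_to_index[name]] = feed_dict[name].data_ptr()
# Output binding names below are assumed; check the exported ONNX graph.
self.bindings[name_to_index["past_key_values_out"]] = self.past_key_values_out.data_ptr()
self.bindings[name_to_index["logits"]] = self.logits.data_ptr()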
def prepare_inputs_for_generation(input_ids, past_key_values, attention_mask=None, position_ids=None, past_kv_is_valid=True):
    token_type_ids = None
    if past_kv_is_valid:
        # With a valid KV cache, only the last token needs to be fed.
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if position_ids is None:
        if past_kv_is_valid:
            # The next position is the current length of the KV cache.
            position_ids = torch.tensor([[past_key_values.shape[-1]]], dtype=torch.long)
        else:
            position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long)
        position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])
    model_inputs = {
        "input_ids": input_ids.cuda(),
        "past_key_values": past_key_values,
        "use_cache": True,
        "position_ids": position_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
    }
    return model_inputs
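As an illustration of what this produces for one incremental decoding step (the values here are made up; the fake KV cache uses the same 5-D layout as elsewhere in this issue):

ids = torch.tensor([[11, 22, 33]])             # full sequence so far, shape (1, 3)
past = torch.zeros((24, 2, 1, 128, 5)).cuda()  # fake KV cache with past length 5
inputs = prepare_inputs_for_generation(ids, past)
print(inputs["input_ids"].shape)   # torch.Size([1, 1]) -- only the last token is fed
print(inputs["position_ids"])      # tensor([[5]], device='cuda:0') -- the past length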
def generate(
    trt_engine,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[np.ndarray] = None,
    ...
):
    ...
    while True:
        loop_start = get_time_ms()
        model_inputs = prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        feed_dict_device = {
            "input_ids": model_inputs['input_ids'],
            "past_key_values_in": model_inputs['past_key_values'],
            "attention_mask": model_inputs['attention_mask'],
            "position_ids": model_inputs['position_ids'],
        }
        trt_engine.prepare(feed_dict_device)
        # infer
        next_token_logits, past_key_values_out_tensor = trt_engine.predict(feed_dict_device)
        loop_start3 = get_time_ms()
        print('\tgenerate cost:%d(ms)' % (loop_start3 - loop_start))
        # postprocess: sample the next token and append it to the sequence
        past_key_values = past_key_values_out_tensor
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
    return input_ids
if __name__ == "__main__":
    text = "# language: Python\n\ndef binary_search("
    text_input = tokenizer(text, return_tensors='pt')
    trt_engine = TrtModel("xxx.engine")
    generate(
        trt_engine,
        text_input['input_ids'].cuda(),
        None,
        ...
    )
Summary
The trtexec summary is as follows:
[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.57195 ms - Host latency: 7.62485 ms (enqueue 3.16179 ms)
[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.57727 ms - Host latency: 7.62781 ms (enqueue 2.67576 ms)
[06/26/2023-11:43:24] [I] Average on 10 runs - GPU latency: 7.5772 ms - Host latency: 7.63237 ms (enqueue 2.6843 ms)
[06/26/2023-11:43:24] [I]
...
[06/26/2023-11:43:24] [I] === Performance summary ===
[06/26/2023-11:43:24] [I] Throughput: 131.187 qps
[06/26/2023-11:43:24] [I] Latency: min = 7.61353 ms, max = 11.1703 ms, mean = 7.65891 ms, median = 7.63509 ms, percentile(90%) = 7.69922 ms, percentile(95%) = 7.72717 ms, percentile(99%) = 7.79211 ms
[06/26/2023-11:43:24] [I] Enqueue Time: min = 1.97473 ms, max = 12.2169 ms, mean = 2.84481 ms, median = 2.67487 ms, percentile(90%) = 2.83813 ms, percentile(95%) = 3.03088 ms, percentile(99%) = 10.4944 ms
[06/26/2023-11:43:24] [I] H2D Latency: min = 0.017395 ms, max = 0.103943 ms, mean = 0.0314965 ms, median = 0.0293579 ms, percentile(90%) = 0.0379639 ms, percentile(95%) = 0.050293 ms, percentile(99%) = 0.097168 ms
[06/26/2023-11:43:24] [I] GPU Compute Time: min = 7.56323 ms, max = 11.1268 ms, mean = 7.60134 ms, median = 7.57861 ms, percentile(90%) = 7.63599 ms, percentile(95%) = 7.66663 ms, percentile(99%) = 7.72607 ms
[06/26/2023-11:43:24] [I] D2H Latency: min = 0.0244141 ms, max = 0.0296631 ms, mean = 0.026073 ms, median = 0.0258789 ms, percentile(90%) = 0.0275879 ms, percentile(95%) = 0.0279541 ms, percentile(99%) = 0.0288086 ms
[06/26/2023-11:43:24] [I] Total Host Walltime: 3.01858 s
[06/26/2023-11:43:24] [I] Total GPU Compute Time: 3.01013 s
[06/26/2023-11:43:24] [W] * GPU compute time is unstable, with coefficient of variance = 2.45062%.
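For context, a summary like this is produced by replaying the serialized engine with trtexec; the exact command is not included above, but it would look roughly like the following, where the shapes are an assumption matching the opt profile:

trtexec --loadEngine=xxx.engine --shapes=input_ids:1x1,past_key_values_in:24x2x1x128x1,attention_mask:1x1,position_ids:1x1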
The ONNX-to-TensorRT conversion script is as follows:
def onnx_to_trt(onnx_path):
    from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path, save_engine
    from polygraphy.backend.trt import Profile
    from tensorrt import PreviewFeature

    # One optimization profile covering prompts up to 256 tokens and a KV
    # cache of up to 256 past positions.
    p = Profile()
    p.add("input_ids", min=(1, 1), opt=(1, 1), max=(1, 256))
    p.add("past_key_values_in", min=(24, 2, 0, 128, 0), opt=(24, 2, 1, 128, 1), max=(24, 2, 1, 128, 256))
    p.add("attention_mask", min=(1, 1), opt=(1, 1), max=(1, 256))
    p.add("position_ids", min=(1, 1), opt=(1, 1), max=(1, 256))

    preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]

    engine = engine_from_network(
        network_from_onnx_path(onnx_path),
        config=CreateConfig(
            tf32=True,
            # fp16=True,
            profiles=[p],
            preview_features=preview_features,
        )
    )
    engine_path = "xxx.engine"
    save_engine(engine, path=engine_path)
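Hypothetical invocation (the ONNX path is a placeholder):

onnx_to_trt("gpt2.onnx")  # writes the serialized engine to xxx.engine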
Environment
TensorRT Version: 8.6
NVIDIA GPU: A30
NVIDIA Driver Version: 515.65.01
CUDA Version: 11.7
CUDNN Version:
Operating System: Ubuntu 20.04.1
Python Version (if applicable): 3.9