diff --git a/examples/pytorch/gpt/utils/gpt.py b/examples/pytorch/gpt/utils/gpt.py index 20d90b45f..2c75d971b 100644 --- a/examples/pytorch/gpt/utils/gpt.py +++ b/examples/pytorch/gpt/utils/gpt.py @@ -38,7 +38,7 @@ def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, has_post_decoder_layernorm: bool = True, int8_mode: int = 0, inter_size: int = 0): - assert(head_num % tensor_para_size == 0) + assert head_num % tensor_para_size == 0 if int8_mode == 1: torch_infer_dtype = str_type_map[inference_data_type] @@ -218,7 +218,7 @@ def __len__(self): return len(self.w) def _map(self, func): - assert(self.pre_embed_idx < self.post_embed_idx, "Pre decoder embedding index should be lower than post decoder embedding index.") + assert self.pre_embed_idx < self.post_embed_idx, "Pre decoder embedding index should be lower than post decoder embedding index." for i in range(len(self.w)): if isinstance(self.w[i], list): for j in range(len(self.w[i])):