TRTLLM new API support #9003

Merged
41 commits, merged May 13, 2024
Changes shown are from 34 of the 41 commits.

Commits
2d01795
Add trtllm checkpoint
meatybobby Apr 13, 2024
fa12e77
Change model config
meatybobby Apr 13, 2024
4c715ae
fix no query_group
meatybobby Apr 16, 2024
c79543b
Merge branch 'main' into bobchen/nemotron
meatybobby Apr 16, 2024
6005519
Using build API
meatybobby Apr 18, 2024
38b3b41
Change export to new API
meatybobby Apr 19, 2024
ed409f8
Update generate API
meatybobby Apr 19, 2024
a472f01
Fix runtime config
meatybobby Apr 19, 2024
a1c477d
Fix for llama
meatybobby Apr 19, 2024
b43f848
Fix for ptuning
abharwani Apr 22, 2024
a827421
Fix TP issue
meatybobby Apr 23, 2024
2b38efb
Change TP rank for building weight dict
meatybobby Apr 23, 2024
64dd631
Add lora config
abharwani Apr 23, 2024
cdb7389
add prompt embedding table config
abharwani Apr 23, 2024
487eb26
Fix PP isue
meatybobby Apr 23, 2024
b80388b
PP layers fix
meatybobby Apr 24, 2024
fab487b
Fix no prompt task ids
meatybobby Apr 24, 2024
8f0f36d
Add bos for Gemma
meatybobby Apr 24, 2024
5d3503e
Add multi block mode
meatybobby Apr 24, 2024
a8f54b0
Embedding and layernorm for PP
meatybobby Apr 24, 2024
bdf7cfc
MPI multiprocess support for multinode
meatybobby Apr 25, 2024
599520f
Only output text on first rank
meatybobby Apr 25, 2024
7821ff9
Change to ModelRunnerCpp
meatybobby Apr 25, 2024
3ecd9a7
Add falcon
meatybobby Apr 25, 2024
0ce5ae5
Add rotary_pct default value
meatybobby Apr 25, 2024
4d576ef
Falcon fix
meatybobby Apr 29, 2024
aa28fc9
Add MOE config
meatybobby Apr 30, 2024
da84b22
Fix MOE weight dict
meatybobby May 1, 2024
30e6ece
Clean code
meatybobby May 2, 2024
479d871
Add rotary_base
meatybobby May 2, 2024
05b4cbc
Fix MOE config
meatybobby May 2, 2024
d2ff752
Fix falcon new architecture
meatybobby May 3, 2024
d3f0307
Merge branch 'main' into bobchen/nemotron
meatybobby May 3, 2024
ad5c2fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 3, 2024
170df0e
Fix Gemma 7B
meatybobby May 3, 2024
b2413a1
Add rotary_scaling
meatybobby May 7, 2024
b19cfac
Merge branch 'main' into bobchen/nemotron
meatybobby May 7, 2024
4ca6eba
Merge branch 'main' into bobchen/nemotron
oyilmaz-nvidia May 8, 2024
625c239
Merge branch 'main' into bobchen/nemotron
oyilmaz-nvidia May 13, 2024
d75d601
Apply isort and black reformatting
oyilmaz-nvidia May 13, 2024
6bb65da
Merge branch 'main' into bobchen/nemotron
ericharper May 13, 2024
125 changes: 66 additions & 59 deletions nemo/export/tensorrt_llm.py
@@ -30,9 +30,10 @@
from nemo.export.tarutils import TarPath, unpack_tarball
from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_to_trtllm_config
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit
from nemo.export.trt_llm.utils import is_nemo_file

@@ -115,6 +116,7 @@ def export(
max_output_token: int = 256,
max_batch_size: int = 8,
max_prompt_embedding_table_size=None,
use_parallel_embedding: bool = False,
use_inflight_batching: bool = False,
enable_context_fmha: bool = True,
paged_kv_cache: bool = False,
@@ -188,65 +190,70 @@ def export(

self.model = None

tmp_dir = tempfile.TemporaryDirectory()
nemo_export_dir = Path(tmp_dir.name)
if tensorrt_llm.mpi_rank() == 0:
tmp_dir = tempfile.TemporaryDirectory()
nemo_export_dir = Path(tmp_dir.name)

if nemo_checkpoint_path.endswith("qnemo"):
if os.path.isdir(nemo_checkpoint_path):
nemo_export_dir = nemo_checkpoint_path
if nemo_checkpoint_path.endswith("qnemo"):
if os.path.isdir(nemo_checkpoint_path):
nemo_export_dir = nemo_checkpoint_path
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_target_modules=lora_target_modules,
)
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_target_modules=lora_target_modules,
)
else:
model_configs, self.tokenizer = nemo_llm_to_model_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)
weights_dicts, model_configs, self.tokenizer = nemo_to_trtllm_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
use_parallel_embedding=use_parallel_embedding,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)

model_config_to_tensorrt_llm(
model_configs,
self.model_dir,
world_size=tensor_parallel_size * pipeline_parallel_size,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_inflight_batching=use_inflight_batching,
paged_kv_cache=paged_kv_cache,
enable_context_fmha=enable_context_fmha,
enable_multi_block_mode=enable_multi_block_mode,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
)
for weight_dict, model_config in zip(weights_dicts, model_configs):
build_and_save_engine(
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
model_config=model_config,
model_weights=weight_dict,
model_dir=self.model_dir,
model_type=model_type,
lora_ckpt_list=self.lora_ckpt_list,
use_lora_plugin=use_lora_plugin,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
enable_multi_block_mode=enable_multi_block_mode,
)

tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
if os.path.exists(tokenizer_path):
shutil.copy(tokenizer_path, self.model_dir)
else:
self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))
tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
if os.path.exists(tokenizer_path):
shutil.copy(tokenizer_path, self.model_dir)
else:
self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))

nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml")
if os.path.exists(nemo_model_config):
shutil.copy(nemo_model_config, self.model_dir)

nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml")
if os.path.exists(nemo_model_config):
shutil.copy(nemo_model_config, self.model_dir)
tmp_dir.cleanup()

tmp_dir.cleanup()
if tensorrt_llm.mpi_world_size() > 1:
tensorrt_llm.mpi_barrier()

if load_model:
self._load()
@@ -394,7 +401,7 @@ def forward(
), "Task: {0} doesn't exist in the task list.".format(task_ids[i])
input_task_ids.append(self.task_ids[task_ids[i]])
if not streaming:
if torch.distributed.is_initialized():
if torch.distributed.is_initialized() or tensorrt_llm.mpi_world_size() > 1:
multiprocessed_env = True
else:
multiprocessed_env = False
@@ -478,7 +485,7 @@ def get_hidden_size(self):
if self.config is None:
return None
else:
return self.config["builder_config"]["hidden_size"]
return self.config["pretrained_config"]["hidden_size"]

@property
def get_triton_input(self):
@@ -694,15 +701,15 @@ def _get_prompt_embedding_table(
raise TypeError(prompt_embeddings_checkpoint_path + " is not a nemo file.")
prompt_embeddings_table = self._get_prompt_embedding_table_ckpt(prompt_embeddings_checkpoint_path)

dtype = self.config['builder_config']['precision']
dtype = self.config['pretrained_config']['dtype']
prompt_embeddings_table = prompt_embeddings_table.to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)
).cuda()

if prompt_embeddings_table.size(dim=1) != self.config["builder_config"]["hidden_size"]:
if prompt_embeddings_table.size(dim=1) != self.config["pretrained_config"]["hidden_size"]:
raise Exception(
"Hidden dimension of the model is {0} and does not match with the dimension of the prompt table.".format(
self.config["builder_config"]["hidden_size"]
self.config["pretrained_config"]["hidden_size"]
)
)

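
The hunks above replace the old nemo_llm_to_model_config / model_config_to_tensorrt_llm flow with nemo_to_trtllm_config followed by build_and_save_engine, run the conversion only on MPI rank 0, and read hidden_size and dtype from pretrained_config instead of builder_config. Below is a minimal sketch of driving the new export path; the checkpoint path, engine directory, and parallelism settings are hypothetical, and the parameter names are taken from the diff rather than from the PR's own tests.

```python
# Sketch only: paths and sizes are made up for illustration.
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")   # engines are written here

exporter.export(
    nemo_checkpoint_path="/models/example-llama.nemo",    # example .nemo checkpoint
    model_type="llama",
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    use_parallel_embedding=True,    # new option added by this PR
    max_input_token=2048,
    max_output_token=256,
    max_batch_size=8,
)

# Inference goes through the ModelRunnerCpp-backed runtime set up in _load().
output = exporter.forward(["What is the capital of France?"], max_output_token=32)
print(output)
```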
8 changes: 8 additions & 0 deletions nemo/export/trt_llm/decoder/__init__.py
@@ -40,6 +40,14 @@
DECODER_GEMMA: GemmaDecoderLayerConfigBuilder,
}

DECODER_MODEL_TYPE = {
DECODER_GPT2: 'GPTForCausalLM',
DECODER_GPTNEXT: 'GPTForCausalLM',
DECODER_LLAMA: 'LLaMAForCausalLM',
DECODER_GEMMA: 'GemmaForCausalLM',
DECODER_FALCON: 'FalconForCausalLM',
}


def build_decoder_layer_config(layer, decoder: str, dtype=trt.float16, rank=0, tensor_parallel=1):
"""Builds the decoder layer config with the input torch module."""
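
DECODER_MODEL_TYPE maps NeMo decoder identifiers to the TensorRT-LLM architecture names consumed by the new checkpoint/build path. A small illustrative lookup follows; it assumes the DECODER_* constants are plain strings such as 'llama', and the resolve_architecture helper is hypothetical, not part of the PR.

```python
from nemo.export.trt_llm.decoder import DECODER_MODEL_TYPE

def resolve_architecture(decoder_type: str) -> str:
    """Return the TRT-LLM architecture name for a NeMo decoder type (illustrative helper)."""
    try:
        return DECODER_MODEL_TYPE[decoder_type]
    except KeyError:
        raise ValueError(f"Unsupported decoder type: {decoder_type}")

print(resolve_architecture("llama"))  # expected: 'LLaMAForCausalLM'
```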
72 changes: 47 additions & 25 deletions nemo/export/trt_llm/nemo/convert.py
@@ -39,12 +39,12 @@


def save_val(val, dir, key, tp_num=None):
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
suffix = "" if tp_num is None else f".{tp_num}.bin"
# Transpose linear layer weights to the correct shape.
if len(val.shape) >= 2:
val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
global weights_dict
weights_dict[f"model.{key}.{suffix}"] = val
weights_dict[f"{key}{suffix}"] = val


def save_split(split_vals, dir, key, i, split_factor):
@@ -55,10 +55,10 @@
def save_expert_split(split_vals, dir, key, i, split_factor):
for j, val in enumerate(split_vals):
tp_num = i * split_factor + j
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
suffix = "" if tp_num is None else f".{tp_num}.bin"

global weights_dict
weights_dict[f"model.{key}.{suffix}"] = val
weights_dict[f"{key}{suffix}"] = val


def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
Expand Down Expand Up @@ -178,11 +178,15 @@
tp_size = config.get("tp_size", 1)
int8_outputs = config.get("int8_outputs", None)
multi_query_mode = config.get("multi_query_mode", False)
local_dim = config.get("local_dim", None)
num_kv_heads = config.get("num_kv_heads", num_attention_heads)
size_per_head = config.get("kv_channels", None)

save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"

layer_num = key.split(".")[1]
layer_prefix = f'transformer.layers.{layer_num}'

if not isinstance(vals, list):
vals = [vals]

@@ -210,12 +214,27 @@
or "final_layernorm.bias" in key
):
# shared weights, only need to convert the weights of rank 0
if "post_self_attn_layernorm.weight" in key:
key = key.replace("post_self_attn_layernorm.weight", "post_attention_layernorm.weight")
elif "mlp.linear_fc2.bias" in key:
key = key.replace("mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias")
elif "attention.linear_proj.bias" in key:
key = key.replace("attention.linear_proj.bias", "attention.dense.bias")
if "post_self_attn_layernorm" in key or "post_attention_layernorm" in key:
if key.endswith('weight'):
key = f'{layer_prefix}.post_layernorm.weight'
else:
key = f'{layer_prefix}.post_layernorm.bias'
elif "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key:
key = f'{layer_prefix}.mlp.proj.bias'
elif "attention.linear_proj.bias" in key or "attention.dense.bias" in key:
key = f'{layer_prefix}.attention.dense.bias'
elif "final_layernorm" in key:
key = key.replace("final_layernorm", "transformer.ln_f")
elif "input_layernorm" in key:
if key.endswith('weight'):
key = f'{layer_prefix}.input_layernorm.weight'
else:
key = f'{layer_prefix}.input_layernorm.bias'
elif "pre_mlp_layernorm" in key:
if key.endswith('weight'):
key = f'{layer_prefix}.post_layernorm.weight'
else:
key = f'{layer_prefix}.post_layernorm.bias'
if tp_rank == 0:
save_val(vals[0], saved_dir, key)

@@ -228,10 +247,10 @@
cat_dim = 0
val = np.concatenate(vals, axis=cat_dim)
split_vals = np.split(val, split_factor, axis=cat_dim)
if "attention.linear_proj.weight" in key:
key = key.replace("attention.linear_proj.weight", "attention.dense.weight")
elif "mlp.linear_fc2.weight" in key:
key = key.replace("mlp.linear_fc2.weight", "mlp.dense_4h_to_h.weight")
if "attention.linear_proj.weight" in key or "attention.dense.weight" in key:
key = f'{layer_prefix}.attention.dense.weight'
elif "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key:
key = f'{layer_prefix}.mlp.proj.weight'
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
if act_range is not None and int8_outputs == "all":
base_key = key.replace(".weight", "")
@@ -251,8 +270,10 @@
val = np.concatenate(vals, axis=cat_dim)
split_vals = np.split(val, split_factor, axis=cat_dim)

if "mlp.linear_fc1" in key:
key = key.replace("mlp.linear_fc1", "mlp.dense_h_to_4h")
if key.endswith("weight"):
key = f'{layer_prefix}.mlp.fc.weight'
else:
key = f'{layer_prefix}.mlp.fc.bias'
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
if act_range is not None and int8_outputs == "all":
base_key = key.replace(".weight", "")
@@ -261,8 +282,10 @@

if split_gated_activation:
assert not save_int8
prefix, dot, suffix = key.rpartition(".")
key = prefix + ".gate" + dot + suffix
if key.endswith("weight"):
key = f'{layer_prefix}.mlp.gate.weight'
else:
key = f'{layer_prefix}.mlp.gate.bias'

gate = np.concatenate(gates, axis=cat_dim)
split_vals = np.split(gate, split_factor, axis=cat_dim)
@@ -279,9 +302,6 @@
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor)

elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key:
if "attention.linear_qkv.bias" in key:
key = key.replace("attention.linear_qkv.bias", "attention.query_key_value.bias")

qkv_hidden_dim = vals[0].shape[0]
size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads)
q_num = num_attention_heads // num_kv_heads
@@ -304,6 +324,7 @@
np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0)
for i in range(split_factor)
]
key = f'{layer_prefix}.attention.qkv.bias'
save_split(split_vals, saved_dir, key, tp_rank, split_factor)

elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key:
@@ -342,8 +363,7 @@
for i in range(split_factor)
]

if "attention.linear_qkv.weight" in key:
key = key.replace("attention.linear_qkv.weight", "attention.query_key_value.weight")
key = f'{layer_prefix}.attention.qkv.weight'
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
if save_int8:
base_key = key.replace(".weight", "")
@@ -366,8 +386,8 @@
pass
elif "mlp.router.weight" in key:
val = np.concatenate(vals, axis=1)
split_vals = np.split(val, split_factor, axis=0)
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
key = f'{layer_prefix}.mlp.router.weight'
save_val(val, saved_dir, key)
elif "experts.linear_fc1.weight" in key:
cat_dim = -1
val = np.concatenate(vals, axis=cat_dim)
@@ -378,12 +398,14 @@
split_w3s = np.split(w3, split_factor, axis=1)

split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)]
key = f'{layer_prefix}.mlp.experts_weight_1'
save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)

elif "experts.linear_fc2.weight" in key:
cat_dim = -1
val = np.concatenate(vals, axis=cat_dim)
split_vals = np.split(val, split_factor, axis=cat_dim)
key = f'{layer_prefix}.mlp.experts_weight_2'
save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)
else:
print(f"[WARNING] {key} not handled by converter")
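
With these changes the converter writes TensorRT-LLM-native parameter names (for example transformer.layers.<n>.attention.qkv.weight) into the in-memory weights dict instead of the old model.<key>.bin files, and the tensor-parallel suffix is appended only for split tensors. The standalone sketch below mirrors the updated save_val naming; it drops the unused dir argument and uses made-up tensors purely to show the resulting keys.

```python
import numpy as np

weights_dict = {}

def save_val(val, key, tp_num=None):
    # New scheme: the key is stored as-is; a ".{tp}.bin" suffix is added only
    # for tensor-parallel shards, and there is no "model." prefix anymore.
    suffix = "" if tp_num is None else f".{tp_num}.bin"
    if len(val.shape) >= 2:
        # Linear weights are transposed to TRT-LLM's expected layout.
        val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
    weights_dict[f"{key}{suffix}"] = val

layer_prefix = "transformer.layers.0"
save_val(np.zeros((4, 4), dtype=np.float16), f"{layer_prefix}.attention.dense.weight")
save_val(np.zeros((4, 4), dtype=np.float16), f"{layer_prefix}.attention.qkv.weight", tp_num=1)
print(sorted(weights_dict))
# ['transformer.layers.0.attention.dense.weight',
#  'transformer.layers.0.attention.qkv.weight.1.bin']
```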