diff --git a/README.md b/README.md
index 7f8e75033c..961d762404 100644
--- a/README.md
+++ b/README.md
@@ -87,10 +87,13 @@ docker run --gpus all --rm -v $(pwd)/workspace:/workspace -it openmmlab/lmdeploy
python3 -m lmdeploy.turbomind.chat /workspace
```
-```{note}
-When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. It is recommended to use NVIDIA cards such as 3090, V100, A100, etc.
-Disable GPU ECC can free up 10% memory, try `sudo nvidia-smi --ecc-config=0` and reboot system.
-```
+> **Note**
+> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7 GB of GPU memory on TurboMind.
+> NVIDIA GPUs such as the 3090, V100, and A100 are recommended.
+> Disabling GPU ECC can free up 10% of GPU memory; run `sudo nvidia-smi --ecc-config=0` and reboot the system for it to take effect.
+
+> **Note**
+> Tensor parallelism is available for inference on multiple GPUs. Add the `--tp=` option to `chat` to enable runtime TP (see the example below).
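+
+For example, a hypothetical two-GPU run (a sketch assuming the `/workspace` model converted above; set `--tp` to the number of GPUs you want to use):
+
+```shell
+python3 -m lmdeploy.turbomind.chat /workspace --tp=2
+```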
#### Serving
@@ -166,6 +169,9 @@ Then adjust `workspace/triton_models/weights/config.ini`
Here is [quantization test results](./docs/en/quantization.md).
+> **Warning**
+> Runtime tensor parallelism is not available for quantized models. Please set `--tp` on `deploy` to enable static TP (see the example below).
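+
+For example, a sketch of a static-TP conversion (the model name and path below are placeholders; the remaining `deploy` arguments follow the deployment section above):
+
+```shell
+python3 -m lmdeploy.serve.turbomind.deploy internlm-7b /path/to/internlm-7b --tp 2
+```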
+
## Contributing
We appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 4e260ac3ac..27fb57a868 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -86,10 +86,12 @@ docker run --gpus all --rm -v $(pwd)/workspace:/workspace -it openmmlab/lmdeploy
python3 -m lmdeploy.turbomind.chat /workspace
```
-```{note}
-turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
-关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。
-```
+> **Note**
+> turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
+> 关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。
+
+> **Note**
+> 使用 Tensor 并发可以利用多张 GPU 进行推理。在 `chat` 时添加参数 `--tp=` 可以启动运行时 TP。
#### 部署推理服务
@@ -165,6 +167,9 @@ python3 -m lmdeploy.lite.apis.kv_qparams \
这里是[量化测试结果](./docs/zh_cn/quantization.md)。
+> **Warning**
+> 量化部署不支持运行时 Tensor 并发。如果希望使用 Tensor 并发,需要在 deploy 时配置 tp 参数。
+
## 贡献指南
我们感谢所有的贡献者为改进和提升 LMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。
diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py
index 9249595f55..affec1d98e 100644
--- a/benchmark/profile_generation.py
+++ b/benchmark/profile_generation.py
@@ -76,10 +76,11 @@ def main(model_path: str,
concurrency: int = 1,
input_seqlen: int = 0,
output_seqlen: int = 512,
- test_round: int = 10):
+ test_round: int = 10,
+ tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
- tm_model = TurboMind(model_path=model_path)
+ tm_model = TurboMind(model_path=model_path, tp=tp)
warmup(tm_model, concurrency, output_seqlen)
diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 3a6500de6f..d8100113c5 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -54,11 +54,11 @@ def sample_requests(
class Engine:
- def __init__(self, model_path: str):
+ def __init__(self, model_path: str, tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
- tm_model = TurboMind(model_path=model_path)
+ tm_model = TurboMind(model_path=model_path, tp=tp)
self.tm_model = tm_model
self.tokenizer = tokenizer
@@ -117,9 +117,10 @@ def process_request(self, requests, concurrency: int = 1):
def main(dataset: str,
model_path: str,
concurrency: int = 1,
- num_prompts: int = 1000):
+ num_prompts: int = 1000,
+ tp: int = 1):
- engine = Engine(model_path)
+ engine = Engine(model_path, tp=tp)
tokenizer = engine.tokenizer
requests = sample_requests(dataset, num_prompts, tokenizer)
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index be2fc39cd2..3b43baa637 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -52,7 +52,7 @@ def stream_callback(que, result, error):
def get_logger(log_file=None, log_level=logging.INFO):
"""Return the logger."""
- from .utils import get_logger
+ from lmdeploy.turbomind.utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger
diff --git a/lmdeploy/serve/turbomind/utils.py b/lmdeploy/serve/turbomind/utils.py
index ba3fd89211..bd1c3a16c2 100644
--- a/lmdeploy/serve/turbomind/utils.py
+++ b/lmdeploy/serve/turbomind/utils.py
@@ -1,88 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-from typing import List, Optional, Union
+from typing import List, Union
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
-logger_initialized = {}
-
-
-def get_logger(name: str,
- log_file: Optional[str] = None,
- log_level: int = logging.INFO,
- file_mode: str = 'w'):
- """Initialize and get a logger by name.
-
- If the logger has not been initialized, this method will initialize the
- logger by adding one or two handlers, otherwise the initialized logger will
- be directly returned. During initialization, a StreamHandler will always be
- added. If `log_file` is specified, a FileHandler will also be added.
- Args:
- name (str): Logger name.
- log_file (str | None): The log filename. If specified, a FileHandler
- will be added to the logger.
- log_level (int): The logger level.
- file_mode (str): The file mode used in opening log file.
- Defaults to 'w'.
- Returns:
- logging.Logger: The expected logger.
- """
- # use logger in mmengine if exists.
- try:
- from mmengine.logging import MMLogger
- if MMLogger.check_instance_created(name):
- logger = MMLogger.get_instance(name)
- else:
- logger = MMLogger.get_instance(name,
- logger_name=name,
- log_file=log_file,
- log_level=log_level,
- file_mode=file_mode)
- return logger
-
- except Exception:
- pass
-
- logger = logging.getLogger(name)
- if name in logger_initialized:
- return logger
- # handle hierarchical names
- # e.g., logger "a" is initialized, then logger "a.b" will skip the
- # initialization since it is a child of "a".
- for logger_name in logger_initialized:
- if name.startswith(logger_name):
- return logger
-
- # handle duplicate logs to the console
- for handler in logger.root.handlers:
- if type(handler) is logging.StreamHandler:
- handler.setLevel(logging.ERROR)
-
- stream_handler = logging.StreamHandler()
- handlers = [stream_handler]
-
- if log_file is not None:
- # Here, the default behaviour of the official logger is 'a'. Thus, we
- # provide an interface to change the file mode to the default
- # behaviour.
- file_handler = logging.FileHandler(log_file, file_mode)
- handlers.append(file_handler)
-
- formatter = logging.Formatter(
- '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- for handler in handlers:
- handler.setFormatter(formatter)
- handler.setLevel(log_level)
- logger.addHandler(handler)
-
- logger.setLevel(log_level)
- logger_initialized[name] = True
-
- return logger
-
def prepare_tensor(name, input_tensor):
"""Create grpcclient's InferInput instance according to a given tensor."""
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 4277350731..f3f2991d43 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -29,7 +29,10 @@ def valid_str(string, coding='utf-8'):
return ret
-def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
+def main(model_path,
+ session_id: int = 1,
+ repetition_penalty: float = 1.0,
+         tp: int = 1):
"""An example to perform model inference through the command line
interface.
@@ -39,7 +42,7 @@ def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
"""
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
- tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
+ tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
generator = tm_model.create_instance()
nth_round = 1
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index efe7e65b7b..0e6859befa 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -13,6 +13,7 @@
import lmdeploy
from lmdeploy.model import MODELS
+from lmdeploy.turbomind.utils import get_logger
# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
@@ -69,14 +70,11 @@ class TurboMind:
Args:
model_path (str): the path of turbomind's model
- data_type (str): the data type
eos_id (int): eos token id
+        tp (int): the number of GPUs used for tensor parallelism
"""
- def __init__(self,
- model_path: str,
- data_type: str = 'fp16',
- eos_id: int = 2):
+ def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
self.eos_id = eos_id
# TODO: support mpi
@@ -84,8 +82,9 @@ def __init__(self,
node_num = 1
# read meta from model path
- self.gpu_count = 1
+ self.gpu_count = tp
self.session_len = 2048
+ data_type = 'fp16'
ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
@@ -97,10 +96,14 @@ def __init__(self,
section_name = 'llama'
if len(section_name) > 0:
- self.gpu_count = parser.getint(section_name,
- 'tensor_para_size')
+ tp_cfg = parser.getint(section_name, 'tensor_para_size')
self.session_len = parser.getint(section_name, 'session_len')
+ if tp_cfg != 1 and tp_cfg != tp:
+ get_logger('turbomind').info(
+ f'found tp={tp_cfg} in config.ini.')
+ self.gpu_count = tp_cfg
self.model_name = parser.get(section_name, 'model_name')
+ data_type = parser.get(section_name, 'weight_type')
model = MODELS.get(self.model_name)()
self.stop_words = _stop_words(model.stop_words)
diff --git a/lmdeploy/turbomind/utils.py b/lmdeploy/turbomind/utils.py
new file mode 100644
index 0000000000..7b6d51a01a
--- /dev/null
+++ b/lmdeploy/turbomind/utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import Optional
+
+logger_initialized = {}
+
+
+def get_logger(name: str,
+ log_file: Optional[str] = None,
+ log_level: int = logging.INFO,
+ file_mode: str = 'w'):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified, a FileHandler will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ # use logger in mmengine if exists.
+ try:
+ from mmengine.logging import MMLogger
+ if MMLogger.check_instance_created(name):
+ logger = MMLogger.get_instance(name)
+ else:
+ logger = MMLogger.get_instance(name,
+ logger_name=name,
+ log_file=log_file,
+ log_level=log_level,
+ file_mode=file_mode)
+ return logger
+
+ except Exception:
+ pass
+
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if log_file is not None:
+ # Here, the default behaviour of the official logger is 'a'. Thus, we
+ # provide an interface to change the file mode to the default
+ # behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ logger.setLevel(log_level)
+ logger_initialized[name] = True
+
+ return logger
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 48b43f8cd9..e39cb0bef6 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -21,6 +21,7 @@
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"
+#include <filesystem>
namespace turbomind {
@@ -99,25 +100,135 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
}
template<typename T>
-void loadWeights(LlamaDenseWeight<T>& w, std::string prefix, int rank, FtCudaDataType model_file_type)
+void loadWeights(LlamaDenseWeight<T>& w,
+                 std::string prefix,
+                 int rank,
+                 FtCudaDataType model_file_type,
+                 size_t tensor_para_size,
+                 int slice_dim = 0,
+                 std::vector<size_t> slice_shape = {})
{
- prefix += "." + std::to_string(rank);
- const auto type = model_file_type;
+ auto max_prefix = prefix + "." + std::to_string(tensor_para_size - 1);
+ const auto type = model_file_type;
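+    // Runtime-TP slicing, in brief (descriptive comment):
+    //  - slice_dim == 0: the full weight holds tensor_para_size * input_dims rows; each rank reads its row slice.
+    //  - slice_dim == 1: the full weight holds tensor_para_size * output_dims columns; each rank reads its column slice.
+    //  - slice_shape: widths of the sub-tensors concatenated along the sliced dim (e.g. q/k/v); each is split per rank.
+    // If pre-split per-rank files ("<prefix>.<tp-1>.weight" / ".qweight") exist, runtime slicing is skipped.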
+
+ bool enable_slice = true;
+ // Disable slice if tensor param rank is 1
+ if (tensor_para_size <= 1) {
+ enable_slice = false;
+ }
+ else {
+ // Disable slice if weight has already been sliced
+ if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) {
+ TM_LOG_DEBUG("TP weight exists. Disable runtime TP.");
+ enable_slice = false;
+ }
+ }
+
+ size_t dim0 = w.input_dims;
+ size_t dim1 = w.output_dims;
+ if (enable_slice) {
+        // multiply by tp size to recover the full (un-sliced) dimension
+ if (slice_dim == 0) {
+ dim0 = dim0 * tensor_para_size;
+ if (slice_shape.size() == 0) {
+ slice_shape = {dim0};
+ }
+ }
+ else {
+ dim1 = dim1 * tensor_para_size;
+ if (slice_shape.size() == 0) {
+ slice_shape = {dim1};
+ }
+ }
+
+ prefix += "." + std::to_string(0);
+ }
+ else {
+ prefix += "." + std::to_string(rank);
+ }
if (w.bias) {
- loadWeightFromBin((T*)w.bias, {w.output_dims}, prefix + ".bias", type);
+        std::vector<ConcateSlice> bias_slices{};
+ if (enable_slice) {
+ if (slice_dim == 1) {
+ size_t start = 0;
+ ConcateSlice slice0{.slices = {{0, 1}}};
+ ConcateSlice slice1{.slices = {{}}};
+ for (auto len : slice_shape) {
+ size_t stride = len / tensor_para_size;
+ slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
+ start += len;
+ }
+ bias_slices = {slice0, slice1};
+ }
+ }
+ loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices);
}
const size_t bit_size = getBitSize(w.type);
if (bit_size >= 16) { // fp16, fp32
- loadWeightFromBin((T*)w.kernel, {w.input_dims, w.output_dims}, prefix + ".weight", type);
+        std::vector<ConcateSlice> weight_slices{};
+ if (enable_slice) {
+ if (slice_dim == 1) {
+ size_t start = 0;
+ ConcateSlice slice0{.slices = {{0, dim0}}};
+ ConcateSlice slice1{.slices = {{}}};
+ for (auto len : slice_shape) {
+ size_t stride = len / tensor_para_size;
+ slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
+ start += len;
+ }
+ weight_slices = {slice0, slice1};
+ }
+ else {
+ size_t start = 0;
+ ConcateSlice slice0{.slices = {}};
+ ConcateSlice slice1{.slices = {{0, dim1}}};
+ for (auto len : slice_shape) {
+ size_t stride = len / tensor_para_size;
+ slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
+ start += len;
+ }
+ weight_slices = {slice0, slice1};
+ }
+ }
+ loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices);
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
- FT_CHECK(w.input_dims % factor == 0);
- const auto f32_type = FtCudaDataType::FP32;
- loadWeightFromBin((float*)w.kernel, {w.input_dims / factor, w.output_dims}, prefix + ".qweight", f32_type);
- loadWeightFromBin((T*)w.scales, {w.output_dims}, prefix + ".scales", type);
- loadWeightFromBin((T*)w.zeros, {w.output_dims}, prefix + ".zeros", type);
+ FT_CHECK(dim0 % factor == 0);
+ const auto f32_type = FtCudaDataType::FP32;
+        std::vector<ConcateSlice> weight_slices{};
+        std::vector<ConcateSlice> bias_slices{};
+ if (enable_slice) {
+ if (slice_dim == 1) {
+ size_t start = 0;
+ ConcateSlice slice0{.slices = {{0, dim0}}};
+ ConcateSlice slice1{.slices = {{}}};
+ for (auto len : slice_shape) {
+ size_t stride = len / tensor_para_size;
+ slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
+ start += len;
+ }
+ weight_slices = {slice0, slice1};
+
+ ConcateSlice bias_slice0{.slices = {{0, 1}}};
+ bias_slices = {bias_slice0, slice1};
+ }
+ else {
+ size_t start = 0;
+ ConcateSlice slice0{.slices = {}};
+ ConcateSlice slice1{.slices = {{0, dim1}}};
+ for (auto len : slice_shape) {
+ size_t stride = len / factor / tensor_para_size;
+ slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
+ start += len;
+ }
+ weight_slices = {slice0, slice1};
+ }
+ }
+ loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices);
+ loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices);
+ loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices);
}
}
@@ -158,11 +269,17 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
(T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type);
loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type);
- loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type);
- loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type);
- loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type);
- loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type);
- loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type);
+ loadWeights(self_attn_weights.qkv,
+ dir_path + ".attention.w_qkv",
+ tensor_para_rank_,
+ type,
+ tensor_para_size_,
+ 1,
+ {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});
+ loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0);
+ loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
+ loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
+ loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
// load kv_cache quant scale
// if file not exist, get empty vector
diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu
index 02be105778..a419e2a9c5 100644
--- a/src/turbomind/utils/memory_utils.cu
+++ b/src/turbomind/utils/memory_utils.cu
@@ -301,58 +301,158 @@ template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size);
// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or
// the product of the elements in shape is 0, this function will return an empty vector.
template<typename T>
-std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename)
+std::vector<T>
+loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices = {})
{
if (shape.size() > 2) {
printf("[ERROR] shape should have less than two dims \n");
        return std::vector<T>();
}
+
size_t dim0 = shape[0], dim1 = 1;
if (shape.size() == 2) {
dim1 = shape[1];
}
- size_t size = dim0 * dim1;
- if (size == 0) {
- TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
-        return std::vector<T>();
- }
-    std::vector<T> host_array(size);
- std::ifstream in(filename, std::ios::in | std::ios::binary);
- if (!in.is_open()) {
- TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
-        return std::vector<T>();
+ if (slices.size() == 0) {
+ size_t size = dim0 * dim1;
+ if (size == 0) {
+ TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
+            return std::vector<T>();
+ }
+
+        std::vector<T> host_array(size);
+ std::ifstream in(filename, std::ios::in | std::ios::binary);
+ if (!in.is_open()) {
+ TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
+            return std::vector<T>();
+ }
+
+ size_t loaded_data_size = sizeof(T) * size;
+ in.seekg(0, in.end);
+ in.seekg(0, in.beg);
+
+ TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
+ in.read((char*)host_array.data(), loaded_data_size);
+
+ size_t in_get_size = in.gcount();
+ if (in_get_size != loaded_data_size) {
+ TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n",
+ filename.c_str(),
+ in_get_size,
+ loaded_data_size);
+            return std::vector<T>();
+ }
+ in.close();
+ // If we succeed, return an array with values.
+ return host_array;
}
+ else {
+        // concatenate all slices along the same dims
+
+ if (slices.size() != shape.size()) {
+ printf("[ERROR] slices should have same dims as shape \n");
+            return std::vector<T>();
+ }
- size_t loaded_data_size = sizeof(T) * size;
- in.seekg(0, in.end);
- in.seekg(0, in.beg);
+ // get slices
+ ConcateSlice slice0{.slices = {{0, dim0}}};
+ ConcateSlice slice1{.slices = {{0, dim1}}};
+ if (slices.size() > 0 && slices[0].slices.size() > 0) {
+ slice0 = slices[0];
+ }
+ if (shape.size() == 2 && slices[1].slices.size() > 0) {
+ slice1 = slices[1];
+ }
- TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
- in.read((char*)host_array.data(), loaded_data_size);
+ size_t w0 = 0;
+ for (auto& s : slice0.slices) {
+ if (s.second > dim0) {
+ s.second = dim0;
+ }
+ if (s.second < s.first) {
+ printf("[ERROR] slice0: end < start \n");
+                return std::vector<T>();
+ }
+ w0 += s.second - s.first;
+ }
- size_t in_get_size = in.gcount();
- if (in_get_size != loaded_data_size) {
- TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n",
- filename.c_str(),
- in_get_size,
- loaded_data_size);
-        return std::vector<T>();
+ size_t w1 = 0;
+ for (auto& s : slice1.slices) {
+ if (s.second > dim1) {
+ s.second = dim1;
+ }
+ if (s.second < s.first) {
+ printf("[ERROR] slice1: end < start \n");
+                return std::vector<T>();
+ }
+ w1 += s.second - s.first;
+ }
+
+ size_t size = w0 * w1;
+ size_t loaded_data_size = size * sizeof(T);
+
+ TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename + " with slice.");
+ if (size == 0) {
+ TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
+            return std::vector<T>();
+ }
+
+        std::vector<T> host_array(size);
+ std::ifstream in(filename, std::ios::in | std::ios::binary);
+ if (!in.is_open()) {
+ TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
+            return std::vector<T>();
+ }
+
+ char* host_ptr = (char*)host_array.data();
+ if (slice1.slices.size() == 0
+ || (slice1.slices.size() == 1 && slice1.slices[0].second - slice1.slices[0].first == dim1)) {
+ for (auto& s : slice0.slices) {
+ size_t read_size = (s.second - s.first) * dim1 * sizeof(T);
+ size_t pos = s.first * dim1;
+ in.seekg(pos * sizeof(T));
+ in.read((char*)host_ptr, read_size);
+ host_ptr += read_size;
+ }
+ in.close();
+ return host_array;
+ }
+
+ {
+ for (auto& s0 : slice0.slices) {
+ // loop over outer slice
+ for (size_t line_id = s0.first; line_id < s0.second; ++line_id) {
+ // loop over lines
+ size_t pos0 = line_id * dim1;
+ for (auto& s1 : slice1.slices) {
+ // loop over inner slice
+ size_t pos = pos0 + s1.first;
+ size_t read_size = (s1.second - s1.first) * sizeof(T);
+ in.seekg(pos * sizeof(T));
+ in.read(host_ptr, read_size);
+ host_ptr += read_size;
+ }
+ }
+ }
+ in.close();
+ }
+ return host_array;
}
- in.close();
- // If we succeed, return an array with values.
- return host_array;
}
-std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename)
+std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices)
{
-    return loadWeightFromBinHelper<float>(shape, filename);
+    return loadWeightFromBinHelper<float>(shape, filename, slices);
}
template<typename T, typename T_IN>
-int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
+int loadWeightFromBinFunc(T* ptr,
+                          std::vector<size_t> shape,
+                          std::string filename,
+                          std::vector<ConcateSlice> slices = std::vector<ConcateSlice>())
{
-    std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename);
+    std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename, slices);
if (host_array.empty()) {
return 0;
@@ -371,49 +471,84 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filenam
return 0;
}
-template int loadWeightFromBinFunc<float, float>(float* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<half, float>(half* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<float, half>(float* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<half, half>(half* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<float, float>(float* ptr,
+                                                 std::vector<size_t> shape,
+                                                 std::string filename,
+                                                 std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<half, float>(half* ptr,
+                                                std::vector<size_t> shape,
+                                                std::string filename,
+                                                std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<float, half>(float* ptr,
+                                                std::vector<size_t> shape,
+                                                std::string filename,
+                                                std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<half, half>(half* ptr,
+                                               std::vector<size_t> shape,
+                                               std::string filename,
+                                               std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr,
+                                                    std::vector<size_t> shape,
+                                                    std::string filename,
+                                                    std::vector<ConcateSlice> slices);
#ifdef ENABLE_BF16
-template int
-loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
-template int
-loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr, std::vector<size_t> shape, std::string filename);
-template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr,
-                                                                 std::vector<size_t> shape,
-                                                                 std::string filename);
+template int loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr,
+                                                          std::vector<size_t> shape,
+                                                          std::string filename,
+                                                          std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr,
+                                                         std::vector<size_t> shape,
+                                                         std::string filename,
+                                                         std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr,
+                                                          std::vector<size_t> shape,
+                                                          std::string filename,
+                                                          std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr,
+                                                         std::vector<size_t> shape,
+                                                         std::string filename,
+                                                         std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr,
+                                                                  std::vector<size_t> shape,
+                                                                  std::string filename,
+                                                                  std::vector<ConcateSlice> slices);
#endif // ENABLE_BF16
-template int loadWeightFromBinFunc<int, int>(int* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<int, int>(int* ptr,
+                                              std::vector<size_t> shape,
+                                              std::string filename,
+                                              std::vector<ConcateSlice> slices);
#ifdef ENABLE_FP8
-template int
-loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr,
+                                                          std::vector<size_t> shape,
+                                                          std::string filename,
+                                                          std::vector<ConcateSlice> slices);
#endif // ENABLE_FP8
template<typename T>
-int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
+int loadWeightFromBin(T* ptr,
+                      std::vector<size_t> shape,
+                      std::string filename,
+                      FtCudaDataType model_file_type,
+                      std::vector<ConcateSlice> slices)
{
switch (model_file_type) {
case FtCudaDataType::FP32:
-            loadWeightFromBinFunc<T, float>(ptr, shape, filename);
+            loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
break;
case FtCudaDataType::FP16:
-            loadWeightFromBinFunc<T, half>(ptr, shape, filename);
+            loadWeightFromBinFunc<T, half>(ptr, shape, filename, slices);
break;
case FtCudaDataType::INT8:
-            loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename);
+            loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename, slices);
break;
#ifdef ENABLE_BF16
case FtCudaDataType::BF16:
-            loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename);
+            loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename, slices);
break;
#endif
#ifdef ENABLE_FP8
case FtCudaDataType::FP8:
-            loadWeightFromBinFunc<T, float>(ptr, shape, filename);
+            loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
break;
#endif
default:
@@ -424,28 +559,50 @@ int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, F
}
template<>
-int loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
+int loadWeightFromBin(int* ptr,
+                      std::vector<size_t> shape,
+                      std::string filename,
+                      FtCudaDataType model_file_type,
+                      std::vector<ConcateSlice> slices)
{
-    loadWeightFromBinFunc<int, int>(ptr, shape, filename);
+    loadWeightFromBinFunc<int, int>(ptr, shape, filename, slices);
return 0;
}
-template int
-loadWeightFromBin(float* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
-template int
-loadWeightFromBin(half* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
-template int
-loadWeightFromBin(int8_t* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int loadWeightFromBin(float* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
+template int loadWeightFromBin(half* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
+template int loadWeightFromBin(int8_t* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
#ifdef ENABLE_BF16
-template int
-loadWeightFromBin(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int loadWeightFromBin(__nv_bfloat16* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
#endif
#ifdef ENABLE_FP8
-template int
-loadWeightFromBin(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int loadWeightFromBin(__nv_fp8_e4m3* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
#endif
-template int
-loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int loadWeightFromBin(int* ptr,
+                               std::vector<size_t> shape,
+                               std::string filename,
+                               FtCudaDataType model_file_type,
+                               std::vector<ConcateSlice> slices);
template<typename T_OUT, typename T_IN>
__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size)
diff --git a/src/turbomind/utils/memory_utils.h b/src/turbomind/utils/memory_utils.h
index 6ce11e7719..e51c903905 100644
--- a/src/turbomind/utils/memory_utils.h
+++ b/src/turbomind/utils/memory_utils.h
@@ -49,13 +49,20 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream =
template<typename T>
void cudaRandomUniform(T* buffer, const size_t size);
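+// Describes which parts of a 2-D weight file to read: one list of half-open
+// [start, end) index ranges per dimension; the selected ranges are concatenated.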
+struct ConcateSlice {
+    std::vector<std::pair<size_t, size_t>> slices;
+};
+
template<typename T>
-int loadWeightFromBin(T* ptr,
-                      std::vector<size_t> shape,
-                      std::string filename,
-                      FtCudaDataType model_file_type = FtCudaDataType::FP32);
+int loadWeightFromBin(T* ptr,
+                      std::vector<size_t> shape,
+                      std::string filename,
+                      FtCudaDataType model_file_type = FtCudaDataType::FP32,
+                      std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
-std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename);
+std::vector<float> loadArrayFromBin(std::vector<size_t> shape,
+                                    std::string filename,
+                                    std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,