Develop (#218)
Add NumPy array support to TTTensor.
Upgrade transformers to 4.11.1
feifeibear committed Sep 30, 2021
1 parent f43f35b commit 559f565
Showing 56 changed files with 640 additions and 187 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
@@ -22,11 +22,9 @@ set(CMAKE_CXX_FLAGS "-Wall")
set(CMAKE_C_FLAGS "-Wall")



set(TURBO_TRANSFORMERS_VERSION 0.6.0)



option(WITH_PROFILER "Compile with profiler" OFF)
option(WITH_GPU "Build with GPU" OFF)
option(WITH_MODULE_BENCHMAKR "Catch2 unitest with benchmarking" ON)
25 changes: 14 additions & 11 deletions README.md
@@ -6,11 +6,12 @@
The WeChat AI open-sourced TurboTransformers with the following characteristics.

1. Supporting both Transformers Encoder and Decoder.
2. Supporting real-time variable-length inputs. No time-consuming offline tuning is required. You can change the batch size and the sequence length of a request in real time.
3. Excellent CPU / GPU performance. The backend is implemented with hand-crafted OpenMP and CUDA code and involves some innovative tricks.
4. Perfect Usability. Supports Python and C++ APIs. It can be used as a plugin for PyTorch. The end-to-end acceleration is obtained by adding a few lines of Python code (see the sketch below).
5. Smart Batching. Minimizes zero-padding overhead for a batch of requests of different lengths.
6. Memory Efficiency. A new model-aware allocator ensures a small memory footprint during variable-length request serving.
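
As a minimal sketch of those "few lines of Python code" (the model name and random input below are illustrative; see ./example/python for the authoritative usage):

```python
import torch
import transformers
import turbo_transformers

# Load an ordinary huggingface/transformers BERT model.
torch_model = transformers.BertModel.from_pretrained("bert-base-uncased")
torch_model.eval()
torch.set_grad_enabled(False)

# Wrap it with TurboTransformers; downstream code keeps calling the model as before.
turbo_model = turbo_transformers.BertModel.from_torch(torch_model)

input_ids = torch.randint(low=0, high=torch_model.config.vocab_size - 1,
                          size=(1, 32), dtype=torch.long)
turbo_output = turbo_model(input_ids)
```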


TurboTransformers has been applied to multiple online BERT service scenarios in Tencent.
For example, it brings 1.88x acceleration to the WeChat FAQ service, 2.11x acceleration to the public cloud sentiment analysis service, and 13.6x acceleration to the QQ recommendation system.
@@ -31,6 +32,7 @@ The following table is a comparison of TurboTransformers and related work.
### Supported Models
We currently support the following transformer models.


* [BERT](https://arxiv.org/abs/1810.04805) [[Python]](./example/python/bert_example.py) [[C++]](./example/python/bert_example.cpp)
* [ALBERT](https://arxiv.org/abs/1909.11942) [[Python]](./example/python/albert_example.py)
* [Roberta](https://arxiv.org/abs/1907.11692) [[Python]](./example/python/roberta_example.py)
@@ -155,7 +157,7 @@ TurboTransformers provides C++ / python API interfaces. We hope to do our best t


#### Pretrained Model Loading
The first step in using Turbo is to load a pre-trained model. We provide a way to load PyTorch and TensorFlow pre-trained models from [huggingface/transformers](https://github.com/huggingface).
The specific conversion method is to use the corresponding script in ./tools to convert the pre-trained model into an npz format file; Turbo then uses the C++ or Python interface to load the npz format model.
In particular, since most pre-trained models are in PyTorch format and used from Python, we provide a shortcut for loading a PyTorch saved model directly in Python.
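
A minimal sketch of the npz path, assuming the checkpoint has already been converted with the corresponding script in ./tools, and using the `from_npz` signature from example/python/bert_example.py in this commit:

```python
import torch
import transformers
import turbo_transformers

# The config must match the converted checkpoint; "bert-base-uncased" is illustrative.
cfg = transformers.BertConfig.from_pretrained("bert-base-uncased")

# Load the npz file produced by the conversion script in ./tools.
tt_model = turbo_transformers.BertModel.from_npz("/workspace/bert_torch.npz", cfg)

input_ids = torch.randint(low=0, high=cfg.vocab_size - 1,
                          size=(1, 16), dtype=torch.long)
output = tt_model(input_ids)
```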

@@ -175,14 +177,15 @@ Users can link turbo-transformers to your code through add_subdirectory.
Usually, when feeding a batch of requests of different lengths into a BERT model for inference,
zero-padding is required to make all the requests the same length.
For example, to serve a list of requests of lengths (100, 10, 50), you need a preprocessing stage to pad them all to length 100.
In this way, 90% and 50% of the computation for the last two sequences is wasted.
As indicated in [Effective Transformer](https://github.com/bytedance/effective_transformer),
it is not necessary to pad the input tensors.
As an alternative, you just have to pad the batch-gemm operations inside the multi-headed attention,
which accounts for a small proportion of the entire BERT computation.
Therefore most gemm operations are processed without zero-padding.
Turbo provides a model, `BertModelSmartBatch`, that includes a smart batching technique.
The example is presented in [./example/python/bert_smart_pad.py](./example/python/bert_smart_pad.py "smart_batching").
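
A condensed sketch of the smart-batching API, following the usage in the bert_smart_pad.py example added in this commit (random inputs stand in for real requests):

```python
import torch
import transformers
import turbo_transformers

torch_model = transformers.BertModel(
    transformers.BertConfig(attention_probs_dropout_prob=0.0,
                            hidden_dropout_prob=0.0))
torch_model.eval()
torch.set_grad_enabled(False)

# Build a smart-batching Turbo model from the torch model.
turbo_model = turbo_transformers.BertModelSmartBatch.from_torch(torch_model)

# A batch of queries with different lengths; no zero-padding is needed.
query_seq_len_list = [18, 2, 3, 51]
input_list = [
    torch.randint(low=0, high=torch_model.config.vocab_size - 1,
                  size=(1, seq_len), dtype=torch.long)
    for seq_len in query_seq_len_list
]
res, _ = turbo_model(input_list, query_seq_len_list)
```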


## How to contribute new models
[How to know hotspots of your code?](./docs/profiler.md)
@@ -205,7 +208,7 @@ Downgrading PyTorch to version 1.1.0 will improve Turbo's performance.
3. onnxruntime-cpu==1.4.0 and onnxruntime-gpu==1.3.0 can not work simultaneously.

## History
1. January 2021 v0.6.0, TurboTransformers supports smart batching.
2. July 2020 v0.4.0, TurboTransformers used onnxruntime as the CPU backend and supports GPT2. Added a quantized BERT.
3. July 2020 v0.3.1, TurboTransformers added support for ALBERT and Roberta on CPU/GPU.
4. June 2020 v0.3.0, TurboTransformers added support for Transformer Decoder on CPU/GPU.
Empty file added distrill/bert_model.txt
Empty file.
Empty file added distrill/distill_bert.txt
Empty file.
45 changes: 45 additions & 0 deletions distrill/distrill_bert.py
@@ -0,0 +1,45 @@
# Copyright (C) 2020 THL A29 Limited, a Tencent company.
# All rights reserved.
# Licensed under the BSD 3-Clause License (the "License"); you may
# not use this file except in compliance with the License. You may
# obtain a copy of the License at
# https://opensource.org/licenses/BSD-3-Clause
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# See the AUTHORS file for names of contributors.

from transformers import DistilBertTokenizer, DistilBertModel
from transformers import BertTokenizer, BertModel
import torch

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# inputs = torch.randint(low=0,
# high=cfg.vocab_size - 1,
# size=(1, 10),
# dtype=torch.long,
# device=torch.device("cpu:0"))

## distillation model
model = DistilBertModel.from_pretrained("distilbert-base-uncased",
return_dict=True)

## bert model
bert_model = BertModel.from_pretrained("bert-base-uncased", return_dict=True)

cfg = model.config
print(cfg)
print(inputs)
outputs = model(**inputs)
bert_outputs = bert_model(**inputs)

print(model)
print(bert_model)

# print(bert_outputs - outputs)
#
# last_hidden_states = outputs.last_hidden_state
# print(last_hidden_states)
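
# A possible numeric comparison of the two models' outputs (an illustrative sketch,
# assuming the return_dict=True ModelOutput attributes used above): both base
# checkpoints use a hidden size of 768, so the last hidden states can be subtracted.
diff = (bert_outputs.last_hidden_state - outputs.last_hidden_state).abs().max()
print("max abs difference between BERT and DistilBERT hidden states:", diff.item())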
10 changes: 10 additions & 0 deletions example/python/README.md
@@ -38,6 +38,16 @@ I have prepared an image for bert only runtime on dockerhub with .

`thufeifeibear/turbo_transformers_cpu:bert_only_v0.1`

**Attention**: If you want to use Turbo with the C++ backend instead of onnxruntime,
directly linking the MKL that comes with a conda-installed PyTorch will lead to poor performance
in our hand-crafted C++ version.
You should install an official MKL and set the MKL path in CMakeLists.txt.
As a less elegant alternative, you can uninstall OpenNMT-py and downgrade torch to 1.1.0.

I have prepared an image for bert only runtime on dockerhub with .

`thufeifeibear/turbo_transformers_cpu:bert_only_v0.1`

### How to customize your post-processing layers after the BERT encoder
[Chinese Version](./README.md)
Because TurboTransformers has accelerated the embedding + BERT encoder + pooler, which are the major hotspots.
3 changes: 1 addition & 2 deletions example/python/bert_example.py
@@ -78,8 +78,7 @@ def test(loadtype: LoadType, use_cuda: bool):
sys.exit("ERROR. can not open " + sys.argv[1])
else:
in_file = "/workspace/bert_torch.npz"
tt_model = turbo_transformers.BertModel.from_npz(
in_file, cfg, test_device)
tt_model = turbo_transformers.BertModel.from_npz(in_file, cfg)
else:
raise ValueError("LoadType is not supported")

54 changes: 29 additions & 25 deletions example/python/bert_for_sequence_classification_example.py
@@ -17,11 +17,10 @@
from turbo_transformers import ReturnType

# import the class of the acceleration model. here is the example of BertForSequenceClassification.
from transformers.modeling_bert import BertModel as TorchBertModel
from transformers.models.bert.modeling_bert import BertModel as TorchBertModel
from transformers import BertTokenizer
from transformers.modeling_bert import (
BertForSequenceClassification as TorchBertForSequenceClassification,
)
from transformers.models.bert.modeling_bert import (
BertForSequenceClassification as TorchBertForSequenceClassification, )
import os
import torch
from typing import Optional
@@ -31,19 +30,19 @@
# Contact me if you find it is wrong.
class BertForSequenceClassification: # create a new class for speeding up
def __init__(
self, bertmodel, classifier
self, bertmodel, classifier
): # the realization of the init function(we can just copy it)
self.bert = bertmodel
self.classifier = classifier

def __call__(
self, # the realization of the call function(we can just copy it)
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
pooling_type=PoolingType.FIRST,
return_type=None,
self, # the realization of the call function(we can just copy it)
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
pooling_type=PoolingType.FIRST,
return_type=None,
):
bert_outputs = self.bert(
input_ids,
@@ -61,9 +60,11 @@ def __call__(

@staticmethod
def from_torch(
model: TorchBertModel, device: Optional[torch.device] = None # from_torch函数实现
model: TorchBertModel,
device: Optional[torch.device] = None # from_torch函数实现
):
if device is not None and "cuda" in device.type and torch.cuda.is_available():
if device is not None and "cuda" in device.type and torch.cuda.is_available(
):
model.to(device)
bertmodel = turbo_transformers.BertModel.from_torch(model.bert)
# We can copy the following code and do not change it
@@ -72,11 +73,11 @@ def from_torch(
return BertForSequenceClassification(bertmodel, model.classifier)

@staticmethod
def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None):
def from_pretrained(model_id_or_path: str,
device: Optional[torch.device] = None):
# First, Use the function of from_pretrained to load the model you trained.
torch_model = TorchBertForSequenceClassification.from_pretrained(
model_id_or_path
)
model_id_or_path)
# Then, Use the init function of the acceleration model to get it.
model = BertForSequenceClassification.from_torch(torch_model, device)
model._torch_model = torch_model # prevent destroy torch model.
@@ -86,18 +87,20 @@ def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None
# use 4 threads for BERT inference
turbo_transformers.set_num_threads(4)

model_id = os.path.join(
os.path.dirname(__file__), "bert_model"
) # the model of huggingface's path
tokenizer = BertTokenizer.from_pretrained(model_id) # the initialization of tokenizer
model_id = os.path.join(os.path.dirname(__file__),
"bert_model") # the model of huggingface's path
tokenizer = BertTokenizer.from_pretrained(
model_id) # the initialization of tokenizer
turbo_model = BertForSequenceClassification.from_pretrained(
model_id, torch.device("cpu:0")
) # the initialization of the acceleration model
model_id,
torch.device("cpu:0")) # the initialization of the acceleration model

# predict after loading the model

text = "Sample input text"
inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors="pt")
inputs = tokenizer.encode_plus(text,
add_special_tokens=True,
return_tensors="pt")
# turbo_result holds the returned logits from TurboTransformers model
turbo_result = turbo_model(**inputs)

@@ -106,5 +109,6 @@ def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None
torch_result = torch_model(**inputs)[0]
print(turbo_result)
# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
print(torch_result) # torch_result and turbo_result should hold the same logits
print(
torch_result) # torch_result and turbo_result should hold the same logits
# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
85 changes: 85 additions & 0 deletions example/python/bert_smart_pad.py
@@ -0,0 +1,85 @@
# Copyright (C) 2020 THL A29 Limited, a Tencent company.
# All rights reserved.
# Licensed under the BSD 3-Clause License (the "License"); you may
# not use this file except in compliance with the License. You may
# obtain a copy of the License at
# https://opensource.org/licenses/BSD-3-Clause
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# See the AUTHORS file for names of contributors.
import torch
import transformers
import turbo_transformers
import enum
import time
import sys


def serial_bert_inference(torch_model, input_list):
res_list = []
for input_seq in input_list:
res = torch_model(input_seq)[0]  # index 0 is last_hidden_state; transformers>=4.x returns a ModelOutput by default
res_list.append(res)

for i in range(len(res_list)):
if i == 0:
concat_res = res_list[i]
else:
concat_res = torch.cat((concat_res, res_list[i]), 1)
return concat_res


def batch_bert_inference(turbo_model, input_list, query_seq_len_list):
res, _ = turbo_model(input_list, query_seq_len_list)
return res


def test_smart_batch(use_cuda: bool):
test_device = torch.device('cuda:0') if use_cuda else \
torch.device('cpu:0')
cfg = transformers.BertConfig(attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0)
torch_model = transformers.BertModel(cfg)

# model_id = "bert-base-uncased"
# torch_model = transformers.BertModel.from_pretrained(model_id)
torch_model.eval()
torch_model.to(test_device)
torch.set_grad_enabled(False)

cfg = torch_model.config
# use 4 threads for computing
if not use_cuda:
turbo_transformers.set_num_threads(4)

# Initialize a turbo BertModel with smart batching from torch model.
turbo_model = turbo_transformers.BertModelSmartBatch.from_torch(
torch_model)

# a batch of queries with different lengths.
query_seq_len_list = [18, 2, 3, 51]
input_list = []

# generate random inputs. Of course you can use real data.
for query_seq_len in query_seq_len_list:
input_seq = torch.randint(low=0,
high=cfg.vocab_size - 1,
size=(1, query_seq_len),
dtype=torch.long,
device=test_device)
input_list.append(input_seq)

# start inference
s_res = serial_bert_inference(torch_model, input_list)
b_res = batch_bert_inference(turbo_model, input_list, query_seq_len_list)
print(torch.max(torch.abs(b_res - s_res)))
assert (torch.max(torch.abs(b_res - s_res)) < 1e-2)


if __name__ == "__main__":
if torch.cuda.is_available():
test_smart_batch(True)
test_smart_batch(False)
