From d2fa90b03c5a8fe56c50bd6710ef6d551279afa7 Mon Sep 17 00:00:00 2001 From: sam-writer Date: Wed, 22 Dec 2021 22:03:01 +0000 Subject: [PATCH] WIP: add support for byt5 and mt5, closes #1655 Signed-off-by: sam-writer --- demo/HuggingFace/T5/T5ModelConfig.py | 6 +- demo/HuggingFace/notebooks/t5.ipynb | 207 ++++++++++++++++++--------- 2 files changed, 140 insertions(+), 73 deletions(-) diff --git a/demo/HuggingFace/T5/T5ModelConfig.py b/demo/HuggingFace/T5/T5ModelConfig.py index f07756929..a53a4ae7b 100644 --- a/demo/HuggingFace/T5/T5ModelConfig.py +++ b/demo/HuggingFace/T5/T5ModelConfig.py @@ -107,7 +107,7 @@ def __init__(self): def get_python_requirements(self): base_requirements = super().get_python_requirements() - base_requirements.append("transformers==4.6.1") + base_requirements.append("transformers>=4.8.0") return base_requirements def get_network_segments(self): @@ -119,8 +119,8 @@ def get_network_segments(self): return T5ModelTRTConfig.NETWORK_SEGMENTS def get_metadata_string(self, metadata: NetworkMetadata) -> str: - # Remove redundant t5 name - metadata = metadata._replace(variant=metadata.variant.lstrip("t5-")) + # Remove redundant google/ if present + metadata = metadata._replace(variant=metadata.variant.lstrip("google/")) return super().get_metadata_string(metadata) @staticmethod diff --git a/demo/HuggingFace/notebooks/t5.ipynb b/demo/HuggingFace/notebooks/t5.ipynb index 929f242ad..436434f84 100644 --- a/demo/HuggingFace/notebooks/t5.ipynb +++ b/demo/HuggingFace/notebooks/t5.ipynb @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "0c36ecb7-c622-4d95-a851-b9a6eb18e81b", "metadata": {}, "outputs": [], @@ -79,9 +79,8 @@ "\n", "# huggingface\n", "from transformers import (\n", - " T5ForConditionalGeneration,\n", - " T5Tokenizer,\n", - " T5Config,\n", + " AutoTokenizer,\n", + " AutoConfig\n", ")" ] }, @@ -101,21 +100,32 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 26, "id": "fae66d58-f994-4987-8f1d-1fa8ac2ec8b4", "metadata": {}, "outputs": [], "source": [ - "T5_VARIANT = 't5-small' # choices: t5-small | t5-base | t5-large\n", + "# choices: t5-small | t5-base | t5-large |\n", + "# experimental choices: google/byt5-small | google/byt5-base | google/byt5-large\n", + "# google/mt5-small | google/mt5-base | google/mt5-large\n", + "T5_VARIANT = 't5-small'\n", + "\n", + "if \"mt5\" in T5_VARIANT:\n", + " from transformers import MT5ForConditionalGeneration\n", + " \n", + " t5_model = MT5ForConditionalGeneration.from_pretrained(T5_VARIANT)\n", + "else:\n", + " from transformers import T5ForConditionalGeneration\n", + " \n", + " t5_model = T5ForConditionalGeneration.from_pretrained(T5_VARIANT) # byt5 also uses this\n", "\n", - "t5_model = T5ForConditionalGeneration.from_pretrained(T5_VARIANT)\n", - "tokenizer = T5Tokenizer.from_pretrained(T5_VARIANT)\n", - "config = T5Config(T5_VARIANT)" + "tokenizer = AutoTokenizer.from_pretrained(T5_VARIANT)\n", + "config = AutoConfig.from_pretrained(T5_VARIANT)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 27, "id": "7252ca90-1104-40dc-8e72-f51c07a4cd11", "metadata": {}, "outputs": [ @@ -123,13 +133,18 @@ "name": "stdout", "output_type": "stream", "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "Pytorch Model saved to ./models/t5-small/pytorch\n" ] } ], "source": [ "# save model locally\n", - "pytorch_model_dir = './models/{}/pytorch'.format(T5_VARIANT)\n", + "T5_VARIANT_SAFE_NAME = T5_VARIANT.replace('/', '_')\n", + "pytorch_model_dir = './models/{}/pytorch'.format(T5_VARIANT_SAFE_NAME)\n", "!mkdir -p $pytorch_model_dir\n", "\n", "t5_model.save_pretrained(pytorch_model_dir)\n", @@ -145,12 +160,14 @@ "\n", "Next, we will carry out inference with the PyTorch model.\n", "\n", - "#### Single example inference" + "#### Single example inference\n", + "\n", + "Note: Only t5 had supervised pretraining. If you use byt5 or mt5 models, this won't work and will return gibberish. That is expected. The mt5 and byt5 variants needs supervised training to be used for tasks like translation or classificaiton. There are community checkpoints after supervised training avaialable for [mt5 here](https://huggingface.co/models?search=mt5) and for [byt5 here](https://huggingface.co/models?search=byt5)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 28, "id": "bc45d9bc-b6ef-485e-8832-6628c292e315", "metadata": {}, "outputs": [], @@ -167,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 29, "id": "98f7fd8b-2ee3-4d25-9204-7713eb7e90b3", "metadata": {}, "outputs": [ @@ -199,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 30, "id": "596ea542-d9e5-4367-b643-d60027fa05e6", "metadata": {}, "outputs": [], @@ -216,17 +233,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 31, "id": "be755fbc-c53e-4f8d-a9c2-4817167cf93a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.0072555395308882" + "0.005005257500670268" ] }, - "execution_count": 7, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -242,17 +259,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 32, "id": "960f05fc-f572-4832-ad82-8a75823866b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.011791097989771515" + "0.008204548999856343" ] }, - "execution_count": 8, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -278,31 +295,46 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 33, + "id": "f1768bbb-07eb-46f5-8558-72e33d13706c", + "metadata": {}, + "outputs": [], + "source": [ + "from T5.T5ModelConfig import T5ModelTRTConfig\n", + "\n", + "# monkey-patch so we don't have to know about every t5 variant ahead of time\n", + "if T5_VARIANT not in T5ModelTRTConfig.TARGET_MODELS:\n", + " T5ModelTRTConfig.TARGET_MODELS.append(T5_VARIANT)\n", + " T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[T5_VARIANT] = config.d_model\n", + " T5ModelTRTConfig.NUMBER_OF_LAYERS[T5_VARIANT] = config.num_layers\n", + " T5ModelTRTConfig.VOCAB_SIZE = config.vocab_size # different in byt5 and possibly in some fine-tuned models\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, "id": "39d511cf-d963-4629-be54-22e9a258716d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.0667644675122574" + "0.04445125900019775" ] }, - "execution_count": 9, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from T5.T5ModelConfig import T5ModelTRTConfig\n", - "\n", "decoder_output_greedy, full_e2e_median_runtime = full_inference_greedy(\n", " t5_torch_encoder,\n", " t5_torch_decoder,\n", " input_ids,\n", " tokenizer,\n", " TimingProfile(iterations=10, number=1, warmup=1),\n", - " max_length=T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[T5_VARIANT],\n", + " max_length=config.d_model,\n", ")\n", "full_e2e_median_runtime" ] @@ -317,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 35, "id": "839bc6bc-65dc-499d-ac26-81456dbc1748", "metadata": {}, "outputs": [ @@ -358,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 36, "id": "c2b2be1a-021c-4f6c-957d-2ff7d1b95976", "metadata": {}, "outputs": [], @@ -369,18 +401,29 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 37, "id": "c50346f7-6c2c-4e4b-ba70-875688947b75", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + } + ], "source": [ - "onnx_model_path = './models/{}/ONNX'.format(T5_VARIANT)\n", + "onnx_model_path = './models/{}/ONNX'.format(T5_VARIANT_SAFE_NAME)\n", "!mkdir -p $onnx_model_path\n", "\n", "metadata=NetworkMetadata(T5_VARIANT, Precision('fp16'), None)\n", "\n", - "encoder_onnx_model_fpath = T5_VARIANT + \"-encoder.onnx\"\n", - "decoder_onnx_model_fpath = T5_VARIANT + \"-decoder-with-lm-head.onnx\"\n", + "encoder_onnx_model_fpath = T5_VARIANT_SAFE_NAME + \"-encoder.onnx\"\n", + "decoder_onnx_model_fpath = T5_VARIANT_SAFE_NAME + \"-decoder-with-lm-head.onnx\"\n", "\n", "t5_encoder = T5EncoderTorchFile(t5_model.to('cpu'), metadata)\n", "t5_decoder = T5DecoderTorchFile(t5_model.to('cpu'), metadata)\n", @@ -407,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 38, "id": "037ac958-2627-439c-9db5-27640e3f7967", "metadata": {}, "outputs": [], @@ -417,21 +460,38 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "6bd6e3fc-6797-46b0-a211-ce42d3769105", "metadata": {}, "outputs": [], "source": [ - "tensorrt_model_path = './models/{}/tensorrt'.format(T5_VARIANT)\n", + "tensorrt_model_path = './models/{}/tensorrt'.format(T5_VARIANT_SAFE_NAME)\n", "!mkdir -p tensorrt_model_path" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 40, "id": "cfb64120-9012-40c8-b1e2-4a6366b71294", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[12/22/2021-21:48:58] [TRT] [W] Output type must be INT32 for shape outputs\n", + "[12/22/2021-21:48:58] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: \n", + "[12/22/2021-21:48:58] [TRT] [W] (# 1 (SHAPE encoder_hidden_states))\n", + "[12/22/2021-21:48:58] [TRT] [W] (# 1 (SHAPE input_ids))\n", + "[12/22/2021-21:50:53] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: \n", + "[12/22/2021-21:50:53] [TRT] [W] (# 1 (SHAPE encoder_hidden_states))\n", + "[12/22/2021-21:50:53] [TRT] [W] (# 1 (SHAPE input_ids))\n", + "[12/22/2021-21:53:16] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: \n", + "[12/22/2021-21:53:16] [TRT] [W] (# 1 (SHAPE encoder_hidden_states))\n", + "[12/22/2021-21:53:16] [TRT] [W] (# 1 (SHAPE input_ids))\n" + ] + } + ], "source": [ "t5_trt_encoder_engine = T5EncoderONNXFile(\n", " os.path.join(onnx_model_path, encoder_onnx_model_fpath), metadata\n", @@ -457,7 +517,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 41, "id": "3954f2f4-c393-463b-a44b-3e5335032b57", "metadata": {}, "outputs": [], @@ -465,10 +525,7 @@ "# Initialize TensorRT engines\n", "from T5.trt import T5TRTEncoder, T5TRTDecoder\n", "\n", - "tfm_config = T5Config(\n", - " use_cache=True,\n", - " num_layers=T5ModelTRTConfig.NUMBER_OF_LAYERS[T5_VARIANT],\n", - ")\n", + "tfm_config = config\n", " \n", "t5_trt_encoder = T5TRTEncoder(\n", " t5_trt_encoder_engine, metadata, tfm_config\n", @@ -480,19 +537,20 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 47, "id": "a9544ecb-2671-4b53-a544-08f13424cefe", "metadata": {}, "outputs": [], "source": [ "# Inference on a single sample\n", "encoder_last_hidden_state = t5_trt_encoder(input_ids=input_ids)\n", + "t5_trt_decoder.set_return_device(encoder_last_hidden_state.device)\n", "outputs = t5_trt_decoder(input_ids, encoder_last_hidden_state)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 48, "id": "8d71a327-546f-4b5b-bd42-caaffcceafc7", "metadata": {}, "outputs": [ @@ -500,7 +558,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Das ist gut.\n" + "Das clic\n" ] } ], @@ -511,7 +569,7 @@ " StoppingCriteriaList,\n", ")\n", "\n", - "max_length = 64\n", + "max_length = 64 if 'byt5' not in T5_VARIANT else 256\n", "\n", "decoder_input_ids = torch.full(\n", " (1, 1), tokenizer.convert_tokens_to_ids(tokenizer.pad_token), dtype=torch.int32\n", @@ -526,6 +584,24 @@ "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" ] }, + { + "cell_type": "code", + "execution_count": 50, + "id": "d4ee9813-bd4e-48f6-9411-416dc3fbae92", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Das clic\n" + ] + } + ], + "source": [ + "print(tokenizer.decode(outputs[0], skip_special_tokens=False))" + ] + }, { "cell_type": "markdown", "id": "ed9d4a98-b034-470e-a9f8-096d4100b8d4", @@ -537,17 +613,17 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 51, "id": "70b37591-4398-40ff-8a39-5f75347192dc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.000644649553578347" + "0.0010203349993389565" ] }, - "execution_count": 19, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } @@ -561,17 +637,17 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 52, "id": "7e5459da-a01b-4894-88dc-01b3637ded53", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.0014052424812689424" + "0.0017294864992436487" ] }, - "execution_count": 20, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -595,26 +671,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "f31cb550-24b9-48cd-a4ec-0bf18ac5e40c", "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Das ist gut.\n" + "[2021-12-22 21:57:39,889][OSS][WARNING] Unable to execute program using cuda compatible device: The expanded size of the tensor (63) must match the existing size (64) at non-singleton dimension 0. Target sizes: [63]. Tensor sizes: [64]\n", + "[2021-12-22 21:57:39,890][OSS][WARNING] Retrying using CPU only.\n" ] - }, - { - "data": { - "text/plain": [ - "0.007905778475105762" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -625,7 +692,7 @@ " tokenizer,\n", " TimingProfile(10,1,1),\n", " max_length=T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[metadata.variant],\n", - " use_cuda=False,\n", + " use_cuda=True,\n", ")\n", "\n", "print(tokenizer.decode(decoder_output_greedy[0], skip_special_tokens=True))\n", @@ -655,7 +722,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -669,7 +736,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.12" } }, "nbformat": 4,