From a8a8338111c0c104cd981a7cf69dcbdaf573f968 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 1 Feb 2023 12:22:29 +0800 Subject: [PATCH] [Benchmark] Add BERT GPT T5 finetune process (#4492) --- paddlenlp/transformers/gpt/modeling.py | 4 +- .../benchmark/modules/benchmark_utils.py | 63 ++++++ .../modules/bert_for_question_answering.py | 106 ++++++++++ .../gpt_for_sequence_classification.py | 79 ++++++++ .../test_tipc/benchmark/modules/model_base.py | 21 +- .../modules/t5_for_conditional_generation.py | 81 ++++++++ tests/test_tipc/benchmark/options.py | 46 ++++- .../train_infer_python.txt | 59 ++++++ .../train_infer_python.txt | 59 ++++++ .../train_infer_python.txt | 59 ++++++ tests/test_tipc/train.py | 190 +++++++++++++++++- 11 files changed, 744 insertions(+), 23 deletions(-) create mode 100644 tests/test_tipc/benchmark/modules/benchmark_utils.py create mode 100644 tests/test_tipc/benchmark/modules/bert_for_question_answering.py create mode 100644 tests/test_tipc/benchmark/modules/gpt_for_sequence_classification.py create mode 100644 tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py create mode 100644 tests/test_tipc/configs/bert_for_question_answering/train_infer_python.txt create mode 100644 tests/test_tipc/configs/gpt_for_sequence_classification/train_infer_python.txt create mode 100644 tests/test_tipc/configs/t5_for_conditional_generation/train_infer_python.txt diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py index 36f8eb0eb431..f4b7d258fe6e 100644 --- a/paddlenlp/transformers/gpt/modeling.py +++ b/paddlenlp/transformers/gpt/modeling.py @@ -1445,7 +1445,9 @@ def forward( "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1)) + pooled_logits = logits.gather_nd( + paddle.stack([paddle.arange(paddle.shape(logits)[0]), sequence_lengths], axis=-1) + ) loss = None if labels is not None: diff --git a/tests/test_tipc/benchmark/modules/benchmark_utils.py b/tests/test_tipc/benchmark/modules/benchmark_utils.py new file mode 100644 index 000000000000..2a062f2b8706 --- /dev/null +++ b/tests/test_tipc/benchmark/modules/benchmark_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
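+#
+# Utilities for the generated-inputs benchmark path: rand_int_tensor draws
+# random int64 ids in [low, high), and clone_input/clone_inputs deep-copy the
+# generated example tensors (preserving stop_gradient and any existing grad)
+# so that every training step starts from a fresh copy.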
+
+import paddle
+
+
+def rand_int_tensor(low, high, shape):
+    return paddle.randint(
+        low,
+        high,
+        shape=shape,
+        dtype=paddle.int64,
+    )
+
+
+def clone_tensor(x):
+    y = x.clone()
+    return y
+
+
+def clone_input(x):
+    def paddle_clone(x):
+        y = paddle.clone(x)
+        if x.is_leaf:
+            y.stop_gradient = x.stop_gradient
+        if x.is_leaf and x.grad is not None:
+            y.grad = clone_input(x.grad)
+        return y
+
+    with paddle.no_grad():
+        result = paddle.empty(x.shape, dtype=x.dtype)
+        result.copy_(x.clone(), True)
+        if x.is_leaf:
+            result.stop_gradient = x.stop_gradient
+        if x.is_leaf and x.grad is not None:
+            result.grad = clone_input(x.grad)
+        return result
+
+
+def clone_inputs(example_inputs):
+    if isinstance(example_inputs, dict):
+        res = dict(example_inputs)
+        for key, value in res.items():
+            assert isinstance(value, paddle.Tensor)
+            res[key] = clone_input(value)
+        return res
+
+    res = list(example_inputs)
+    for i in range(len(res)):
+        if isinstance(res[i], paddle.Tensor):
+            res[i] = clone_input(res[i])
+    return res
diff --git a/tests/test_tipc/benchmark/modules/bert_for_question_answering.py b/tests/test_tipc/benchmark/modules/bert_for_question_answering.py
new file mode 100644
index 000000000000..26f57acf88a1
--- /dev/null
+++ b/tests/test_tipc/benchmark/modules/bert_for_question_answering.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from paddlenlp.transformers import BertForQuestionAnswering
+from paddlenlp.utils.log import logger
+
+from .benchmark_utils import rand_int_tensor
+from .model_base import BenchmarkBase
+
+
+class BertForQuestionAnsweringBenchmark(BenchmarkBase):
+    def __init__(self):
+        self.label_list = None
+        super().__init__()
+
+    @staticmethod
+    def add_args(args, parser):
+        parser.add_argument(
+            "--model_name_or_path", type=str, default="bert-base-cased", help="Model name. Defaults to bert-base-cased. "
+        )
+        # args.max_seq_len
+
+    def create_data_loader(self, args, **kwargs):
+        raise NotImplementedError(
+            "bert_for_question_answering's DataLoader is not implemented. Please use --generated_inputs. "
+        )
+
+    def generate_inputs_for_model(self, args, model, **kwargs):
+        input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
+        start_positions = rand_int_tensor(
+            0,
+            args.max_seq_len,
+            [
+                args.batch_size,
+            ],
+        )
+        end_positions = rand_int_tensor(
+            0,
+            args.max_seq_len,
+            [
+                args.batch_size,
+            ],
+        )
+        return {"input_ids": input_ids, "start_positions": start_positions, "end_positions": end_positions}
+
+    def build_model(self, args, **kwargs):
+        model = BertForQuestionAnswering.from_pretrained(args.model_name_or_path)
+        return model
+
+    def forward(self, model, args, input_data=None, **kwargs):
+        start_positions = input_data.pop("start_positions")
+        end_positions = input_data.pop("end_positions")
+        start_logits, end_logits = model(**input_data)
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+            if start_positions.ndim > 1:
+                start_positions = start_positions.squeeze(-1)
+            if end_positions.ndim > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = paddle.shape(start_logits)[1]
+            start_positions = start_positions.clip(0, ignored_index)
+            end_positions = end_positions.clip(0, ignored_index)
+
+            loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        return (
+            total_loss,
+            paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
+        )
+
+    def logger(
+        self,
+        args,
+        step_id=None,
+        pass_id=None,
+        batch_id=None,
+        loss=None,
+        batch_cost=None,
+        reader_cost=None,
+        num_samples=None,
+        ips=None,
+        **kwargs
+    ):
+        logger.info(
+            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
+            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
+        )
diff --git a/tests/test_tipc/benchmark/modules/gpt_for_sequence_classification.py b/tests/test_tipc/benchmark/modules/gpt_for_sequence_classification.py
new file mode 100644
index 000000000000..708f52480257
--- /dev/null
+++ b/tests/test_tipc/benchmark/modules/gpt_for_sequence_classification.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from paddlenlp.transformers import GPTForSequenceClassification
+from paddlenlp.utils.log import logger
+
+from .benchmark_utils import rand_int_tensor
+from .model_base import BenchmarkBase
+
+
+class GPTForSequenceClassificationBenchmark(BenchmarkBase):
+    def __init__(self):
+        self.label_list = None
+        super().__init__()
+
+    @staticmethod
+    def add_args(args, parser):
+        parser.add_argument(
+            "--model_name_or_path", type=str, default="gpt2-en", help="Model name. Defaults to gpt2-en. "
+        )
+        # args.max_seq_len
+
+    def create_data_loader(self, args, **kwargs):
+        raise NotImplementedError(
+            "gpt_for_sequence_classification's DataLoader is not implemented. Please use --generated_inputs. "
+        )
+
+    def create_input_specs(self):
+        input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64")
+        labels = paddle.static.InputSpec(name="labels", shape=[-1], dtype="int64")
+        return [input_ids, None, None, None, labels]
+
+    def generate_inputs_for_model(self, args, model, **kwargs):
+        input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
+        labels = rand_int_tensor(0, model.config.num_classes - 1, [args.batch_size])
+
+        return {"input_ids": input_ids, "labels": labels}
+
+    def build_model(self, args, **kwargs):
+        model = GPTForSequenceClassification.from_pretrained(args.model_name_or_path)
+        return model
+
+    def forward(self, model, args, input_data=None, **kwargs):
+        res = model(**input_data)
+        return (
+            res[0],
+            paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
+        )
+
+    def logger(
+        self,
+        args,
+        step_id=None,
+        pass_id=None,
+        batch_id=None,
+        loss=None,
+        batch_cost=None,
+        reader_cost=None,
+        num_samples=None,
+        ips=None,
+        **kwargs
+    ):
+        logger.info(
+            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
+            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
+        )
diff --git a/tests/test_tipc/benchmark/modules/model_base.py b/tests/test_tipc/benchmark/modules/model_base.py
index c3184febe01f..cca9b84c8f8a 100644
--- a/tests/test_tipc/benchmark/modules/model_base.py
+++ b/tests/test_tipc/benchmark/modules/model_base.py
@@ -1,4 +1,17 @@
-import paddle
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlenlp.utils.log import logger
 
@@ -16,6 +29,12 @@ def create_data_loader(self, args, **kwargs):
     def build_model(self, args, **kwargs):
         raise NotImplementedError
 
+    def generate_inputs_for_model(self, args, **kwargs):
+        raise NotImplementedError
+
+    def create_input_specs(self):
+        return None
+
     def forward(self, model, args, input_data=None, **kwargs):
         raise NotImplementedError
diff --git a/tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py b/tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py
new file mode 100644
index 000000000000..09addb1fe589
--- /dev/null
+++ b/tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from paddlenlp.transformers import T5ForConditionalGeneration
+from paddlenlp.utils.log import logger
+
+from .benchmark_utils import rand_int_tensor
+from .model_base import BenchmarkBase
+
+
+class T5ForConditionalGenerationBenchmark(BenchmarkBase):
+    def __init__(self):
+        self.label_list = None
+        super().__init__()
+
+    @staticmethod
+    def add_args(args, parser):
+        parser.add_argument(
+            "--model_name_or_path", type=str, default="t5-small", help="Model name. Defaults to t5-small. "
+        )
+        # args.max_seq_len
+
+    def create_data_loader(self, args, **kwargs):
+        raise NotImplementedError(
+            "t5_for_conditional_generation's DataLoader is not implemented. Please use --generated_inputs. "
+        )
+
+    def create_input_specs(self):
+        input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64")
+        decoder_input_ids = paddle.static.InputSpec(name="decoder_input_ids", shape=[-1, -1], dtype="int64")
+        labels = paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64")
+        return [input_ids, None, decoder_input_ids, None, None, None, labels]
+
+    def generate_inputs_for_model(self, args, model, **kwargs):
+        input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
+        decoder_input_ids = input_ids
+        labels = rand_int_tensor(0, model.config.vocab_size - 1, [args.batch_size, args.max_seq_len])
+
+        return {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "labels": labels}
+
+    def build_model(self, args, **kwargs):
+        model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
+        return model
+
+    def forward(self, model, args, input_data=None, **kwargs):
+        res = model(**input_data)
+        return (
+            res[0],
+            paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
+        )
+
+    def logger(
+        self,
+        args,
+        step_id=None,
+        pass_id=None,
+        batch_id=None,
+        loss=None,
+        batch_cost=None,
+        reader_cost=None,
+        num_samples=None,
+        ips=None,
+        **kwargs
+    ):
+        logger.info(
+            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
+            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
+        )
diff --git a/tests/test_tipc/benchmark/options.py b/tests/test_tipc/benchmark/options.py
index 87080b29afab..dc98f4614768 100644
--- a/tests/test_tipc/benchmark/options.py
+++ b/tests/test_tipc/benchmark/options.py
@@ -1,12 +1,31 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
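+#
+# The registries below map CLI choices to benchmark classes: --model is looked
+# up in MODEL_REGISTRY, --optimizer in OPTIMIZER_REGISTRY, and --lr_scheduler
+# in LR_SCHEDULER_REGISTRY. The three new fine-tune benchmarks (BERT QA, GPT
+# sequence classification, T5 conditional generation) are registered in
+# MODEL_REGISTRY.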
+
 import argparse
 
-from .modules.seq2seq import Seq2SeqBenchmark
+from .modules.bert_for_question_answering import BertForQuestionAnsweringBenchmark
 from .modules.bigru_crf import BiGruCrfBenchmark
-from .modules.xlnet import XLNetBenchmark
-from .modules.rnnlm import RNNLMBenchmark
 from .modules.ernie_tiny import ErnieTinyBenchmark
-from .modules.optimizer import *
-from .modules.lr_scheduler import *
+from .modules.gpt_for_sequence_classification import (
+    GPTForSequenceClassificationBenchmark,
+)
+from .modules.lr_scheduler import *  # noqa: F403
+from .modules.optimizer import *  # noqa: F403
+from .modules.rnnlm import RNNLMBenchmark
+from .modules.seq2seq import Seq2SeqBenchmark
+from .modules.t5_for_conditional_generation import T5ForConditionalGenerationBenchmark
+from .modules.xlnet import XLNetBenchmark
 
 __all__ = [
     "MODEL_REGISTRY",
@@ -22,17 +41,20 @@
     "lac": BiGruCrfBenchmark,
     "ptb": RNNLMBenchmark,
     "ernie_tiny": ErnieTinyBenchmark,
+    "bert_for_question_answering": BertForQuestionAnsweringBenchmark,
+    "gpt_for_sequence_classification": GPTForSequenceClassificationBenchmark,
+    "t5_for_conditional_generation": T5ForConditionalGenerationBenchmark,
 }
 
 OPTIMIZER_REGISTRY = {
-    "adam": AdamBenchmark,
-    "adamw": AdamWBenchmark,
-    "sgd": SGDBenchmark,
+    "adam": AdamBenchmark,  # noqa: F405
+    "adamw": AdamWBenchmark,  # noqa: F405
+    "sgd": SGDBenchmark,  # noqa: F405
 }
 
 LR_SCHEDULER_REGISTRY = {
-    "lambda_decay": LambdaDecayBenchmark,
-    "linear_decay_with_warmup": LinearDecayWithWarmupBenchmark,
+    "lambda_decay": LambdaDecayBenchmark,  # noqa: F405
+    "linear_decay_with_warmup": LinearDecayWithWarmupBenchmark,  # noqa: F405
 }
 
 
@@ -93,9 +115,13 @@ def get_parser():
     parser.add_argument("--amp_level", type=str, default="O1", help="AMP LEVEL. O1 or O2. ")
     parser.add_argument("--custom_black_list", type=str, nargs="+", default=None, help="Custom black list for AMP. ")
 
+    parser.add_argument("--to_static", action="store_true", help="Enable @to_static training. ")
+
    parser.add_argument("--max_steps", type=int, default=None, help="Maximum steps. ")
     parser.add_argument("--epoch", type=int, default=10, help="Number of epochs. ")
 
+    parser.add_argument("--generated_inputs", action="store_true", help="Use randomly generated inputs. ")
+
     # For benchmark.
parser.add_argument( "--profiler_options", diff --git a/tests/test_tipc/configs/bert_for_question_answering/train_infer_python.txt b/tests/test_tipc/configs/bert_for_question_answering/train_infer_python.txt new file mode 100644 index 000000000000..5d88f9696940 --- /dev/null +++ b/tests/test_tipc/configs/bert_for_question_answering/train_infer_python.txt @@ -0,0 +1,59 @@ +===========================train_params=========================== +model_name:bert_for_question_answering +python:python3.7 +gpu_list:0|0,1 +null:null +--use_amp:null +--max_steps:null +null:null +--batch_size:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train:test_tipc/train.py --model bert_for_question_answering --optimizer adam --max_seq_len 512 --generated_inputs +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +null:null +null:null +norm_export:null +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:null +inference:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +===========================to_static_train_benchmark_params=========================== +to_static_train:--to_static +===========================train_benchmark_params========================== +batch_size:32|64|96 +fp_items:fp32|fp16 +epoch:500 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/tests/test_tipc/configs/gpt_for_sequence_classification/train_infer_python.txt b/tests/test_tipc/configs/gpt_for_sequence_classification/train_infer_python.txt new file mode 100644 index 000000000000..4bb6ca1b736c --- /dev/null +++ b/tests/test_tipc/configs/gpt_for_sequence_classification/train_infer_python.txt @@ -0,0 +1,59 @@ +===========================train_params=========================== +model_name:gpt_for_sequence_classification +python:python3.7 +gpu_list:0|0,1 +null:null +--use_amp:null +--max_steps:null +null:null +--batch_size:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train:test_tipc/train.py --model gpt_for_sequence_classification --optimizer adam --max_seq_len 1024 --generated_inputs +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +null:null +null:null +norm_export:null +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:null +inference:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +===========================to_static_train_benchmark_params=========================== +to_static_train:--to_static +===========================train_benchmark_params========================== +batch_size:2|8 +fp_items:fp32|fp16 +epoch:500 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile 
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/tests/test_tipc/configs/t5_for_conditional_generation/train_infer_python.txt b/tests/test_tipc/configs/t5_for_conditional_generation/train_infer_python.txt new file mode 100644 index 000000000000..5b50c48e0d17 --- /dev/null +++ b/tests/test_tipc/configs/t5_for_conditional_generation/train_infer_python.txt @@ -0,0 +1,59 @@ +===========================train_params=========================== +model_name:t5_for_conditional_generation +python:python3.7 +gpu_list:0|0,1 +null:null +--use_amp:null +--max_steps:null +null:null +--batch_size:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train:test_tipc/train.py --model t5_for_conditional_generation --optimizer adam --max_seq_len 1024 --generated_inputs +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +null:null +null:null +norm_export:null +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:null +inference:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +===========================to_static_train_benchmark_params=========================== +to_static_train:--to_static +===========================train_benchmark_params========================== +batch_size:2 +fp_items:fp32|fp16 +epoch:500 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 diff --git a/tests/test_tipc/train.py b/tests/test_tipc/train.py index 3ca65018d626..77619bb9223f 100644 --- a/tests/test_tipc/train.py +++ b/tests/test_tipc/train.py @@ -1,21 +1,32 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
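+#
+# Entry point for TIPC benchmark training. __main__ dispatches to one of three
+# paths: do_generated_inputs (--generated_inputs, trains on synthetic tensors
+# cloned each step), do_hapi (high-level API models), or do_train (the regular
+# DataLoader-driven loop).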
+
+import inspect
 import os
-import sys
-import time
 import random
-import inspect
-import numpy as np
+import time
 from pprint import pprint
 
-from paddlenlp.utils import profiler
-
+import numpy as np
 import paddle
 import paddle.distributed as dist
-
-import benchmark
 from benchmark import options
-from benchmark.options import MODEL_REGISTRY, OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
+from benchmark.modules.benchmark_utils import clone_inputs
+from benchmark.options import LR_SCHEDULER_REGISTRY, MODEL_REGISTRY, OPTIMIZER_REGISTRY
 from benchmark.utils.record import AverageStatistical
 
+from paddlenlp.utils import profiler
 from paddlenlp.utils.log import logger
 
 
@@ -25,6 +36,150 @@ def set_seed(seed):
     np.random.seed(seed)
 
 
+def do_generated_inputs(args):
+    if args.device == "gpu":
+        rank = dist.get_rank()
+        trainer_count = dist.get_world_size()
+    else:
+        rank = 0
+        trainer_count = 1
+        paddle.set_device("cpu")
+
+    if trainer_count > 1:
+        dist.init_parallel_env()
+
+    # Set seed for CE
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    benchmark_model = MODEL_REGISTRY[args.model]()
+    benchmark_optimizer = OPTIMIZER_REGISTRY[args.optimizer]()
+
+    if args.max_steps is None or args.max_steps < 0:
+        args.max_steps = 10000
+
+    # Define model
+    model = benchmark_model.build_model(args)
+
+    if args.to_static:
+        input_spec = benchmark_model.create_input_specs()
+        model = paddle.jit.to_static(model, input_spec=input_spec)
+        logger.info("Successfully applied @to_static with specs: {}".format(input_spec))
+
+    # Generate example inputs once; no DataLoader is used on this path
+    example_inputs = benchmark_model.generate_inputs_for_model(args, model)
+
+    if args.lr_scheduler is not None:
+        benchmark_lr_scheduler = LR_SCHEDULER_REGISTRY[args.lr_scheduler]()
+        lr = benchmark_lr_scheduler.build_scheculer(args)
+    else:
+        lr = args.learning_rate
+
+    optimizer = benchmark_optimizer.build_optimizer(args, lr, model)
+
+    # for amp training
+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=args.scale_loss)
+        model = paddle.amp.decorate(models=model, level=args.amp_level, save_dtype="float32")
+
+    # for distributed training
+    if trainer_count > 1:
+        model = paddle.DataParallel(model)
+
+    step_id = 1
+
+    # For benchmark
+    reader_cost_avg = AverageStatistical()
+    batch_cost_avg = AverageStatistical()
+    batch_ips_avg = AverageStatistical()
+
+    # Train loop
+    for pass_id in range(args.epoch):
+        epoch_start = time.time()
+
+        batch_start = time.time()
+        for batch_id in range(args.max_steps):
+            train_reader_cost = time.time() - batch_start
+            cloned_inputs = clone_inputs(example_inputs)
+
+            if args.use_amp:
+                with paddle.amp.auto_cast(
+                    custom_black_list=args.custom_black_list if args.amp_level == "O2" else {}, level=args.amp_level
+                ):
+                    loss, sample_per_cards = benchmark_model.forward(model, args, cloned_inputs)
+
+                scaled = scaler.scale(loss)
+                scaled.backward()
+
+                scaler.minimize(optimizer, scaled)
+                if "set_to_zero" in inspect.getfullargspec(optimizer.clear_grad).args:
+                    optimizer.clear_grad(set_to_zero=False)
+                else:
+                    optimizer.clear_grad()
+            else:
+                loss, sample_per_cards = benchmark_model.forward(model, args, cloned_inputs)
+
+                loss.backward()
+
+                optimizer.step()
+                optimizer.clear_grad()
+
+            if args.profiler_options is not None:
+                profiler.add_profiler_step(args.profiler_options)
+
+            if args.max_steps and step_id == args.max_steps:
+                if args.save_model and rank == 0:
+                    model_dir = args.save_model
+                    if not os.path.exists(model_dir):
+                        os.makedirs(model_dir)
+                    paddle.save(model.state_dict(), os.path.join(model_dir, "model.pdparams"))
+                    paddle.save(optimizer.state_dict(), os.path.join(model_dir, "model.pdopt"))
+                return
+
+            if args.lr_scheduler is not None and not args.scheduler_update_by_epoch:
+                lr.step()
+
+            if step_id % args.logging_steps == 0:
+                total_avg_loss = loss.numpy()
+
+                train_batch_cost = time.time() - batch_start
+                reader_cost_avg.record(train_reader_cost)
+                batch_cost_avg.record(train_batch_cost)
+                batch_ips_avg.record(train_batch_cost, sample_per_cards)
+
+                benchmark_model.logger(
+                    args,
+                    step_id=step_id,
+                    pass_id=pass_id,
+                    batch_id=batch_id,
+                    loss=total_avg_loss,
+                    batch_cost=batch_cost_avg.get_average(),
+                    reader_cost=reader_cost_avg.get_average(),
+                    num_samples=sample_per_cards,
+                    ips=batch_ips_avg.get_average_per_sec(),
+                )
+
+                reader_cost_avg.reset()
+                batch_cost_avg.reset()
+                batch_ips_avg.reset()
+            else:
+                train_batch_cost = time.time() - batch_start
+                reader_cost_avg.record(train_reader_cost)
+                batch_cost_avg.record(train_batch_cost)
+                batch_ips_avg.record(train_batch_cost, sample_per_cards)
+
+            batch_start = time.time()
+
+            batch_id += 1
+            step_id += 1
+
+        if args.lr_scheduler is not None and args.scheduler_update_by_epoch:
+            lr.step()
+
+        train_epoch_cost = time.time() - epoch_start
+        logger.info("train epoch: %d, epoch_cost: %.5f s" % (pass_id, train_epoch_cost))
+
+
 def do_train(args):
     if args.device == "gpu":
         rank = dist.get_rank()
@@ -53,6 +208,11 @@ def do_train(args):
     # Define model
     model = benchmark_model.build_model(args)
 
+    if args.to_static:
+        input_spec = benchmark_model.create_input_specs()
+        model = paddle.jit.to_static(model, input_spec=input_spec)
+        logger.info("Successfully applied @to_static with specs: {}".format(input_spec))
+
     if args.lr_scheduler is not None:
         benchmark_lr_scheduler = LR_SCHEDULER_REGISTRY[args.lr_scheduler]()
         lr = benchmark_lr_scheduler.build_scheculer(args)
@@ -165,7 +325,7 @@ def do_train(args):
 
 
 def do_hapi(args):
-    device = paddle.set_device(args.device)
+    paddle.set_device(args.device)
 
     # Set seed for CE
     if args.seed is not None:
@@ -185,6 +345,11 @@ def do_hapi(args):
 
     model = benchmark_model.build_model(args)
 
+    if args.to_static:
+        input_spec = benchmark_model.create_input_specs()
+        model = paddle.jit.to_static(model, input_spec=input_spec)
+        logger.info("Successfully applied @to_static with specs: {}".format(input_spec))
+
     optimizer = benchmark_optimizer.build_optimizer(args, lr, model)
 
     benchmark_model.forward(model, args, optimizer=optimizer, train_loader=train_loader, eval_loader=eval_loader)
@@ -193,8 +358,11 @@ def do_hapi(args):
 if __name__ == "__main__":
     parser = options.get_training_parser()
     args = options.parse_args_and_model(parser)
+
     pprint(args)
 
-    if getattr(args, "use_hapi", False):
+    if args.generated_inputs:
+        do_generated_inputs(args)
+    elif getattr(args, "use_hapi", False):
         do_hapi(args)
     else:
         do_train(args)
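
Usage note: per the TIPC configs added above, the harness launches the generated-inputs path with commands such as

    python tests/test_tipc/train.py --model bert_for_question_answering --optimizer adam --max_seq_len 512 --generated_inputs

with --batch_size, AMP settings, and --to_static filled in from the train_benchmark_params and to_static_train_benchmark_params sections. Below is a minimal standalone sketch of the same flow, assuming it is run from tests/test_tipc so the benchmark package is importable; the SimpleNamespace args object is hypothetical and only mirrors the parser options used by these calls:

    from types import SimpleNamespace

    from benchmark.modules.benchmark_utils import clone_inputs
    from benchmark.options import MODEL_REGISTRY

    args = SimpleNamespace(batch_size=2, max_seq_len=128, model_name_or_path="bert-base-cased")
    bench = MODEL_REGISTRY["bert_for_question_answering"]()
    model = bench.build_model(args)
    example_inputs = bench.generate_inputs_for_model(args, model)

    # Each step trains on a fresh clone of the generated inputs, as in do_generated_inputs;
    # forward() returns the loss and the number of non-pad tokens in the batch.
    loss, num_tokens = bench.forward(model, args, clone_inputs(example_inputs))
    loss.backward()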