-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Benchmark] Add BERT GPT T5 finetune process (#4492)
- Loading branch information
Showing
11 changed files
with
744 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
|
||
|
||
def rand_int_tensor(low, high, shape):
    """Return a random int64 tensor of the given shape with values in [low, high)."""
    return paddle.randint(low, high, shape=shape, dtype=paddle.int64)
|
||
|
||
def clone_tensor(x):
    """Return a clone of tensor ``x`` (same values, new storage)."""
    return x.clone()
|
||
|
||
def clone_input(x):
    """Deep-copy a tensor together with its autograd metadata.

    Creates a fresh tensor holding the same values as ``x``. For leaf
    tensors the ``stop_gradient`` flag is preserved, and an accumulated
    ``.grad`` (if any) is cloned recursively through this same function.

    Note: the original version defined an inner helper ``paddle_clone``
    that was never called; that dead code has been removed.
    """
    with paddle.no_grad():
        result = paddle.empty(x.shape, dtype=x.dtype)
        # Second argument ``True`` presumably requests a blocking/synchronous
        # copy — TODO confirm against paddle.Tensor.copy_ semantics.
        result.copy_(x.clone(), True)
        if x.is_leaf:
            result.stop_gradient = x.stop_gradient
        if x.is_leaf and x.grad is not None:
            result.grad = clone_input(x.grad)
        return result
|
||
|
||
def clone_inputs(example_inputs):
    """Clone every tensor in ``example_inputs``.

    Accepts either a dict (all values must be tensors) or a sequence
    (non-tensor entries are passed through unchanged). Returns a new
    dict or list, respectively; the input container is not mutated.
    """
    if isinstance(example_inputs, dict):
        cloned = {}
        for key, value in example_inputs.items():
            assert isinstance(value, paddle.Tensor)
            cloned[key] = clone_input(value)
        return cloned

    return [clone_input(item) if isinstance(item, paddle.Tensor) else item for item in example_inputs]
106 changes: 106 additions & 0 deletions
106
tests/test_tipc/benchmark/modules/bert_for_question_answering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
|
||
from paddlenlp.transformers import BertForQuestionAnswering | ||
from paddlenlp.utils.log import logger | ||
|
||
from .benchmark_utils import rand_int_tensor | ||
from .model_base import BenchmarkBase | ||
|
||
|
||
class BertForQuestionAnsweringBenchmark(BenchmarkBase):
    """Benchmark wrapper around ``BertForQuestionAnswering``.

    Generates random SQuAD-style batches (via ``--generated_inputs``) and
    computes the standard start/end span cross-entropy loss.
    """

    def __init__(self):
        self.label_list = None
        super().__init__()

    @staticmethod
    def add_args(args, parser):
        # Only the checkpoint name is configurable here; --max_seq_len is
        # expected to be registered by the shared benchmark arg parser.
        parser.add_argument(
            "--model_name_or_path", type=str, default="bert-base-cased", help="Model name. Defaults to bert-base. "
        )

    def create_data_loader(self, args, **kwargs):
        raise NotImplementedError(
            "bert_for_question_answering's DataLoader is not implemented. Please use --generated_inputs. "
        )

    def generate_inputs_for_model(self, args, model, **kwargs):
        """Build one batch of random token ids plus random answer spans."""
        input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
        start_positions = rand_int_tensor(0, args.max_seq_len, [args.batch_size])
        end_positions = rand_int_tensor(0, args.max_seq_len, [args.batch_size])
        return {"input_ids": input_ids, "start_positions": start_positions, "end_positions": end_positions}

    def build_model(self, args, **kwargs):
        """Load the pretrained QA model named by --model_name_or_path."""
        model = BertForQuestionAnswering.from_pretrained(args.model_name_or_path)
        return model

    def forward(self, model, args, input_data=None, **kwargs):
        """Run one forward pass.

        Returns a tuple of (span loss, number of non-pad tokens in the batch).
        Note: pops start/end positions out of ``input_data`` (mutates it).
        """
        start_positions = input_data.pop("start_positions")
        end_positions = input_data.pop("end_positions")
        start_logits, end_logits = model(**input_data)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds an extra dimension.
            if start_positions.ndim > 1:
                start_positions = start_positions.squeeze(-1)
            # BUG FIX: the original re-tested start_positions.ndim here, so
            # end_positions was never squeezed; check its own rank instead.
            if end_positions.ndim > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the model input are clipped to seq_len, which
            # is then used as the ignore_index of the loss below.
            ignored_index = paddle.shape(start_logits)[1]
            start_positions = start_positions.clip(0, ignored_index)
            end_positions = end_positions.clip(0, ignored_index)

            loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return (
            total_loss,
            paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
        )

    def logger(
        self,
        args,
        step_id=None,
        pass_id=None,
        batch_id=None,
        loss=None,
        batch_cost=None,
        reader_cost=None,
        num_samples=None,
        ips=None,
        **kwargs
    ):
        """Log one step's metrics.

        ``logger`` in the body resolves to the module-level paddlenlp logger
        (class attributes are not visible inside method bodies), so the method
        name shadowing the import is harmless here.
        """
        logger.info(
            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
        )
79 changes: 79 additions & 0 deletions
79
tests/test_tipc/benchmark/modules/gpt_for_sequence_classification.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
|
||
from paddlenlp.transformers import GPTForSequenceClassification | ||
from paddlenlp.utils.log import logger | ||
|
||
from .benchmark_utils import rand_int_tensor | ||
from .model_base import BenchmarkBase | ||
|
||
|
||
class GPTForSequenceClassificationBenchmark(BenchmarkBase):
    """Benchmark wrapper around ``GPTForSequenceClassification`` driven by
    randomly generated classification batches."""

    def __init__(self):
        self.label_list = None
        super().__init__()

    @staticmethod
    def add_args(args, parser):
        # Only the checkpoint name is configurable here; --max_seq_len is
        # expected to come from the shared benchmark arg parser.
        parser.add_argument(
            "--model_name_or_path", type=str, default="gpt2-en", help="Model name. Defaults to gpt2-en. "
        )

    def create_data_loader(self, args, **kwargs):
        raise NotImplementedError(
            "gpt_for_sequence_classification's DataLoader is not implemented. Please use --generated_inputs. "
        )

    def create_input_specs(self):
        """Static-graph specs: [input_ids, None, None, None, labels]."""
        return [
            paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64"),
            None,
            None,
            None,
            paddle.static.InputSpec(name="labels", shape=[-1], dtype="int64"),
        ]

    def generate_inputs_for_model(self, args, model, **kwargs):
        """Build one random batch of token ids and class labels."""
        token_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
        # NOTE(review): randint's upper bound is exclusive, so labels never
        # take the value num_classes - 1 — presumably intentional; confirm.
        class_labels = rand_int_tensor(0, model.config.num_classes - 1, [args.batch_size])
        return {"input_ids": token_ids, "labels": class_labels}

    def build_model(self, args, **kwargs):
        """Load the pretrained classifier named by --model_name_or_path."""
        return GPTForSequenceClassification.from_pretrained(args.model_name_or_path)

    def forward(self, model, args, input_data=None, **kwargs):
        """Return (first model output, count of non-pad tokens in the batch)."""
        outputs = model(**input_data)
        non_pad_mask = input_data["input_ids"] != model.config.pad_token_id
        word_count = paddle.sum(non_pad_mask).numpy().astype("int64").item()
        return (outputs[0], word_count)

    def logger(
        self,
        args,
        step_id=None,
        pass_id=None,
        batch_id=None,
        loss=None,
        batch_cost=None,
        reader_cost=None,
        num_samples=None,
        ips=None,
        **kwargs
    ):
        """Log one step's metrics via the module-level paddlenlp logger."""
        message = (
            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
        )
        logger.info(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
|
||
from paddlenlp.transformers import T5ForConditionalGeneration | ||
from paddlenlp.utils.log import logger | ||
|
||
from .benchmark_utils import rand_int_tensor | ||
from .model_base import BenchmarkBase | ||
|
||
|
||
class T5ForConditionalGenerationBenchmark(BenchmarkBase):
    """Benchmark wrapper around ``T5ForConditionalGeneration``.

    Generates random seq2seq batches (via ``--generated_inputs``); the model
    computes its own loss from the ``labels`` tensor.
    """

    def __init__(self):
        self.label_list = None
        super().__init__()

    @staticmethod
    def add_args(args, parser):
        # Only the checkpoint name is configurable here; --max_seq_len is
        # expected to be registered by the shared benchmark arg parser.
        parser.add_argument(
            "--model_name_or_path", type=str, default="t5-small", help="Model name. Defaults to t5-small. "
        )

    def create_data_loader(self, args, **kwargs):
        # BUG FIX: corrected the misspelled module name in the message
        # ("genneration" -> "generation").
        raise NotImplementedError(
            "t5_for_conditional_generation's DataLoader is not implemented. Please use --generated_inputs. "
        )

    def create_input_specs(self):
        """Static-graph specs: [input_ids, None, decoder_input_ids, None, None, None, labels]."""
        input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64")
        decoder_input_ids = paddle.static.InputSpec(name="decoder_input_ids", shape=[-1, -1], dtype="int64")
        labels = paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64")
        return [input_ids, None, decoder_input_ids, None, None, None, labels]

    def generate_inputs_for_model(self, args, model, **kwargs):
        """Build one random batch; decoder input reuses the encoder input."""
        input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
        decoder_input_ids = input_ids
        labels = rand_int_tensor(0, model.config.vocab_size - 1, [args.batch_size, args.max_seq_len])

        return {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "labels": labels}

    def build_model(self, args, **kwargs):
        """Load the pretrained seq2seq model named by --model_name_or_path."""
        model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
        return model

    def forward(self, model, args, input_data=None, **kwargs):
        """Return (first model output, count of non-pad tokens in the batch)."""
        res = model(**input_data)
        return (
            res[0],
            paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
        )

    def logger(
        self,
        args,
        step_id=None,
        pass_id=None,
        batch_id=None,
        loss=None,
        batch_cost=None,
        reader_cost=None,
        num_samples=None,
        ips=None,
        **kwargs
    ):
        """Log one step's metrics.

        ``logger`` in the body resolves to the module-level paddlenlp logger
        (class attributes are not visible inside method bodies), so the method
        name shadowing the import is harmless here.
        """
        logger.info(
            "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
            % (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
        )
Oops, something went wrong.