Skip to content

Commit

Permalink
[Benchmark] Add BERT GPT T5 finetune process (#4492)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrostML committed Feb 1, 2023
1 parent 6b0b8f9 commit a8a8338
Show file tree
Hide file tree
Showing 11 changed files with 744 additions and 23 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/transformers/gpt/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,9 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1))
pooled_logits = logits.gather_nd(
paddle.stack([paddle.arange(paddle.shape(logits)[0]), sequence_lengths], axis=-1)
)

loss = None
if labels is not None:
Expand Down
63 changes: 63 additions & 0 deletions tests/test_tipc/benchmark/modules/benchmark_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def rand_int_tensor(low, high, shape):
return paddle.randint(
low,
high,
shape=shape,
dtype=paddle.int64,
)


def clone_tensor(x):
y = x.clone()
return y


def clone_input(x):
def paddle_clone(x):
y = paddle.clone(x)
if x.is_leaf:
y.stop_gradient = x.stop_gradient
if x.is_leaf and x.grad is not None:
y.grad = clone_input(x.grad)
return y

with paddle.no_grad():
result = paddle.empty(x.shape, dtype=x.dtype)
result.copy_(x.clone(), True)
if x.is_leaf:
result.stop_gradient = x.stop_gradient
if x.is_leaf and x.grad is not None:
result.grad = clone_input(x.grad)
return result


def clone_inputs(example_inputs):
if isinstance(example_inputs, dict):
res = dict(example_inputs)
for key, value in res.items():
assert isinstance(value, paddle.Tensor)
res[key] = clone_input(value)
return res

res = list(example_inputs)
for i in range(len(res)):
if isinstance(res[i], paddle.Tensor):
res[i] = clone_input(res[i])
return res
106 changes: 106 additions & 0 deletions tests/test_tipc/benchmark/modules/bert_for_question_answering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import BertForQuestionAnswering
from paddlenlp.utils.log import logger

from .benchmark_utils import rand_int_tensor
from .model_base import BenchmarkBase


class BertForQuestionAnsweringBenchmark(BenchmarkBase):
def __init__(self):
self.label_list = None
super().__init__()

@staticmethod
def add_args(args, parser):
parser.add_argument(
"--model_name_or_path", type=str, default="bert-base-cased", help="Model name. Defaults to bert-base. "
)
# args.max_seq_len

def create_data_loader(self, args, **kwargs):
raise NotImplementedError(
"bert_for_question_answering's DataLoader is not implemented. Please use --generated_inputs. "
)

def generate_inputs_for_model(self, args, model, **kwargs):
input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
start_positions = rand_int_tensor(
0,
args.max_seq_len,
[
args.batch_size,
],
)
end_positions = rand_int_tensor(
0,
args.max_seq_len,
[
args.batch_size,
],
)
return {"input_ids": input_ids, "start_positions": start_positions, "end_positions": end_positions}

def build_model(self, args, **kwargs):
model = BertForQuestionAnswering.from_pretrained(args.model_name_or_path)
return model

def forward(self, model, args, input_data=None, **kwargs):
start_positions = input_data.pop("start_positions")
end_positions = input_data.pop("end_positions")
start_logits, end_logits = model(**input_data)

total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if start_positions.ndim > 1:
start_positions = start_positions.squeeze(-1)
if start_positions.ndim > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = paddle.shape(start_logits)[1]
start_positions = start_positions.clip(0, ignored_index)
end_positions = end_positions.clip(0, ignored_index)

loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

return (
total_loss,
paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
)

def logger(
self,
args,
step_id=None,
pass_id=None,
batch_id=None,
loss=None,
batch_cost=None,
reader_cost=None,
num_samples=None,
ips=None,
**kwargs
):
logger.info(
"global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
% (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import GPTForSequenceClassification
from paddlenlp.utils.log import logger

from .benchmark_utils import rand_int_tensor
from .model_base import BenchmarkBase


class GPTForSequenceClassificationBenchmark(BenchmarkBase):
def __init__(self):
self.label_list = None
super().__init__()

@staticmethod
def add_args(args, parser):
parser.add_argument(
"--model_name_or_path", type=str, default="gpt2-en", help="Model name. Defaults to gpt2-en. "
)
# args.max_seq_len

def create_data_loader(self, args, **kwargs):
raise NotImplementedError(
"gpt_for_sequence_classification's DataLoader is not implemented. Please use --generated_inputs. "
)

def create_input_specs(self):
input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64")
labels = paddle.static.InputSpec(name="labels", shape=[-1], dtype="int64")
return [input_ids, None, None, None, labels]

def generate_inputs_for_model(self, args, model, **kwargs):
input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
labels = rand_int_tensor(0, model.config.num_classes - 1, [args.batch_size])

return {"input_ids": input_ids, "labels": labels}

def build_model(self, args, **kwargs):
model = GPTForSequenceClassification.from_pretrained(args.model_name_or_path)
return model

def forward(self, model, args, input_data=None, **kwargs):
res = model(**input_data)
return (
res[0],
paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
)

def logger(
self,
args,
step_id=None,
pass_id=None,
batch_id=None,
loss=None,
batch_cost=None,
reader_cost=None,
num_samples=None,
ips=None,
**kwargs
):
logger.info(
"global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
% (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
)
21 changes: 20 additions & 1 deletion tests/test_tipc/benchmark/modules/model_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
import paddle
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlenlp.utils.log import logger


Expand All @@ -16,6 +29,12 @@ def create_data_loader(self, args, **kwargs):
def build_model(self, args, **kwargs):
raise NotImplementedError

def generate_inputs_for_model(self, args, **kwargs):
raise NotImplementedError

def create_input_specs(self):
return None

def forward(self, model, args, input_data=None, **kwargs):
raise NotImplementedError

Expand Down
81 changes: 81 additions & 0 deletions tests/test_tipc/benchmark/modules/t5_for_conditional_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import T5ForConditionalGeneration
from paddlenlp.utils.log import logger

from .benchmark_utils import rand_int_tensor
from .model_base import BenchmarkBase


class T5ForConditionalGenerationBenchmark(BenchmarkBase):
def __init__(self):
self.label_list = None
super().__init__()

@staticmethod
def add_args(args, parser):
parser.add_argument(
"--model_name_or_path", type=str, default="t5-small", help="Model name. Defaults to t5-small. "
)
# args.max_seq_len

def create_data_loader(self, args, **kwargs):
raise NotImplementedError(
"t5_for_conditional_genneration's DataLoader is not implemented. Please use --generated_inputs. "
)

def create_input_specs(self):
input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64")
decoder_input_ids = paddle.static.InputSpec(name="decoder_input_ids", shape=[-1, -1], dtype="int64")
labels = paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64")
return [input_ids, None, decoder_input_ids, None, None, None, labels]

def generate_inputs_for_model(self, args, model, **kwargs):
input_ids = rand_int_tensor(0, model.config.vocab_size, [args.batch_size, args.max_seq_len])
decoder_input_ids = input_ids
labels = rand_int_tensor(0, model.config.vocab_size - 1, [args.batch_size, args.max_seq_len])

return {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "labels": labels}

def build_model(self, args, **kwargs):
model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
return model

def forward(self, model, args, input_data=None, **kwargs):
res = model(**input_data)
return (
res[0],
paddle.sum((input_data["input_ids"] != model.config.pad_token_id)).numpy().astype("int64").item(),
)

def logger(
self,
args,
step_id=None,
pass_id=None,
batch_id=None,
loss=None,
batch_cost=None,
reader_cost=None,
num_samples=None,
ips=None,
**kwargs
):
logger.info(
"global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f words/sec"
% (step_id, args.epoch * self.num_batch, loss, reader_cost, batch_cost, num_samples, ips)
)
Loading

0 comments on commit a8a8338

Please sign in to comment.