[MLU] support SQuAD_Bert with mlu device (#3434)
qipengh authored Oct 12, 2022
1 parent 907144f commit 0b4985d
Showing 2 changed files with 50 additions and 10 deletions.
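Taken together, the two diffs below add Cambricon MLU as a selectable training device and wire optional automatic mixed precision (AMP) into the training loop. As a hedged usage sketch only — this commit touches just --device, --use_amp, and --scale_loss, and the remaining flags are assumed from the example's usual entry point — a run on MLU might look like:

    python run_squad.py \
        --model_type bert \
        --model_name_or_path bert-base-uncased \
        --do_train \
        --device mlu \
        --use_amp \
        --scale_loss 32768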
23 changes: 22 additions & 1 deletion examples/machine_reading_comprehension/SQuAD/args.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
@@ -78,7 +92,7 @@ def parse_args():
                         help="random seed for initialization")
     parser.add_argument(
         '--device',
-        choices=['cpu', 'gpu'],
+        choices=['cpu', 'gpu', 'mlu'],
         default="gpu",
         help="Select which device to train model, defaults to gpu.")
     parser.add_argument(
@@ -131,5 +145,12 @@ def parse_args():
     parser.add_argument("--do_predict",
                         action='store_true',
                         help="Whether to predict.")
+    parser.add_argument("--use_amp",
+                        action='store_true',
+                        help="Whether to use AMP.")
+    parser.add_argument("--scale_loss",
+                        type=float,
+                        default=2**15,
+                        help="The value of scale_loss for fp16.")
     args = parser.parse_args()
     return args
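The new 'mlu' choice only matters once the runner hands it to the framework; run_squad.py presumably does so via paddle.set_device, which is the standard Paddle API for this. A minimal sketch of that consuming side, assuming args.py sits importable next to the script:

    import paddle

    from args import parse_args  # the args.py shown above

    args = parse_args()
    # "mlu" is accepted only by Paddle packages built with Cambricon MLU
    # support; stock CPU/GPU wheels will reject it here.
    paddle.set_device(args.device)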
37 changes: 28 additions & 9 deletions examples/machine_reading_comprehension/SQuAD/run_squad.py
@@ -288,27 +288,46 @@ def run(args):
             apply_decay_param_fun=lambda x: x in decay_params)
         criterion = CrossEntropyLossForSQuAD()
 
+        if args.use_amp:
+            scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+
         global_step = 0
         tic_train = time.time()
 
         for epoch in range(num_train_epochs):
             for step, batch in enumerate(train_data_loader):
                 global_step += 1
-                logits = model(input_ids=batch['input_ids'],
-                               token_type_ids=batch['token_type_ids'],
-                               attention_mask=batch['attention_mask'])
-                loss = criterion(
-                    logits, (batch['start_positions'], batch['end_positions']))
+                if args.use_amp:
+                    with paddle.amp.auto_cast(
+                            args.use_amp,
+                            custom_white_list=["layer_norm", "softmax",
+                                               "gelu"]):
+                        logits = model(input_ids=batch['input_ids'],
+                                       token_type_ids=batch['token_type_ids'],
+                                       attention_mask=batch['attention_mask'])
+                        loss = criterion(
+                            logits,
+                            (batch['start_positions'], batch['end_positions']))
+                    scaler.scale(loss).backward()
+                    scaler.minimize(optimizer, loss)
+                else:
+                    logits = model(input_ids=batch['input_ids'],
+                                   token_type_ids=batch['token_type_ids'],
+                                   attention_mask=batch['attention_mask'])
+                    loss = criterion(
+                        logits,
+                        (batch['start_positions'], batch['end_positions']))
+                    loss.backward()
+                    optimizer.step()
+                lr_scheduler.step()
+                optimizer.clear_grad()
 
                 if global_step % args.logging_steps == 0:
                     print(
                         "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                         % (global_step, epoch + 1, step + 1, loss,
                            args.logging_steps / (time.time() - tic_train)))
                     tic_train = time.time()
-                loss.backward()
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.clear_grad()
 
                 if global_step % args.save_steps == 0 or global_step == num_training_steps:
                     if rank == 0:
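For readers unfamiliar with Paddle's AMP flow, here is a minimal, self-contained sketch of the pattern the diff introduces: the loss is scaled before backward so fp16 gradients do not underflow, and the scaler then unscales the gradients and steps the optimizer. A toy linear model stands in for BERT; the white list mirrors the one above.

    import paddle

    model = paddle.nn.Linear(4, 1)
    optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)  # the --scale_loss default

    x = paddle.randn([8, 4])
    y = paddle.randn([8, 1])

    # Ops on the white list run in fp16 inside this context; the rest stay fp32.
    with paddle.amp.auto_cast(custom_white_list=["layer_norm", "softmax", "gelu"]):
        loss = paddle.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()     # backward on the scaled loss
    scaler.minimize(optimizer, loss)  # unscales gradients, then steps the optimizer
    optimizer.clear_grad()

Note that scaler.minimize both unscales and applies the update, which is why the diff keeps optimizer.step() only in the non-AMP branch.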
