Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#193 from PaddlePaddle/model_check_for…
Browse files Browse the repository at this point in the history
…_bert

Add cuda check for bert
  • Loading branch information
Yibing Liu committed Jul 6, 2019
2 parents 67edb3e + 5d94fd1 commit c0c669a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
5 changes: 3 additions & 2 deletions BERT/run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from model.bert import BertConfig
from model.classifier import create_model
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint
import dist_utils

Expand Down Expand Up @@ -281,7 +281,7 @@ def main(args):
exec_strategy=exec_strategy,
build_strategy = build_strategy,
main_program=train_program)

train_pyreader.decorate_tensor_provider(train_data_generator)
else:
train_exe = None
Expand Down Expand Up @@ -415,4 +415,5 @@ def main(args):

if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
main(args)
3 changes: 2 additions & 1 deletion BERT/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from reader.squad import DataProcessor, write_predictions
from model.bert import BertConfig, BertModel
from utils.args import ArgumentGroup, print_arguments
from utils.args import ArgumentGroup, print_arguments, check_cuda
from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint

Expand Down Expand Up @@ -424,4 +424,5 @@ def train(args):

if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
train(args)
3 changes: 2 additions & 1 deletion BERT/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from reader.pretraining import DataReader
from model.bert import BertModel, BertConfig
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params

# yapf: disable
Expand Down Expand Up @@ -418,6 +418,7 @@ def train(args):

if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test:
test(args)
else:
Expand Down
11 changes: 11 additions & 0 deletions BERT/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,14 @@ def print_arguments(args):
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')

def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass

0 comments on commit c0c669a

Please sign in to comment.