From 470a8ff234a3026c80f6958561fbaa663810b39a Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Mon, 13 Dec 2021 12:00:02 +0800 Subject: [PATCH] fix max_steps of msra_ner. --- examples/information_extraction/msra_ner/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/information_extraction/msra_ner/train.py b/examples/information_extraction/msra_ner/train.py index fa35d2fdf62d..773be0fd2983 100644 --- a/examples/information_extraction/msra_ner/train.py +++ b/examples/information_extraction/msra_ner/train.py @@ -42,7 +42,7 @@ parser = argparse.ArgumentParser() # yapf: disable -parser.add_argument("--model_type", default="bert", type=str, required=True, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), ) +parser.add_argument("--model_type", default="bert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), ) parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( sum([ list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values() ], [])), ) parser.add_argument("--dataset", default="msra_ner", type=str, choices=["msra_ner", "peoples_daily_ner"] ,help="The named entity recognition datasets.") parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.") @@ -218,7 +218,7 @@ def do_train(args): optimizer.step() lr_scheduler.step() optimizer.clear_grad() - if global_step % args.save_steps == 0 or global_step == last_step: + if global_step % args.save_steps == 0 or global_step == num_training_steps: if paddle.distributed.get_rank() == 0: if args.dataset == "peoples_daily_ner": evaluate(model, loss_fct, metric, dev_data_loader, @@ -229,6 +229,8 @@ def do_train(args): paddle.save(model.state_dict(), os.path.join(args.output_dir, "model_%d.pdparams" % global_step)) + if global_step >= num_training_steps: + return if __name__ == "__main__":