From 452b84c2058d1b5adf914ba8532e71874cd2ed4f Mon Sep 17 00:00:00 2001 From: Ravi shankar Kolli Date: Wed, 28 Apr 2021 09:13:19 -0700 Subject: [PATCH 1/2] Update for ORTModule package --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f717c7f62a534c..7c608bf69de121 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1081,7 +1081,7 @@ def train( delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE model = self.model if self.args.ort: - from onnxruntime.training import ORTModule + from onnxruntime.training.ortmodule import ORTModule logger.info("Converting to ORTModule ....") model = ORTModule(self.model) self.model_wrapped = model From 76cbfc186c28528e50df26682b358ecdaf2fa2f5 Mon Sep 17 00:00:00 2001 From: Ravi shankar Kolli Date: Wed, 28 Apr 2021 10:51:33 -0700 Subject: [PATCH 2/2] Cleanup deepspeed and ort code paths --- src/transformers/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7c608bf69de121..f35bee6936eb8a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1079,20 +1079,18 @@ def train( num_update_steps_per_epoch = max_steps delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE - model = self.model - if self.args.ort: + if args.ort: from onnxruntime.training.ortmodule import ORTModule logger.info("Converting to ORTModule ....") model = ORTModule(self.model) self.model_wrapped = model - if self.args.deepspeed: - if self.args.ort: - self.model = model if args.deepspeed: + if args.ort: + self.model = model deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint ) - self.model = deepspeed_engine.module._original_module if self.args.ort else deepspeed_engine.module + self.model = deepspeed_engine.module._original_module if args.ort else deepspeed_engine.module self.model_wrapped = deepspeed_engine self.deepspeed = deepspeed_engine self.optimizer = optimizer