Skip to content

Commit

Permalink
Merge pull request #10 from microsoft/raviskolli/ort
Browse files Browse the repository at this point in the history
Update for ORTModule package
  • Loading branch information
raviskolli committed Apr 28, 2021
2 parents e522c75 + 76cbfc1 commit ae1411f
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,20 +1079,18 @@ def train(
num_update_steps_per_epoch = max_steps

delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
model = self.model
if self.args.ort:
from onnxruntime.training import ORTModule
if args.ort:
from onnxruntime.training.ortmodule import ORTModule
logger.info("Converting to ORTModule ....")
model = ORTModule(self.model)
self.model_wrapped = model
if self.args.deepspeed:
if self.args.ort:
self.model = model
if args.deepspeed:
if args.ort:
self.model = model
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
self.model = deepspeed_engine.module._original_module if self.args.ort else deepspeed_engine.module
self.model = deepspeed_engine.module._original_module if args.ort else deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
Expand Down

0 comments on commit ae1411f

Please sign in to comment.