Add get_model method in Trainer (#39)
bepuca committed Oct 14, 2022
1 parent 34cd222 commit e26e219
Showing 9 changed files with 17 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/trainer.rst
@@ -124,6 +124,7 @@ Utility Methods
 .. automethod:: Trainer.load_checkpoint
 .. automethod:: Trainer.print
 .. automethod:: Trainer.gather
+.. automethod:: Trainer.get_model

 Customizing Trainer Behaviour
 ================================
4 changes: 2 additions & 2 deletions pytorch_accelerated/schedulers/cosine_scheduler.py
@@ -125,8 +125,8 @@ def get_updated_values(self, num_updates: int):
                 1
                 + math.cos(
                     math.pi
-                    * num_updates ** self.k_decay
-                    / total_cosine_iterations ** self.k_decay
+                    * num_updates**self.k_decay
+                    / total_cosine_iterations**self.k_decay
                 )
             )
             for lr_max in self.base_lr_values
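For context, the reformatted expression is the cosine term of the scheduler's k-decay annealing. A minimal standalone sketch of that formula follows; the lr_min/lr_max scaling around the cosine term is an assumption about the surrounding code, which this hunk does not show:

import math

def k_decay_cosine_lr(lr_max, num_updates, total_iterations, k=1.0, lr_min=0.0):
    # The (1 + cos(...)) factor is the expression touched in the diff; k
    # controls the decay shape (k=1 reduces to plain cosine annealing).
    cosine_factor = 1 + math.cos(
        math.pi * num_updates**k / total_iterations**k
    )
    return lr_min + 0.5 * (lr_max - lr_min) * cosine_factor

At num_updates == 0 this returns lr_max, and it anneals towards lr_min as num_updates approaches total_iterations.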
17 changes: 12 additions & 5 deletions pytorch_accelerated/trainer.py
@@ -657,7 +657,7 @@ def _create_run_config(
             else False,
             "mixed_precision": self._accelerator.mixed_precision,
             "gradient_clip_value": gradient_clip_value,
-            "num_processes": self._accelerator.num_processes
+            "num_processes": self._accelerator.num_processes,
         }

         return TrainerRunConfig(**config)
@@ -884,7 +884,7 @@ def save_checkpoint(
         # TODO: add save method for run history?

         checkpoint = {
-            "model_state_dict": self._accelerator.unwrap_model(self.model).state_dict(),
+            "model_state_dict": self.get_model().state_dict(),
         }

         if save_optimizer:
@@ -918,9 +918,7 @@ def load_checkpoint(self, checkpoint_path, load_optimizer=True):
         """
         self._accelerator.wait_for_everyone()
         checkpoint = torch.load(checkpoint_path, map_location="cpu")
-        self._accelerator.unwrap_model(self.model).load_state_dict(
-            checkpoint["model_state_dict"]
-        )
+        self.get_model().load_state_dict(checkpoint["model_state_dict"])
         if load_optimizer and "optimizer_state_dict" in checkpoint:
             if self.optimizer is None:
                 raise ValueError(
@@ -930,6 +928,15 @@ def load_checkpoint(self, checkpoint_path, load_optimizer=True):
                 )
             self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

+    def get_model(self):
+        """
+        Extract the model in :class:`Trainer` from its distributed containers.
+        Useful before saving a model.
+
+        :return: the model in :class:`Trainer`, subclassed from :class:`~torch.nn.Module`
+        """
+        return self._accelerator.unwrap_model(self.model)
+

 class TrainerWithTimmScheduler(Trainer):
     """Subclass of the :class:`Trainer` that works with `timm schedulers <https://fastai.github.io/timmdocs/schedulers>`_ instead
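The new get_model method gives both checkpoint paths a single way to reach the underlying nn.Module. A usage sketch, assuming the documented pytorch-accelerated API; the toy model/dataset and the positional save path argument are illustrative assumptions, not part of this commit:

import torch
from torch import nn
from pytorch_accelerated import Trainer

# Toy components so the sketch is self-contained
model = nn.Linear(10, 2)
train_dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 10), torch.randint(0, 2, (64,))
)
trainer = Trainer(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
)
trainer.train(train_dataset=train_dataset, num_epochs=1, per_device_batch_size=8)

# Both checkpoint methods now route through get_model(), so the saved
# state dict always comes from the unwrapped module:
trainer.save_checkpoint("checkpoint.pt", save_optimizer=True)
trainer.load_checkpoint("checkpoint.pt", load_optimizer=True)

# get_model() strips any distributed container (e.g. DistributedDataParallel),
# which is handy when exporting weights directly:
torch.save(trainer.get_model().state_dict(), "weights.pt")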
2 changes: 1 addition & 1 deletion requirements.dev.txt
@@ -1,4 +1,4 @@
 black
 versioneer==0.21
-pytest==1.11.0
+pytest==7.1.3
 pytest-mock==3.6.1
File renamed without changes.
2 changes: 1 addition & 1 deletion test/placeholders.py → test/test_placeholders.py
@@ -38,7 +38,7 @@ def create_run_config(
        is_local_process_zero=True,
        is_world_process_zero=True,
        is_distributed=True,
-       mixed_precision='fp16',
+       mixed_precision="fp16",
        num_processes=1,
        num_update_steps_per_epoch=num_update_steps_per_epoch,
    )
File renamed without changes.
File renamed without changes.
File renamed without changes.
