Trainer does not switch to train mode after validation step #20177

Open
ClemensSchwarke opened this issue Aug 7, 2024 · 2 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.3.x ver: 2.4.x

ClemensSchwarke commented Aug 7, 2024

Bug description

After the validation step, the model is not set back to train mode because the following hook

def on_validation_model_train(self) -> None:
    """Called when the validation loop ends.

    The validation loop by default restores the `training` mode of the LightningModule to what it was before
    starting validation. Override this hook to change the behavior. See also
    :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval`.

    """
    # The loop won't call this hook unless it is overridden. The line below is here in case the user calls super().
    self.trainer.model.train()

is not called as can be seen here:

def _on_evaluation_model_train(self) -> None:
    """Undoes the eval mode."""
    trainer = self.trainer
    hook_name = "on_test_model_train" if trainer.testing else "on_validation_model_train"
    if is_overridden(hook_name, trainer.lightning_module):
        call._call_lightning_module_hook(trainer, hook_name)
    else:
        self._module_mode.restore(trainer.lightning_module)

I don't see the point of this behavior and I think it is very likely to cause bugs if people start implementing their own mode logic and expect it to be called. Would be happy to understand this (:

What version are you seeing the problem on?

v2.4

How to reproduce the bug

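# Override of nn.Module.train() on the LightningModule subclass
def train(self, mode=True):
    super().train(mode)
    if mode:
        print("set to training mode")
    else:
        print("set to evaluation mode")
    return self  # nn.Module.train() returns the module itself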

Error messages and logs

Epoch 0: 100%
set to evaluation mode
Epoch 1:  27%

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

@ClemensSchwarke ClemensSchwarke added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Aug 7, 2024
@ClemensSchwarke ClemensSchwarke changed the title Trainer does not switch to train mode after validation step. Trainer does not switch to train mode after validation step Aug 7, 2024
@awaelchli
Contributor

The implementation of how the mode is captured and restored is here:

class _ModuleMode:
    """Captures the ``nn.Module.training`` (bool) mode of every submodule, and allows it to be restored later on."""

    def __init__(self) -> None:
        self.mode: Dict[str, bool] = {}

    def capture(self, module: nn.Module) -> None:
        self.mode.clear()
        for name, mod in module.named_modules():
            self.mode[name] = mod.training

    def restore(self, module: nn.Module) -> None:
        for name, mod in module.named_modules():
            if name not in self.mode:
                _log.debug(
                    f"Restoring training mode on module '{name}' not possible, it was never captured."
                    f" Is your module structure changing?"
                )
                continue
            mod.training = self.mode[name]

As you can see, it's on a per-module basis, and we don't call .train()/.eval() because that would set the mode on children, and we don't want that.
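To make the difference concrete, here is a small sketch using plain PyTorch only. The helpers `capture_modes` and `restore_modes` are hypothetical stand-ins that mirror the capture/restore idea, not Lightning's code, and the frozen BatchNorm layer is an assumed use case.

```python
from typing import Dict

import torch.nn as nn


def capture_modes(module: nn.Module) -> Dict[str, bool]:
    """Record the `training` flag of the module and every submodule."""
    return {name: mod.training for name, mod in module.named_modules()}


def restore_modes(module: nn.Module, modes: Dict[str, bool]) -> None:
    """Write the recorded flags back without calling train()/eval()."""
    for name, mod in module.named_modules():
        if name in modes:
            mod.training = modes[name]


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model[1].eval()  # assumption: the user deliberately keeps BatchNorm frozen

modes = capture_modes(model)
model.eval()  # validation runs with everything in eval mode
restore_modes(model, modes)
after_restore = (model.training, model[0].training, model[1].training)
print(after_restore)  # (True, True, False): the frozen layer stays in eval mode

model.train()  # recursive: also flips the frozen BatchNorm back to train mode
after_train = (model.training, model[0].training, model[1].training)
print(after_train)  # (True, True, True)
```

Restoring the flags one module at a time preserves the frozen layer, while a single `.train()` call on the parent would silently undo it.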

You can check whether the model is in training or eval mode in your training_step or validation_step through the module.training boolean.
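As an illustration of reading the flag at call time, here is a minimal sketch in plain PyTorch. `ModeAwareAugment` is a hypothetical module and the random horizontal flip is only an example augmentation. Because `forward` reads `self.training` on every call, it behaves correctly even when the flag is set directly instead of through `train()`.

```python
import torch
import torch.nn as nn


class ModeAwareAugment(nn.Module):
    """Hypothetical augmentation: random horizontal flip in train mode only."""

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # `images` is assumed to have shape (batch, channels, height, width).
        if self.training:
            # Flip each image along its width with probability 0.5.
            flip = torch.rand(images.shape[0], device=images.device) < 0.5
            return torch.where(flip[:, None, None, None], images.flip(-1), images)
        return images  # eval mode: pass through unchanged


augment = ModeAwareAugment()
batch = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)

augment.training = False  # flag set directly, no train()/eval() call involved
eval_out = augment(batch)

augment.training = True
train_out = augment(batch)
```

No state other than the boolean needs to change, so no `train()` override is required for this pattern.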

@ClemensSchwarke
Author

Thanks for the explanation! I guess it makes sense to not override the mode of children. However, how would one then implement my print example for a module? I.e. how could I make sure the state of a module is actually changed as opposed to just the boolean flag? I want to use a different image augmentation in my module depending on the mode.


2 participants