Updated Fabric trainer example to not call self.trainer.model during validation (#19993)

liambsmith committed Jun 21, 2024
1 parent 5981aeb · commit 709a2a9
Showing 1 changed file with 9 additions and 3 deletions.
examples/fabric/build_your_own_trainer/trainer.py (9 additions & 3 deletions)
@@ -264,7 +264,7 @@ def val_loop(
         val_loader: Optional[torch.utils.data.DataLoader],
         limit_batches: Union[int, float] = float("inf"),
     ):
-        """The validation loop ruunning a single validation epoch.
+        """The validation loop running a single validation epoch.

         Args:
             model: the LightningModule to evaluate
@@ -285,7 +285,10 @@ def val_loop(
             )
             return

-        self.fabric.call("on_validation_model_eval")  # calls `model.eval()`
+        if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
+            model.eval()
+        else:
+            self.fabric.call("on_validation_model_eval")  # calls `model.eval()`

         torch.set_grad_enabled(False)

@@ -311,7 +314,10 @@ def val_loop(

         self.fabric.call("on_validation_epoch_end")

-        self.fabric.call("on_validation_model_train")
+        if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
+            model.train()
+        else:
+            self.fabric.call("on_validation_model_train")
         torch.set_grad_enabled(True)

     def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
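For readers without the full file, here is a minimal, self-contained sketch of the pattern this diff introduces: switch the model between eval and train mode directly unless the user has overridden the corresponding hook, in which case the hook is dispatched through Fabric. The helper names (toggle_eval_mode, toggle_train_mode) and the standalone fabric/model parameters are illustrative only, and the import paths for is_overridden and _unwrap_objects are assumed to match what the example trainer already uses.

import torch
from lightning.fabric.wrappers import _unwrap_objects
from lightning.pytorch.utilities.model_helpers import is_overridden


def toggle_eval_mode(fabric, model):
    # Sketch only: if the LightningModule does not override
    # on_validation_model_eval, call model.eval() directly instead of
    # routing through the hook (and thus through self.trainer.model).
    if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
        model.eval()
    else:
        fabric.call("on_validation_model_eval")  # hook is expected to call model.eval()
    torch.set_grad_enabled(False)


def toggle_train_mode(fabric, model):
    # Mirror of toggle_eval_mode, run once the validation epoch has finished.
    if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
        model.train()
    else:
        fabric.call("on_validation_model_train")  # hook is expected to call model.train()
    torch.set_grad_enabled(True)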
