Description & Motivation
Hi,
I have noticed that the train_dataloader()'s workers were still up, idle but holding on to resources, while the val_dataloader()'s workers were actively delivering batches.
After some investigation, I found the following pseudo-code describing fit(), simplified here:
def fit(self):
    [...]
    for epoch in epochs:
        fit_loop()
        [...]

def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val:
            val_loop()
        [...]
    [...]

def val_loop():
    [...]
    for batch in val_dataloader():
        [...]
    [...]
The actual behaviour matches the pseudo-code, so this is not a bug; it is working as intended.
However, I have been struggling to balance data-processing speed against memory footprint when running instance segmentation on large, dense, non-public datasets.
I understand that when val_check_interval is not None, running the val_loop within the train_dataloader() loop is necessary. However, when val_check_interval is None, I think it would be beneficial to modify the fit_loop() to something like:
def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val and val_check_interval is not None:
            val_loop()
        [...]
    [...]
    if should_check_val and val_check_interval is None:
        val_loop()
    [...]
That way, resources would be freed as soon as they are no longer needed.
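For context on why this would free resources: with persistent_workers=False, the worker processes of a PyTorch DataLoader live as long as its iterator, and they are only shut down once that iterator is exhausted or released. A minimal plain-PyTorch sketch illustrating the mechanism (the dataset and numbers are made up, and the process count reported by multiprocessing.active_children() can vary slightly by platform and start method):

import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    dataset = TensorDataset(torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=8, num_workers=2, persistent_workers=False)

    iterator = iter(loader)  # the two worker processes are spawned here
    next(iterator)
    print(len(multiprocessing.active_children()))  # workers stay alive while the iterator exists

    del iterator  # releasing the iterator shuts the workers down
    print(len(multiprocessing.active_children()))


if __name__ == "__main__":
    main()

In the current fit_loop(), the iterator over train_dataloader() is still alive while val_loop() runs, which is why its workers keep holding their resources during validation.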
Pitch
In the implementation, the val_loop() is called from on_advance_end(), and the fit_loop() inside run() is considerably different from the pseudo-code above.
I am assuming that we would need to modify and re-use on_advance_end() after the completion of the while-loop in run().
Is this correct?
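To make the question more concrete, here is a rough sketch of what I have in mind, in the same pseudo-code style as above. The structure of run() is simplified from my reading of the code rather than copied from it, and release_train_dataloader() is a hypothetical helper standing in for whatever mechanism actually releases the train_dataloader() iterator and its workers:

def run():
    [...]
    while not done:
        advance()         # consumes batches from train_dataloader()
        on_advance_end()  # with val_check_interval=None, this would no longer call val_loop()
    [...]
    release_train_dataloader()  # hypothetical: free the train_dataloader() iterator and its workers
    on_advance_end()            # re-used so that val_loop() runs once the train workers are gone
    [...]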
Alternatives
No response
Additional context
I have put together this boring.py to illustrate the situation and to have a concrete example to debug on:
import torch
from torch import Tensor
from torch.nn import Linear, MSELoss
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, Dataset, DataLoader
from torchmetrics import regression

from lightning.pytorch import LightningModule, LightningDataModule, Trainer


class BoringDataset(Dataset):
    def __init__(self, num_samples: int):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        x = torch.randn(1, dtype=torch.float32)
        y = 5.0 * x + 2.0
        return {"x": x, "y": y}
class BoringDataModule(LightningDataModule):
    train_datasets: list[BoringDataset]
    val_datasets: list[BoringDataset]
    test_datasets: list[BoringDataset]
    predict_datasets: list[BoringDataset]

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()
        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]
        if stage in ["fit", "all"]:
            self.train_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]
        if stage in ["fit", "validate", "all"]:
            self.val_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]
        if stage in ["test", "all"]:
            self.test_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]
        if stage in ["predict", "all"]:
            self.predict_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

    def teardown(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]
        if stage in ["fit", "all"]:
            del self.train_datasets
        if stage in ["fit", "validate", "all"]:
            del self.val_datasets
        if stage in ["test", "all"]:
            del self.test_datasets
        if stage in ["predict", "all"]:
            del self.predict_datasets

    def train_dataloader(
        self,
    ) -> DataLoader:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": True,
            "persistent_workers": False,
            "shuffle": True,
        }
        dataloader = DataLoader(ConcatDataset(self.train_datasets), **kwargs)
        return dataloader

    def val_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.val_datasets]
        return dataloaders

    def test_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.test_datasets]
        return dataloaders

    def predict_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [
            DataLoader(dataset, **kwargs) for dataset in self.predict_datasets
        ]
        return dataloaders
class BoringModule(LightningModule):
    val_dataloader_idx: int = 0
    test_dataloader_idx: int = 0
    predict_dataloader_idx: int = 0

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()
        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]
        self.datamodule = BoringDataModule(
            num_datasets=self.num_datasets,
            num_samples=self.num_samples,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        self.datamodule.setup(stage=stage)
        if stage in ["fit", "all"]:
            self.loss_function = MSELoss()
            self.train_metric = regression.MeanSquaredError()
        if stage in ["fit", "validate", "all"]:
            self.val_metric = regression.MeanSquaredError()
        if stage in ["test", "all"]:
            self.test_metric = regression.MeanSquaredError()
        if stage in ["predict", "all"]:
            self.predict_metric = regression.MeanSquaredError()

    def configure_model(self):
        self.model = Linear(in_features=1, out_features=1, bias=True)

    def teardown(self, stage: str | None = None):
        assert stage in ["fit", "validate", "test", "predict", "all", None]
        self.datamodule.teardown(stage=stage)
        del self.datamodule
        del self.model
        if stage in ["fit", "all"]:
            del self.loss_function
            del self.train_metric
        if stage in ["fit", "validate", "all"]:
            del self.val_metric
        if stage in ["test", "all"]:
            del self.test_metric
        if stage in ["predict", "all"]:
            del self.predict_metric

    def train_dataloader(self) -> DataLoader:
        return self.datamodule.train_dataloader()

    def val_dataloader(self) -> list[DataLoader]:
        return self.datamodule.val_dataloader()

    def test_dataloader(self) -> list[DataLoader]:
        return self.datamodule.test_dataloader()

    def predict_dataloader(self) -> list[DataLoader]:
        return self.datamodule.predict_dataloader()

    def forward(self, input: dict) -> dict:
        return {
            "y": self.model(input["x"]),
        }

    def training_step(
        self,
        input: dict,
        batch_idx: int,
    ) -> Tensor:
        output = self(input)
        train_loss = self.loss_function(input=output["y"], target=input["y"])
        self.train_metric.update(preds=output["y"], target=input["y"])
        self.log_dict(
            dictionary={"train_loss": train_loss},
            prog_bar=True,
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )
        self.log_dict(
            dictionary={"train_metric": self.train_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )
        return train_loss

    def validation_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.val_dataloader_idx != dataloader_idx:
            self.val_dataloader_idx = dataloader_idx
            self.val_metric.reset()
        output = self(input)
        self.val_metric.update(preds=output["y"], target=input["y"])
        self.log_dict(
            dictionary={f"val_metric/{dataloader_idx}": self.val_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def test_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.test_dataloader_idx != dataloader_idx:
            self.test_dataloader_idx = dataloader_idx
            self.test_metric.reset()
        output = self(input)
        self.test_metric.update(preds=output["y"], target=input["y"])
        self.log_dict(
            dictionary={f"test_metric/{dataloader_idx}": self.test_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def predict_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.predict_dataloader_idx != dataloader_idx:
            self.predict_dataloader_idx = dataloader_idx
            self.predict_metric.reset()
        output = self(input)
        self.predict_metric.update(preds=output["y"], target=input["y"])
        self.log_dict(
            dictionary={f"predict_metric/{dataloader_idx}": self.predict_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def configure_optimizers(self):
        return {
            "optimizer": AdamW(
                self.model.parameters(),
                lr=1e-1,
            ),
        }
def main():
    module = BoringModule(
        num_datasets=2,
        num_samples=10000,
        batch_size=32,
        num_workers=1,
    )
    trainer = Trainer(
        logger=True,
        max_epochs=10,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        benchmark=True,
        detect_anomaly=False,
        sync_batchnorm=True,
        # reload_dataloaders_every_n_epochs=0,  # Neither of those two options has any effect
        # reload_dataloaders_every_n_epochs=1,  # on the lifetime of the train_dataloader()'s workers
    )
    trainer.fit(model=module)


if __name__ == "__main__":
    main()
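To observe the behaviour while the script runs, a small callback along these lines could be passed to the Trainer via callbacks=[WorkerWatcher()]. WorkerWatcher is a made-up name, and multiprocessing.active_children() also counts any other helper processes, so the numbers are only indicative:

import multiprocessing

from lightning.pytorch.callbacks import Callback


class WorkerWatcher(Callback):
    def on_validation_start(self, trainer, pl_module):
        # During in-epoch validation the train_dataloader()'s workers are still alive,
        # so this count covers both the train and the val workers.
        print(f"child processes at validation start: {len(multiprocessing.active_children())}")

    def on_validation_end(self, trainer, pl_module):
        print(f"child processes at validation end: {len(multiprocessing.active_children())}")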
cc @Borda