Optimize fit_loop() to reduce train_dataloader()'s memory footprint #20382

@guillaume-rochette-oxb


Description & Motivation

Hi,

I have noticed that the train_dataloader()'s workers remain up, idle but still holding resources, while the val_dataloader()'s workers are actively delivering batches.
After some investigation, I traced fit() to the following simplified pseudo-code:

def fit(self):
    [...]
    for epoch in epochs:
        fit_loop()
    [...]

def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val:
            val_loop()
        [...]
    [...]

def val_loop():
    [...]
    for batch in val_dataloader():
        [...]
    [...]

The actual behaviour matches this pseudo-code, so this is not a bug; it is working as intended.
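
For context, this matches how plain PyTorch behaves: a DataLoader's worker processes live for as long as its iterator object does, so holding the train iterator open across val_loop() keeps its workers, and the memory they hold, alive. A minimal demonstration (on platforms that default to the "spawn" start method, run it under a __main__ guard):

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(1024, 1)), batch_size=32, num_workers=2)

iterator = iter(loader)  # the two worker processes are spawned here
batch = next(iterator)   # workers prefetch batches in the background
del iterator             # the workers are only torn down once the iterator is released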

However, I have been struggling to balance data-processing speed against memory footprint when running instance segmentation experiments on large, dense non-public datasets.

I understand that when val_check_interval is not None, running val_loop() inside the train_dataloader() loop is necessary. However, when val_check_interval is None, I think it would be beneficial to modify fit_loop() to something like:

def fit_loop():
    [...]
    for batch in train_dataloader():
        [...]
        if should_check_val and val_check_interval is not None:
            val_loop()
        [...]
    [...]
    if should_check_val and val_check_interval is None:
        val_loop()
    [...]

That way, the train_dataloader()'s resources would be freed as soon as they are no longer needed.

Pitch

In the actual implementation, val_loop() is called from on_advance_end(), and the fit_loop() logic inside run() differs considerably from the pseudo-code above.
I am assuming that we would need to modify on_advance_end() and re-use it after the while-loop in run() completes, roughly as sketched below.
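
For concreteness, here is a toy model of the two orderings in plain, runnable Python. This is not the actual Lightning internals; the class and function names are illustrative only:

class ToyWorkerPool:
    # stands in for a DataLoader iterator whose worker processes hold memory
    def __init__(self, name: str):
        self.name = name
        print(f"{self.name}: workers up")

    def close(self):
        print(f"{self.name}: workers down")

def fit_loop_current():
    train = ToyWorkerPool("train")
    for batch in range(3):
        pass  # training steps
    val = ToyWorkerPool("val")  # epoch-end val runs while the train workers still exist
    val.close()
    train.close()

def fit_loop_proposed():
    train = ToyWorkerPool("train")
    for batch in range(3):
        pass  # training steps
    train.close()  # release the train workers before the epoch-end val starts
    val = ToyWorkerPool("val")
    val.close()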

Is this correct?

Alternatives

No response

Additional context

I have written this boring.py to illustrate the situation and provide a concrete example to debug with:

import torch
from torch import Tensor
from torch.nn import Linear, MSELoss
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, Dataset, DataLoader

from torchmetrics import regression

from lightning.pytorch import LightningModule, LightningDataModule, Trainer


class BoringDataset(Dataset):
    def __init__(self, num_samples: int):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        x = torch.randn(1, dtype=torch.float32)
        y = 5.0 * x + 2.0
        return {"x": x, "y": y}


class BoringDataModule(LightningDataModule):
    train_datasets: list[BoringDataset]
    val_datasets: list[BoringDataset]
    test_datasets: list[BoringDataset]
    predict_datasets: list[BoringDataset]

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()
        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        if stage in ["fit", "all"]:
            self.train_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["fit", "validate", "all"]:
            self.val_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["test", "all"]:
            self.test_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

        if stage in ["predict", "all"]:
            self.predict_datasets = [
                BoringDataset(num_samples=self.num_samples)
                for _ in range(self.num_datasets)
            ]

    def teardown(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        if stage in ["fit", "all"]:
            del self.train_datasets

        if stage in ["fit", "validate", "all"]:
            del self.val_datasets

        if stage in ["test", "all"]:
            del self.test_datasets

        if stage in ["predict", "all"]:
            del self.predict_datasets

    def train_dataloader(
        self,
    ) -> DataLoader:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": True,
            "persistent_workers": False,  # workers should not outlive the epoch
            "shuffle": True,
        }
        dataloader = DataLoader(ConcatDataset(self.train_datasets), **kwargs)
        return dataloader

    def val_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.val_datasets]
        return dataloaders

    def test_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.test_datasets]
        return dataloaders

    def predict_dataloader(self) -> list[DataLoader]:
        kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": False,
            "drop_last": False,
            "persistent_workers": False,
            "shuffle": False,
        }
        dataloaders = [
            DataLoader(dataset, **kwargs) for dataset in self.predict_datasets
        ]
        return dataloaders


class BoringModule(LightningModule):
    val_dataloader_idx: int = 0
    test_dataloader_idx: int = 0
    predict_dataloader_idx: int = 0

    def __init__(
        self, num_datasets: int, num_samples: int, batch_size: int, num_workers: int
    ):
        super().__init__()

        self.num_datasets = num_datasets
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: str | None = None):
        assert stage in ["all", "fit", "validate", "test", "predict", None]

        self.datamodule = BoringDataModule(
            num_datasets=self.num_datasets,
            num_samples=self.num_samples,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        self.datamodule.setup(stage=stage)

        if stage in ["fit", "all"]:
            self.loss_function = MSELoss()
            self.train_metric = regression.MeanSquaredError()

        if stage in ["fit", "validate", "all"]:
            self.val_metric = regression.MeanSquaredError()

        if stage in ["test", "all"]:
            self.test_metric = regression.MeanSquaredError()

        if stage in ["predict", "all"]:
            self.predict_metric = regression.MeanSquaredError()

    def configure_model(self):
        self.model = Linear(in_features=1, out_features=1, bias=True)

    def teardown(self, stage: str | None = None):
        assert stage in ["fit", "validate", "test", "predict", "all", None]

        self.datamodule.teardown(stage=stage)
        del self.datamodule

        del self.model

        if stage in ["fit", "all"]:
            del self.loss_function
            del self.train_metric

        if stage in ["fit", "validate", "all"]:
            del self.val_metric

        if stage in ["test", "all"]:
            del self.test_metric

        if stage in ["predict", "all"]:
            del self.predict_metric

    def train_dataloader(self) -> DataLoader:
        return self.datamodule.train_dataloader()

    def val_dataloader(self) -> list[DataLoader]:
        return self.datamodule.val_dataloader()

    def test_dataloader(self) -> list[DataLoader]:
        return self.datamodule.test_dataloader()

    def predict_dataloader(self) -> list[DataLoader]:
        return self.datamodule.predict_dataloader()

    def forward(self, input: dict) -> dict:
        return {
            "y": self.model(input["x"]),
        }

    def training_step(
        self,
        input: dict,
        batch_idx: int,
    ) -> Tensor:
        output = self(input)

        train_loss = self.loss_function(input=output["y"], target=input["y"])
        self.train_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={"train_loss": train_loss},
            prog_bar=True,
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

        self.log_dict(
            dictionary={"train_metric": self.train_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

        return train_loss

    def validation_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.val_dataloader_idx != dataloader_idx:
            self.val_dataloader_idx = dataloader_idx
            self.val_metric.reset()

        output = self(input)

        self.val_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"val_metric/{dataloader_idx}": self.val_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def test_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.test_dataloader_idx != dataloader_idx:
            self.test_dataloader_idx = dataloader_idx
            self.test_metric.reset()

        output = self(input)

        self.test_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"test_metric/{dataloader_idx}": self.test_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def predict_step(
        self,
        input: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if self.predict_dataloader_idx != dataloader_idx:
            self.predict_dataloader_idx = dataloader_idx
            self.predict_metric.reset()

        output = self(input)

        self.predict_metric.update(preds=output["y"], target=input["y"])

        self.log_dict(
            dictionary={f"predict_metric/{dataloader_idx}": self.predict_metric},
            sync_dist=not self.training,
            add_dataloader_idx=False,
        )

    def configure_optimizers(self):
        return {
            "optimizer": AdamW(
                self.model.parameters(),
                lr=1e-1,
            ),
        }


def main():
    module = BoringModule(
        num_datasets=2,
        num_samples=10000,
        batch_size=32,
        num_workers=1,
    )
    trainer = Trainer(
        logger=True,
        max_epochs=10,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        benchmark=True,
        detect_anomaly=False,
        sync_batchnorm=True,
        # reload_dataloaders_every_n_epochs=0,  # Neither of these two options has any effect
        # reload_dataloaders_every_n_epochs=1,  # on the lifetime of the train_dataloader()'s workers.
    )
    trainer.fit(model=module)


if __name__ == "__main__":
    main()
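
To make the behaviour visible, one can count the training process's child processes around validation, since DataLoader workers are children of the training process. A quick sketch, assuming psutil is installed (the callback name is just for illustration):

import os
import psutil
from lightning.pytorch.callbacks import Callback

class WorkerCountCallback(Callback):
    # DataLoader workers are child processes of the training process
    def on_validation_start(self, trainer, pl_module):
        num_children = len(psutil.Process(os.getpid()).children(recursive=True))
        print(f"child processes at validation start: {num_children}")

Passing callbacks=[WorkerCountCallback()] to the Trainer above makes it easy to watch when workers come and go at each epoch-end validation pass.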

cc @Borda
