Trainer profilers are typehinted with the deprecated BaseProfiler instead of Profiler #13046

Closed
mads-oestergaard opened this issue May 12, 2022 · 2 comments · Fixed by #13084
Labels: bug, good first issue, profiler
Milestone: 1.6.x

Comments

@mads-oestergaard
Contributor

mads-oestergaard commented May 12, 2022

🐛 Bug

The Trainer's profiler argument is typehinted with the deprecated BaseProfiler instead of Profiler. This means that you cannot use class_path initialization of profilers with LightningCLI.

Error message:

  - "pytorch_lightning.profiler.PyTorchProfiler" is not a subclass of <class 'pytorch_lightning.profiler.base.BaseProfiler'>
  - Expected a <class 'str'> but got "{'class_path': 'pytorch_lightning.profiler.PyTorchProfiler', 'init_args': {'dirpath': 'profiling', 'filename': 'profile.txt'}}"
  - Expected a <class 'NoneType'> but got "{'class_path': 'pytorch_lightning.profiler.PyTorchProfiler', 'init_args': {'dirpath': 'profiling', 'filename': 'profile.txt'}}"
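
The first line is the real failure: jsonargparse accepts a class_path only if the class is a subclass of the declared typehint, and the next two lines are it falling back to the remaining members of the Optional[Union[..., str]] annotation. A minimal sketch of the check that fails (my illustration, not code from Lightning or jsonargparse):

from pytorch_lightning.profiler import PyTorchProfiler
from pytorch_lightning.profiler.base import BaseProfiler  # deprecated base, used in the Trainer typehint
from pytorch_lightning.profiler.profiler import Profiler  # new base class

# PyTorchProfiler derives from the new Profiler, not the deprecated BaseProfiler,
# so validating the class_path against the BaseProfiler hint rejects it.
print(issubclass(PyTorchProfiler, Profiler))      # True
print(issubclass(PyTorchProfiler, BaseProfiler))  # False on 1.6.3, hence the error above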

To Reproduce

Use this BoringModel code modified for LightningCLI

pl_bug.py:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class RandomDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()  # LightningDataModule needs its own __init__ to run
        self.train = RandomDataset(32, 64)
        self.val = RandomDataset(32, 64)
        self.test = RandomDataset(32, 64)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=2)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    LightningCLI(BoringModel, RandomDataModule)

along with this trainer config (bug_trainer.yaml):

profiler: 
  class_path: pytorch_lightning.profiler.PyTorchProfiler
  init_args:
    dirpath: profiling
    filename: profile.txt
    profiler_kwargs:
      profile_memory: true
      
limit_train_batches: 1
limit_val_batches: 1
limit_test_batches: 1
num_sanity_val_steps: 0
max_epochs: 1
enable_model_summary: false

Run the script from the command line with
$ python pl_bug.py fit --trainer bug_trainer.yaml

Expected behavior

The script runs, and the PyTorch profiler is instantiated and used.

Fix

Change line 176 in trainer/trainer.py from
profiler: Optional[Union[BaseProfiler, str]] = None,
to
profiler: Optional[Union[Profiler, str]] = None,
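
For context, a sketch of the corrected parameter in the Trainer signature (surrounding arguments elided; the exact import line used in trainer.py is my assumption):

from typing import Optional, Union

from pytorch_lightning.profiler import Profiler


class Trainer:
    def __init__(
        self,
        # ... other arguments unchanged ...
        profiler: Optional[Union[Profiler, str]] = None,
    ) -> None:
        ...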

Related issue:
The **profiler_kwargs parameter in profiler/pytorch.py is typehinted as Any, which gives the error:

Parser key "trainer.profiler": Value "Namespace(class_path='pytorch_lightning.profiler.PyTorchProfiler', init_args=Namespace(dirpath='profiling', filename='profile.txt', group_by_input_shapes=False, emit_nvtx=False, export_to_chrome=True, row_limit=20, sort_by_key=None, record_module_names=True, profiler_kwargs=Namespace(profile_memory=True)))" does not validate against any of the types in typing.Union[pytorch_lightning.profiler.profiler.Profiler, str, NoneType]:
  - Problem with given class_path "pytorch_lightning.profiler.PyTorchProfiler":
    - 'Configuration check failed :: No action for destination key "profiler_kwargs.profile_memory" to check its value.'

A fix would be to typehint profiler_kwargs as a Dict with an empty dict as the default argument, i.e. change **profiler_kwargs: Any to profiler_kwargs: Dict = {}. This appears related to #11653 (Refactor use of **kwargs in PL classes for better LightningCLI support).
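
A sketch of that suggested change in the PyTorchProfiler signature (other parameters abbreviated and types simplified; this illustrates the proposal, not the shipped code):

from typing import Dict, Optional

from pytorch_lightning.profiler.profiler import Profiler


class PyTorchProfiler(Profiler):
    def __init__(
        self,
        dirpath: Optional[str] = None,    # types simplified for the sketch
        filename: Optional[str] = None,
        # before: **profiler_kwargs: Any  -- opaque to LightningCLI/jsonargparse
        profiler_kwargs: Dict = {},       # after: an explicit dict the CLI can parse from YAML
    ) -> None:
        ...

(In practice an Optional[Dict] = None default would avoid the mutable-default pitfall, but the effect on CLI parsing is the same.)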

Environment

  • CUDA:
    - GPU:
      - NVIDIA GeForce GTX 1080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.21.6
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0+cu102
    - pytorch-lightning: 1.6.3
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.9.12
    - version: #45~20.04.1-Ubuntu SMP Mon Apr 4 09:38:31 UTC 2022

Additional context

cc @otaj @Borda @carmocca @kaushikb11 @ninginthecloud @rohitgr7 @nbcsm @guotuofeng

mads-oestergaard added the needs triage label on May 12, 2022
@carmocca
Contributor

Hi!

Would you like to open a PR with the typehint fix for the Trainer?

> A fix would be to typehint profiler_kwargs

I would wait on this change until we know the direction on #11653

carmocca added the bug, good first issue, and profiler labels and removed the needs triage label on May 13, 2022
carmocca added this to the 1.6.x milestone on May 13, 2022
@mads-oestergaard
Contributor Author

Yeah sure, but I won't have time before Monday, so if there is another PR that depends on this small fix (#12308), then it's fine to have them include it.
