diff --git a/.azure/hpu-tests.yml b/.azure/hpu-tests.yml index bdfada907cac9..a8ecb3ad5efa8 100644 --- a/.azure/hpu-tests.yml +++ b/.azure/hpu-tests.yml @@ -9,10 +9,26 @@ trigger: - "master" - "release/*" - "refs/tags/*" + paths: + include: + - ".azure/hpu-tests.yml" + - "examples/pl_hpu/mnist_sample.py" + - "requirements/pytorch/**" + - "src/pytorch_lightning/**" + - "tests/tests_pytorch/**" pr: - - "master" - - "release/*" + branches: + include: + - "master" + - "release/*" + paths: + include: + - ".azure/hpu-tests.yml" + - "examples/pl_hpu/mnist_sample.py" + - "requirements/pytorch/**" + - "src/pytorch_lightning/**" + - "tests/tests_pytorch/**" jobs: - job: testing diff --git a/.github/file-filters.yml b/.github/file-filters.yml deleted file mode 100644 index e621cd83881e4..0000000000000 --- a/.github/file-filters.yml +++ /dev/null @@ -1,9 +0,0 @@ -# This file contains filters to be used in the CI to detect file changes and run the required CI jobs. - -app_examples: - - "src/lightning_app/**" - - "tests/tests_app_examples/**" - - "requirements/app/**" - - "examples/app_*" - - "setup.py" - - "src/pytorch_lightning/__version__.py" diff --git a/.github/workflows/_check-schema.yml b/.github/workflows/_check-schema.yml deleted file mode 100644 index 299af83503831..0000000000000 --- a/.github/workflows/_check-schema.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: Reusable Check Schema - -on: - workflow_call: - inputs: - azure-dir: - description: 'Directory containing Azure Pipelines config files. Provide an empty string to skip checking on Azure Pipelines files.' - default: './.azure/' - required: false - type: string - -jobs: - schema: - runs-on: ubuntu-20.04 - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: pip install check-jsonschema - - - name: GitHub Actions - workflow - run: check-jsonschema $(find .github/workflows -name '*.yml' -a ! -name '_*.yml') --builtin-schema "github-workflows" - - - name: GitHub Actions - action - run: | - if [ -d ".github/actions" ]; then - check-jsonschema .github/actions/*/*.yml --builtin-schema "github-actions" - fi - - - name: Azure Pipelines - env: - SCHEMA_FILE: https://raw.githubusercontent.com/microsoft/azure-pipelines-vscode/v1.204.0/service-schema.json - run: | - if [ -d ${{ inputs.azure-dir }} ]; then - check-jsonschema ${{ inputs.azure-dir }}/*.yml --schemafile "$SCHEMA_FILE" - fi diff --git a/.github/workflows/ci-app-cloud-e2e-test.yml b/.github/workflows/ci-app-cloud-e2e-test.yml index 3ad455650a117..c8cef5fbf53f9 100644 --- a/.github/workflows/ci-app-cloud-e2e-test.yml +++ b/.github/workflows/ci-app-cloud-e2e-test.yml @@ -57,9 +57,8 @@ jobs: - commands timeout-minutes: 35 steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: "3.8" diff --git a/.github/workflows/ci-app-examples.yml b/.github/workflows/ci-app-examples.yml index 32e1fc54e1814..818777727ca5a 100644 --- a/.github/workflows/ci-app-examples.yml +++ b/.github/workflows/ci-app-examples.yml @@ -32,7 +32,7 @@ jobs: timeout-minutes: 10 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/ci-app-tests.yml b/.github/workflows/ci-app-tests.yml index e4e6574d9aa31..f5725fab0e832 100644 --- a/.github/workflows/ci-app-tests.yml +++ b/.github/workflows/ci-app-tests.yml @@ -30,7 +30,7 @@ jobs: timeout-minutes: 20 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/ci-pkg-install.yml b/.github/workflows/ci-pkg-install.yml index a9fdd36693a67..b4fae74f991aa 100644 --- a/.github/workflows/ci-pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -38,7 +38,7 @@ jobs: python-version: [3.8] # , 3.9 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -72,7 +72,7 @@ jobs: python-version: [3.8] # , 3.9 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -104,7 +104,7 @@ jobs: python-version: [3.8] # , 3.9 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/ci-pr-gatekeeper.yml b/.github/workflows/ci-pr-gatekeeper.yml index 8714bec926c23..5c235f151b59b 100644 --- a/.github/workflows/ci-pr-gatekeeper.yml +++ b/.github/workflows/ci-pr-gatekeeper.yml @@ -20,7 +20,7 @@ jobs: fetch-depth: "2" # To retrieve the preceding commit. - name: Get changed files using defaults id: changed-files - uses: tj-actions/changed-files@v29.0.1 + uses: tj-actions/changed-files@v29.0.3 - name: Determine changes id: touched run: | diff --git a/.github/workflows/ci-pytorch-test-conda.yml b/.github/workflows/ci-pytorch-test-conda.yml index 82c463a54169f..64d06a22949d8 100644 --- a/.github/workflows/ci-pytorch-test-conda.yml +++ b/.github/workflows/ci-pytorch-test-conda.yml @@ -33,11 +33,11 @@ jobs: - name: Workaround for https://github.com/actions/checkout/issues/760 run: git config --global --add safe.directory /__w/lightning/lightning - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v29.0.1 + uses: tj-actions/changed-files@v29.0.3 - name: Decide if the test should be skipped id: skip diff --git a/.github/workflows/ci-pytorch-test-full.yml b/.github/workflows/ci-pytorch-test-full.yml index 987373b6ea2bf..fbdc81b91c0ed 100644 --- a/.github/workflows/ci-pytorch-test-full.yml +++ b/.github/workflows/ci-pytorch-test-full.yml @@ -35,11 +35,11 @@ jobs: timeout-minutes: 40 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v29.0.1 + uses: tj-actions/changed-files@v29.0.3 - name: Decide if the test should be skipped id: skip diff --git a/.github/workflows/ci-pytorch-test-slow.yml b/.github/workflows/ci-pytorch-test-slow.yml index 126eaaf17da1a..091c3f606c3ca 100644 --- a/.github/workflows/ci-pytorch-test-slow.yml +++ b/.github/workflows/ci-pytorch-test-slow.yml @@ -7,6 +7,12 @@ on: # Trigger the workflow on push or pull request, but only for the master bra pull_request: branches: [master, "release/*"] types: [opened, reopened, ready_for_review, synchronize] + paths: + - "requirements/pytorch/**" + - "src/pytorch_lightning/**" + - "tests/tests_pytorch/**" + - "setup.cfg" # includes pytest config + - ".github/workflows/ci-pytorch-test-slow.yml" concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} @@ -26,45 +32,21 @@ jobs: timeout-minutes: 20 steps: - - uses: actions/checkout@v2 - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v29.0.1 - - - name: Decide if the test should be skipped - id: skip - shell: bash -l {0} - run: | - FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*' - echo "${{ steps.changed-files.outputs.all_changed_files }}" | tr " " "\n" > changed_files.txt - MATCHES=$(cat changed_files.txt | grep -E $FILTER) - echo $MATCHES - if [ -z "$MATCHES" ]; then - echo "Skip" - echo "::set-output name=continue::0" - else - echo "Continue" - echo "::set-output name=continue::1" - fi + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 - if: ${{ (steps.skip.outputs.continue == '1') }} with: python-version: ${{ matrix.python-version }} - name: Reset caching - if: ${{ (steps.skip.outputs.continue == '1') }} run: python -c "import time; days = time.time() / 60 / 60 / 24; print(f'TIME_PERIOD=d{int(days / 2) * 2}')" >> $GITHUB_ENV - name: Get pip cache - if: ${{ (steps.skip.outputs.continue == '1') }} id: pip-cache run: | python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" - name: Cache pip - if: ${{ (steps.skip.outputs.continue == '1') }} uses: actions/cache@v3 with: path: ${{ steps.pip-cache.outputs.dir }} @@ -73,7 +55,6 @@ jobs: ${{ runner.os }}-pip-td${{ env.TIME_PERIOD }}-py${{ matrix.python-version }}- - name: Install dependencies - if: ${{ (steps.skip.outputs.continue == '1') }} env: PACKAGE_NAME: pytorch FREEZE_REQUIREMENTS: 1 @@ -85,21 +66,20 @@ jobs: shell: bash - name: Testing PyTorch - if: ${{ (steps.skip.outputs.continue == '1') }} working-directory: tests/tests_pytorch run: coverage run --source pytorch_lightning -m pytest -v --junitxml=results-${{ runner.os }}-py${{ matrix.python-version }}.xml env: PL_RUN_SLOW_TESTS: 1 - name: Upload pytest test results - if: ${{ (failure()) && (steps.skip.outputs.continue == '1') }} + if: failure() uses: actions/upload-artifact@v3 with: name: unittest-results-${{ runner.os }}-py${{ matrix.python-version }} path: tests/tests_pytorch/results-${{ runner.os }}-py${{ matrix.python-version }}.xml - name: Statistics - if: ${{ (success()) && (steps.skip.outputs.continue == '1') }} + if: success() working-directory: tests/tests_pytorch run: | coverage report @@ -107,7 +87,7 @@ jobs: - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 - if: ${{ (success()) && (steps.skip.outputs.continue == '1') }} + if: success() # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 156334ae96043..364266d340520 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -1,10 +1,11 @@ name: Check Schema on: - push: {} + push: + branches: [master, "release/*"] pull_request: branches: [master, "release/*"] jobs: check: - uses: ./.github/workflows/_check-schema.yml + uses: Lightning-AI/devtools/.github/workflows/check-schema.yml@v0.1.0 diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index a91f216af963f..0de1d16cfba58 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -19,7 +19,7 @@ jobs: matrix: pkg: ["app", "pytorch"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true - uses: actions/setup-python@v4 @@ -70,7 +70,7 @@ jobs: matrix: pkg: ["app", "pytorch"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true # lfs: true diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index dd589baf2fa46..97c320ca84298 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -9,8 +9,8 @@ jobs: runs-on: ubuntu-20.04 steps: - name: Checkout 🛎️ - uses: actions/checkout@v2 - # If you're using actions/checkout@v2 you must set persist-credentials to false in most cases for the deployment to work correctly. + uses: actions/checkout@v3 + # If you're using actions/checkout@v3 you must set persist-credentials to false in most cases for the deployment to work correctly. with: persist-credentials: false - uses: actions/setup-python@v4 diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 13d3895bf365d..2576b05e33566 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -15,7 +15,7 @@ jobs: steps: # does nightly releases from feature branch - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.github/workflows/legacy-checkpoints.yml b/.github/workflows/legacy-checkpoints.yml index 0856cfd3229a2..7a59b9446aab0 100644 --- a/.github/workflows/legacy-checkpoints.yml +++ b/.github/workflows/legacy-checkpoints.yml @@ -8,7 +8,7 @@ jobs: create-legacy-ckpts: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 2de330ea5ca75..67503ba2b2c0d 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -22,7 +22,7 @@ jobs: - {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.6.1"} steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Get release version id: get_version diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index 97c3b8eca77d1..2c6f5da240f63 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -21,7 +21,7 @@ jobs: build-pkgs: ${{ steps.candidate.outputs.pkgs }} pull-pkgs: ${{ steps.download.outputs.pkgs }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: 3.9 @@ -60,7 +60,7 @@ jobs: max-parallel: 1 matrix: ${{ fromJSON(needs.releasing.outputs.build-pkgs) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/download-artifact@v3 with: name: dist-packages-${{ github.sha }} @@ -94,7 +94,7 @@ jobs: max-parallel: 1 matrix: ${{ fromJSON(needs.releasing.outputs.pull-pkgs) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/download-artifact@v3 with: name: pypi-packages-${{ github.sha }} @@ -118,7 +118,7 @@ jobs: needs: [build-package, download-package] runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/download-artifact@v3 with: name: dist-packages-${{ github.sha }} @@ -169,7 +169,7 @@ jobs: needs: build-meta-pkg if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/download-artifact@v3 with: name: dist-packages-${{ github.sha }} @@ -188,7 +188,7 @@ jobs: needs: build-meta-pkg if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/download-artifact@v3 with: name: dist-packages-${{ github.sha }} @@ -220,7 +220,7 @@ jobs: runs-on: ubuntu-20.04 needs: [build-package, publish-package] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 1ab762936db18..466defcae79ce 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -7,3 +7,4 @@ playwright==1.22.0 # pytest-flake8 httpx trio +psutil diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 20b6c1b8dbc12..b331a93c0b0bb 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -9,3 +9,4 @@ hydra-core>=1.0.5, <1.3.0 jsonargparse[signatures]>=4.12.0, <=4.12.0 gcsfs>=2021.5.0, <2022.8.0 rich>=10.14.0, !=10.15.0.a, <13.0.0 +protobuf<=3.20.1 # strict # an extra is updating protobuf, this pin prevents TensorBoard failure diff --git a/requirements/pytorch/loggers.txt b/requirements/pytorch/loggers.txt index 905823451973b..573daaa541ced 100644 --- a/requirements/pytorch/loggers.txt +++ b/requirements/pytorch/loggers.txt @@ -6,5 +6,4 @@ neptune-client>=0.10.0, <0.16.4 comet-ml>=3.1.12, <3.31.8 mlflow>=1.0.0, <1.29.0 -test_tube>=0.7.5, <=0.7.5 wandb>=0.10.22, <0.13.2 diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f9c9ecdc46a98..d1838774b49ac 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.7.5] - 2022-09-06 + +### Fixed + +- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489)) +- Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326)) +- Fixed `Trainer.estimated_stepping_batches` when maximum number of epochs is not set ([#14317](https://github.com/Lightning-AI/lightning/pull/14317)) + + ## [1.7.4] - 2022-08-31 ### Added diff --git a/src/pytorch_lightning/__version__.py b/src/pytorch_lightning/__version__.py index 582554e87c281..57a819f4fc5bb 100644 --- a/src/pytorch_lightning/__version__.py +++ b/src/pytorch_lightning/__version__.py @@ -1 +1 @@ -version = "1.7.4" +version = "1.7.5" diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 1f89609e82e82..39023dd37fdf1 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -419,8 +419,7 @@ def log( " but it should not contain information about `dataloader_idx`" ) - value = apply_to_collection(value, numbers.Number, self.__to_tensor) - apply_to_collection(value, torch.Tensor, self.__check_numel_1, name) + value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name) if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name): # if we started a new epoch (running its first batch) the hook name has changed @@ -552,16 +551,15 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: numbers.Number) -> Tensor: - return torch.tensor(value, device=self.device) - - @staticmethod - def __check_numel_1(value: Tensor, name: str) -> None: + def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor: + value = torch.tensor(value, device=self.device) if not torch.numel(value) == 1: raise ValueError( f"`self.log({name}, {value})` was called, but the tensor must have a single element." f" You can try doing `self.log({name}, {value}.mean())`" ) + value = value.squeeze() + return value def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None: """Override this method to change the default behaviour of ``log_grad_norm``. diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index baf4bc9092774..3198e46b1a586 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -223,7 +223,7 @@ def __init__(self, *args, **kwarg): Args: name: Display name for the run. - save_dir: Path where data is saved (wandb dir by default). + save_dir: Path where data is saved. offline: Run offline (data can be streamed later to wandb servers). id: Sets the version, mainly used to resume a previous run. version: Same as id. @@ -255,7 +255,7 @@ def __init__(self, *args, **kwarg): def __init__( self, name: Optional[str] = None, - save_dir: Optional[str] = None, + save_dir: str = ".", offline: bool = False, id: Optional[str] = None, anonymous: Optional[bool] = None, @@ -300,7 +300,7 @@ def __init__( name=name, project=project, id=version or id, - dir=save_dir, + dir=save_dir or kwargs.pop("dir"), resume="allow", anonymous=("allow" if anonymous else None), ) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 7025e49ee5613..378a969830a6f 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -2769,8 +2769,8 @@ def configure_optimizers(self): ) # infinite training - if self.max_epochs == -1 and self.max_steps == -1: - return float("inf") + if self.max_epochs == -1: + return float("inf") if self.max_steps == -1 else self.max_steps if self.train_dataloader is None: rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.") diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py index 3e003293692a8..ec942db6f157c 100644 --- a/tests/tests_app/cli/test_cli.py +++ b/tests/tests_app/cli/test_cli.py @@ -140,3 +140,6 @@ def test_cli_logout(exists: mock.MagicMock, unlink: mock.MagicMock, creds: bool) unlink.assert_called_once_with() else: unlink.assert_not_called() + + +# TODO: test for the other commands diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 648e1a8f38ec8..b408046c9e5d2 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -58,9 +58,15 @@ def test_wandb_logger_init(wandb, monkeypatch): wandb.init.reset_mock() WandbLogger(project="test_project").experiment wandb.init.assert_called_once_with( - name=None, dir=None, id=None, project="test_project", resume="allow", anonymous=None + name=None, dir=".", id=None, project="test_project", resume="allow", anonymous=None ) + # test wandb.init set save_dir correctly after created + wandb.run = None + wandb.init.reset_mock() + logger = WandbLogger(project="test_project") + assert logger.save_dir is not None + # test wandb.init and setting logger experiment externally wandb.run = None run = wandb.init() diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index d16be306b9365..cd7f83ddc7bfe 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -28,9 +28,9 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.datasets import RandomDictDataset from tests_pytorch.helpers.runif import RunIf @@ -837,3 +837,13 @@ def on_train_start(self): assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)] assert trainer.max_epochs > 1 + + +def test_unsqueezed_tensor_logging(): + model = BoringModel() + trainer = Trainer() + trainer.state.stage = RunningStage.TRAINING + model._current_fx_name = "training_step" + model.trainer = trainer + model.log("foo", torch.Tensor([1.2])) + assert trainer.callback_metrics["foo"].ndim == 0 diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 92a1126294dfc..0f694757ca22d 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -95,9 +95,9 @@ def test_num_stepping_batches_infinite_training(): assert trainer.estimated_stepping_batches == float("inf") -def test_num_stepping_batches_with_max_steps(): +@pytest.mark.parametrize("max_steps", [2, 100]) +def test_num_stepping_batches_with_max_steps(max_steps): """Test stepping batches with `max_steps`.""" - max_steps = 2 trainer = Trainer(max_steps=max_steps) model = BoringModel() trainer.fit(model)