Merged
22 commits
49f4775
Use a single instance of `rich.console.Console` throughout the codeba…
otaj Apr 27, 2022
5fe708d
Fix to ensure the checkpoint states are saved in a common filepath wi…
rohitgr7 Apr 27, 2022
bd83e87
Fix `trainer.logger` deprecation message (#12671)
carmocca Apr 27, 2022
fafa5f4
Fix tests related to DDP communication hooks (#12878)
akihironitta Apr 27, 2022
7ae1743
Use cmake installed with apt (#12907)
akihironitta Apr 28, 2022
2a66abd
ShardedGradScaler should only be set for FP16 (#12915)
Apr 28, 2022
0d668c9
Print ragged dict of metrics in `EvaluationLoop._print_results` prope…
otaj Apr 28, 2022
155985e
CHANGELOG + version update
carmocca Apr 28, 2022
9287a7f
Invoke parent DDP configuration for torch>1.10.2 (#12912)
jerome-habana Apr 28, 2022
c58a509
Threading support for legacy loading of checkpoints (#12814)
krshrimali Apr 28, 2022
86002c0
Remove use of jsonargparse internals (#12918)
mauvilsa Apr 29, 2022
838bbaf
Fix pickling of KFoldLoop (#12441)
niberger Apr 29, 2022
903a867
Exclude the CHANGELOG from the pre-commit size check (#12931)
carmocca Apr 29, 2022
0bd7521
Override `optimizer_zero_grad` when using the `IPUStrategy` (#12913)
hmellor May 1, 2022
b12a567
Update nvidia gpg key to fix nightly docker builds (#12930)
akihironitta May 2, 2022
60a1680
Fuse_modules in a qat-respecting way (#12891)
ORippler May 2, 2022
869f88f
Add hook test for reloading with max epochs (#12932)
carmocca May 2, 2022
04f6cfe
Enforce eval shuffle warning only for default samplers (#12653)
rohitgr7 May 2, 2022
fe934e6
Merge pull request #12920 from PyTorchLightning/rename/lightning_exte…
kaushikb11 Apr 28, 2022
9b7ef7a
[FIX] Enable mixed precision in the Fully Sharded Strategy when `prec…
May 3, 2022
683d1ef
Fix `TQDMProgressBar` reset and update to show correct time estimatio…
rohitgr7 May 3, 2022
694a819
Fix fit loop restart logic to enable resume using the checkpoint (#12…
rohitgr7 May 3, 2022
1 change: 1 addition & 0 deletions .github/workflows/ci_dockers.yml
@@ -79,6 +79,7 @@ jobs:
- {python_version: "3.7", pytorch_version: "1.11", cuda_version: "11.3.1"}
# latest (used in Tutorials)
- {python_version: "3.8", pytorch_version: "1.8", cuda_version: "11.1"}
- {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1"}
- {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1"}
- {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
steps:
1 change: 1 addition & 0 deletions .github/workflows/events-nightly.yml
@@ -119,6 +119,7 @@ jobs:
- {python_version: "3.7", pytorch_version: "1.11", cuda_version: "11.3.1"}
# latest (used in Tutorials)
- {python_version: "3.8", pytorch_version: "1.8", cuda_version: "11.1"}
- {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1"}
- {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1"}
- {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}

1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -36,6 +36,7 @@ repos:
args: ['--maxkb=350', '--enforce-all']
exclude: |
(?x)^(
CHANGELOG.md|
docs/source/_static/images/general/fast_2.gif|
docs/source/_static/images/mnist_imgs/pt_to_pl.jpg|
docs/source/_static/images/lightning_module/pt_to_pl.png|
20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.6.3] - 2022-05-03

### Fixed

- Use only a single instance of `rich.console.Console` throughout the codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886))
- Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887))
- Fixed `trainer.logger` deprecation message ([#12671](https://github.com/PyTorchLightning/pytorch-lightning/pull/12671))
- Fixed an issue where sharded grad scaler is passed in when using BF16 with the `ShardedStrategy` ([#12915](https://github.com/PyTorchLightning/pytorch-lightning/pull/12915))
- Fixed an issue with recursive invocation of DDP configuration in the HPU parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))
- Fixed printing of ragged dictionaries in `Trainer.validate` and `Trainer.test` ([#12857](https://github.com/PyTorchLightning/pytorch-lightning/pull/12857))
- Fixed threading support for legacy loading of checkpoints ([#12814](https://github.com/PyTorchLightning/pytorch-lightning/pull/12814))
- Fixed pickling of `KFoldLoop` ([#12441](https://github.com/PyTorchLightning/pytorch-lightning/pull/12441))
- Stopped `optimizer_zero_grad` from being called after IPU execution ([#12913](https://github.com/PyTorchLightning/pytorch-lightning/pull/12913))
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))
- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/PyTorchLightning/pytorch-lightning/pull/12821))


## [1.6.2] - 2022-04-27

### Fixed
16 changes: 5 additions & 11 deletions dockers/base-conda/Dockerfile
@@ -29,7 +29,11 @@ ENV \
# CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \
MKL_THREADING_LAYER=GNU

RUN apt-get update -qq --fix-missing && \
RUN \
# TODO: Remove the manual key installation once the base image is updated.
# https://github.com/NVIDIA/nvidia-docker/issues/1631
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
apt-get update -qq --fix-missing && \
apt-get install -y --no-install-recommends \
build-essential \
cmake \
@@ -104,16 +108,6 @@ RUN \
pip install -r requirements-examples.txt --no-cache-dir --find-links https://download.pytorch.org/whl/test/torch_test.html && \
rm assistant.py

RUN \
apt-get purge -y cmake && \
wget -q https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz && \
tar -zxvf cmake-3.20.2.tar.gz && \
cd cmake-3.20.2 && \
./bootstrap -- -DCMAKE_USE_OPENSSL=OFF && \
make && \
make install && \
cmake --version

ENV \
# if you want this environment to be the default one, uncomment the following line:
CONDA_DEFAULT_ENV=${CONDA_ENV} \
6 changes: 5 additions & 1 deletion dockers/base-cuda/Dockerfile
@@ -31,7 +31,11 @@ ENV \
# MAKEFLAGS="-j$(nproc)"
MAKEFLAGS="-j2"

RUN apt-get update -qq --fix-missing && \
RUN \
# TODO: Remove the manual key installation once the base image is updated.
# https://github.com/NVIDIA/nvidia-docker/issues/1631
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
apt-get update -qq --fix-missing && \
apt-get install -y --no-install-recommends \
build-essential \
pkg-config \
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -121,7 +121,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
"sphinx_copybutton",
"sphinx_paramlinks",
"sphinx_togglebutton",
"pt_lightning_sphinx_theme.extensions.lightning_tutorials",
"pt_lightning_sphinx_theme.extensions.lightning",
]

# Suppress warnings about duplicate labels (needed for PL tutorials)
3 changes: 3 additions & 0 deletions pl_examples/loop_examples/kfold.py
@@ -239,6 +239,9 @@ def __getattr__(self, key) -> Any:
return getattr(self.fit_loop, key)
return self.__dict__[key]

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)


class LitImageClassifier(ImageClassifier):
def __init__(self) -> None:
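The `__setstate__` addition above matters because `KFoldLoop.__getattr__` delegates unknown attributes to `self.fit_loop`, and during unpickling that hook can fire while `__dict__` is still empty, recursing until a `RecursionError`. Defining `__setstate__` restores the instance dict directly. A minimal sketch of the pattern (the class below is illustrative, not the actual loop):

```python
import pickle
from typing import Any, Dict


class DelegatingLoop:
    """Illustrative stand-in for a loop that forwards unknown attributes to a wrapped loop."""

    def __init__(self, fit_loop: Any) -> None:
        self.fit_loop = fit_loop

    def __getattr__(self, key: str) -> Any:
        # Only called when normal lookup fails. While unpickling, pickle probes for
        # `__setstate__` before `__dict__` is restored; without the method below, that
        # probe would land here, `self.fit_loop` would also miss, and lookup would
        # recurse forever.
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Restore state without going through attribute lookup.
        self.__dict__.update(state)


loop = DelegatingLoop(fit_loop=["fold-0", "fold-1"])
restored = pickle.loads(pickle.dumps(loop))
print(restored.fit_loop)  # ['fold-0', 'fold-1']
```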
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = "1.6.2"
__version__ = "1.6.3"
__author__ = "William Falcon et al."
__author_email__ = "waf2107@columbia.edu"
__license__ = "Apache-2.0"
6 changes: 4 additions & 2 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -22,7 +22,8 @@

Task, Style = None, None
if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich import get_console, reconfigure
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar
from rich.style import Style
@@ -278,7 +279,8 @@ def enable(self) -> None:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console = Console(**self._console_kwargs)
reconfigure(**self._console_kwargs)
self._console = get_console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
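The swap above replaces per-callback `Console(**kwargs)` construction with rich's global console, so the progress bar, the model summary, and any other rich output share one live display. A rough sketch of the public-API pattern involved (the kwargs shown are illustrative, not Lightning defaults):

```python
from rich import get_console, reconfigure

# Push user-supplied options onto the shared global console rather than
# building a new, independent Console for each callback.
console_kwargs = {"width": 120}  # illustrative option
reconfigure(**console_kwargs)

console = get_console()  # the same Console object everywhere in the process
console.clear_live()     # safe to reset any live display tied to it
console.print("training output goes through one console instance")
```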
25 changes: 14 additions & 11 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -262,13 +262,13 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
self.main_progress_bar.total = convert_inf(total_batches)
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
current = self.train_batch_idx + self._val_processed
if self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current)
_update_n(self.main_progress_bar, current, self.refresh_rate)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -288,17 +288,17 @@ def on_validation_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader)
self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
_update_n(self.val_progress_bar, self.val_batch_idx)
_update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)

current = self.train_batch_idx + self._val_processed
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current)
_update_n(self.main_progress_bar, current, self.refresh_rate)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -315,12 +315,12 @@ def on_test_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader)
self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
_update_n(self.test_progress_bar, self.test_batch_idx)
_update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar.close()
@@ -335,12 +335,12 @@ def on_predict_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader)
self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
_update_n(self.predict_progress_bar, self.predict_batch_idx)
_update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar.close()
@@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x


def _update_n(bar: _tqdm, value: int) -> None:
def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
if not bar.disable:
bar.n = value
total = bar.total
leftover = current % refresh_rate
advance = leftover if (current == total and leftover != 0) else refresh_rate
bar.update(advance)
bar.refresh()
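In the reworked `_update_n` above, the bar advances via `bar.update(...)` instead of assigning `bar.n` directly, so tqdm's internal rate (and therefore the remaining-time estimate) stays consistent with the configured refresh rate; the final update only advances by the leftover batches. A small illustration of the advance arithmetic (a standalone helper mirroring the logic, not the callback itself):

```python
def compute_advance(current: int, total: int, refresh_rate: int) -> int:
    # Mirrors the advance computed in `_update_n`: a full refresh_rate on regular
    # updates, only the leftover on the final, off-grid update.
    leftover = current % refresh_rate
    return leftover if (current == total and leftover != 0) else refresh_rate


# With refresh_rate=20 and 65 batches, updates fire at 20, 40, 60 and 65:
print([compute_advance(c, total=65, refresh_rate=20) for c in (20, 40, 60, 65)])
# -> [20, 20, 20, 5]  (advances sum to 65, so the displayed progress tracks real work)
```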
9 changes: 7 additions & 2 deletions pytorch_lightning/callbacks/quantization.py
@@ -26,14 +26,19 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_10:
from torch.ao.quantization.qconfig import QConfig
else:
from torch.quantization import QConfig

if _TORCH_GREATER_EQUAL_1_11:
from torch.ao.quantization import fuse_modules_qat as fuse_modules
else:
from torch.quantization import fuse_modules


def wrap_qat_forward_context(
quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
@@ -252,7 +257,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
model.qconfig = self._qconfig

if self._check_feasible_fuse(model):
torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True)
fuse_modules(model, self._modules_to_fuse, inplace=True)

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model that will observe weight and activation tensors during calibration.
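The import change above selects the QAT-aware fuser when available: on `torch>=1.11`, `fuse_modules_qat` preserves quantization-aware-training semantics that the generic `torch.quantization.fuse_modules` no longer applies to models in training mode. A hedged standalone sketch of the same gate, using `torch.__version__` directly instead of Lightning's internal version flags:

```python
import torch

_TORCH_MAJOR_MINOR = tuple(int(part) for part in torch.__version__.split(".")[:2])

if _TORCH_MAJOR_MINOR >= (1, 11):
    # QAT-respecting fusion, available from torch 1.11 onwards.
    from torch.ao.quantization import fuse_modules_qat as fuse_modules
else:
    from torch.quantization import fuse_modules

# Later, during model preparation (the module list is model-specific):
# fuse_modules(model, modules_to_fuse, inplace=True)
```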
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/rich_model_summary.py
@@ -18,7 +18,7 @@
from pytorch_lightning.utilities.model_summary import get_human_readable_count

if _RICH_AVAILABLE:
from rich.console import Console
from rich import get_console
from rich.table import Table


@@ -73,7 +73,7 @@ def summarize(
model_size: float,
) -> None:

console = Console()
console = get_console()

table = Table(header_style="bold magenta")
table.add_column(" ", style="dim")
25 changes: 22 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -37,10 +37,10 @@
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@@ -249,7 +249,26 @@ def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
@property
def logger(self) -> Optional[LightningLoggerBase]:
"""Reference to the logger object in the Trainer."""
return self.trainer.logger if self.trainer else None
# this should match the implementation of `trainer.logger`
# we don't reuse it so we can properly set the deprecation stacklevel
if self.trainer is None:
return
loggers = self.trainer.loggers
if len(loggers) == 0:
return None
if len(loggers) == 1:
return loggers[0]
else:
if not self._running_torchscript:
rank_zero_deprecation(
"Using `lightning_module.logger` when multiple loggers are configured."
" This behavior will change in v1.8 when `LoggerCollection` is removed, and"
" `lightning_module.logger` will return the first logger available.",
stacklevel=5,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return LoggerCollection(loggers)

@property
def loggers(self) -> List[LightningLoggerBase]:
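For context on the `logger` property above: a single configured logger is returned as-is, while multiple loggers are still wrapped in the deprecated `LoggerCollection`, with that wrapper's own warning suppressed so users see exactly one deprecation message (raised at a `stacklevel` pointing at their call site). A rough sketch of the suppress-while-constructing idiom, assuming the 1.6.x import path for `LoggerCollection`:

```python
import warnings

from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger


def combined_logger(loggers):
    # Warn once at the call site elsewhere, then silence the deprecation the
    # collection itself would raise while it still exists (removal planned for v1.8).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return LoggerCollection(loggers)


loggers = [TensorBoardLogger("logs/a"), TensorBoardLogger("logs/b")]
print(type(combined_logger(loggers)).__name__)  # LoggerCollection
```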