15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.5.2] - 2021-11-16

### Fixed

- Fixed an issue where `CombinedLoader` with `max_size_cycle` did not receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))
- Fixed `isinstance` not working with `init_meta_context` and the materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/pytorch-lightning/pull/10493))
- Fixed an issue that prevented the Trainer from shutting down workers when execution was interrupted due to a failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
- Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
- Fixed scripting causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/pull/10470), [#10555](https://github.com/PyTorchLightning/pytorch-lightning/pull/10555))
- Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))
- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))


## [1.5.1] - 2021-11-09

### Fixed
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = "1.5.1"
__version__ = "1.5.2"
__author__ = "William Falcon et al."
__author_email__ = "waf2107@columbia.edu"
__license__ = "Apache-2.0"
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -202,7 +202,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
): # short circuit if metric not present
return

current = logs.get(self.monitor)
current = logs[self.monitor].squeeze()
should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
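For context on the `squeeze()` change above, here is a minimal, hypothetical sketch (not part of the diff) of what it does to a monitored value that was logged with extra size-1 dimensions:

```python
import torch

# Illustration only: `squeeze()` drops all size-1 dimensions, so a metric
# logged as, say, shape (1, 1) becomes a 0-dim scalar tensor before it is
# compared against the best score and synced across processes.
logged = torch.tensor([[0.25]])      # hypothetical monitored value, shape (1, 1)
current = logged.squeeze()           # shape () after removing the empty dimensions
print(logged.shape, current.shape)   # torch.Size([1, 1]) torch.Size([])
```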
34 changes: 23 additions & 11 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -129,11 +129,12 @@ def render(self, task) -> RenderableType:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""

def __init__(self, trainer):
def __init__(self, trainer, style):
self._trainer = trainer
self._tasks = {}
self._current_task_id = 0
self._metrics = {}
self._style = style
super().__init__()

def update(self, metrics):
@@ -158,23 +159,34 @@ def render(self, task) -> Text:

for k, v in self._metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")
return Text(_text, justify="left", style=self._style)


@dataclass
class RichProgressBarTheme:
"""Styles to associate to different base components.

Args:
description: Style for the progress bar description, e.g. "Epoch x", "Testing", etc.
progress_bar: Style for the bar in progress.
progress_bar_finished: Style for the finished progress bar.
progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
batch_progress: Style for the progress tracker (e.g. "10/50" batches completed).
time: Style for the processed time and estimated time remaining.
processing_speed: Style for the speed of the batches being processed.
metrics: Style for the metrics.

https://rich.readthedocs.io/en/stable/style.html
"""

text_color: str = "white"
progress_bar_complete: Union[str, Style] = "#6206E0"
description: Union[str, Style] = "white"
progress_bar: Union[str, Style] = "#6206E0"
progress_bar_finished: Union[str, Style] = "#6206E0"
progress_bar_pulse: Union[str, Style] = "#6206E0"
batch_process: str = "white"
time: str = "grey54"
processing_speed: str = "grey70"
batch_progress: Union[str, Style] = "white"
time: Union[str, Style] = "grey54"
processing_speed: Union[str, Style] = "grey70"
metrics: Union[str, Style] = "white"


class RichProgressBar(ProgressBarBase):
@@ -268,7 +280,7 @@ def _init_progress(self, trainer):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer)
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
@@ -351,7 +363,7 @@ def on_validation_epoch_start(self, trainer, pl_module):
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, visible: bool = True) -> None:
@@ -448,11 +460,11 @@ def configure_columns(self, trainer) -> list:
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
complete_style=self.theme.progress_bar_complete,
complete_style=self.theme.progress_bar,
finished_style=self.theme.progress_bar_finished,
pulse_style=self.theme.progress_bar_pulse,
),
BatchesProcessedColumn(style=self.theme.batch_process),
BatchesProcessedColumn(style=self.theme.batch_progress),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
]
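Since the theme fields are renamed above (`text_color` becomes `description`, `progress_bar_complete` becomes `progress_bar`, `batch_process` becomes `batch_progress`, and a `metrics` field is added), a hedged usage sketch follows; the colour values are arbitrary placeholders, not defaults, and it assumes `RichProgressBar` accepts a `theme` argument as in the 1.5 API:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# Sketch of the renamed theme fields after this change; colours are placeholders.
theme = RichProgressBarTheme(
    description="white",
    progress_bar="#6206E0",
    progress_bar_finished="#6206E0",
    progress_bar_pulse="#6206E0",
    batch_progress="white",
    time="grey54",
    processing_speed="grey70",
    metrics="white",
)
trainer = Trainer(callbacks=[RichProgressBar(theme=theme)])
```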
17 changes: 12 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -116,6 +116,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._param_requires_grad_state = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
# TODO: remove after the 1.6 release
self._running_torchscript = False

self._register_sharded_tensor_state_dict_hooks_if_available()

Expand Down Expand Up @@ -1962,6 +1964,8 @@ def to_torchscript(
"""
mode = self.training

self._running_torchscript = True

if method == "script":
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == "trace":
@@ -1987,6 +1991,8 @@
with fs.open(file_path, "wb") as f:
torch.jit.save(torchscript_module, f)

self._running_torchscript = False

return torchscript_module

@property
@@ -1996,11 +2002,12 @@ def model_size(self) -> float:
Note:
This property will not return the correct value for DeepSpeed (stage 3) and fully-sharded training.
"""
rank_zero_deprecation(
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
stacklevel=5,
)
if not self._running_torchscript: # remove with the deprecation removal
rank_zero_deprecation(
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
stacklevel=5,
)
return get_model_size_mb(self)

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
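The `_running_torchscript` flag added above is toggled around scripting so that touching `model_size` during export does not emit the deprecation warning. A minimal sketch of the export path it guards; the module below is hypothetical and assumed to script cleanly:

```python
import torch
from pytorch_lightning import LightningModule


class TinyModel(LightningModule):
    """Hypothetical module, used only to illustrate the export call."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
# With this change, `_running_torchscript` is set for the duration of the call,
# so the `model_size` deprecation warning is not raised while scripting.
scripted = model.to_torchscript(method="script")
print(isinstance(scripted, torch.jit.ScriptModule))  # True
```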
6 changes: 5 additions & 1 deletion pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -17,6 +17,8 @@
import torch
from torch.nn import Module

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
@@ -177,7 +179,9 @@ def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device
3 changes: 2 additions & 1 deletion pytorch_lightning/lite/wrappers.py
@@ -24,6 +24,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

@@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None:
)


class _LiteModule(nn.Module):
class _LiteModule(DeviceDtypeModuleMixin):
def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
"""The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
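Because `_LiteModule` now inherits from `DeviceDtypeModuleMixin`, device and dtype information propagates to modules wrapped by LightningLite. A hedged sketch of what that enables; the model and optimizer are placeholders:

```python
import torch
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        # Placeholder model/optimizer; with this change, the wrapper returned by
        # `setup()` is a `DeviceDtypeModuleMixin`, so `.device` and `.dtype`
        # reflect where Lite placed the module.
        model = torch.nn.Linear(4, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)
        print(model.device, model.dtype)


Lite(accelerator="cpu").run()
```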
2 changes: 2 additions & 0 deletions pytorch_lightning/loggers/tensorboard.py
@@ -240,7 +240,9 @@ def log_graph(self, model: "pl.LightningModule", input_array=None):

if input_array is not None:
input_array = model._apply_batch_transfer_handler(input_array)
model._running_torchscript = True
self.experiment.add_graph(model, input_array)
model._running_torchscript = False
else:
rank_zero_warn(
"Could not log computational graph since the"
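The `log_graph` path above now sets `_running_torchscript` around `add_graph`, which traces the model. A hedged configuration sketch of how that path is reached; the model and its data are assumed to exist and the model is assumed to define `example_input_array`:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# With `log_graph=True` the logger calls `add_graph`, which now runs with
# `_running_torchscript` set so the tracing pass does not emit the
# `model_size` deprecation warning.
logger = TensorBoardLogger("tb_logs", name="demo", log_graph=True)
trainer = Trainer(logger=logger, max_epochs=1)
# trainer.fit(model)  # `model` (with `example_input_array`) is assumed to exist
```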
21 changes: 13 additions & 8 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -622,11 +622,6 @@ def _format_batch_size_and_grad_accum_config(self):
)
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
if "train_micro_batch_size_per_gpu" not in self.config:
rank_zero_warn(
"Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
"If you require skipping this, please pass "
"`Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
)
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
if "gradient_clipping" not in self.config:
@@ -638,9 +633,19 @@ def _auto_select_batch_size(self):
batch_size = 1
train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
if train_dl_source.is_defined():
train_dataloader = train_dl_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
try:
train_dataloader = train_dl_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
# broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
# to have been called before
except Exception:
if self.global_rank == 0:
deepspeed.utils.logging.logger.warning(
"Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
"To ensure DeepSpeed logging remains correct, please manually pass the plugin with the "
"batch size, `Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`."
)
return batch_size

def _format_precision_config(self):
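As the warning text above suggests, the batch size used for DeepSpeed's internal logging can be passed explicitly instead of being inferred from `train_dataloader()`. A hedged sketch; the value 32 and the rest of the Trainer flags are placeholders:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Pass the logging batch size explicitly so DeepSpeed does not need to infer it
# from `train_dataloader()` (which may require `setup()` to have run already).
trainer = Trainer(
    strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
    gpus=1,
    precision=16,
)
```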
20 changes: 12 additions & 8 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -237,21 +237,25 @@ def to_tensor(x):
args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
return args

def training_step(self, *args, **kwargs):
def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
args = self._prepare_input(args)
return self.poptorch_models[RunningStage.TRAINING](*args, **kwargs)
poptorch_model = self.poptorch_models[stage]
self.lightning_module._running_torchscript = True
out = poptorch_model(*args, **kwargs)
self.lightning_module._running_torchscript = False
return out

def training_step(self, *args, **kwargs):
return self._step(RunningStage.TRAINING, *args, **kwargs)

def validation_step(self, *args, **kwargs):
args = self._prepare_input(args)
return self.poptorch_models[RunningStage.VALIDATING](*args, **kwargs)
return self._step(RunningStage.VALIDATING, *args, **kwargs)

def test_step(self, *args, **kwargs):
args = self._prepare_input(args)
return self.poptorch_models[RunningStage.TESTING](*args, **kwargs)
return self._step(RunningStage.TESTING, *args, **kwargs)

def predict_step(self, *args, **kwargs):
args = self._prepare_input(args)
return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)
return self._step(RunningStage.PREDICTING, *args, **kwargs)

def teardown(self) -> None:
# undo dataloader patching
26 changes: 23 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -51,8 +51,8 @@ class _Sync:
fn: Optional[Callable] = None
_should: bool = False
rank_zero_only: bool = False
op: Optional[str] = None
group: Optional[Any] = None
_op: Optional[str] = None
_group: Optional[Any] = None

def __post_init__(self) -> None:
self._generate_sync_fn()
Expand All @@ -67,6 +67,26 @@ def should(self, should: bool) -> None:
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

@property
def op(self) -> Optional[str]:
return self._op

@op.setter
def op(self, op: Optional[str]) -> None:
self._op = op
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

@property
def group(self) -> Optional[Any]:
return self._group

@group.setter
def group(self, group: Optional[Any]) -> None:
self._group = group
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

def _generate_sync_fn(self) -> None:
"""Used to compute the syncing function and cache it."""
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
@@ -426,7 +446,7 @@ def log(
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)

# register logged value if it doesn't exist
if key not in self:
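The new `op` and `group` setters exist because `_generate_sync_fn()` caches a function that depends on those values, so mutating either attribute must refresh the cache. A generic, hypothetical sketch of that pattern (not Lightning code):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class CachedSync:
    """Hypothetical illustration of the pattern used by `_Sync` above."""

    _op: Optional[str] = None
    _fn: Callable[[Any], Any] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._generate_fn()

    @property
    def op(self) -> Optional[str]:
        return self._op

    @op.setter
    def op(self, op: Optional[str]) -> None:
        self._op = op
        # the cached function depends on `op`, so it must be re-generated here
        self._generate_fn()

    def _generate_fn(self) -> None:
        self._fn = (lambda x: x) if self._op is None else (lambda x: f"{self._op}({x})")


sync = CachedSync()
sync.op = "mean"
print(sync._fn(3))  # prints "mean(3)"
```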