19 changes: 11 additions & 8 deletions CHANGELOG.md
@@ -86,21 +86,24 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))


- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


### Removed

- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))


- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))
- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([#7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))


- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))
- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([#7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))


- Removed deprecated utils modules `model_utils`, `warning_utils`, `xla_device_utils` and partially `argparse_utils` ([7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))
- Removed support for `self.log(tbptt_reduce_fx)` and `self.log(tbptt_pad_token)`. Please, open a discussion explaining your use-case if you relied on these. ([#7644](https://github.com/PyTorchLightning/pytorch-lightning/pull/7644))


- Removed deprecated utils modules `model_utils`, `warning_utils`, `xla_device_utils` and partially `argparse_utils` ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))


- Removed deprecated trainer attributes - `on_cpu`, `on_tpu`, `use_tpu`, `on_gpu`, `use_dp`, `use_ddp`, `use_ddp2`, `use_horovod`, `use_single_gpu` ([#7501](https://github.com/PyTorchLightning/pytorch-lightning/pull/7501))
@@ -1338,7 +1341,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))
- Fixed `overfit_batches` which now correctly disables shuffling for the training loader. ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
- Fixed gradient norm tracking for `row_log_interval > 1` ([#3489](https://github.com/PyTorchLightning/pytorch-lightning/pull/3489))
- Fixed `ModelCheckpoint` name formatting ([3164](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163))
- Fixed `ModelCheckpoint` name formatting ([#3164](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163))
- Fixed example implementation of AutoEncoder ([#3190](https://github.com/PyTorchLightning/pytorch-lightning/pull/3190))
- Fixed invalid paths when remote logging with TensorBoard ([#3236](https://github.com/PyTorchLightning/pytorch-lightning/pull/3236))
- Fixed change `t()` to `transpose()` as XLA devices do not support `.t()` on 1-dim tensor ([#3252](https://github.com/PyTorchLightning/pytorch-lightning/pull/3252))
@@ -1598,8 +1601,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` ([#1908](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908))
- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([#1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/pull/1667))
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
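For context on the CHANGELOG entry above: `tbptt_reduce_fx` and `tbptt_pad_token` previously let a `self.log` call control how values from truncated back-propagation-through-time (TBPTT) splits were padded and reduced. The sketch below shows the call before and after this PR; the manual-reduction helper is an assumption for illustration, not an officially documented migration path.

```python
import torch

# Before this PR (deprecated here, removal targeted for v1.6): per-key TBPTT
# reduction/padding was configured directly on the logging call:
#     self.log("tbptt_loss", loss, tbptt_reduce_fx=torch.max, tbptt_pad_token=0)
#
# After this PR the call drops the TBPTT-specific flags; `reduce_fx` still
# controls the end-of-epoch reduction over step values:
#     self.log("tbptt_loss", loss, reduce_fx=torch.mean)

# If a non-mean reduction across TBPTT splits is still needed, one option
# (an assumption, not an official migration path) is to reduce the split
# values manually before logging. `split_losses` is a hypothetical list of
# per-split scalar tensors collected in `training_step`.
def reduce_tbptt_splits(split_losses, reduce_fx=torch.max):
    return reduce_fx(torch.stack([loss.detach() for loss in split_losses]))
```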
27 changes: 17 additions & 10 deletions pytorch_lightning/core/lightning.py
@@ -263,8 +263,8 @@ def log(
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
@@ -299,8 +299,6 @@ def log(
on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
@@ -309,6 +307,19 @@ def log(
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
'`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6.'
' Please, open a discussion explaining your use-case in'
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)
if tbptt_pad_token is not None:
rank_zero_deprecation(
'`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6.'
' Please, open a discussion explaining your use-case in'
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)

if self._results is not None:
# TODO: if logged twice fail with crash

@@ -333,8 +344,6 @@ def log(
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
@@ -352,8 +361,8 @@ def log_dict(
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
@@ -375,8 +384,6 @@ def log_dict(
on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
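The guard added to `LightningModule.log` above only warns and then ignores the flag; it relies on a `None` sentinel default to detect that a caller actually passed it. A minimal, Lightning-independent sketch of that pattern (hypothetical `log` function, plain `warnings` in place of `rank_zero_deprecation`):

```python
import warnings
from typing import Any, Callable, Optional


def log(name: str, value: Any, reduce_fx: Callable = sum,
        tbptt_reduce_fx: Optional[Callable] = None) -> None:
    # A `None` sentinel default distinguishes "flag not passed" from any real
    # callable a caller might supply, so the warning fires only for actual
    # users of the removed flag.
    if tbptt_reduce_fx is not None:
        warnings.warn(
            "`log(tbptt_reduce_fx=...)` is no longer supported and will be removed.",
            DeprecationWarning,
        )
    # The flag is otherwise ignored; logging proceeds with `reduce_fx` only.
    print(f"logged {name}={value} (epoch reduce_fx={reduce_fx.__name__})")


log("loss", 0.5)                       # no warning
log("loss", 0.5, tbptt_reduce_fx=max)  # emits a DeprecationWarning
```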
60 changes: 2 additions & 58 deletions pytorch_lightning/core/step_result.py
@@ -85,8 +85,6 @@ def log(
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
@@ -134,8 +132,6 @@ def log(
on_step=True,
on_epoch=False,
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False,
dataloader_idx=dataloader_idx,
)
@@ -153,8 +149,6 @@ def log(
on_step=False,
on_epoch=True,
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False,
dataloader_idx=dataloader_idx,
)
@@ -169,8 +163,6 @@ def log(
on_step,
on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked,
dataloader_idx=dataloader_idx,
)
@@ -187,8 +179,6 @@ def __set_meta(
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool,
dataloader_idx: Union[int, None],
):
@@ -201,8 +191,6 @@ def __set_meta(
on_epoch=on_epoch,
reduce_fx=reduce_fx,
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked,
dataloader_idx=dataloader_idx,
)
@@ -424,47 +412,6 @@ def unpack_batch_size(sample):
size = 1
return size

@classmethod
def gather(cls, outputs):
meta = outputs[0].get('meta')
result = cls()
result = recursive_gather(outputs, result)
recursive_stack(result)

if meta:
result['meta'] = meta
return result

@classmethod
def padded_gather(cls, outputs):
meta = outputs[0].get('meta')
result = cls()
result = recursive_gather(outputs, result)

# find the padding used for other values
default_padding_idx = 0
for name, value in result.items():
if (
name != 'minimize' and isinstance(value, list) and len(value) > 0
and isinstance(value[0], torch.Tensor)
):
default_padding_idx = meta[name]['tbptt_pad_token']
break

# pad across each key individually
for name, value in result.items():
if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
result[name] = padded

# also update the result
if meta and name != "minimize":
meta[name]['value'] = padded
if meta:
result['meta'] = meta
return result

@classmethod
def reduce_on_epoch_end(cls, outputs):
# get the batch sizes for all outputs
@@ -522,17 +469,14 @@ def reduce_across_time(cls, time_outputs):
if k in ['meta', 'extra'] or isinstance(value, Metric):
continue

# pick the reduce fx
tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']

if isinstance(value, list):
value = torch.tensor(value)

if isinstance(value, dict):
# TODO: recursive reduce:
_recursive_fx_apply(value, tbptt_reduce_fx)
_recursive_fx_apply(value, torch.mean)
else:
result[k] = tbptt_reduce_fx(value.float())
result[k] = torch.mean(value.float())

result['meta'] = meta
return result
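With the per-key `tbptt_reduce_fx` metadata gone, `Result.reduce_across_time` above now applies `torch.mean` to every logged value across the TBPTT splits. A rough standalone illustration, with plain dicts standing in for `Result` objects (an assumption for brevity):

```python
import torch

# Three TBPTT splits that each logged the same keys during one step.
time_outputs = [
    {"loss": torch.tensor(1.0), "acc": torch.tensor(0.50)},
    {"loss": torch.tensor(2.0), "acc": torch.tensor(0.75)},
    {"loss": torch.tensor(3.0), "acc": torch.tensor(1.00)},
]

# With the per-key `tbptt_reduce_fx` removed, the reduction over splits is
# always `torch.mean`:
reduced = {
    key: torch.mean(torch.stack([out[key].float() for out in time_outputs]))
    for key in time_outputs[0]
}
print(reduced)  # {'loss': tensor(2.), 'acc': tensor(0.7500)}
```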
92 changes: 0 additions & 92 deletions tests/core/test_results.py
@@ -182,98 +182,6 @@ def test_dataloader(self):
assert len(predictions) == len(dm.random_test)


def test_result_gather_stack():
""" Test that tensors get concatenated when they all have the same shape. """
outputs = [
{
"foo": torch.zeros(4, 5)
},
{
"foo": torch.zeros(4, 5)
},
{
"foo": torch.zeros(4, 5)
},
]
result = Result.gather(outputs)
assert isinstance(result["foo"], torch.Tensor)
assert list(result["foo"].shape) == [12, 5]


def test_result_gather_concatenate():
""" Test that tensors get concatenated when they have varying size in first dimension. """
outputs = [
{
"foo": torch.zeros(4, 5)
},
{
"foo": torch.zeros(8, 5)
},
{
"foo": torch.zeros(3, 5)
},
]
result = Result.gather(outputs)
assert isinstance(result["foo"], torch.Tensor)
assert list(result["foo"].shape) == [15, 5]


def test_result_gather_scalar():
""" Test that 0-dim tensors get gathered and stacked correctly. """
outputs = [
{
"foo": torch.tensor(1)
},
{
"foo": torch.tensor(2)
},
{
"foo": torch.tensor(3)
},
]
result = Result.gather(outputs)
assert isinstance(result["foo"], torch.Tensor)
assert list(result["foo"].shape) == [3]


def test_result_gather_different_shapes():
""" Test that tensors of varying shape get gathered into a list. """
outputs = [
{
"foo": torch.tensor(1)
},
{
"foo": torch.zeros(2, 3)
},
{
"foo": torch.zeros(1, 2, 3)
},
]
result = Result.gather(outputs)
expected = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)]
assert isinstance(result["foo"], list)
assert all(torch.eq(r, e).all() for r, e in zip(result["foo"], expected))


def test_result_gather_mixed_types():
""" Test that a collection of mixed types gets gathered into a list. """
outputs = [
{
"foo": 1.2
},
{
"foo": ["bar", None]
},
{
"foo": torch.tensor(1)
},
]
result = Result.gather(outputs)
expected = [1.2, ["bar", None], torch.tensor(1)]
assert isinstance(result["foo"], list)
assert result["foo"] == expected


def test_result_retrieve_last_logged_item():
result = Result()
result.log('a', 5., on_step=True, on_epoch=True)
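Because `Result.gather` and `Result.padded_gather` are removed along with these tests, concatenating per-step outputs is left to user code. A minimal replacement sketch, assuming all tensors share their trailing dimensions:

```python
import torch

# Per-step outputs with a shared trailing shape, as in the removed
# `test_result_gather_concatenate` case above:
outputs = [
    {"foo": torch.zeros(4, 5)},
    {"foo": torch.zeros(8, 5)},
    {"foo": torch.zeros(3, 5)},
]

# Concatenate along the first dimension, mirroring what `Result.gather`
# used to produce for same-rank tensors:
gathered = {key: torch.cat([out[key] for out in outputs], dim=0) for key in outputs[0]}
assert gathered["foo"].shape == (15, 5)
```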
26 changes: 26 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -61,3 +61,29 @@ def test_v1_6_0_ddp_spawn_num_nodes():
def test_v1_6_0_ddp_spawn_sync_batchnorm():
with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
DDPSpawnPlugin(sync_batchnorm=False)


def test_v1_6_0_tbptt_reduce_fx(tmpdir):

class TestModel(BoringModel):

def training_step(self, *args):
self.log("foo", 1, tbptt_reduce_fx=lambda x: x)
return super().training_step(*args)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.deprecated_call(match=r"tbptt_reduce_fx=...\)` is no longer supported"):
trainer.fit(TestModel())


def test_v1_6_0_tbptt_pad_token(tmpdir):

class TestModel(BoringModel):

def training_step(self, *args):
self.log("foo", 1, tbptt_pad_token=0)
return super().training_step(*args)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"):
trainer.fit(TestModel())