From 0220f7388910b7c164cee6281910f9ad78dcb7d4 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:05:27 -0700
Subject: [PATCH 1/7] Fix truncated backprop through time set on LightningModule and not Trainer

---
 pytorch_lightning/loops/batch/training_batch_loop.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index e50045676d6de..b32a4e9779a38 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -452,10 +452,11 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
             batch: the current batch to split
         """
         splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
+        tbptt_steps = self._truncated_bptt_steps()
+        if tbptt_steps > 0:
             model_ref = self.trainer.lightning_module
             with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+                splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
         return splits
 
     def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:

From 19a958f687df57aed7fe7b2355abc156ddd02d6e Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:12:51 -0700
Subject: [PATCH 2/7] Update test_cpu.py

---
 tests/models/test_cpu.py | 72 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index c7bf787248ce3..237c1726deafe 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -384,3 +384,75 @@ def train_dataloader(self):
     )
     trainer.fit(model)
     assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
+
+
+def test_tbptt_cpu_model_lightning_module_property(tmpdir):
+    """Test truncated back propagation through time works when set as a property of the LightningModule."""
+    truncated_bptt_steps = 2
+    sequence_size = 30
+    batch_size = 30
+
+    x_seq = torch.rand(batch_size, sequence_size, 1)
+    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+
+        def __getitem__(self, i):
+            return x_seq, y_seq_list
+
+        def __len__(self):
+            return 1
+
+    class BpttTestModel(BoringModel):
+
+        def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.test_hidden = None
+            self.batch_size = batch_size
+            self.layer = torch.nn.Linear(in_features, out_features)
+            self.truncated_bptt_steps = truncated_bptt_steps
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            self.test_hidden = torch.rand(1)
+
+            x_tensor, y_list = batch
+            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+
+            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
+            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
+
+            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
+            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            return {
+                "loss": loss_val,
+                "hiddens": self.test_hidden,
+            }
+
+        def training_epoch_end(self, training_step_outputs):
+            training_step_outputs = training_step_outputs[0]
+            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
+            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
+            self.log("train_loss", loss)
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(
+                dataset=MockSeq2SeqDataset(),
+                batch_size=batch_size,
+                shuffle=False,
+                sampler=None,
+            )
+
+    model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_val_batches=0,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.state.finished, f"Training model with `truncated_bptt_steps` set on the LightningModule failed with {trainer.state}"

From 98de7ac8444b17148df9aad196fc95864b705fe0 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:14:50 -0700
Subject: [PATCH 3/7] Update CHANGELOG.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ac97333e6d90..2b4d576b969cd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
+- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))
+
+
 ## [1.4.0] - 2021-07-27
 
 ### Added
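Patches 1-3 carry the functional fix: the batch loop previously consulted only `self.trainer.truncated_bptt_steps`, so a `truncated_bptt_steps` attribute set on the LightningModule never enabled batch splitting. A minimal sketch of the resolution the fixed loop relies on; `resolve_tbptt_steps` is a hypothetical standalone helper standing in for the loop's `_truncated_bptt_steps()`, and the precedence shown (module attribute first, Trainer argument as fallback) is an assumption about its behavior, not confirmed by this diff:

from typing import Any, Callable, List, Optional


def resolve_tbptt_steps(module_steps: Optional[int], trainer_steps: Optional[int]) -> int:
    # Hypothetical helper: prefer the LightningModule attribute, fall back
    # to the Trainer argument, and treat None/0 as "splitting disabled".
    if module_steps:
        return module_steps
    if trainer_steps:
        return trainer_steps
    return 0


def tbptt_split(batch: Any, steps: int, split_fn: Callable[[Any, int], List[Any]]) -> List[Any]:
    # Mirrors the shape of `_tbptt_split_batch` after patch 1: delegate to
    # the model's splitter only when a positive step count was resolved.
    if steps == 0:
        return [batch]
    return split_fn(batch, steps)

Whatever the exact precedence, returning a plain int lets the loop use a single `> 0` (later `== 0`) guard instead of the old `is not None` check, which silently ignored the module-level setting.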
From 0c9acb8e1e9694cb773cf7aa571b894050a3dac9 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:30:32 -0700
Subject: [PATCH 4/7] create separate test file for truncated bptt

---
 tests/models/test_cpu.py            | 146 ------------------------
 tests/models/test_truncated_bptt.py | 163 ++++++++++++++++++++++++++++
 2 files changed, 163 insertions(+), 146 deletions(-)
 create mode 100644 tests/models/test_truncated_bptt.py

diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index 237c1726deafe..015c79458e1aa 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 
-import pytest
 import torch
 
 import tests.helpers.pipelines as tpipes
@@ -311,148 +310,3 @@ def test_all_features_cpu_model(tmpdir):
 
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01)
-
-
-@pytest.mark.parametrize("n_hidden_states", [1, 2])
-def test_tbptt_cpu_model(tmpdir, n_hidden_states):
-    """Test truncated back propagation through time works."""
-    truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
-
-        def __len__(self):
-            return 1
-
-    class BpttTestModel(BoringModel):
-        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.test_hidden = None
-            self.batch_size = batch_size
-            self.layer = torch.nn.Linear(in_features, out_features)
-            self.n_hidden_states = n_hidden_states
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            if self.n_hidden_states == 1:
-                self.test_hidden = torch.rand(1)
-            else:
-                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
-
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
-
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
-            return {"loss": loss_val, "hiddens": self.test_hidden}
-
-        def training_epoch_end(self, training_step_outputs):
-            training_step_outputs = training_step_outputs[0]
-            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
-            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
-            self.log("train_loss", loss)
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
-            )
-
-    model = BpttTestModel(
-        batch_size=batch_size,
-        in_features=truncated_bptt_steps,
-        out_features=truncated_bptt_steps,
-        n_hidden_states=n_hidden_states,
-    )
-    model.example_input_array = torch.randn(5, truncated_bptt_steps)
-
-    # fit model
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        truncated_bptt_steps=truncated_bptt_steps,
-        limit_val_batches=0,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
-
-
-def test_tbptt_cpu_model_lightning_module_property(tmpdir):
-    """Test truncated back propagation through time works when set as a property of the LightningModule."""
-    truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
-
-        def __len__(self):
-            return 1
-
-    class BpttTestModel(BoringModel):
-
-        def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.test_hidden = None
-            self.batch_size = batch_size
-            self.layer = torch.nn.Linear(in_features, out_features)
-            self.truncated_bptt_steps = truncated_bptt_steps
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.rand(1)
-
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
-
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
-            return {
-                "loss": loss_val,
-                "hiddens": self.test_hidden,
-            }
-
-        def training_epoch_end(self, training_step_outputs):
-            training_step_outputs = training_step_outputs[0]
-            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
-            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
-            self.log("train_loss", loss)
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(),
-                batch_size=batch_size,
-                shuffle=False,
-                sampler=None,
-            )
-
-    model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
-    model.example_input_array = torch.randn(5, truncated_bptt_steps)
-
-    # fit model
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_val_batches=0,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    assert trainer.state.finished, f"Training model with `truncated_bptt_steps` set on the LightningModule failed with {trainer.state}"
diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py
new file mode 100644
index 0000000000000..dad19b485fdce
--- /dev/null
+++ b/tests/models/test_truncated_bptt.py
@@ -0,0 +1,163 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from tests.helpers import BoringModel
+
+
+@pytest.mark.parametrize("n_hidden_states", [1, 2])
+def test_tbptt_cpu_model(tmpdir, n_hidden_states):
+    """Test truncated back propagation through time works."""
+    truncated_bptt_steps = 2
+    sequence_size = 30
+    batch_size = 30
+
+    x_seq = torch.rand(batch_size, sequence_size, 1)
+    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+        def __getitem__(self, i):
+            return x_seq, y_seq_list
+
+        def __len__(self):
+            return 1
+
+    class BpttTestModel(BoringModel):
+        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.test_hidden = None
+            self.batch_size = batch_size
+            self.layer = torch.nn.Linear(in_features, out_features)
+            self.n_hidden_states = n_hidden_states
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            if self.n_hidden_states == 1:
+                self.test_hidden = torch.rand(1)
+            else:
+                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
+
+            x_tensor, y_list = batch
+            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+
+            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
+            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
+
+            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
+            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            return {"loss": loss_val, "hiddens": self.test_hidden}
+
+        def training_epoch_end(self, training_step_outputs):
+            training_step_outputs = training_step_outputs[0]
+            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
+            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
+            self.log("train_loss", loss)
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(
+                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
+            )
+
+    model = BpttTestModel(
+        batch_size=batch_size,
+        in_features=truncated_bptt_steps,
+        out_features=truncated_bptt_steps,
+        n_hidden_states=n_hidden_states,
+    )
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        truncated_bptt_steps=truncated_bptt_steps,
+        limit_val_batches=0,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
+
+
+@pytest.mark.parametrize("n_hidden_states", [1, 2])
+def test_tbptt_cpu_model_lightning_module_property(tmpdir, n_hidden_states):
+    """Test truncated back propagation through time works when set as a property of the LightningModule."""
+    truncated_bptt_steps = 2
+    sequence_size = 30
+    batch_size = 30
+
+    x_seq = torch.rand(batch_size, sequence_size, 1)
+    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+        def __getitem__(self, i):
+            return x_seq, y_seq_list
+
+        def __len__(self):
+            return 1
+
+    class BpttTestModel(BoringModel):
+        def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.test_hidden = None
+            self.batch_size = batch_size
+            self.layer = torch.nn.Linear(in_features, out_features)
+            self.truncated_bptt_steps = truncated_bptt_steps
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            self.test_hidden = torch.rand(1)
+
+            x_tensor, y_list = batch
+            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+
+            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
+            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
+
+            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
+            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            return {
+                "loss": loss_val,
+                "hiddens": self.test_hidden,
+            }
+
+        def training_epoch_end(self, training_step_outputs):
+            training_step_outputs = training_step_outputs[0]
+            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
+            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
+            self.log("train_loss", loss)
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(
+                dataset=MockSeq2SeqDataset(),
+                batch_size=batch_size,
+                shuffle=False,
+                sampler=None,
+            )
+
+    model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_val_batches=0,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
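Both tests in the new file assert `x_tensor.shape[1] == truncated_bptt_steps` and the matching list-split shape, which reflects the batch being sliced along the time dimension (dim 1). A rough standalone sketch of that slicing for plain tensors; `split_along_time` is illustrative, not Lightning's actual `tbptt_split_batch` implementation, which also handles lists and tuples:

import torch


def split_along_time(x: torch.Tensor, steps: int):
    # Slice a (batch, time, features) tensor into chunks of width `steps`
    # along dim 1; the final chunk may be narrower if time % steps != 0.
    return [x[:, t : t + steps] for t in range(0, x.size(1), steps)]


x = torch.rand(30, 30, 1)  # batch_size=30, sequence_size=30, as in the tests
splits = split_along_time(x, 2)
assert len(splits) == 15 and all(s.shape[1] == 2 for s in splits)

With sequence_size divisible by the step count, every split has the same width, which is why `training_step` can assert the split shape unconditionally and why `training_epoch_end` expects exactly `sequence_size / truncated_bptt_steps` outputs.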
From 88abb8af2cdcc059dbdd80d7db0f3bfecdfd1ece Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:38:00 -0700
Subject: [PATCH 5/7] Update training_batch_loop.py

---
 pytorch_lightning/loops/batch/training_batch_loop.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index b32a4e9779a38..8b5268539aab7 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -451,12 +451,13 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         Args:
             batch: the current batch to split
         """
-        splits = [batch]
         tbptt_steps = self._truncated_bptt_steps()
-        if tbptt_steps > 0:
-            model_ref = self.trainer.lightning_module
-            with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
+        if tbptt_steps == 0:
+            return [batch]
+
+        model_ref = self.trainer.lightning_module
+        with self.trainer.profiler.profile("tbptt_split_batch"):
+            splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
         return splits
 
     def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
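Patch 5 is a behavior-preserving cleanup: assuming the resolved step count is a non-negative int (as the `== 0` guard implies), the guarded-assignment form and the early-return form produce identical results. A self-contained sketch of the equivalence, with `split_fn` standing in for `tbptt_split_batch`:

from typing import Any, Callable, List


def split_guarded(batch: Any, steps: int, split_fn: Callable) -> List[Any]:
    # Shape of the code before patch 5: default, then conditionally overwrite.
    splits = [batch]
    if steps > 0:
        splits = split_fn(batch, steps)
    return splits


def split_early_return(batch: Any, steps: int, split_fn: Callable) -> List[Any]:
    # Shape of the code after patch 5: the disabled case exits first.
    if steps == 0:
        return [batch]
    return split_fn(batch, steps)


def chunk(b, n):
    return [b[i : i + n] for i in range(0, len(b), n)]


for steps in (0, 2, 3):
    assert split_guarded(list(range(6)), steps, chunk) == split_early_return(list(range(6)), steps, chunk)

The early return keeps the common "no TBPTT" path flat and avoids entering the profiler context at all when splitting is disabled.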
From c7c4478220983fd6aa5bd0d3a325e4ea6ced0d8e Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 8 Aug 2021 17:53:32 -0700
Subject: [PATCH 6/7] Update test_truncated_bptt.py

---
 tests/models/test_truncated_bptt.py | 80 +++--------------------------
 1 file changed, 7 insertions(+), 73 deletions(-)

diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py
index dad19b485fdce..b40d3b83aa9f1 100644
--- a/tests/models/test_truncated_bptt.py
+++ b/tests/models/test_truncated_bptt.py
@@ -19,8 +19,9 @@
 from tests.helpers import BoringModel
 
 
-@pytest.mark.parametrize("n_hidden_states", [1, 2])
-def test_tbptt_cpu_model(tmpdir, n_hidden_states):
+@pytest.mark.parametrize("n_hidden_states", (1, 2))
+@pytest.mark.parametrize("property_on_module", (False, True))
+def test_tbptt_cpu_model(tmpdir, n_hidden_states, property_on_module):
     """Test truncated back propagation through time works."""
     truncated_bptt_steps = 2
     sequence_size = 30
@@ -43,6 +44,8 @@ def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args
             self.batch_size = batch_size
             self.layer = torch.nn.Linear(in_features, out_features)
             self.n_hidden_states = n_hidden_states
+            if property_on_module:
+                self.truncated_bptt_steps = truncated_bptt_steps
 
         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
@@ -80,84 +83,15 @@ def train_dataloader(self):
     )
     model.example_input_array = torch.randn(5, truncated_bptt_steps)
 
-    # fit model
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        truncated_bptt_steps=truncated_bptt_steps,
-        limit_val_batches=0,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
-
-
-@pytest.mark.parametrize("n_hidden_states", [1, 2])
-def test_tbptt_cpu_model_lightning_module_property(tmpdir, n_hidden_states):
-    """Test truncated back propagation through time works when set as a property of the LightningModule."""
-    truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
-
-        def __len__(self):
-            return 1
-
-    class BpttTestModel(BoringModel):
-        def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.test_hidden = None
-            self.batch_size = batch_size
-            self.layer = torch.nn.Linear(in_features, out_features)
-            self.truncated_bptt_steps = truncated_bptt_steps
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.rand(1)
-
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
-
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
-            return {
-                "loss": loss_val,
-                "hiddens": self.test_hidden,
-            }
-
-        def training_epoch_end(self, training_step_outputs):
-            training_step_outputs = training_step_outputs[0]
-            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
-            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
-            self.log("train_loss", loss)
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(),
-                batch_size=batch_size,
-                shuffle=False,
-                sampler=None,
-            )
-
-    model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
-    model.example_input_array = torch.randn(5, truncated_bptt_steps)
+    trainer_tbptt_steps = None if property_on_module else truncated_bptt_steps
 
     # fit model
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
+        truncated_bptt_steps=trainer_tbptt_steps,
         limit_val_batches=0,
         weights_summary=None,
     )
     trainer.fit(model)
-
     assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
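Patch 6 folds the property-based test into the original one by stacking parametrize decorators, which runs the body over the full cross product — four cases instead of two near-duplicate tests. A minimal sketch of the pattern:

import pytest


@pytest.mark.parametrize("n_hidden_states", (1, 2))
@pytest.mark.parametrize("property_on_module", (False, True))
def test_enablement_matrix(n_hidden_states, property_on_module):
    # Four invocations: (1, False), (1, True), (2, False), (2, True).
    # Exactly one of the two TBPTT enablement paths is active per case.
    trainer_steps = None if property_on_module else 2
    assert (trainer_steps is None) == property_on_module

This also covers a combination the two separate tests never exercised: multiple hidden states together with the module-property path.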
From 7887ad3e0f211fe20de052f4ab460529f41231a7 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 9 Aug 2021 11:21:49 -0700
Subject: [PATCH 7/7] move tbptt steps from test_train_loop_logging to new file

---
 tests/models/test_truncated_bptt.py     | 70 +++++++++++++++++++
 .../logging_/test_train_loop_logging.py | 70 -------------------
 2 files changed, 70 insertions(+), 70 deletions(-)

diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py
index b40d3b83aa9f1..c454753e81151 100644
--- a/tests/models/test_truncated_bptt.py
+++ b/tests/models/test_truncated_bptt.py
@@ -95,3 +95,73 @@ def train_dataloader(self):
     )
     trainer.fit(model)
     assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
+
+
+def test_tbptt_log(tmpdir):
+    truncated_bptt_steps = 2
+    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
+    batch_size = 10
+    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+        def __init__(self):
+            self.x_seq = torch.randn(N, T, F)
+            self.y_seq = torch.randn(N, T, F)
+
+        def __getitem__(self, index):
+            return self.x_seq[index], self.y_seq[index]
+
+        def __len__(self):
+            return N
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.test_hidden = None
+            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
+            self.truncated_bptt_steps = truncated_bptt_steps
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            if hiddens is not None:
+                assert hiddens.grad_fn is None
+            split_idx = self.trainer.fit_loop.split_idx
+            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)
+
+            x, y = batch
+            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
+                # last split idx, not aligned
+                assert x.shape[1] == T % truncated_bptt_steps
+                assert y.shape[1] == T % truncated_bptt_steps
+            else:
+                assert x.shape[1] == truncated_bptt_steps
+                assert y.shape[1] == truncated_bptt_steps
+
+            pred, _ = self(x)
+            loss = torch.nn.functional.mse_loss(pred, y)
+
+            self.log("a", loss, on_epoch=True)
+
+            return {"loss": loss, "hiddens": self.test_hidden}
+
+        def on_train_batch_start(self, *args, **kwargs) -> None:
+            self.test_hidden = None
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=2,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.fit_loop.batch_idx == N // batch_size
+    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
+    assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 976f043201650..b218df9e3b15d 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -212,76 +212,6 @@ def validation_step(self, batch, batch_idx):
     assert trainer.logged_metrics["bar"] == result
 
-
-def test_tbptt_log(tmpdir):
-    truncated_bptt_steps = 2
-    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
-    batch_size = 10
-    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __init__(self):
-            self.x_seq = torch.randn(N, T, F)
-            self.y_seq = torch.randn(N, T, F)
-
-        def __getitem__(self, index):
-            return self.x_seq[index], self.y_seq[index]
-
-        def __len__(self):
-            return N
-
-    class TestModel(BoringModel):
-        def __init__(self):
-            super().__init__()
-            self.test_hidden = None
-            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            if hiddens is not None:
-                assert hiddens.grad_fn is None
-            split_idx = self.trainer.fit_loop.split_idx
-            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)
-
-            x, y = batch
-            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
-                # last split idx, not aligned
-                assert x.shape[1] == T % truncated_bptt_steps
-                assert y.shape[1] == T % truncated_bptt_steps
-            else:
-                assert x.shape[1] == truncated_bptt_steps
-                assert y.shape[1] == truncated_bptt_steps
-
-            pred, _ = self(x)
-            loss = torch.nn.functional.mse_loss(pred, y)
-
-            self.log("a", loss, on_epoch=True)
-
-            return {"loss": loss, "hiddens": self.test_hidden}
-
-        def on_train_batch_start(self, *args, **kwargs) -> None:
-            self.test_hidden = None
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
-
-    model = TestModel()
-    model.training_epoch_end = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_val_batches=0,
-        truncated_bptt_steps=truncated_bptt_steps,
-        max_epochs=2,
-        log_every_n_steps=2,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    assert trainer.fit_loop.batch_idx == N // batch_size
-    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
-    assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
-
 
 def test_different_batch_types_for_sizing(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
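The moved `test_tbptt_log` deliberately picks T = 15 with truncated_bptt_steps = 2 so the sequence does not divide evenly and the final split is narrower than the rest, which is what the `done` branch of `training_step` checks. The arithmetic behind the test's assertions, as a quick standalone check:

import math

T, steps = 15, 2
widths = [min(steps, T - t) for t in range(0, T, steps)]
assert len(widths) == math.ceil(T / steps) == 8  # eight splits per batch
assert widths[-1] == T % steps == 1              # leftover chunk triggers the `done` branch
assert len(widths) - 1 == T // steps == 7        # final split_idx the test verifies after fit()

Switching this test from `Trainer(truncated_bptt_steps=...)` to the `self.truncated_bptt_steps` module attribute gives the property path coverage for logging, uneven splits, and hidden-state handling as well, which is the point of consolidating all TBPTT tests in the new file.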