3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))


- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))


## [1.4.0] - 2021-07-27

### Added
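For context on the changelog entry above: the old splitting code only read `self.trainer.truncated_bptt_steps`, so declaring the value on the LightningModule alone left the batch unsplit. Below is a minimal sketch of the module-level usage this PR fixes, modeled on the tests added further down (the module and layer names are illustrative, not part of the change):

```python
import torch
from pytorch_lightning import LightningModule


class TBPTTModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
        self.head = torch.nn.Linear(8, 1)
        # Setting the attribute on the module, instead of Trainer(truncated_bptt_steps=2),
        # is the case this PR fixes.
        self.truncated_bptt_steps = 2

    def training_step(self, batch, batch_idx, hiddens):
        # `batch` arrives as a 2-timestep slice along dim 1; `hiddens` carries the
        # (detached) state between slices of the same original batch.
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = torch.nn.functional.mse_loss(self.head(out), y)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

With the fix, fitting this module splits each batch into 2-step windows even though `truncated_bptt_steps` was never passed to the Trainer.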
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -451,11 +451,13 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
Args:
batch: the current batch to split
"""
splits = [batch]
if self.trainer.truncated_bptt_steps is not None:
model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("tbptt_split_batch"):
splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
tbptt_steps = self._truncated_bptt_steps()
if tbptt_steps == 0:
return [batch]

model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("tbptt_split_batch"):
splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
return splits

def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
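The `_truncated_bptt_steps()` helper called in the hunk above is not shown in this diff. Judging from the new guard (`== 0` rather than `is None`) and the tests below, it presumably resolves the setting by giving the LightningModule attribute precedence over the Trainer flag and returning 0 when TBPTT is disabled. A standalone approximation of that resolution order (an assumption for illustration, not the verbatim helper):

```python
def resolve_tbptt_steps(module_steps, trainer_steps):
    # Assumed precedence: the value set on the LightningModule wins; otherwise
    # fall back to Trainer(truncated_bptt_steps=...); 0 means "do not split".
    if module_steps and module_steps > 0:
        return module_steps
    return trainer_steps or 0


assert resolve_tbptt_steps(2, None) == 2   # attribute on the module only (the fixed case)
assert resolve_tbptt_steps(0, 2) == 2      # Trainer flag only (previous behavior)
assert resolve_tbptt_steps(0, None) == 0   # disabled: _tbptt_split_batch returns [batch]
```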
74 changes: 0 additions & 74 deletions tests/models/test_cpu.py
@@ -13,7 +13,6 @@
# limitations under the License.
import os

import pytest
import torch

import tests.helpers.pipelines as tpipes
@@ -311,76 +310,3 @@ def test_all_features_cpu_model(tmpdir):
model = BoringModel()

tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01)


@pytest.mark.parametrize("n_hidden_states", [1, 2])
def test_tbptt_cpu_model(tmpdir, n_hidden_states):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
sequence_size = 30
batch_size = 30

x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list

def __len__(self):
return 1

class BpttTestModel(BoringModel):
def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
self.batch_size = batch_size
self.layer = torch.nn.Linear(in_features, out_features)
self.n_hidden_states = n_hidden_states

def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if self.n_hidden_states == 1:
self.test_hidden = torch.rand(1)
else:
self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)

x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
return {"loss": loss_val, "hiddens": self.test_hidden}

def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
self.log("train_loss", loss)

def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
)

model = BpttTestModel(
batch_size=batch_size,
in_features=truncated_bptt_steps,
out_features=truncated_bptt_steps,
n_hidden_states=n_hidden_states,
)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
limit_val_batches=0,
weights_summary=None,
)
trainer.fit(model)
assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
167 changes: 167 additions & 0 deletions tests/models/test_truncated_bptt.py
@@ -0,0 +1,167 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


@pytest.mark.parametrize("n_hidden_states", (1, 2))
@pytest.mark.parametrize("property_on_module", (False, True))
def test_tbptt_cpu_model(tmpdir, n_hidden_states, property_on_module):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
sequence_size = 30
batch_size = 30

x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list

def __len__(self):
return 1

class BpttTestModel(BoringModel):
def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
self.batch_size = batch_size
self.layer = torch.nn.Linear(in_features, out_features)
self.n_hidden_states = n_hidden_states
if property_on_module:
self.truncated_bptt_steps = truncated_bptt_steps

def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if self.n_hidden_states == 1:
self.test_hidden = torch.rand(1)
else:
self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)

x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
return {"loss": loss_val, "hiddens": self.test_hidden}

def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
self.log("train_loss", loss)

def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
)

model = BpttTestModel(
batch_size=batch_size,
in_features=truncated_bptt_steps,
out_features=truncated_bptt_steps,
n_hidden_states=n_hidden_states,
)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

trainer_tbptt_steps = None if property_on_module else truncated_bptt_steps

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=trainer_tbptt_steps,
limit_val_batches=0,
weights_summary=None,
)
trainer.fit(model)
assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"


def test_tbptt_log(tmpdir):
truncated_bptt_steps = 2
N, T, F = 32, 15, 1 # batches x timesteps (sequence size) x features
batch_size = 10
assert T % truncated_bptt_steps != 0, "Should test leftover time steps"

class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __init__(self):
self.x_seq = torch.randn(N, T, F)
self.y_seq = torch.randn(N, T, F)

def __getitem__(self, index):
return self.x_seq[index], self.y_seq[index]

def __len__(self):
return N

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.test_hidden = None
self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
self.truncated_bptt_steps = truncated_bptt_steps

def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if hiddens is not None:
assert hiddens.grad_fn is None
split_idx = self.trainer.fit_loop.split_idx
self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)

x, y = batch
if self.trainer.fit_loop.epoch_loop.batch_loop.done:
# last split idx, not aligned
assert x.shape[1] == T % truncated_bptt_steps
assert y.shape[1] == T % truncated_bptt_steps
else:
assert x.shape[1] == truncated_bptt_steps
assert y.shape[1] == truncated_bptt_steps

pred, _ = self(x)
loss = torch.nn.functional.mse_loss(pred, y)

self.log("a", loss, on_epoch=True)

return {"loss": loss, "hiddens": self.test_hidden}

def on_train_batch_start(self, *args, **kwargs) -> None:
self.test_hidden = None

def train_dataloader(self):
return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=2,
weights_summary=None,
)
trainer.fit(model)

assert trainer.fit_loop.batch_idx == N // batch_size
assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
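`test_tbptt_log` above is moved here from `tests/trainer/logging_/test_train_loop_logging.py` (see the deletion below) and now enables TBPTT through the module attribute. It uses `T = 15` with `truncated_bptt_steps = 2` on purpose so the final split is shorter than the others. The shape assertions rely on the default `tbptt_split_batch` slicing each tensor along the time dimension; here is a rough, self-contained approximation of that slicing (`split_along_time` is an illustrative helper, not Lightning API):

```python
import torch


def split_along_time(batch, split_size):
    # Slice every tensor in the batch along dim 1 into `split_size`-step windows;
    # the final window keeps whatever timesteps are left over.
    x, y = batch
    return [(x[:, i:i + split_size], y[:, i:i + split_size]) for i in range(0, x.shape[1], split_size)]


splits = split_along_time((torch.randn(10, 15, 1), torch.randn(10, 15, 1)), split_size=2)
assert len(splits) == 8                  # index of the last window is 7 == 15 // 2
assert splits[-1][0].shape[1] == 15 % 2  # leftover window, as asserted in the test
```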
70 changes: 0 additions & 70 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -212,76 +212,6 @@ def validation_step(self, batch, batch_idx):
assert trainer.logged_metrics["bar"] == result


def test_tbptt_log(tmpdir):
truncated_bptt_steps = 2
N, T, F = 32, 15, 1 # batches x timesteps (sequence size) x features
batch_size = 10
assert T % truncated_bptt_steps != 0, "Should test leftover time steps"

class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __init__(self):
self.x_seq = torch.randn(N, T, F)
self.y_seq = torch.randn(N, T, F)

def __getitem__(self, index):
return self.x_seq[index], self.y_seq[index]

def __len__(self):
return N

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.test_hidden = None
self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)

def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if hiddens is not None:
assert hiddens.grad_fn is None
split_idx = self.trainer.fit_loop.split_idx
self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)

x, y = batch
if self.trainer.fit_loop.epoch_loop.batch_loop.done:
# last split idx, not aligned
assert x.shape[1] == T % truncated_bptt_steps
assert y.shape[1] == T % truncated_bptt_steps
else:
assert x.shape[1] == truncated_bptt_steps
assert y.shape[1] == truncated_bptt_steps

pred, _ = self(x)
loss = torch.nn.functional.mse_loss(pred, y)

self.log("a", loss, on_epoch=True)

return {"loss": loss, "hiddens": self.test_hidden}

def on_train_batch_start(self, *args, **kwargs) -> None:
self.test_hidden = None

def train_dataloader(self):
return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=0,
truncated_bptt_steps=truncated_bptt_steps,
max_epochs=2,
log_every_n_steps=2,
weights_summary=None,
)
trainer.fit(model)

assert trainer.fit_loop.batch_idx == N // batch_size
assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}


def test_different_batch_types_for_sizing(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):