Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Merge branch 'master' into vision
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jul 20, 2020
2 parents f87df83 + 478bf46 commit 6cc508d
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 42 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Removed unnecessary warning about deadlocks in `DataLoader`.
- Use slower tqdm intervals when output is being piped or redirected.


## [v1.1.0rc1](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc1) - 2020-07-14

### Fixed
Expand All @@ -31,7 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
in case it does not have a tokenizer.
- `reg_loss` is only now returned for models that have some regularization penalty configured.
- Fixed a bug that prevented `cached_path` from downloading assets from GitHub releases.
- Fixed a bug that erronously increased last label's false positive count in calculating fbeta metrics.
- Fixed a bug that erroneously increased last label's false positive count in calculating fbeta metrics.
- `Tqdm` output now looks much better when the output is being piped or redirected.
- Small improvements to how the API documentation is rendered.
- Only show validation progress bar from main process in distributed training.
Expand All @@ -50,7 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
scalar mix of all hidden layers from the transformer model instead of just the last layer. To utilize
this, just set `last_layer_only` to `False`.
- `cached_path()` can now read files inside of archives.
- Training metrics now include per-batch loss in addition to aggregate loss across number of batches.
- Training metrics now include `batch_loss` and `batch_reg_loss` in addition to aggregate loss across number of batches.

### Changed

Expand Down
4 changes: 3 additions & 1 deletion allennlp/common/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def flush(self):
class Tqdm:
@staticmethod
def tqdm(*args, **kwargs):
new_kwargs = {"file": TqdmToLogsWriter(), **kwargs}
# Use a slow interval if the output is being piped or redirected.
default_mininterval = 0.1 if sys.stderr.isatty() else 10.0
new_kwargs = {"file": TqdmToLogsWriter(), "mininterval": default_mininterval, **kwargs}

return _tqdm(*args, **new_kwargs)
9 changes: 0 additions & 9 deletions allennlp/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import List, Dict, Union, Iterator
import warnings

import torch
from torch.utils import data

from allennlp.common.registrable import Registrable
from allennlp.common.lazy import Lazy
from allennlp.data.instance import Instance
from allennlp.data.dataset_readers.dataset_reader import AllennlpLazyDataset
from allennlp.data.batch import Batch
from allennlp.data.samplers import Sampler, BatchSampler

Expand Down Expand Up @@ -87,13 +85,6 @@ def __init__(
multiprocessing_context: str = None,
batches_per_epoch: int = None,
):
if num_workers and isinstance(dataset, AllennlpLazyDataset):
warnings.warn(
"Using multi-process data loading with a lazy dataset could lead to "
"deadlocks with certain tokenizers. See:\n"
" https://github.com/allenai/allennlp/issues/4330\n",
UserWarning,
)
super().__init__(
dataset=dataset,
batch_size=batch_size,
Expand Down
30 changes: 24 additions & 6 deletions allennlp/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,11 +507,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
regularization_penalty = self.model.get_regularization_penalty()

train_loss = 0.0
batch_loss = 0.0

if regularization_penalty is not None:
train_reg_loss = 0.0
batch_reg_loss = 0.0
else:
train_reg_loss = None
batch_reg_loss = None
# Set the model to "train" mode.
self._pytorch_model.train()

Expand Down Expand Up @@ -588,10 +591,12 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
scaled_loss.backward()
else:
loss.backward()
train_loss += loss.item()
batch_loss = loss.item()
train_loss += batch_loss
if reg_loss is not None:
reg_loss = reg_loss / len(batch_group)
train_reg_loss += reg_loss.item()
batch_reg_loss = reg_loss.item()
train_reg_loss += batch_reg_loss

batch_grad_norm = self.rescale_gradients()

Expand Down Expand Up @@ -627,6 +632,8 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
self.model,
train_loss,
train_reg_loss,
batch_loss,
batch_reg_loss,
batches_this_epoch,
world_size=self._world_size,
cuda_device=self.cuda_device,
Expand Down Expand Up @@ -675,7 +682,9 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
self.model,
train_loss,
train_reg_loss,
batches_this_epoch,
batch_loss=None,
batch_reg_loss=None,
num_batches=batches_this_epoch,
reset=True,
world_size=self._world_size,
cuda_device=self.cuda_device,
Expand Down Expand Up @@ -717,10 +726,13 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:

batches_this_epoch = 0
val_loss = 0
val_batch_loss = 0
if regularization_penalty is not None:
val_reg_loss = 0
val_batch_reg_loss = 0
else:
val_reg_loss = None
val_batch_reg_loss = None
done_early = False
for batch in val_generator_tqdm:
if self._distributed:
Expand Down Expand Up @@ -752,15 +764,19 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
# count those batches for which we actually have a loss. If this variable ever
# gets used for something else, we might need to change things around a bit.
batches_this_epoch += 1
val_loss += loss.detach().cpu().numpy()
val_batch_loss = loss.detach().cpu().numpy()
val_loss += val_batch_loss
if reg_loss is not None:
val_reg_loss += reg_loss.detach().cpu().numpy()
val_batch_reg_loss = reg_loss.detach().cpu().numpy()
val_reg_loss += val_batch_reg_loss

# Update the description with the latest metrics
val_metrics = training_util.get_metrics(
self.model,
val_loss,
val_reg_loss,
val_batch_loss,
val_batch_reg_loss,
batches_this_epoch,
world_size=self._world_size,
cuda_device=self.cuda_device,
Expand Down Expand Up @@ -852,7 +868,9 @@ def train(self) -> Dict[str, Any]:
self.model,
val_loss,
val_reg_loss,
num_batches,
batch_loss=None,
batch_reg_loss=None,
num_batches=num_batches,
reset=True,
world_size=self._world_size,
cuda_device=self.cuda_device,
Expand Down
10 changes: 6 additions & 4 deletions allennlp/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def get_metrics(
model: Model,
total_loss: float,
total_reg_loss: Optional[float],
batch_loss: Optional[float],
batch_reg_loss: Optional[float],
num_batches: int,
reset: bool = False,
world_size: int = 1,
Expand All @@ -285,12 +287,12 @@ def get_metrics(
Returns the `"batch_loss"` separately.
"""
metrics = model.get_metrics(reset=reset)
if not reset:
metrics["batch_loss"] = total_loss
if batch_loss is not None:
metrics["batch_loss"] = batch_loss
metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
if total_reg_loss is not None:
if not reset:
metrics["batch_reg_loss"] = total_reg_loss
if batch_reg_loss is not None:
metrics["batch_reg_loss"] = batch_reg_loss
metrics["reg_loss"] = float(total_reg_loss / num_batches) if num_batches > 0 else 0.0

if world_size > 1:
Expand Down
15 changes: 1 addition & 14 deletions tests/data/dataloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,7 @@
from allennlp.data.fields import LabelField
from allennlp.data.instance import Instance
from allennlp.data.dataloader import PyTorchDataLoader
from allennlp.data.dataset_readers.dataset_reader import (
DatasetReader,
AllennlpLazyDataset,
)


def test_multi_processing_with_lazy_dataset_warns():
def fake_instance_generator(file_name: str) -> Iterable[Instance]:
yield from []

with pytest.warns(UserWarning, match=r".*deadlocks.*"):
PyTorchDataLoader(
AllennlpLazyDataset(fake_instance_generator, "nonexistent_file"), num_workers=1
)
from allennlp.data.dataset_readers.dataset_reader import DatasetReader


@pytest.mark.parametrize("lazy", (True, False))
Expand Down
37 changes: 37 additions & 0 deletions tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,43 @@ def __call__(
expected_calls = [epoch for epoch in range(-1, 4)]
assert trainer.epoch_callback_calls == expected_calls

def test_total_loss_is_average_of_batch_loss(self):

batches_per_epoch = 3

data_loader_custom_epoch_lazy = PyTorchDataLoader(
self.instances_lazy,
batch_size=2,
collate_fn=allennlp_collate,
batches_per_epoch=batches_per_epoch,
)

class FakeBatchCallback(BatchCallback):
def __call__(
self,
trainer: "GradientDescentTrainer",
batch_inputs: List[List[TensorDict]],
batch_outputs: List[Dict[str, Any]],
epoch: int,
batch_number: int,
is_training: bool,
is_master: bool,
) -> None:
if not hasattr(trainer, "batch_losses"):
trainer.batch_losses = [] # type: ignore
trainer.batch_losses.append(batch_outputs[0]["loss"].item()) # type: ignore

trainer = GradientDescentTrainer(
self.model,
self.optimizer,
data_loader_custom_epoch_lazy,
num_epochs=1,
batch_callbacks=[FakeBatchCallback()],
)
metrics = trainer.train()

assert metrics["training_loss"] == float(sum(trainer.batch_losses) / batches_per_epoch)


class TestApexTrainer(TrainerTestBase):
@requires_gpu
Expand Down
13 changes: 7 additions & 6 deletions tests/training/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,15 @@ def forward(self, **kwargs):
return {}

model = FakeModel(None)
loss = 10.0
total_loss = 100.0
batch_loss = 10.0
num_batches = 2
metrics = get_metrics(model, loss, None, num_batches)
metrics = get_metrics(model, total_loss, None, batch_loss, None, num_batches)

assert metrics["loss"] == float(loss / num_batches)
assert metrics["batch_loss"] == loss
assert metrics["loss"] == float(total_loss / num_batches)
assert metrics["batch_loss"] == batch_loss

metrics = get_metrics(model, loss, None, num_batches, reset=True)
metrics = get_metrics(model, total_loss, None, None, None, num_batches)

assert metrics["loss"] == float(loss / num_batches)
assert metrics["loss"] == float(total_loss / num_batches)
assert "batch_loss" not in metrics

0 comments on commit 6cc508d

Please sign in to comment.