
Commit 412896b: fix merge conflicts
epwalsh committed Jan 20, 2021
2 parents 5497394 + ed322eb
Showing 6 changed files with 48 additions and 23 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/ci.yml
@@ -281,7 +281,7 @@ jobs:
runs-on: [self-hosted, GPU]
strategy:
matrix:
cuda: ['10.2', '11.0']
cuda: ['10.1', '10.2', '11.0']

steps:
- uses: actions/checkout@v2
@@ -291,8 +291,10 @@ jobs:
CUDA: ${{ matrix.cuda }}
run: |
# Check the install instructions on https://pytorch.org/ to keep these up-to-date.
if [[ $CUDA == '10.2' ]]; then
echo "DOCKER_TORCH_VERSION='torch==1.7.1 torchvision==0.8.2'" >> $GITHUB_ENV;
if [[ $CUDA == '10.1' ]]; then
echo "DOCKER_TORCH_VERSION='torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html'" >> $GITHUB_ENV;
elif [[ $CUDA == '10.2' ]]; then
echo "DOCKER_TORCH_VERSION='torch==1.7.1'" >> $GITHUB_ENV;
elif [[ $CUDA == '11.0' ]]; then
echo "DOCKER_TORCH_VERSION='torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html'" >> $GITHUB_ENV;
else
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -73,7 +73,10 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
instead of an exception when lacking write permissions on an existing file lock.
This makes it possible to use the `FileLock` class on a read-only file system.
- Added a new learning rate scheduler: `CombinedLearningRateScheduler`. This can be used to combine different LR schedulers, using one after the other.
- Added an official CUDA 10.1 Docker image.
- Moving `ModelCard` and `TaskCard` abstractions into the main repository.
- Added a util function `allennlp.nn.util.dist_reduce(...)` for handling distributed reductions.
This is especially useful when implementing a distributed `Metric`.

### Changed

37 changes: 35 additions & 2 deletions allennlp/nn/util.py
@@ -11,8 +11,11 @@
import math
import numpy
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from allennlp.common.checks import ConfigurationError
from allennlp.common.util import int_to_device, is_distributed

logger = logging.getLogger(__name__)

@@ -25,8 +28,6 @@ def move_to_device(obj, device: Union[torch.device, int]):
move all the Tensors to the specified device (or do nothing, if they are already on
the target device).
"""
from allennlp.common.util import int_to_device

device = int_to_device(device)

if isinstance(obj, torch.Tensor):
@@ -2014,3 +2015,35 @@ def tiny_value_of_dtype(dtype: torch.dtype):
return 1e-4
else:
raise TypeError("Does not support dtype " + str(dtype))


_V = TypeVar("_V", int, float)


def dist_reduce(value: _V, reduce_op: ReduceOp, **kwargs) -> _V:
"""
Reduces the given `value` across all distributed worker nodes according to the given
reduction operation.
If called outside of a distributed context, it will just return `value`.
# Parameters
value : `_V`
The value to reduce across distributed nodes.
reduce_op : `ReduceOp`
The reduction operation to use.
**kwargs : `Any`
Additional arguments used to construct the tensor that will wrap `value`.
# Returns
`_V`
The final value.
"""
if not is_distributed():
return value
device = int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device())
value_tensor = torch.tensor(value, device=device, **kwargs)
dist.all_reduce(value_tensor, op=reduce_op)
return value_tensor.item() # type: ignore[return-value]
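
To illustrate how the new helper is meant to be used, here is a minimal sketch (not part of this commit) of a custom distributed metric. The class name `RunningTotal` and its registration key are made up for illustration; the sketch assumes the standard `Metric` interface (`__call__`, `get_metric`, `reset`). Because `dist_reduce` simply returns `value` outside of a distributed context, the same metric code works in both single-process and distributed runs.

    from allennlp.nn.util import dist_reduce, ReduceOp
    from allennlp.training.metrics.metric import Metric


    @Metric.register("running_total")  # hypothetical key, for illustration only
    class RunningTotal(Metric):
        """Sums a value across batches and, when running distributed, across workers."""

        def __init__(self) -> None:
            self._total = 0.0

        def __call__(self, value: float) -> None:
            # `dist_reduce` sums `value` over all workers, or returns it
            # unchanged when not in a distributed context.
            self._total += dist_reduce(float(value), ReduceOp.SUM)

        def get_metric(self, reset: bool = False) -> float:
            total = self._total
            if reset:
                self.reset()
            return total

        def reset(self) -> None:
            self._total = 0.0
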
19 changes: 3 additions & 16 deletions allennlp/training/metrics/average.py
@@ -1,10 +1,7 @@
from overrides import overrides

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce, ReduceOp


@Metric.register("average")
@@ -28,18 +25,8 @@ def __call__(self, value):
value : `float`
The value to average.
"""
_total_value = list(self.detach_tensors(value))[0]
_count = 1
if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
count = torch.tensor(_count, device=device)
total_value = torch.tensor(_total_value, device=device)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
dist.all_reduce(total_value, op=dist.ReduceOp.SUM)
_count = count.item()
_total_value = total_value.item()
self._count += _count
self._total_value += _total_value
self._count += dist_reduce(1, ReduceOp.SUM)
self._total_value += dist_reduce(float(list(self.detach_tensors(value))[0]), ReduceOp.SUM)

@overrides
def get_metric(self, reset: bool = False):
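
For reference, a small usage sketch (not part of the diff) of the simplified `Average` metric, with made-up values. In a distributed run, each call reduces the count and the value across workers before accumulating them.

    from allennlp.training.metrics import Average

    average = Average()
    for batch_loss in [0.75, 0.50, 0.25]:
        average(batch_loss)  # each call accumulates dist_reduce(1) and the reduced value

    print(average.get_metric(reset=True))  # 0.5 on a single worker
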
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -42,7 +42,7 @@ nr.databind.core<0.0.17
nr.interface<0.0.4

mkdocs==1.1.2
mkdocs-material>=5.5.0,<6.2.0
mkdocs-material>=5.5.0,<6.3.0
markdown-include==0.6.0

#### PACKAGE-UPLOAD PACKAGES ####
2 changes: 1 addition & 1 deletion setup.py
@@ -65,7 +65,7 @@
"scikit-learn",
"scipy",
"pytest",
"transformers>=4.0,<4.1",
"transformers>=4.1,<4.3",
"sentencepiece",
"jsonpickle",
"dataclasses;python_version<'3.7'",
