This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Distributed training with gradient accumulation (#5100)
* Fixes distributed training with gradient accumulation

* Fix in case we don't do anything in a batch group

* Test for the problematic condition

* Formatting

* More formatting

* Changelog

* Fix another test

* Fix even more tests

* Fixes one more test

* I can fix these tests all day.
dirkgr committed Apr 8, 2021
1 parent fe2d6e5 commit de61100
Showing 8 changed files with 97 additions and 26 deletions.
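For context on the fix below: the stall comes from the collective calls hidden inside `Model.forward`. Every worker must reach each collective the same number of times, so the trainer makes workers agree on when to stop via an `all_reduce` handshake. Before this commit that handshake ran once per batch *group*; with gradient accumulation, a worker that ran out of data partway through a group stopped participating while its peers were still inside the group, and everyone hung. The sketch below is illustrative only, not AllenNLP code: the worker function, the two-process `gloo` setup, and the port are assumptions made for the example, but it shows the per-iteration handshake pattern that has to stay in lockstep.

```python
# Illustrative sketch only (not AllenNLP code): every worker announces whether it
# still has data before each step, so all workers stop together even when their
# data is imbalanced. The function name, the two-process setup, and the port are
# assumptions made for the example.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def train_worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    # Deliberately imbalanced: worker 0 has 3 batches, worker 1 has 5.
    batch_iter = iter(range(3 if rank == 0 else 5))
    while True:
        try:
            batch = next(batch_iter)
            done = torch.tensor(0)  # this worker still has data
        except StopIteration:
            batch, done = None, torch.tensor(1)  # this worker ran out of data
        # Collective call: it only returns once every worker has made it, and it
        # must be made the same number of times on every worker.
        dist.all_reduce(done, dist.ReduceOp.SUM)
        if done.item() > 0:
            break  # someone is finished, so everyone stops in lockstep
        # ... forward/backward on `batch` would go here (more collectives) ...
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(train_worker, args=(2,), nprocs=2, join=True)
```

In the real trainer the additional collectives happen inside `Model.forward`, which is why a missed handshake leaves the other workers waiting there.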
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed a bug where some `Activation` implementations could not be pickled due to involving a lambda function.
 - Fixed `__str__()` method on `ModelCardInfo` class.
+- Fixed a stall when using distributed training and gradient accumulation at the same time
 - Fixed an issue where using the `from_pretrained_transformer` `Vocabulary` constructor in distributed training via the `allennlp train` command
   would result in the data being iterated through unnecessarily.

41 changes: 23 additions & 18 deletions allennlp/training/trainer.py
@@ -466,24 +466,8 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

         done_early = False
         for batch_group in batch_group_generator_tqdm:
-            if self._distributed:
-                # Check whether the other workers have stopped already (due to differing amounts of
-                # data in each). If so, we can't proceed because we would hang when we hit the
-                # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor
-                # here because NCCL process groups apparently don't support BoolTensor.
-                done = torch.tensor(0, device=self.cuda_device)
-                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
-                if done.item() > 0:
-                    done_early = True
-                    logger.warning(
-                        f"Worker {torch.distributed.get_rank()} finishing training early! "
-                        "This implies that there is an imbalance in your training "
-                        "data across the workers and that some amount of it will be "
-                        "ignored. A small amount of this is fine, but a major imbalance "
-                        "should be avoided. Note: This warning will appear unless your "
-                        "data is perfectly balanced."
-                    )
-                    break
+            if done_early:
+                break

             batches_this_epoch += 1
             self._batch_num_total += 1
@@ -499,6 +483,25 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             batch_loss = 0.0
             batch_group_outputs = []
             for batch in batch_group:
+                if self._distributed:
+                    # Check whether the other workers have stopped already (due to differing amounts of
+                    # data in each). If so, we can't proceed because we would hang when we hit the
+                    # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor
+                    # here because NCCL process groups apparently don't support BoolTensor.
+                    done = torch.tensor(0, device=self.cuda_device)
+                    torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
+                    if done.item() > 0:
+                        done_early = True
+                        logger.warning(
+                            f"Worker {torch.distributed.get_rank()} finishing training early! "
+                            "This implies that there is an imbalance in your training "
+                            "data across the workers and that some amount of it will be "
+                            "ignored. A small amount of this is fine, but a major imbalance "
+                            "should be avoided. Note: This warning will appear unless your "
+                            "data is perfectly balanced."
+                        )
+                        break
+
                 with amp.autocast(self._use_amp):
                     batch_outputs = self.batch_outputs(batch, for_training=True)
                     batch_group_outputs.append(batch_outputs)
@@ -518,6 +521,8 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                     self._scaler.scale(loss).backward()
                 else:
                     loss.backward()
+            if len(batch_group_outputs) <= 0:
+                continue

             train_loss += batch_loss

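To see the two pieces of the fix above outside the diff context: the early-exit check now runs per batch inside the accumulation group, and a group that produced no outputs is skipped rather than triggering an optimizer step with empty gradients. Below is a minimal, self-contained illustration in plain PyTorch with a toy model and a simulated stop signal; it is not the trainer code, and `batches_before_stop` stands in for the distributed `all_reduce` handshake shown above.

```python
# Minimal illustration (toy model and data, not the AllenNLP trainer) of the
# reworked control flow: check for an early stop before every batch, and skip
# the optimizer step for a batch group in which nothing ran.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(5)]
grad_accum_steps = 2
batches_before_stop = 2  # stand-in for another worker signalling "done"

done_early = False
batches_seen = 0
steps_taken = 0
for start in range(0, len(batches), grad_accum_steps):
    if done_early:
        break
    batch_group = batches[start : start + grad_accum_steps]
    batch_group_outputs = []
    optimizer.zero_grad()
    for x, y in batch_group:
        # Per-batch check, mirroring the all_reduce handshake in the diff above.
        if batches_seen >= batches_before_stop:
            done_early = True
            break
        loss = loss_fn(model(x), y) / len(batch_group)  # scale for accumulation
        loss.backward()
        batch_group_outputs.append(loss.detach())
        batches_seen += 1
    if len(batch_group_outputs) <= 0:
        continue  # nothing ran in this group, so don't take an optimizer step
    optimizer.step()
    steps_taken += 1

print(f"processed {batches_seen} batches in {steps_taken} optimizer steps")
```

With the stop landing exactly on a group boundary (two batches seen, groups of two), the second group produces no outputs and the guard skips its optimizer step; the script prints `processed 2 batches in 1 optimizer steps`.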
1 change: 1 addition & 0 deletions test_fixtures/data/sequence_tagging.tsv
@@ -2,3 +2,4 @@ cats###N are###V animals###N .###N
 dogs###N are###V animals###N .###N
 snakes###N are###V animals###N .###N
 birds###N are###V animals###N .###N
+horses###N are###V animals###N .###N
2 changes: 1 addition & 1 deletion tests/commands/no_op_train_test.py
@@ -21,7 +21,7 @@ def test_train_model(self):
         params = lambda: Params(
             {
                 "model": {"type": "constant"},
-                "dataset_reader": {"type": "sequence_tagging"},
+                "dataset_reader": {"type": "sequence_tagging", "max_instances": 4},
                 "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
                 "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
                 "data_loader": {"batch_size": 2},
Expand Down
70 changes: 68 additions & 2 deletions tests/commands/train_test.py
@@ -237,6 +237,62 @@ def test_train_model_distributed(self):
         # Check we can load the serialized model
         assert load_archive(out_dir).model

+    @pytest.mark.parametrize("max_instances", [1, 2, 3, 4, None])
+    @pytest.mark.parametrize("grad_acc", [None, 2])
+    @pytest.mark.parametrize("batch_size", [1, 2, 3])
+    def test_train_model_distributed_with_gradient_accumulation(
+        self, max_instances, grad_acc, batch_size
+    ):
+        if torch.cuda.device_count() >= 2:
+            devices = [0, 1]
+        else:
+            devices = [-1, -1]
+
+        params = lambda: Params(
+            {
+                "model": {
+                    "type": "simple_tagger",
+                    "text_field_embedder": {
+                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
+                    },
+                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
+                },
+                "dataset_reader": {"type": "sequence_tagging", "max_instances": max_instances},
+                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "data_loader": {"batch_size": batch_size},
+                "trainer": {
+                    "num_epochs": 2,
+                    "optimizer": "adam",
+                    "num_gradient_accumulation_steps": grad_acc,
+                },
+                "distributed": {"cuda_devices": devices},
+            }
+        )
+
+        out_dir = os.path.join(self.TEST_DIR, "test_distributed_train_with_grad_acc")
+        train_model(params(), serialization_dir=out_dir)
+
+        # Check that some logs specific to distributed
+        # training are where we expect.
+        serialized_files = os.listdir(out_dir)
+        assert "out_worker0.log" in serialized_files
+        assert "out_worker1.log" in serialized_files
+        assert "model.tar.gz" in serialized_files
+        assert "metrics.json" in serialized_files
+
+        # Make sure the metrics look right.
+        with open(os.path.join(out_dir, "metrics.json")) as f:
+            metrics = json.load(f)
+            assert metrics["peak_worker_0_memory_MB"] > 0
+            assert metrics["peak_worker_1_memory_MB"] > 0
+            if torch.cuda.device_count() >= 2:
+                assert metrics["peak_gpu_0_memory_MB"] > 0
+                assert metrics["peak_gpu_1_memory_MB"] > 0
+
+        # Check we can load the serialized model
+        assert load_archive(out_dir).model
+
     @cpu_or_gpu
     @pytest.mark.parametrize("max_instances_in_memory", [None, 10])
     def test_train_model_distributed_with_sharded_reader(self, max_instances_in_memory):
@@ -343,7 +399,7 @@ def test_train_model_distributed_without_sharded_reader(self, max_instances_in_m
                     },
                     "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
                 },
-                "dataset_reader": {"type": "sequence_tagging"},
+                "dataset_reader": {"type": "sequence_tagging", "max_instances": 4},
                 "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
                 "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
                 "data_loader": {
@@ -731,7 +787,17 @@ def test_dry_run_makes_vocab(self):
             tokens = [line.strip() for line in f]

         tokens.sort()
-        assert tokens == [".", "@@UNKNOWN@@", "animals", "are", "birds", "cats", "dogs", "snakes"]
+        assert tokens == [
+            ".",
+            "@@UNKNOWN@@",
+            "animals",
+            "are",
+            "birds",
+            "cats",
+            "dogs",
+            "horses",
+            "snakes",
+        ]

         with open(vocab_path / "labels.txt") as f:
             labels = [line.strip() for line in f]
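The three `parametrize` decorators are what drive the new test into the problematic condition: with two workers, several of the `max_instances` / `batch_size` combinations leave the workers with different numbers of batches, and combined with `num_gradient_accumulation_steps=2` some workers run out of data partway through a batch group. Below is a rough back-of-the-envelope helper, assuming instances are split as evenly as possible between two workers; that split is an assumption and may not match AllenNLP's exact sharding, and the helper itself is not part of the commit.

```python
# Rough illustration of how the test parameters create imbalance across two
# workers. The even-as-possible split is an assumption, not AllenNLP's exact
# sharding logic.
import math


def batches_per_worker(num_instances: int, batch_size: int, world_size: int = 2):
    per_worker = [num_instances // world_size] * world_size
    for i in range(num_instances % world_size):
        per_worker[i] += 1  # leftover instances go to the first workers
    return [math.ceil(n / batch_size) for n in per_worker]


# The fixture now has five rows; max_instances=None would use all of them.
for max_instances in (1, 2, 3, 4, 5):
    for batch_size in (1, 2, 3):
        print(max_instances, batch_size, batches_per_worker(max_instances, batch_size))
```

For example `max_instances=1, batch_size=1` gives `[1, 0]`: one worker has a batch and the other has none at all, which is exactly the kind of run that used to hang.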
2 changes: 1 addition & 1 deletion tests/data/dataset_readers/sequence_tagging_test.py
@@ -4,7 +4,7 @@

 class TestSequenceTaggingDatasetReader:
     def test_default_format(self):
-        reader = SequenceTaggingDatasetReader()
+        reader = SequenceTaggingDatasetReader(max_instances=4)
         instances = list(
             reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
         )
Expand Down
2 changes: 1 addition & 1 deletion tests/data/dataset_readers/sharded_dataset_reader_test.py
@@ -27,7 +27,7 @@ def setup_method(self) -> None:
         super().setup_method()

         # use SequenceTaggingDatasetReader as the base reader
-        self.base_reader = SequenceTaggingDatasetReader()
+        self.base_reader = SequenceTaggingDatasetReader(max_instances=4)
         base_file_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"

         # Make 100 copies of the data
Expand Down
4 changes: 1 addition & 3 deletions tests/training/trainer_test.py
@@ -103,7 +103,7 @@ class TrainerTestBase(AllenNlpTestCase):
     def setup_method(self):
         super().setup_method()
         self.data_path = str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
-        self.reader = SequenceTaggingDatasetReader()
+        self.reader = SequenceTaggingDatasetReader(max_instances=4)
         self.data_loader = MultiProcessDataLoader(self.reader, self.data_path, batch_size=2)
         self.data_loader_lazy = MultiProcessDataLoader(
             self.reader, self.data_path, batch_size=2, max_instances_in_memory=10
@@ -815,7 +815,6 @@ def test_trainer_can_log_learning_rates_tensorboard(self):
         trainer.train()

     def test_sanity_check_callback(self):
-
         model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True)
         inst = Instance({"x": TensorField(torch.rand(3, 1, 4))})
         data_loader = SimpleDataLoader([inst, inst], 2)
@@ -1201,7 +1200,6 @@ def test_trainer_can_log_batch_inputs(self):
         trainer.train()

     def test_console_log_callback(self):
-
         total_instances = 1000
         batch_size = 25

