
strip out old DP stuff, ensure multiple cuda devices raises errors #3516

Merged
merged 15 commits on Dec 13, 2019
12 changes: 5 additions & 7 deletions allennlp/commands/find_learning_rate.py
@@ -54,7 +54,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import Params, Tqdm
from allennlp.common.util import prepare_environment, lazy_groups_of
from allennlp.common.util import prepare_environment
from allennlp.data import Vocabulary, DataIterator
from allennlp.models import Model
from allennlp.training import Trainer
@@ -223,6 +223,7 @@ def find_learning_rate_model(
train_data = all_datasets["train"]

trainer_params = params.pop("trainer")

no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
if any(re.search(regex, name) for regex in no_grad_regexes):
@@ -296,10 +297,7 @@ def search_learning_rate(

trainer.model.train()

num_gpus = len(trainer._cuda_devices)

raw_train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_batches)

learning_rates = []
@@ -310,7 +308,7 @@
else:
lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

for i, batch_group in enumerate(train_generator_tqdm):
for i, batch in enumerate(train_generator_tqdm):

if linear_steps:
current_lr = start_lr + (lr_update_factor * i)
@@ -321,7 +319,7 @@
param_group["lr"] = current_lr

trainer.optimizer.zero_grad()
loss = trainer.batch_loss(batch_group, for_training=True)
loss = trainer.batch_loss(batch, for_training=True)
loss.backward()
loss = loss.detach().cpu().item()

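With batch groups removed, search_learning_rate now steps the learning rate once per plain batch. The sketch below illustrates that schedule: the exponential update factor matches the hunk above, while the linear factor and the exponential per-step formula are assumptions (those lines are collapsed in this diff), and the helper name is hypothetical.

# Illustration of the per-batch learning-rate schedule driven by the loop above.
# The exponential factor comes from the diff; the linear factor and the
# exponential per-step formula are assumed, since those lines are collapsed here.
def candidate_learning_rates(start_lr, end_lr, num_batches, linear_steps=False):
    if linear_steps:
        lr_update_factor = (end_lr - start_lr) / num_batches  # assumed
    else:
        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)
    for i in range(num_batches):
        if linear_steps:
            yield start_lr + lr_update_factor * i
        else:
            yield start_lr * lr_update_factor ** i  # assumed

# e.g. the first exponential candidates between 1e-5 and 10 over 100 batches:
# list(candidate_learning_rates(1e-5, 10, 100))[:3] -> [1e-05, ~1.148e-05, ~1.318e-05]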
2 changes: 1 addition & 1 deletion allennlp/commands/fine_tune.py
@@ -382,7 +382,7 @@ def fine_tune_model(
model,
test_data,
validation_iterator or iterator,
cuda_device=trainer._cuda_devices[0],
cuda_device=trainer.cuda_device,
batch_weight_key=batch_weight_key,
)

52 changes: 27 additions & 25 deletions allennlp/commands/train.py
@@ -56,7 +56,7 @@
from allennlp.commands.make_vocab import make_vocab_from_params
from allennlp.commands.subcommand import Subcommand
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError, check_for_gpu, parse_cuda_device
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common.util import (
prepare_environment,
prepare_global_logging,
@@ -269,11 +269,10 @@ def train_model(
create_serialization_dir(params, serialization_dir, recover, force)
params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

cuda_device = params.params.get("trainer").get("cuda_device", -1)
check_for_gpu(cuda_device)

distributed = params.params.get("trainer").get("distributed", False)
if not distributed:
distributed_params = params.params.pop("distributed", None)
# If distributed isn't in the config and the config contains strictly
# one cuda device, we just run a single training process.
if distributed_params is None:
model = _train_worker(
process_rank=0,
params=params,
@@ -286,18 +285,24 @@
)
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
return model

# Otherwise, we are running multiple processes for training.
else:
device_id = parse_cuda_device(cuda_device)
# We are careful here so that we can raise a good error if someone
# passed the wrong thing - cuda_devices are required.
device_ids = distributed_params.pop("cuda_devices", None)
multi_device = isinstance(device_ids, list) and len(device_ids) > 1

if not isinstance(device_id, list):
if not multi_device:
raise ConfigurationError(
"Multiple cuda devices need to be configured to run distributed training."
)
check_for_gpu(device_ids)

master_addr = params.params.get("trainer").pop("master_address", "127.0.0.1")
master_port = params.params.get("trainer").pop("master_port", 29500)
num_procs = len(device_id)
num_nodes = params.params.get("trainer").pop("num_nodes", 1)
master_addr = distributed_params.pop("master_address", "127.0.0.1")
master_port = distributed_params.pop("master_port", 29500)
num_procs = len(device_ids)
num_nodes = distributed_params.pop("num_nodes", 1)
world_size = num_nodes * num_procs

os.environ["MASTER_ADDR"] = master_addr
@@ -332,10 +337,10 @@
cache_prefix,
include_package,
node_rank,
num_procs,
master_addr,
master_port,
world_size,
device_ids,
),
nprocs=num_procs,
)
@@ -353,10 +358,10 @@ def _train_worker(
cache_prefix: str = None,
include_package: List[str] = None,
node_rank: int = 0,
num_procs_per_node: int = 0,
master_addr: str = "127.0.0.1",
master_port: int = 29500,
world_size: int = 1,
distributed_device_ids: List[str] = None,
) -> Optional[Model]:
"""
Helper to train the configured model/experiment. In distributed mode, this is spawned as a
@@ -415,18 +420,21 @@
for package_name in include_package:
import_submodules(package_name)

num_procs_per_node = len(distributed_device_ids)
# The unique identifier of the worker process among all the processes in the
# distributed training group is computed here. This is used while initializing
# the process group using `init_process_group`
global_rank = node_rank * num_procs_per_node + process_rank

cuda_device = params.params.get("trainer").get("cuda_device", -1)
device_list = parse_cuda_device(cuda_device)

# In distributed training, the configured device is always going to be a list.
# The corresponding gpu id for the particular worker is obtained by picking the id
# from the device list with the rank as index
gpu_id = device_list[process_rank] # type: ignore
gpu_id = distributed_device_ids[process_rank] # type: ignore

# Till now, "cuda_device" might not be set in the trainer params.
# But a worker trainer needs to only know about its specific GPU id.
params["trainer"]["cuda_device"] = gpu_id
params["trainer"]["world_size"] = world_size

torch.cuda.set_device(gpu_id)
dist.init_process_group(
@@ -440,12 +448,6 @@
f"for distributed training in worker {global_rank}"
)

# Till now, "cuda_device" will be a list of ids as configured originally
# in params. But a worker trainer needs to only know about its specific
# GPU id.
params["trainer"]["cuda_device"] = gpu_id
params["trainer"]["world_size"] = world_size

trainer_type = params.get("trainer", {}).get("type", "default")

if trainer_type == "default":
@@ -504,7 +506,7 @@
trainer.model,
evaluation_dataset,
evaluation_iterator,
cuda_device=trainer._cuda_devices[0], # pylint: disable=protected-access,
cuda_device=trainer.cuda_device,
# TODO(brendanr): Pass in an arg following Joel's trainer refactor.
batch_weight_key="",
)
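The distributed path above derives each worker's identity from the top-level distributed config rather than from trainer cuda_device lists. The sketch below mirrors the names used in _train_worker to show the arithmetic; it is an illustration under those assumptions, not the AllenNLP implementation, and worker_assignment is a hypothetical helper.

# Sketch of the per-worker bookkeeping done in _train_worker (illustration only).
def worker_assignment(node_rank, process_rank, distributed_device_ids, num_nodes=1):
    num_procs_per_node = len(distributed_device_ids)
    # Unique rank of this worker across all nodes, passed to dist.init_process_group.
    global_rank = node_rank * num_procs_per_node + process_rank
    world_size = num_nodes * num_procs_per_node
    # Each worker sees exactly one GPU, picked by its local process rank.
    gpu_id = distributed_device_ids[process_rank]
    return global_rank, world_size, gpu_id

# e.g. node 1 of a 2-node job with cuda_devices [0, 1]: process 0 gets
# global_rank 2, world_size 4, gpu_id 0.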
15 changes: 6 additions & 9 deletions allennlp/tests/commands/train_test.py
@@ -99,12 +99,8 @@ def test_train_model_distributed(self):
"train_data_path": SEQUENCE_TAGGING_DATA_PATH,
"validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
"iterator": {"type": "basic", "batch_size": 2},
"trainer": {
"num_epochs": 2,
"optimizer": "adam",
"distributed": True,
"cuda_device": [0, 1],
},
"trainer": {"num_epochs": 2, "optimizer": "adam"},
"distributed": {"cuda_devices": [0, 1]},
}
)

@@ -136,7 +132,8 @@ def test_distributed_raises_error_with_no_gpus(self):
"train_data_path": SEQUENCE_TAGGING_DATA_PATH,
"validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
"iterator": {"type": "basic", "batch_size": 2},
"trainer": {"num_epochs": 2, "optimizer": "adam", "distributed": True},
"trainer": {"num_epochs": 2, "optimizer": "adam"},
"distributed": {},
}
)
with pytest.raises(ConfigurationError):
@@ -183,8 +180,8 @@ def test_error_is_throw_when_cuda_device_is_not_available(self):
"encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
},
"dataset_reader": {"type": "sequence_tagging"},
"train_data_path": "tests/fixtures/data/sequence_tagging.tsv",
"validation_data_path": "tests/fixtures/data/sequence_tagging.tsv",
"train_data_path": "allennlp/tests/fixtures/data/sequence_tagging.tsv",
"validation_data_path": "allennlp/tests/fixtures/data/sequence_tagging.tsv",
"iterator": {"type": "basic", "batch_size": 2},
"trainer": {
"num_epochs": 2,
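These tests exercise the new configuration contract: distributed settings live in a top-level "distributed" block whose "cuda_devices" key is required, instead of "distributed"/"cuda_device" flags inside "trainer". A minimal before/after sketch, using only the keys shown in this diff and placeholder values elsewhere:

from allennlp.common import Params

# Old style: distributed flags inside "trainer"; no longer supported after this change.
old_style = Params({
    "trainer": {"num_epochs": 2, "optimizer": "adam", "distributed": True, "cuda_device": [0, 1]},
})

# New style: "distributed" is a top-level block and must list multiple cuda devices.
new_style = Params({
    "trainer": {"num_epochs": 2, "optimizer": "adam"},
    "distributed": {"cuda_devices": [0, 1]},
})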
8 changes: 4 additions & 4 deletions allennlp/tests/models/simple_tagger_test.py
@@ -63,8 +63,8 @@ def test_regularization(self):
training_batch = next(iterator(self.instances, num_epochs=1))
validation_batch = next(iterator(self.instances, num_epochs=1))

training_loss = trainer.batch_loss([training_batch], for_training=True).item()
validation_loss = trainer.batch_loss([validation_batch], for_training=False).item()
training_loss = trainer.batch_loss(training_batch, for_training=True).item()
validation_loss = trainer.batch_loss(validation_batch, for_training=False).item()

# Training loss should have the regularization penalty, but validation loss should not.
numpy.testing.assert_almost_equal(training_loss, validation_loss)
@@ -124,8 +124,8 @@ def test_regularization(self):
training_batch = next(self.iterator(self.instances, num_epochs=1))
validation_batch = next(self.iterator(self.instances, num_epochs=1))

training_loss = self.trainer.batch_loss([training_batch], for_training=True).data
validation_loss = self.trainer.batch_loss([validation_batch], for_training=False).data
training_loss = self.trainer.batch_loss(training_batch, for_training=True).data
validation_loss = self.trainer.batch_loss(validation_batch, for_training=False).data

# Training loss should have the regularization penalty, but validation loss should not.
assert (training_loss != validation_loss).all()
57 changes: 17 additions & 40 deletions allennlp/tests/training/callback_trainer_test.py
@@ -262,52 +262,29 @@ def test_trainer_can_run_cuda(self):
callbacks=self.default_callbacks(),
cuda_device=0,
)
trainer.train()

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
def test_trainer_can_run_multiple_gpu(self):
self.model.cuda()

class MetaDataCheckWrapper(Model):
"""
Checks that the metadata field has been correctly split across the batch dimension
when running on multiple gpus.
"""

def __init__(self, model):
super().__init__(model.vocab)
self.model = model

def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore
assert (
"metadata" in kwargs and "tags" in kwargs
), f"tokens and metadata must be provided. Got {kwargs.keys()} instead."
batch_size = kwargs["tokens"]["tokens"].size()[0]
assert len(kwargs["metadata"]) == batch_size, (
f"metadata must be split appropriately. Expected {batch_size} elements, "
f"got {len(kwargs['metadata'])} elements."
)
return self.model.forward(**kwargs)

multigpu_iterator = BasicIterator(batch_size=4)
multigpu_iterator.index_with(self.vocab)
trainer = CallbackTrainer(
MetaDataCheckWrapper(self.model),
training_data=self.instances,
iterator=multigpu_iterator,
optimizer=self.optimizer,
num_epochs=2,
callbacks=self.default_callbacks(),
cuda_device=[0, 1],
)
metrics = trainer.train()
assert "peak_cpu_memory_MB" in metrics
assert isinstance(metrics["peak_cpu_memory_MB"], float)
assert metrics["peak_cpu_memory_MB"] > 0
assert "peak_gpu_0_memory_MB" in metrics
assert isinstance(metrics["peak_gpu_0_memory_MB"], int)
assert "peak_gpu_1_memory_MB" in metrics
assert isinstance(metrics["peak_gpu_1_memory_MB"], int)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 or more GPUs required.")
def test_passing_trainer_multiple_gpus_raises_error(self):
self.model.cuda()

multigpu_iterator = BasicIterator(batch_size=4)
multigpu_iterator.index_with(self.vocab)
with pytest.raises(ConfigurationError):
CallbackTrainer(
self.model,
training_data=self.instances,
iterator=multigpu_iterator,
optimizer=self.optimizer,
num_epochs=2,
callbacks=self.default_callbacks(),
cuda_device=[0, 1],
)

def test_trainer_can_resume_training(self):
trainer = CallbackTrainer(
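The replaced multi-GPU DataParallel test now asserts that handing CallbackTrainer a list of cuda devices raises ConfigurationError. The sketch below shows the kind of guard that behavior implies; the actual check lives in trainer code outside this diff, so the function and the error wording are assumptions.

from allennlp.common.checks import ConfigurationError

# Sketch of the device check the new test implies (not the actual trainer code;
# the error message is assumed).
def check_single_cuda_device(cuda_device):
    if isinstance(cuda_device, (list, tuple)):
        raise ConfigurationError(
            "Multiple cuda devices are no longer supported here; "
            "configure a top-level 'distributed' block instead."
        )
    return cuda_device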
10 changes: 4 additions & 6 deletions allennlp/tests/training/gan_callback_trainer_test.py
@@ -207,7 +207,7 @@ def __init__(
num_epochs: int = 20,
shuffle: bool = False,
serialization_dir: Optional[str] = None,
cuda_device: Union[int, List] = -1,
cuda_device: int = -1,
callbacks: List[Callback] = None,
distributed: bool = False,
rank: int = 0,
@@ -235,11 +235,9 @@ def _reset_counters(self) -> None:
self.fake_stdev = 0.0
self.count = 0

def train_one_batch_group(self, batch_group):
# Each batch_group should have only one batch
batch, = batch_group
array = batch["array"]
def train_one_batch(self, batch):

array = batch["array"]
# We should not have mixed batches:
if len(set(batch["stage"])) != 1:
raise ValueError("mixed batch")
@@ -290,7 +288,7 @@ def train_one_epoch(self) -> None:
# Reset epoch counters
self._reset_counters()

# Will call `self.train_one_batch_group`
# Will call `self.train_one_batch`
super().train_one_epoch()

