
Commit 523b3a8

Merge branch 'master' into bugfix/13285_disable_non_blocking_to_device_mps

Borda committed Aug 26, 2022
2 parents 39af0b9 + 33a5ed9
Showing 9 changed files with 127 additions and 67 deletions.
11 changes: 7 additions & 4 deletions .azure/gpu-tests.yml
@@ -44,7 +44,7 @@ jobs:

- bash: |
CHANGED_FILES=$(git diff --name-status origin/master -- . | awk '{print $2}')
FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
FILTER='.azure/gpu_*|src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
echo $CHANGED_FILES > changed_files.txt
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
echo $MATCHES
@@ -72,12 +72,15 @@ jobs:
set -e
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
TORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [115,113,111,102] if $CUDA_VERSION_MM >= ver][0])")
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION}
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt ${PYTORCH_VERSION}
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/examples.txt ${PYTORCH_VERSION}
pip install "bagua-cuda$CUDA_VERSION_BAGUA>=0.9.0"
pip install -e .[strategies]
pip install -U deepspeed # TODO: remove when docker images are upgraded
pip install --requirement requirements/pytorch/devel.txt
pip install -e .[strategies] --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
pip install --requirement requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
pip list
env:
PACKAGE_NAME: pytorch
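The `CUDA_VERSION_BAGUA` one-liner above picks the newest CUDA build that bagua publishes wheels for without exceeding the CUDA version detected from torch. A minimal Python sketch of the same selection (the function name `pick_bagua_cuda` is ours, for illustration only):

    def pick_bagua_cuda(cuda_version_mm: int) -> int:
        """Newest bagua wheel build (115, 113, 111, 102) not newer than the detected CUDA."""
        return [ver for ver in [115, 113, 111, 102] if cuda_version_mm >= ver][0]

    assert pick_bagua_cuda(113) == 113  # exact match: install bagua-cuda113
    assert pick_bagua_cuda(116) == 115  # no cu116 build, fall back to cu115

As in the shell one-liner, a CUDA version older than 10.2 leaves the list empty and raises an IndexError.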
14 changes: 8 additions & 6 deletions dockers/base-conda/Dockerfile
@@ -34,6 +34,10 @@ RUN \
# https://github.com/NVIDIA/nvidia-docker/issues/1631
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
apt-get update -qq --fix-missing && \
NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s*$') && \
CUDA_VERSION_MM="${CUDA_VERSION%.*}" && \
MAX_ALLOWED_NCCL=2.11.4 && \
TO_INSTALL_NCCL=$(echo -e "$MAX_ALLOWED_NCCL\n$NCCL_VER" | sort -V | head -n1)-1+cuda${CUDA_VERSION_MM} && \
apt-get install -y --no-install-recommends \
build-essential \
cmake \
@@ -42,17 +46,15 @@
curl \
unzip \
ca-certificates \
libopenmpi-dev

RUN \
libopenmpi-dev \
libnccl2=$TO_INSTALL_NCCL \
libnccl-dev=$TO_INSTALL_NCCL && \
# Install conda and python.
# NOTE new Conda does not forward the exit status... https://github.com/conda/conda/issues/8385
curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_${CONDA_VERSION}-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b && \
rm ~/miniconda.sh

RUN \
rm ~/miniconda.sh && \
# Cleaning
apt-get autoremove -y && \
apt-get clean && \
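The `TO_INSTALL_NCCL` computation above caps NCCL at 2.11.4: `sort -V | head -n1` keeps the version-wise minimum of the cap and the preinstalled `libnccl2`. A rough Python equivalent, assuming the `packaging` library and with `pick_nccl` as an illustrative name:

    from packaging.version import Version

    def pick_nccl(installed: str, max_allowed: str = "2.11.4") -> str:
        """Version-wise minimum, mirroring `sort -V | head -n1`."""
        return min(installed, max_allowed, key=Version)

    # The apt pin then becomes e.g. "2.11.4-1+cuda11.3":
    assert pick_nccl("2.12.10") == "2.11.4"  # newer than the cap -> downgrade
    assert pick_nccl("2.10.3") == "2.10.3"   # at or below the cap -> keep as-is

Pinning below the preinstalled version is a downgrade, which is why the base-cuda variant below adds `--allow-downgrades --allow-change-held-packages` to `apt-get install`.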
23 changes: 7 additions & 16 deletions dockers/base-cuda/Dockerfile
@@ -37,7 +37,11 @@ RUN \
# https://github.com/NVIDIA/nvidia-docker/issues/1631
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
apt-get update -qq --fix-missing && \
apt-get install -y --no-install-recommends \
NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s*$') && \
CUDA_VERSION_MM="${CUDA_VERSION%.*}" && \
MAX_ALLOWED_NCCL=2.11.4 && \
TO_INSTALL_NCCL=$(echo -e "$MAX_ALLOWED_NCCL\n$NCCL_VER" | sort -V | head -n1)-1+cuda${CUDA_VERSION_MM} && \
apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \
build-essential \
pkg-config \
cmake \
@@ -50,19 +54,17 @@ RUN \
libopenmpi-dev \
openmpi-bin \
ssh \
&& \

libnccl2=$TO_INSTALL_NCCL \
libnccl-dev=$TO_INSTALL_NCCL && \
# Install python
add-apt-repository ppa:deadsnakes/ppa && \
apt-get install -y \
python${PYTHON_VERSION} \
python${PYTHON_VERSION}-distutils \
python${PYTHON_VERSION}-dev \
&& \

update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \

# Cleaning
apt-get autoremove -y && \
apt-get clean && \
@@ -78,7 +80,6 @@ RUN \
wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate && \
python${PYTHON_VERSION} get-pip.py && \
rm get-pip.py && \

pip install -q fire && \
# Disable cache \
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
@@ -91,16 +92,6 @@ RUN \
pip install -r requirements/pytorch/devel.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
rm assistant.py

RUN \
apt-get purge -y cmake && \
wget -q https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz && \
tar -zxvf cmake-3.20.2.tar.gz && \
cd cmake-3.20.2 && \
./bootstrap -- -DCMAKE_USE_OPENSSL=OFF && \
make && \
make install && \
cmake --version

ENV \
HOROVOD_CUDA_HOME=$CUDA_TOOLKIT_ROOT_DIR \
HOROVOD_GPU_OPERATIONS=NCCL \
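In this Dockerfile, `CUDA_VERSION_MM` is derived with `print(''.join('$CUDA_VERSION'.split('.')[:2]))`, which drops the patch version and the dots so the value can be spliced into wheel indexes such as `https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html`. The same transform in plain Python:

    def cuda_mm(cuda_version: str) -> str:
        """'11.3.1' -> '113': major and minor digits, as used in cuXXX wheel tags."""
        return "".join(cuda_version.split(".")[:2])

    assert cuda_mm("11.3.1") == "113"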
7 changes: 6 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -27,7 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


- Enabled `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator` ([14023](https://github.com/Lightning-AI/lightning/pull/14023))
- Enabled `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator` ([#14023](https://github.com/Lightning-AI/lightning/pull/14023))

- Included `torch.cuda` rng state to the aggregate `_collect_rng_states()` and `_set_rng_states()` ([#14384](https://github.com/Lightning-AI/lightning/pull/14384))



@@ -85,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))


- Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))


## [1.7.2] - 2022-08-17

### Added
13 changes: 8 additions & 5 deletions src/pytorch_lightning/utilities/parsing.py
@@ -321,14 +321,17 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str)
holders.append(model)

# Check if attribute in model.hparams, either namespace or dict
if hasattr(model, "hparams"):
if attribute in model.hparams:
holders.append(model.hparams)
if hasattr(model, "hparams") and attribute in model.hparams:
holders.append(model.hparams)

trainer = model._trainer
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)
if trainer is not None and trainer.datamodule is not None:
if hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)

if hasattr(trainer.datamodule, "hparams") and attribute in trainer.datamodule.hparams:
holders.append(trainer.datamodule.hparams)

return holders

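The practical effect of the extra holder: `lightning_getattr`/`lightning_setattr` now see an attribute that lives only in a datamodule's `hparams`. A minimal sketch mirroring the new test cases (the class names here are ours):

    from pytorch_lightning import LightningDataModule, LightningModule, Trainer
    from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr

    class PlainModel(LightningModule):
        pass  # no batch_size anywhere on the model

    class HparamsDataModule(LightningDataModule):
        hparams = {"batch_size": 8}  # attribute lives only in the datamodule's hparams

    model = PlainModel()
    trainer = Trainer()
    model.trainer = trainer
    trainer.datamodule = HparamsDataModule()

    assert lightning_getattr(model, "batch_size") == 8  # found via the datamodule hparams
    lightning_setattr(model, "batch_size", 128)         # written back to the same holder
    assert lightning_getattr(model, "batch_size") == 128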
15 changes: 12 additions & 3 deletions src/pytorch_lightning/utilities/seed.py
@@ -121,13 +121,22 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:


def _collect_rng_states() -> Dict[str, Any]:
"""Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}
"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
return {
"torch": torch.get_rng_state(),
"torch.cuda": torch.cuda.get_rng_state_all(),
"numpy": np.random.get_state(),
"python": python_get_rng_state(),
}


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
"""Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
process."""
torch.set_rng_state(rng_state_dict["torch"])
# torch.cuda rng_state is only included since v1.8.
if "torch.cuda" in rng_state_dict:
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
np.random.set_state(rng_state_dict["numpy"])
version, state, gauss = rng_state_dict["python"]
python_set_rng_state((version, tuple(state), gauss))
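These are private helpers, but a short round-trip shows what the change buys: with the `torch.cuda` entry in the dictionary, restoring now rewinds GPU sampling as well, not just the CPU, NumPy and Python generators. A minimal sketch:

    import torch
    from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states

    states = _collect_rng_states()            # torch, torch.cuda, numpy and python state
    first = torch.rand(3)                     # advances the global torch RNG
    _set_rng_states(states)                   # rewind everything to the snapshot
    assert torch.equal(first, torch.rand(3))  # the same numbers are drawn again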
69 changes: 43 additions & 26 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -29,8 +29,8 @@


class BatchSizeDataModule(BoringDataModule):
def __init__(self, batch_size):
super().__init__()
def __init__(self, data_dir, batch_size):
super().__init__(data_dir)
if batch_size is not None:
self.batch_size = batch_size

@@ -58,7 +58,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
tuner = Tuner(trainer)

model = BatchSizeModel(model_bs)
datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
datamodule = BatchSizeDataModule(tmpdir, dm_bs) if dm_bs != -1 else None

new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
assert new_batch_size == 16
@@ -140,47 +140,64 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
assert not os.path.exists(tmpdir / "scale_batch_size_temp_model.ckpt")


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("use_hparams", [True, False])
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
"""Test that new batch size gets written to the correct hyperparameter attribute."""
"""Test that new batch size gets written to the correct hyperparameter attribute for model."""
tutils.reset_seed()

hparams = {"batch_size": 2}
before_batch_size = hparams.get("batch_size")
before_batch_size = hparams["batch_size"]

class HparamsBatchSizeModel(BatchSizeModel):
class HparamsBatchSizeModel(BoringModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__()
self.save_hyperparameters()

def dataloader(self, *args, **kwargs):
# artificially set batch_size so we can get a dataloader
# remove it immediately after, because we want only self.hparams.batch_size
setattr(self, "batch_size", before_batch_size)
dataloader = super().dataloader(*args, **kwargs)
del self.batch_size
return dataloader
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
model = model_class(**hparams)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)


@pytest.mark.parametrize("use_hparams", [True, False])
def test_auto_scale_batch_size_set_datamodule_attribute(tmpdir, use_hparams):
"""Test that new batch size gets written to the correct hyperparameter attribute for datamodule."""
tutils.reset_seed()

hparams = {"batch_size": 2}
before_batch_size = hparams["batch_size"]

class HparamsBatchSizeDataModule(BoringDataModule):
def __init__(self, data_dir, batch_size):
super().__init__(data_dir)
self.batch_size = batch_size
self.save_hyperparameters()

def train_dataloader(self):
return DataLoader(self.random_train, batch_size=self.batch_size)
return DataLoader(self.random_train, batch_size=self.hparams.batch_size)

datamodule_fit = HparamsBatchSizeDataModule(data_dir=tmpdir, batch_size=before_batch_size)
model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
model = model_class(**hparams)
def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True, accelerator="gpu", devices=1)
trainer.tune(model, datamodule_fit)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert trainer.datamodule == datamodule_fit
assert before_batch_size != after_batch_size
datamodule_class = HparamsBatchSizeDataModule if use_hparams else BatchSizeDataModule
datamodule = datamodule_class(data_dir=tmpdir, batch_size=before_batch_size)
model = BatchSizeModel(**hparams)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, datamodule=datamodule, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
after_batch_size = datamodule.hparams.batch_size if use_hparams else datamodule.batch_size
assert trainer.datamodule == datamodule
assert before_batch_size < after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)
assert datamodule_fit.batch_size == after_batch_size


def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
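What the rewritten datamodule test exercises end to end: the batch-size finder locates `batch_size` inside the datamodule's `hparams` (via the parsing change above) and writes the scaled value back there. A minimal sketch of the same flow, assuming the `pytorch_lightning.demos.boring_classes` helpers available in this release:

    from torch.utils.data import DataLoader
    from pytorch_lightning import LightningDataModule, Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset

    class HparamsDataModule(LightningDataModule):
        def __init__(self, batch_size):
            super().__init__()
            self.save_hyperparameters()  # exposes self.hparams.batch_size

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

    datamodule = HparamsDataModule(batch_size=2)
    trainer = Trainer(max_epochs=1, auto_scale_batch_size=True)
    trainer.tune(BoringModel(), datamodule=datamodule,
                 scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
    print(datamodule.hparams.batch_size)  # scaled up from 2, capped at the dataset size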
22 changes: 17 additions & 5 deletions tests/tests_pytorch/utilities/test_parsing.py
@@ -64,8 +64,8 @@ class TestModel4(LightningModule): # fail case
batch_size = 1

model4 = TestModel4()

trainer = Trainer()
model4.trainer = trainer
datamodule = LightningDataModule()
datamodule.batch_size = 8
trainer.datamodule = datamodule
@@ -87,12 +87,21 @@ class TestModel7(LightningModule): # test for datamodule w/ hparams w/ attribute
model7 = TestModel7()
model7.trainer = trainer

return model1, model2, model3, model4, model5, model6, model7
class TestDataModule8(LightningDataModule): # test for hparams dict
hparams = TestHparamsDict2

model8 = TestModel1()
trainer = Trainer()
model8.trainer = trainer
datamodule = TestDataModule8()
trainer.datamodule = datamodule

return model1, model2, model3, model4, model5, model6, model7, model8


def test_lightning_hasattr():
"""Test that the lightning_hasattr works in all cases."""
model1, model2, model3, model4, model5, model6, model7 = models = model_cases()
model1, model2, model3, model4, model5, model6, model7, model8 = models = model_cases()
assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
@@ -104,6 +113,7 @@ def test_lightning_hasattr():
assert lightning_hasattr(
model7, "batch_size"
), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
assert lightning_hasattr(model8, "batch_size")

for m in models:
assert not lightning_hasattr(m, "this_attr_not_exist")
@@ -116,10 +126,11 @@ def test_lightning_getattr():
value = lightning_getattr(m, "learning_rate")
assert value == i, "attribute not correctly extracted"

model5, model6, model7 = models[4:]
model5, model6, model7, model8 = models[4:]
assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model8, "batch_size") == 2, "batch_size not correctly extracted"

for m in models:
with pytest.raises(
Expand All @@ -136,13 +147,14 @@ def test_lightning_setattr(tmpdir):
lightning_setattr(m, "learning_rate", 10)
assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"

model5, model6, model7 = models[4:]
model5, model6, model7, model8 = models[4:]
lightning_setattr(model5, "batch_size", 128)
lightning_setattr(model6, "batch_size", 128)
lightning_setattr(model7, "batch_size", 128)
assert lightning_getattr(model5, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model6, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model7, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model8, "batch_size") == 128, "batch_size not correctly set"

for m in models:
with pytest.raises(