Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .actions/pull_legacy_checkpoints.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/bin/bash

# Run this script from the project root.
URL="https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip"
mkdir -p tests/legacy
# wget is simpler but does not work on Windows
python -c "from urllib.request import urlretrieve; urlretrieve('$URL', 'tests/legacy/checkpoints.zip')"
ls -l tests/legacy/

unzip -o tests/legacy/checkpoints.zip -d tests/legacy/
ls -l tests/legacy/checkpoints/
10 changes: 7 additions & 3 deletions .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,17 @@ jobs:
python requirements/pytorch/check-avail-extras.py
displayName: 'Env details'

- bash: bash .actions/pull_legacy_checkpoints.sh
displayName: 'Get legacy checkpoints'

- bash: python -m pytest pytorch_lightning
workingDirectory: src
displayName: 'Testing: PyTorch doctests'

- bash: |
bash .actions/pull_legacy_checkpoints.sh
cd tests/legacy
bash generate_checkpoints.sh
ls -l checkpoints/
displayName: 'Get legacy checkpoints'

- bash: python -m coverage run --source pytorch_lightning -m pytest --ignore benchmarks -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
env:
PL_RUN_CUDA_TESTS: "1"
Expand Down
12 changes: 8 additions & 4 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ jobs:
if: ${{ matrix.requires == 'oldest' }}
run: python .actions/assistant.py replace_oldest_ver

- name: Pull legacy checkpoints
if: ${{ matrix.type != 'slow' }}
run: bash .actions/pull_legacy_checkpoints.sh

- name: Adjust PyTorch versions in requirements files
if: ${{ matrix.requires != 'oldest' }}
run: |
Expand Down Expand Up @@ -161,6 +157,14 @@ jobs:
--source_import="pytorch_lightning,lightning_fabric" \
--target_import="lightning.pytorch,lightning.fabric"

- name: Get legacy checkpoints
if: ${{ matrix.type != 'slow' }}
run: |
bash .actions/pull_legacy_checkpoints.sh
cd tests/legacy
bash generate_checkpoints.sh
ls -l checkpoints/

- name: Testing Warnings
# the stacklevel can only be set on >=3.7
if: matrix.python-version != '3.7'
Expand Down
6 changes: 6 additions & 0 deletions tests/legacy/back-compatible-versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@
1.8.4
1.8.5
1.8.6
1.9.0
1.9.1
1.9.2
1.9.3
1.9.4
1.9.5
20 changes: 9 additions & 11 deletions tests/legacy/generate_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ set -e
LEGACY_PATH=$(cd $(dirname $0); pwd -P)
ENV_PATH=$LEGACY_PATH/vEnv
export PYTHONPATH=$(dirname $LEGACY_PATH) # for `import tests_pytorch`
echo LEGACY_PATH: $LEGACY_PATH
echo ENV_PATH: $ENV_PATH
echo PYTHONPATH: $PYTHONPATH
printf "LEGACY_PATH: $LEGACY_PATH"
printf "ENV_PATH: $ENV_PATH"
printf "PYTHONPATH: $PYTHONPATH"
rm -rf $ENV_PATH

function create_and_save_checkpoint {
python --version
python -m pip --version
python -m pip list

python $LEGACY_PATH/simple_classif_training.py
python $LEGACY_PATH/simple_classif_training.py $pl_ver

cp $LEGACY_PATH/simple_classif_training.py $LEGACY_PATH/checkpoints/$pl_ver
mv $LEGACY_PATH/checkpoints/$pl_ver/lightning_logs/version_0/checkpoints/*.ckpt $LEGACY_PATH/checkpoints/$pl_ver/
Expand All @@ -28,11 +29,9 @@ function create_and_save_checkpoint {
# iterate over all arguments assuming that each argument is version
for pl_ver in "$@"
do
echo processing version: $pl_ver
printf "processing version: $pl_ver"

# Don't install/update anything before activating venv
# to avoid breaking any existing environment.
rm -rf $ENV_PATH
# Don't install/update anything before activating venv to avoid breaking any existing environment.
python -m venv $ENV_PATH
source $ENV_PATH/bin/activate

Expand All @@ -47,10 +46,9 @@ done

# use the PL installed in the environment if no PL version is specified
if [[ -z "$@" ]]; then
pl_ver=$(python -c "import pytorch_lightning as pl; print(pl.__version__)")
echo processing version: $pl_ver
printf "processing local version"

python -m pip install -r $LEGACY_PATH/requirements.txt

pl_ver="local"
create_and_save_checkpoint
fi
2 changes: 1 addition & 1 deletion tests/legacy/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torchmetrics # necessary because old PL verions don't have it as dependency
torchmetrics # necessary because old PL versions don't have it as dependency
scikit-learn
4 changes: 3 additions & 1 deletion tests/legacy/simple_classif_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import torch

Expand Down Expand Up @@ -50,5 +51,6 @@ def main_train(dir_path, max_epochs: int = 20):


if __name__ == "__main__":
path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
name = sys.argv[1] if len(sys.argv) > 1 else str(pl.__version__)
path_dir = os.path.join(PATH_LEGACY, "checkpoints", name)
main_train(path_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
# load list of all back compatible versions
with open(os.path.join(_PATH_LEGACY, "back-compatible-versions.txt")) as fp:
LEGACY_BACK_COMPATIBLE_PL_VERSIONS = [ln.strip() for ln in fp.readlines()]
# This shall be created for each CI run
LEGACY_BACK_COMPATIBLE_PL_VERSIONS += ["local"]


@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_imports_unified(pl_version: str):
path_ckpt = path_ckpts[-1]

# only below version 1.5.0 we pickled stuff in checkpoints
if Version(pl_version) < Version("1.5.0"):
if pl_version != "local" and Version(pl_version) < Version("1.5.0"):
context = pytest.warns(UserWarning, match="Redirecting import of")
else:
context = no_warning_call(match="Redirecting import of*")
Expand Down