diff --git a/.github/workflows/legacy-checkpoints.yml b/.github/workflows/legacy-checkpoints.yml
new file mode 100644
index 0000000000000..36c4df760aaaf
--- /dev/null
+++ b/.github/workflows/legacy-checkpoints.yml
@@ -0,0 +1,40 @@
+name: Create Legacy Ckpts
+
+# https://help.github.com/en/actions/reference/events-that-trigger-workflows
+on:
+  workflow_dispatch:
+
+jobs:
+  create-legacy-ckpts:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v2
+
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+
+      - name: Install dependencies
+        run: |
+          pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
+          pip install awscli
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_KEY_ID }}
+          aws-region: us-east-1
+
+      - name: Generate checkpoint
+        run: |
+          while IFS= read -r line; do
+            bash legacy/generate_checkpoints.sh $line
+          done <<< $(cat legacy/back-compatible-versions.txt)
+
+      - name: Push files to S3
+        working-directory: ./legacy
+        run: |
+          aws s3 sync legacy/checkpoints/ s3://pl-public-data/legacy/checkpoints/
+          zip -r checkpoints.zip checkpoints
+          aws s3 cp checkpoints.zip s3://pl-public-data/legacy/ --acl public-read
diff --git a/legacy/README.md b/legacy/README.md
index 5ffcecf971f3a..efbd18f7eede6 100644
--- a/legacy/README.md
+++ b/legacy/README.md
@@ -14,6 +14,6 @@ unzip -o checkpoints.zip
 To back populate collection with past version you can use following bash:
 
 ```bash
-bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4
+bash generate_checkpoints.sh "1.3.7" "1.3.8"
 zip -r checkpoints.zip checkpoints/
 ```
diff --git a/legacy/back-compatible-versions.txt b/legacy/back-compatible-versions.txt
new file mode 100644
index 0000000000000..d9141ba2d3d0b
--- /dev/null
+++ b/legacy/back-compatible-versions.txt
@@ -0,0 +1,39 @@
+1.0.0
+1.0.1
+1.0.2
+1.0.3
+1.0.4
+1.0.5
+1.0.6
+1.0.7
+1.0.8
+1.1.0
+1.1.1
+1.1.2
+1.1.3
+1.1.4
+1.1.5
+1.1.6
+1.1.7
+1.1.8
+1.2.0
+1.2.1
+1.2.2
+1.2.3
+1.2.4
+1.2.5
+1.2.6
+1.2.7
+1.2.8
+1.2.10
+1.3.0
+1.3.1
+1.3.2
+1.3.3
+1.3.4
+1.3.5
+1.3.6
+1.3.7
+1.3.8
+1.4.0
+1.4.1
diff --git a/legacy/generate_checkpoints.sh b/legacy/generate_checkpoints.sh
index 53e09e5135206..e5152e55f1da6 100644
--- a/legacy/generate_checkpoints.sh
+++ b/legacy/generate_checkpoints.sh
@@ -2,12 +2,14 @@
 # Sample call:
 # bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4
 
+set -e
+
 LEGACY_PATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
-FROZEN_MIN_PT_VERSION="1.4"
+FROZEN_MIN_PT_VERSION="1.6"
 
 echo $LEGACY_PATH
 # install some PT version here so it does not need to reinstalled for each env
-pip install virtualenv "torch==1.5" --quiet --no-cache-dir
+pip install virtualenv "torch==1.6" --quiet
 
 ENV_PATH="$LEGACY_PATH/vEnv"
 
@@ -23,14 +25,14 @@ do
   # activate and install PL version
   source "$ENV_PATH/bin/activate"
   # there are problem to load ckpt in older versions since they are saved the newer versions
-  pip install "pytorch_lightning==$ver" "torch==$FROZEN_MIN_PT_VERSION" --quiet --no-cache-dir
+  pip install "pytorch_lightning==$ver" "torch==$FROZEN_MIN_PT_VERSION" "torchmetrics" "scikit-learn" --quiet
 
   python --version
   pip --version
   pip list | grep torch
 
-  python "$LEGACY_PATH/zero_training.py"
-  cp "$LEGACY_PATH/zero_training.py" ${LEGACY_PATH}/checkpoints/${ver}
+  python "$LEGACY_PATH/simple_classif_training.py"
+  cp "$LEGACY_PATH/simple_classif_training.py" ${LEGACY_PATH}/checkpoints/${ver}
 
   mv ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs/version_0/checkpoints/*.ckpt ${LEGACY_PATH}/checkpoints/${ver}/
   rm -rf ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs
diff --git a/legacy/simple_classif_training.py b/legacy/simple_classif_training.py
new file mode 100644
index 0000000000000..5189d0c1819b6
--- /dev/null
+++ b/legacy/simple_classif_training.py
@@ -0,0 +1,170 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+import torch.nn.functional as F
+from sklearn.datasets import make_classification
+from sklearn.model_selection import train_test_split
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy
+
+import pytorch_lightning as pl
+from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
+from pytorch_lightning.callbacks import EarlyStopping
+
+PATH_LEGACY = os.path.dirname(__file__)
+
+
+class SklearnDataset(Dataset):
+    def __init__(self, x, y, x_type, y_type):
+        self.x = x
+        self.y = y
+        self._x_type = x_type
+        self._y_type = y_type
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)
+
+    def __len__(self):
+        return len(self.y)
+
+
+class SklearnDataModule(LightningDataModule):
+    def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
+        super().__init__()
+        self.batch_size = batch_size
+        self._x, self._y = sklearn_dataset
+        self._split_data()
+        self._x_type = x_type
+        self._y_type = y_type
+
+    def _split_data(self):
+        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
+            self._x, self._y, test_size=0.20, random_state=42
+        )
+        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
+            self.x_train, self.y_train, test_size=0.40, random_state=42
+        )
+
+    def train_dataloader(self):
+        return DataLoader(
+            SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type),
+            shuffle=True,
+            batch_size=self.batch_size,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
+        )
+
+
+class ClassifDataModule(SklearnDataModule):
+    def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
+        data = make_classification(
+            n_samples=length,
+            n_features=num_features,
+            n_classes=num_classes,
+            n_clusters_per_class=2,
+            n_informative=int(num_features / num_classes),
+            random_state=42,
+        )
+        super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size)
+
+
+class ClassificationModel(LightningModule):
+    def __init__(self, num_features=24, num_classes=3, lr=0.01):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.lr = lr
+        for i in range(3):
+            setattr(self, f"layer_{i}", nn.Linear(num_features, num_features))
+            setattr(self, f"layer_{i}a", torch.nn.ReLU())
+        setattr(self, "layer_end", nn.Linear(num_features, num_classes))
+
+        self.train_acc = Accuracy()
+        self.valid_acc = Accuracy()
+        self.test_acc = Accuracy()
+
+    def forward(self, x):
+        x = self.layer_0(x)
+        x = self.layer_0a(x)
+        x = self.layer_1(x)
+        x = self.layer_1a(x)
+        x = self.layer_2(x)
+        x = self.layer_2a(x)
+        x = self.layer_end(x)
+        logits = F.softmax(x, dim=1)
+        return logits
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        return [optimizer], []
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = F.cross_entropy(logits, y)
+        self.log("train_loss", loss, prog_bar=True)
+        self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
+        self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
+        self.log("test_acc", self.test_acc(logits, y), prog_bar=True)
+
+
+def main_train(dir_path, max_epochs: int = 20):
+    seed_everything(42)
+    stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
+    trainer = pl.Trainer(
+        default_root_dir=dir_path,
+        gpus=int(torch.cuda.is_available()),
+        precision=(16 if torch.cuda.is_available() else 32),
+        checkpoint_callback=True,
+        callbacks=[stopping],
+        min_epochs=3,
+        max_epochs=max_epochs,
+        accumulate_grad_batches=2,
+        deterministic=True,
+    )
+
+    dm = ClassifDataModule()
+    model = ClassificationModel()
+    trainer.fit(model, datamodule=dm)
+    res = trainer.test(model, datamodule=dm)
+    assert res[0]["test_loss"] <= 0.7
+    assert res[0]["test_acc"] >= 0.85
+    assert trainer.current_epoch < (max_epochs - 1)
+
+
+if __name__ == "__main__":
+    path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
+    main_train(path_dir)
diff --git a/legacy/zero_training.py b/legacy/zero_training.py
deleted file mode 100644
index d1716b38962df..0000000000000
--- a/legacy/zero_training.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-
-import torch
-from torch.utils.data import Dataset
-
-import pytorch_lightning as pl
-
-PATH_LEGACY = os.path.dirname(__file__)
-
-
-class RandomDataset(Dataset):
-    def __init__(self, size, length: int = 100):
-        self.len = length
-        self.data = torch.randn(length, size)
-
-    def __getitem__(self, index):
-        return self.data[index]
-
-    def __len__(self):
-        return self.len
-
-
-class DummyModel(pl.LightningModule):
-    def __init__(self):
-        super().__init__()
-        self.layer = torch.nn.Linear(32, 2)
-
-    def forward(self, x):
-        return self.layer(x)
-
-    def _loss(self, batch, prediction):
-        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
-        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
-
-    def _step(self, batch, batch_idx):
-        output = self.layer(batch)
-        loss = self._loss(batch, output)
-        # return {'loss': loss}  # used for PL<1.0
-        return loss  # used for PL >= 1.0
-
-    def training_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx)
-
-    def validation_step(self, batch, batch_idx):
-        self._step(batch, batch_idx)
-
-    def test_step(self, batch, batch_idx):
-        self._step(batch, batch_idx)
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
-        return [optimizer], [lr_scheduler]
-
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(RandomDataset(32, 64))
-
-    def val_dataloader(self):
-        return torch.utils.data.DataLoader(RandomDataset(32, 64))
-
-    def test_dataloader(self):
-        return torch.utils.data.DataLoader(RandomDataset(32, 64))
-
-
-def main_train(dir_path, max_epochs: int = 5):
-
-    trainer = pl.Trainer(default_root_dir=dir_path, checkpoint_callback=True, max_epochs=max_epochs)
-
-    model = DummyModel()
-    trainer.fit(model)
-
-
-if __name__ == "__main__":
-    path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
-    main_train(path_dir)
diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 8693965a52abc..040cd642556cf 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,84 +14,79 @@
 import glob
 import os
 import sys
+from unittest.mock import patch
 
 import pytest
+import torch
 
-from pytorch_lightning import Trainer
-from tests import _PATH_LEGACY
+import pytorch_lightning as pl
+from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import EarlyStopping
+from tests import _PATH_LEGACY, _PROJECT_ROOT
 
 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
 CHECKPOINT_EXTENSION = ".ckpt"
+# load list of all back compatible versions
+with open(os.path.join(_PROJECT_ROOT, "legacy", "back-compatible-versions.txt")) as fp:
+    LEGACY_BACK_COMPATIBLE_PL_VERSIONS = [ln.strip() for ln in fp.readlines()]
 
 
-# todo: add more legacy checkpoints - for < v0.8
-@pytest.mark.parametrize(
-    "pl_version",
-    [
-        # "0.8.1",
-        "0.8.3",
-        "0.8.4",
-        # "0.8.5", # this version has problem with loading on PT<=1.4 as it seems to be archive
-        # "0.9.0", # this version has problem with loading on PT<=1.4 as it seems to be archive
-        "0.10.0",
-        "1.0.0",
-        "1.0.1",
-        "1.0.2",
-        "1.0.3",
-        "1.0.4",
-        "1.0.5",
-        "1.0.6",
-        "1.0.7",
-        "1.0.8",
-        "1.1.0",
-        "1.1.1",
-        "1.1.2",
-        "1.1.3",
-        "1.1.4",
-        "1.1.5",
-        "1.1.6",
-        "1.1.7",
-        "1.1.8",
-        "1.2.0",
-        "1.2.1",
-        "1.2.2",
-        "1.2.3",
-        "1.2.4",
-        "1.2.5",
-        "1.2.6",
-        "1.2.7",
-        "1.2.8",
-        "1.2.10",
-        "1.3.0",
-        "1.3.1",
-        "1.3.2",
-        "1.3.3",
-        "1.3.4",
-        "1.3.5",
-        "1.3.6",
-        "1.3.7",
-        "1.3.8",
-    ],
-)
-def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
-    path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
+@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
+def test_load_legacy_checkpoints(tmpdir, pl_version: str):
+    PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
+    with patch("sys.path", [PATH_LEGACY] + sys.path):
+        from simple_classif_training import ClassifDataModule, ClassificationModel
+
+        path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
+        assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
+        path_ckpt = path_ckpts[-1]
+
+        model = ClassificationModel.load_from_checkpoint(path_ckpt)
+        trainer = Trainer(default_root_dir=str(tmpdir))
+        dm = ClassifDataModule()
+        res = trainer.test(model, datamodule=dm)
+        assert res[0]["test_loss"] <= 0.7
+        assert res[0]["test_acc"] >= 0.85
+        print(res)
 
-    # todo: make this as mock, so it is cleaner...
-    orig_sys_paths = list(sys.path)
-    sys.path.insert(0, path_dir)
-    from zero_training import DummyModel
 
-    path_ckpts = sorted(glob.glob(os.path.join(path_dir, f"*{CHECKPOINT_EXTENSION}")))
-    assert path_ckpts, 'No checkpoints found in folder "%s"' % path_dir
-    path_ckpt = path_ckpts[-1]
+class LimitNbEpochs(Callback):
+    def __init__(self, nb: int):
+        self.limit = nb
+        self._count = 0
 
-    model = DummyModel.load_from_checkpoint(path_ckpt)
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=6)
-    trainer.fit(model)
+    def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._count += 1
+        if self._count >= self.limit:
+            trainer.should_stop = True
+
+
+@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
+def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
+    PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
+    with patch("sys.path", [PATH_LEGACY] + sys.path):
+        from simple_classif_training import ClassifDataModule, ClassificationModel
 
-    # todo
-    # model = DummyModel()
-    # trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt)
-    # trainer.fit(model)
+        path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
+        assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
+        path_ckpt = path_ckpts[-1]
 
-    sys.path = orig_sys_paths
+        dm = ClassifDataModule()
+        model = ClassificationModel()
+        es = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
+        stop = LimitNbEpochs(1)
+        trainer = Trainer(
+            default_root_dir=str(tmpdir),
+            gpus=int(torch.cuda.is_available()),
+            precision=(16 if torch.cuda.is_available() else 32),
+            checkpoint_callback=True,
+            callbacks=[es, stop],
+            max_epochs=21,
+            accumulate_grad_batches=2,
+            deterministic=True,
+            resume_from_checkpoint=path_ckpt,
+        )
+        trainer.fit(model, datamodule=dm)
+        res = trainer.test(model, datamodule=dm)
+        assert res[0]["test_loss"] <= 0.7
+        assert res[0]["test_acc"] >= 0.85
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 7ff4b98907d1d..3ea319a27a07a 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -14,7 +14,6 @@
 """Test deprecated functionality which will be removed in v1.5.0"""
 import os
 from typing import Any, Dict
-from unittest import mock
 
 import pytest
 import torch
diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py
index 6f6de3ff16c81..6c5b5a8e33ffb 100644
--- a/tests/utilities/test_data.py
+++ b/tests/utilities/test_data.py
@@ -1,7 +1,6 @@
 import pytest
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset, IterableDataset
 
 from pytorch_lightning.utilities.data import extract_batch_size, get_len, has_iterable_dataset, has_len
 from tests.helpers.boring_model import RandomDataset, RandomIterableDataset