From 27349216e078c471b8314e5708f265945e714eab Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:27:32 -0400 Subject: [PATCH 01/29] made validation step optional --- pytorch_lightning/models/trainer.py | 9 +++++++-- pytorch_lightning/root_module/root_module.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index de8a0a03e1a9e..d5f90fa5dc8a9 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -345,13 +345,13 @@ def __layout_bookeeping(self): self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) # determine number of validation batches - self.nb_val_batches = len(self.val_dataloader) + self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0 self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check) self.nb_val_batches = max(1, self.nb_val_batches) self.nb_val_batches = self.nb_val_batches # determine number of test batches - self.nb_test_batches = len(self.test_dataloader) + self.nb_test_batches = len(self.test_dataloader) if self.test_dataloader is not None else 0 self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check) # determine when to check validation @@ -372,6 +372,10 @@ def validate(self, model, dataloader, max_batches): :param max_batches: Scalar :return: """ + # skip validation if model has no validation_step defined + if not self.__is_function_implemented('validation_step'): + return {} + # enable eval mode model.zero_grad() model.eval() @@ -439,6 +443,7 @@ def get_dataloaders(self, model): :return: """ self.tng_dataloader = model.tng_dataloader + self.test_dataloader = model.test_dataloader self.val_dataloader = model.val_dataloader diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 700e6db14e1e4..13cf46818f345 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -36,18 +36,20 @@ def forward(self, *args, **kwargs): def validation_step(self, data_batch, batch_nb): """ return whatever outputs will need to be aggregated in validation_end + OPTIONAL :param data_batch: :return: """ - raise NotImplementedError + pass def validation_end(self, outputs): """ Outputs has the appended output after each validation step + OPTIONAL :param outputs: :return: dic_with_metrics for tqdm """ - raise NotImplementedError + pass def training_step(self, data_batch, batch_nb): """ @@ -67,7 +69,7 @@ def configure_optimizers(self): @data_loader def tng_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ raise NotImplementedError @@ -75,18 +77,18 @@ def tng_dataloader(self): @data_loader def test_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ - raise NotImplementedError + return None @data_loader def val_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ - raise NotImplementedError + return None @classmethod def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None): From 332c07f8f907aed3f5ba8449f2b00f98553b1830 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:29:29 -0400 Subject: [PATCH 02/29] added no val model --- pytorch_lightning/testing/no_val_module.py | 196 +++++++++++++++++++++ 1 
file changed, 196 insertions(+) create mode 100644 pytorch_lightning/testing/no_val_module.py diff --git a/pytorch_lightning/testing/no_val_module.py b/pytorch_lightning/testing/no_val_module.py new file mode 100644 index 0000000000000..4dd5156267da9 --- /dev/null +++ b/pytorch_lightning/testing/no_val_module.py @@ -0,0 +1,196 @@ +import os +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import MNIST +from torchvision import transforms +from test_tube import HyperOptArgumentParser + +from pytorch_lightning.root_module.root_module import LightningModule +from pytorch_lightning import data_loader + + +class LightningNoValTestModel(LightningModule): + """ + Sample model to show how to define a template + """ + + def __init__(self, hparams, force_remove_distributed_sampler=False): + """ + Pass in parsed HyperOptArgumentParser to the model + :param hparams: + """ + # init superclass + super(LightningNoValTestModel, self).__init__() + self.hparams = hparams + + self.batch_size = hparams.batch_size + + # if you specify an example input, the summary will show input/output for each layer + self.example_input_array = torch.rand(5, 28 * 28) + + # remove to test warning for dist sampler + self.force_remove_distributed_sampler = force_remove_distributed_sampler + + # build model + self.__build_model() + + # --------------------- + # MODEL SETUP + # --------------------- + def __build_model(self): + """ + Layout model + :return: + """ + self.c_d1 = nn.Linear(in_features=self.hparams.in_features, + out_features=self.hparams.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) + self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, + out_features=self.hparams.out_features) + + # --------------------- + # TRAINING + # --------------------- + def forward(self, x): + """ + No special modification required for lightning, define as you normally would + :param x: + :return: + """ + + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + + x = self.c_d2(x) + logits = F.log_softmax(x, dim=1) + + return logits + + def loss(self, labels, logits): + nll = F.nll_loss(logits, labels) + return nll + + def training_step(self, data_batch, batch_i): + """ + Lightning calls this inside the training loop + :param data_batch: + :return: + """ + # forward pass + x, y = data_batch + x = x.view(x.size(0), -1) + + y_hat = self.forward(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val + + def on_tng_metrics(self, logs): + logs['some_tensor_to_test'] = torch.rand(1) + + # --------------------- + # TRAINING SETUP + # --------------------- + def configure_optimizers(self): + """ + return whatever optimizers we want here + :return: list of optimizers + """ + # try no scheduler for this model (testing purposes) + optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + # test returning only 1 list 
instead of 2 + return [optimizer] + + def __dataloader(self, train): + # init data generators + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=train, + transform=transform, download=True) + + # when using multi-node we need to add the datasampler + train_sampler = None + batch_size = self.hparams.batch_size + + try: + if self.on_gpu and not self.force_remove_distributed_sampler: + train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) + batch_size = batch_size // self.trainer.world_size # scale batch size + except Exception: + pass + + should_shuffle = train_sampler is None + loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=should_shuffle, + sampler=train_sampler + ) + + return loader + + @data_loader + def tng_dataloader(self): + return self.__dataloader(train=True) + + @staticmethod + def add_model_specific_args(parent_parser, root_dir): # pragma: no cover + """ + Parameters you define here will be available to your model through self.hparams + :param parent_parser: + :param root_dir: + :return: + """ + parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) + + # param overwrites + # parser.set_defaults(gradient_clip=5.0) + + # network params + parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) + parser.add_argument('--in_features', default=28 * 28, type=int) + parser.add_argument('--out_features', default=10, type=int) + # use 500 for CPU, 50000 for GPU to see speed difference + parser.add_argument('--hidden_dim', default=50000, type=int) + + # data + parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) + + # training params (opt) + parser.opt_list('--learning_rate', default=0.001 * 8, type=float, + options=[0.0001, 0.0005, 0.001, 0.005], + tunable=False) + parser.opt_list('--optimizer_name', default='adam', type=str, + options=['adam'], tunable=False) + + # if using 2 nodes with 4 gpus each the batch size here + # (256) will be 256 / (2*8) = 16 per gpu + parser.opt_list('--batch_size', default=256 * 8, type=int, + options=[32, 64, 128, 256], tunable=False, + help='batch size will be divided over all gpus being used across all nodes') + return parser From 1a17e48654aed8ade1b546c28528c5dc91f47046 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:31:24 -0400 Subject: [PATCH 03/29] val_step can be implemented but not validation_end --- pytorch_lightning/models/trainer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index d5f90fa5dc8a9..ea261f295fe5e 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -422,11 +422,13 @@ def validate(self, model, dataloader, max_batches): if self.progress_bar and self.prog_bar is not None: self.prog_bar.update(1) - # give model a chance to do something with the outputs - if self.data_parallel: - val_results = model.module.validation_end(outputs) - else: - val_results = model.validation_end(outputs) + # give model a chance to do something with the outputs (and method defined) + val_results = {} + if self.__is_function_implemented('validation_end'): + if self.data_parallel: + val_results = model.module.validation_end(outputs) + else: + val_results = model.validation_end(outputs) # enable train mode again model.train() From 
1a93a212cc47d3bc18b5e5809fb9c50a2ff2e366 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:32:22 -0400 Subject: [PATCH 04/29] added no val end model --- .../testing/no_val_end_module.py | 247 ++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 pytorch_lightning/testing/no_val_end_module.py diff --git a/pytorch_lightning/testing/no_val_end_module.py b/pytorch_lightning/testing/no_val_end_module.py new file mode 100644 index 0000000000000..4fd7eb09d5725 --- /dev/null +++ b/pytorch_lightning/testing/no_val_end_module.py @@ -0,0 +1,247 @@ +import os +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import MNIST +from torchvision import transforms +from test_tube import HyperOptArgumentParser + +from pytorch_lightning.root_module.root_module import LightningModule +from pytorch_lightning import data_loader + + +class LightningTestModel(LightningModule): + """ + Sample model to show how to define a template + """ + + def __init__(self, hparams, force_remove_distributed_sampler=False): + """ + Pass in parsed HyperOptArgumentParser to the model + :param hparams: + """ + # init superclass + super(LightningTestModel, self).__init__() + self.hparams = hparams + + self.batch_size = hparams.batch_size + + # if you specify an example input, the summary will show input/output for each layer + self.example_input_array = torch.rand(5, 28 * 28) + + # remove to test warning for dist sampler + self.force_remove_distributed_sampler = force_remove_distributed_sampler + + # build model + self.__build_model() + + # --------------------- + # MODEL SETUP + # --------------------- + def __build_model(self): + """ + Layout model + :return: + """ + self.c_d1 = nn.Linear(in_features=self.hparams.in_features, + out_features=self.hparams.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) + self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, + out_features=self.hparams.out_features) + + # --------------------- + # TRAINING + # --------------------- + def forward(self, x): + """ + No special modification required for lightning, define as you normally would + :param x: + :return: + """ + + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + + x = self.c_d2(x) + logits = F.log_softmax(x, dim=1) + + return logits + + def loss(self, labels, logits): + nll = F.nll_loss(logits, labels) + return nll + + def training_step(self, data_batch, batch_i): + """ + Lightning calls this inside the training loop + :param data_batch: + :return: + """ + # forward pass + x, y = data_batch + x = x.view(x.size(0), -1) + + y_hat = self.forward(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val + + def validation_step(self, data_batch, batch_i): + """ + Lightning calls this inside the validation loop + :param data_batch: + :return: + """ + x, y = data_batch + x = x.view(x.size(0), -1) + 
y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_i % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_i % 2 == 0: + return val_acc + + if batch_i % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': {'val_loss_a': loss_val} + }) + return output + + def on_tng_metrics(self, logs): + logs['some_tensor_to_test'] = torch.rand(1) + + # --------------------- + # TRAINING SETUP + # --------------------- + def configure_optimizers(self): + """ + return whatever optimizers we want here + :return: list of optimizers + """ + # try no scheduler for this model (testing purposes) + optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + # test returning only 1 list instead of 2 + return [optimizer] + + def __dataloader(self, train): + # init data generators + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=train, + transform=transform, download=True) + + # when using multi-node we need to add the datasampler + train_sampler = None + batch_size = self.hparams.batch_size + + try: + if self.on_gpu and not self.force_remove_distributed_sampler: + train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) + batch_size = batch_size // self.trainer.world_size # scale batch size + except Exception: + pass + + should_shuffle = train_sampler is None + loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=should_shuffle, + sampler=train_sampler + ) + + return loader + + @data_loader + def tng_dataloader(self): + return self.__dataloader(train=True) + + @data_loader + def val_dataloader(self): + return self.__dataloader(train=False) + + @data_loader + def test_dataloader(self): + return self.__dataloader(train=False) + + @staticmethod + def add_model_specific_args(parent_parser, root_dir): # pragma: no cover + """ + Parameters you define here will be available to your model through self.hparams + :param parent_parser: + :param root_dir: + :return: + """ + parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) + + # param overwrites + # parser.set_defaults(gradient_clip=5.0) + + # network params + parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) + parser.add_argument('--in_features', default=28 * 28, type=int) + parser.add_argument('--out_features', default=10, type=int) + # use 500 for CPU, 50000 for GPU to see speed difference + parser.add_argument('--hidden_dim', default=50000, type=int) + + # data + parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) + + # training params (opt) + parser.opt_list('--learning_rate', default=0.001 * 8, type=float, + options=[0.0001, 0.0005, 0.001, 0.005], + tunable=False) + parser.opt_list('--optimizer_name', default='adam', type=str, + options=['adam'], tunable=False) + + # if using 2 nodes with 4 gpus each the batch size here + # (256) will be 256 / (2*8) 
= 16 per gpu + parser.opt_list('--batch_size', default=256 * 8, type=int, + options=[32, 64, 128, 256], tunable=False, + help='batch size will be divided over all gpus being used across all nodes') + return parser From c836ef2e3ac13bf76b0a13c98fcc6cd93c341289 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:36:44 -0400 Subject: [PATCH 05/29] added tests --- pytorch_lightning/testing/__init__.py | 3 + .../testing/no_val_end_module.py | 4 +- pytorch_lightning/testing/no_val_module.py | 4 +- tests/test_models.py | 120 +++++++++++++++++- 4 files changed, 126 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/testing/__init__.py b/pytorch_lightning/testing/__init__.py index e69de29bb2d1d..7beb286e30790 100644 --- a/pytorch_lightning/testing/__init__.py +++ b/pytorch_lightning/testing/__init__.py @@ -0,0 +1,3 @@ +from .lm_test_module import LightningTestModel +from .no_val_end_module import NoValEndTestModel +from .no_val_module import NoValModel \ No newline at end of file diff --git a/pytorch_lightning/testing/no_val_end_module.py b/pytorch_lightning/testing/no_val_end_module.py index 4fd7eb09d5725..3b42ab02565c8 100644 --- a/pytorch_lightning/testing/no_val_end_module.py +++ b/pytorch_lightning/testing/no_val_end_module.py @@ -15,7 +15,7 @@ from pytorch_lightning import data_loader -class LightningTestModel(LightningModule): +class NoValEndTestModel(LightningModule): """ Sample model to show how to define a template """ @@ -26,7 +26,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False): :param hparams: """ # init superclass - super(LightningTestModel, self).__init__() + super(NoValEndTestModel, self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size diff --git a/pytorch_lightning/testing/no_val_module.py b/pytorch_lightning/testing/no_val_module.py index 4dd5156267da9..029bc44769302 100644 --- a/pytorch_lightning/testing/no_val_module.py +++ b/pytorch_lightning/testing/no_val_module.py @@ -15,7 +15,7 @@ from pytorch_lightning import data_loader -class LightningNoValTestModel(LightningModule): +class NoValModel(LightningModule): """ Sample model to show how to define a template """ @@ -26,7 +26,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False): :param hparams: """ # init superclass - super(LightningNoValTestModel, self).__init__() + super(NoValModel, self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size diff --git a/tests/test_models.py b/tests/test_models.py index 896eb88490913..e20349e15ab23 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -10,7 +10,7 @@ # sys.path += [os.path.abspath('..'), os.path.abspath('../..')] from pytorch_lightning import Trainer -from pytorch_lightning.testing.lm_test_module import LightningTestModel +from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.utilities.debugging import MisconfigurationException from pytorch_lightning.root_module import memory @@ -26,6 +26,124 @@ # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ +def test_no_val_module(): + """ + Tests use case where trainer saves the model, and user loads it from tags independently + :return: + """ + hparams = get_hparams() + model = NoValModel(hparams) + + save_dir = init_save_dir() + + # exp file to get meta + exp = get_exp(False) + 
exp.argparse(hparams) + exp.save() + + trainer_options = dict( + max_nb_epochs=1, + cluster=SlurmCluster(), + experiment=exp, + checkpoint_callback=ModelCheckpoint(save_dir) + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # traning complete + assert result == 1, 'amp + ddp model failed to complete' + + # make a prediction + for batch in model.test_dataloader: + break + + x, y = batch + x = x.view(x.size(0), -1) + + # generate preds before saving model + model.eval() + pred_before_saving = model(x) + + # save model + new_weights_path = os.path.join(save_dir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + tags_path = exp.get_data_path(exp.name, exp.version) + tags_path = os.path.join(tags_path, 'meta_tags.csv') + model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path, + tags_csv=tags_path, on_gpu=False) + model_2.eval() + + # make prediction + # assert that both predictions are the same + new_pred = model_2(x) + assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 + + clear_save_dir() + + +def test_no_val_end_module(): + """ + Tests use case where trainer saves the model, and user loads it from tags independently + :return: + """ + hparams = get_hparams() + model = NoValEndTestModel(hparams) + + save_dir = init_save_dir() + + # exp file to get meta + exp = get_exp(False) + exp.argparse(hparams) + exp.save() + + trainer_options = dict( + max_nb_epochs=1, + cluster=SlurmCluster(), + experiment=exp, + checkpoint_callback=ModelCheckpoint(save_dir) + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # traning complete + assert result == 1, 'amp + ddp model failed to complete' + + # make a prediction + for batch in model.test_dataloader: + break + + x, y = batch + x = x.view(x.size(0), -1) + + # generate preds before saving model + model.eval() + pred_before_saving = model(x) + + # save model + new_weights_path = os.path.join(save_dir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + tags_path = exp.get_data_path(exp.name, exp.version) + tags_path = os.path.join(tags_path, 'meta_tags.csv') + model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path, + tags_csv=tags_path, on_gpu=False) + model_2.eval() + + # make prediction + # assert that both predictions are the same + new_pred = model_2(x) + assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 + + clear_save_dir() + + def test_simple_cpu(): """ Verify continue training session on CPU From eefe5120c1ddd46def2b01d3685f9d79b6def9e1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:47:09 -0400 Subject: [PATCH 06/29] added tests --- pytorch_lightning/root_module/root_module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 13cf46818f345..e12d1336bb0a8 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -5,9 +5,10 @@ from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.root_module.hooks import ModelHooks from pytorch_lightning.root_module.decorators import data_loader +from abc import ABC, abstractmethod -class LightningModule(GradInformation, ModelIO, ModelHooks): +class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def __init__(self, *args, 
**kwargs): super(LightningModule, self).__init__(*args, **kwargs) @@ -51,6 +52,7 @@ def validation_end(self, outputs): """ pass + @abstractmethod def training_step(self, data_batch, batch_nb): """ return loss, dict with metrics for tqdm @@ -59,6 +61,7 @@ def training_step(self, data_batch, batch_nb): """ raise NotImplementedError + @abstractmethod def configure_optimizers(self): """ Return a list of optimizers and a list of schedulers (could be empty) From 5de39830f77c85f35203cebd2d435b2fe7016251 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:48:12 -0400 Subject: [PATCH 07/29] remove class --- pytorch_lightning/root_module/root_module.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index e12d1336bb0a8..13cf46818f345 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -5,10 +5,9 @@ from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.root_module.hooks import ModelHooks from pytorch_lightning.root_module.decorators import data_loader -from abc import ABC, abstractmethod -class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): +class LightningModule(GradInformation, ModelIO, ModelHooks): def __init__(self, *args, **kwargs): super(LightningModule, self).__init__(*args, **kwargs) @@ -52,7 +51,6 @@ def validation_end(self, outputs): """ pass - @abstractmethod def training_step(self, data_batch, batch_nb): """ return loss, dict with metrics for tqdm @@ -61,7 +59,6 @@ def training_step(self, data_batch, batch_nb): """ raise NotImplementedError - @abstractmethod def configure_optimizers(self): """ Return a list of optimizers and a list of schedulers (could be empty) From da1f7f7ab05138d2ca171f3bcd23510e2cadd863 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:52:32 -0400 Subject: [PATCH 08/29] remove class --- pytorch_lightning/models/trainer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ea261f295fe5e..3cfbcdb9fb612 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -13,6 +13,7 @@ import torch.multiprocessing as mp import torch.distributed as dist +from pytorch_lightning import LightningModule from pytorch_lightning.root_module.memory import get_gpu_memory_map from pytorch_lightning.root_module.model_saving import TrainerIO from pytorch_lightning.pt_overrides.override_data_parallel import ( @@ -312,6 +313,12 @@ def __is_function_implemented(self, f_name): f_op = getattr(model, f_name, None) return callable(f_op) + def __is_overriden(self, f_name): + model = self.__get_model() + model_op = getattr(model, f_name, None) + parent_op = getattr(LightningModule(), f_name, None) + return model_op.__code__ is not parent_op.__code__ + @property def __tng_tqdm_dic(self): tqdm_dic = { @@ -373,7 +380,7 @@ def validate(self, model, dataloader, max_batches): :return: """ # skip validation if model has no validation_step defined - if not self.__is_function_implemented('validation_step'): + if not self.__is_overriden('validation_step'): return {} # enable eval mode @@ -424,7 +431,7 @@ def validate(self, model, dataloader, max_batches): # give model a chance to do something with the outputs (and method defined) val_results = {} - if 
self.__is_function_implemented('validation_end'): + if self.__is_overriden('validation_end'): if self.data_parallel: val_results = model.module.validation_end(outputs) else: From da15dd054636345a5387683561adffac4d9a30f9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:54:01 -0400 Subject: [PATCH 09/29] remove class --- pytorch_lightning/models/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 3cfbcdb9fb612..5c08ce5c8e675 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -13,7 +13,6 @@ import torch.multiprocessing as mp import torch.distributed as dist -from pytorch_lightning import LightningModule from pytorch_lightning.root_module.memory import get_gpu_memory_map from pytorch_lightning.root_module.model_saving import TrainerIO from pytorch_lightning.pt_overrides.override_data_parallel import ( @@ -316,7 +315,7 @@ def __is_function_implemented(self, f_name): def __is_overriden(self, f_name): model = self.__get_model() model_op = getattr(model, f_name, None) - parent_op = getattr(LightningModule(), f_name, None) + parent_op = getattr(model.super(), f_name, None) return model_op.__code__ is not parent_op.__code__ @property From f1b612b491692eec2afee5cde2b11bc6459e9095 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 07:56:46 -0400 Subject: [PATCH 10/29] remove class --- pytorch_lightning/models/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 5c08ce5c8e675..58ab7b5db7f35 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -313,6 +313,8 @@ def __is_function_implemented(self, f_name): return callable(f_op) def __is_overriden(self, f_name): + import pdb + pdb.set_trace() model = self.__get_model() model_op = getattr(model, f_name, None) parent_op = getattr(model.super(), f_name, None) From 21e25b6422088b9693272d829f0ead789648084b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:05:45 -0400 Subject: [PATCH 11/29] remove class --- pytorch_lightning/models/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 58ab7b5db7f35..8632253f0914d 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -13,6 +13,7 @@ import torch.multiprocessing as mp import torch.distributed as dist +from pytorch_lightning.root_module.root_module import LightningModule from pytorch_lightning.root_module.memory import get_gpu_memory_map from pytorch_lightning.root_module.model_saving import TrainerIO from pytorch_lightning.pt_overrides.override_data_parallel import ( @@ -316,8 +317,8 @@ def __is_overriden(self, f_name): import pdb pdb.set_trace() model = self.__get_model() - model_op = getattr(model, f_name, None) - parent_op = getattr(model.super(), f_name, None) + model_op_id = model.__dict__[f_name] + parent_op_id = LightningModule.__dict__[f_name] return model_op.__code__ is not parent_op.__code__ @property From 5f2ca7ff31eb2e2245dc83a93fb85992a1d2968f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:05:58 -0400 Subject: [PATCH 12/29] remove class --- pytorch_lightning/models/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py 
b/pytorch_lightning/models/trainer.py index 8632253f0914d..ee78725f5a6bf 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -319,7 +319,7 @@ def __is_overriden(self, f_name): model = self.__get_model() model_op_id = model.__dict__[f_name] parent_op_id = LightningModule.__dict__[f_name] - return model_op.__code__ is not parent_op.__code__ + return model_op_id is not parent_op_id @property def __tng_tqdm_dic(self): From 8a08ef45207891ea57b3a7b17d5a8a728cd375fd Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:08:33 -0400 Subject: [PATCH 13/29] remove class --- pytorch_lightning/models/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ee78725f5a6bf..4094cda1311fe 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -317,9 +317,7 @@ def __is_overriden(self, f_name): import pdb pdb.set_trace() model = self.__get_model() - model_op_id = model.__dict__[f_name] - parent_op_id = LightningModule.__dict__[f_name] - return model_op_id is not parent_op_id + return f_name in model.__dict__ @property def __tng_tqdm_dic(self): From c4e92c4f6cd7d994e8b8435b307d747a6cc61d79 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:08:41 -0400 Subject: [PATCH 14/29] remove class --- pytorch_lightning/models/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 4094cda1311fe..49796fa1a0e61 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -314,8 +314,6 @@ def __is_function_implemented(self, f_name): return callable(f_op) def __is_overriden(self, f_name): - import pdb - pdb.set_trace() model = self.__get_model() return f_name in model.__dict__ From b3af23c852ed9d11e0d0e38395695a9731463ba1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:10:50 -0400 Subject: [PATCH 15/29] remove class --- pytorch_lightning/models/trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 49796fa1a0e61..65ee380729995 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -315,7 +315,11 @@ def __is_function_implemented(self, f_name): def __is_overriden(self, f_name): model = self.__get_model() - return f_name in model.__dict__ + try: + model.__dict__[f_name] + return True + except KeyError: + return False @property def __tng_tqdm_dic(self): From 2678bf15d10853462f3f52f8d06ee6d5e2a03ac3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:12:39 -0400 Subject: [PATCH 16/29] remove class --- pytorch_lightning/models/trainer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 65ee380729995..49796fa1a0e61 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -315,11 +315,7 @@ def __is_function_implemented(self, f_name): def __is_overriden(self, f_name): model = self.__get_model() - try: - model.__dict__[f_name] - return True - except KeyError: - return False + return f_name in model.__dict__ @property def __tng_tqdm_dic(self): From f700b3ce1d67f9869775a799f46a2473506f152c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:13:43 -0400 
Subject: [PATCH 17/29] remove class --- tests/test_models.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index e20349e15ab23..6cc744805fcdb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,17 +55,6 @@ def test_no_val_module(): # traning complete assert result == 1, 'amp + ddp model failed to complete' - # make a prediction - for batch in model.test_dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) - - # generate preds before saving model - model.eval() - pred_before_saving = model(x) - # save model new_weights_path = os.path.join(save_dir, 'save_test.ckpt') trainer.save_checkpoint(new_weights_path) @@ -78,10 +67,6 @@ def test_no_val_module(): model_2.eval() # make prediction - # assert that both predictions are the same - new_pred = model_2(x) - assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 - clear_save_dir() @@ -114,17 +99,6 @@ def test_no_val_end_module(): # traning complete assert result == 1, 'amp + ddp model failed to complete' - # make a prediction - for batch in model.test_dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) - - # generate preds before saving model - model.eval() - pred_before_saving = model(x) - # save model new_weights_path = os.path.join(save_dir, 'save_test.ckpt') trainer.save_checkpoint(new_weights_path) @@ -137,13 +111,10 @@ def test_no_val_end_module(): model_2.eval() # make prediction - # assert that both predictions are the same - new_pred = model_2(x) - assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 - clear_save_dir() + def test_simple_cpu(): """ Verify continue training session on CPU From f93ed92a3c699f776231fc15e06257b667912e4f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:15:28 -0400 Subject: [PATCH 18/29] updated docs --- docs/LightningModule/RequiredTrainerInterface.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index d0b5b6b3d94c1..677815b73a1e1 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -10,16 +10,14 @@ Otherwise, to Define a Lightning Module, implement the following methods: **Required**: - [training_step](RequiredTrainerInterface.md#training_step) -- [validation_step](RequiredTrainerInterface.md#validation_step) -- [validation_end](RequiredTrainerInterface.md#validation_end) - -- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers) - -- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) - [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) -- [test_dataloader](RequiredTrainerInterface.md#test_dataloader) +- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers) **Optional**: +- [validation_step](RequiredTrainerInterface.md#validation_step) +- [validation_end](RequiredTrainerInterface.md#validation_end) +- [val_dataloader](RequiredTrainerInterface.md#val_dataloader) +- [test_dataloader](RequiredTrainerInterface.md#test_dataloader) - [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint) - [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint) From f6ddc15b19fe50eade19cf90149f8f4c4c4a813c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:17:08 -0400 Subject: [PATCH 19/29] updated docs --- README.md | 12 ++++++++---- 1 file 
changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 775f6bc7bf93e..2245eba3aa6fb 100644 --- a/README.md +++ b/README.md @@ -81,36 +81,40 @@ class CoolModel(pl.LightningModule): def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) - def my_loss(self, y_hat, y): - return F.cross_entropy(y_hat, y) - def training_step(self, batch, batch_nb): + # REQUIRED x, y = batch y_hat = self.forward(x) - return {'loss': self.my_loss(y_hat, y)} + return {'loss': F.cross_entropy(y_hat, y)} def validation_step(self, batch, batch_nb): + # OPTIONAL x, y = batch y_hat = self.forward(x) return {'val_loss': self.my_loss(y_hat, y)} def validation_end(self, outputs): + # OPTIONAL avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() return {'avg_val_loss': avg_loss} def configure_optimizers(self): + # REQUIRED return [torch.optim.Adam(self.parameters(), lr=0.02)] @pl.data_loader def tng_dataloader(self): + # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def val_dataloader(self): + # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def test_dataloader(self): + # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) ``` From c8c7b70ec3c4d330617b2367c84da310c59c2422 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:24:25 -0400 Subject: [PATCH 20/29] updated test --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 6cc744805fcdb..6c242f2ae56f0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -540,7 +540,7 @@ def test_early_stopping_cpu_model(): :return: """ - stopping = EarlyStopping() + stopping = EarlyStopping(monitor='val_acc') trainer_options = dict( early_stop_callback=stopping, gradient_clip=1.0, From 698a3edffa21ce84d79aece8493e22b365690206 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:26:14 -0400 Subject: [PATCH 21/29] updated test --- tests/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_models.py b/tests/test_models.py index 6c242f2ae56f0..c78f8f489728f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -614,6 +614,7 @@ def test_all_features_cpu_model(): print_nan_grads=True, progress_bar=False, experiment=get_exp(), + accumulate_grad_batches=2, max_nb_epochs=1, train_percent_check=0.4, val_percent_check=0.4 From 81df0c168d98f70b853ec3609f6aaf218aeccbcd Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:35:30 -0400 Subject: [PATCH 22/29] updated test --- pytorch_lightning/models/trainer.py | 2 ++ tests/test_models.py | 54 ++++++++++++++--------------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 49796fa1a0e61..4094cda1311fe 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -314,6 +314,8 @@ def __is_function_implemented(self, f_name): return callable(f_op) def __is_overriden(self, f_name): + import pdb + pdb.set_trace() model = self.__get_model() return f_name in model.__dict__ diff --git a/tests/test_models.py b/tests/test_models.py index c78f8f489728f..a5413de64b987 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,6 +26,33 @@ # 
------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ +def test_early_stopping_cpu_model(): + """ + Test each of the trainer options + :return: + """ + + stopping = EarlyStopping(monitor='val_acc') + trainer_options = dict( + early_stop_callback=stopping, + gradient_clip=1.0, + overfit_pct=0.20, + track_grad_norm=2, + print_nan_grads=True, + progress_bar=False, + experiment=get_exp(), + train_percent_check=0.1, + val_percent_check=0.1 + ) + + model, hparams = get_model() + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + # test freeze on cpu + model.freeze() + model.unfreeze() + + def test_no_val_module(): """ Tests use case where trainer saves the model, and user loads it from tags independently @@ -534,33 +561,6 @@ def test_amp_gpu_ddp_slurm_managed(): clear_save_dir() -def test_early_stopping_cpu_model(): - """ - Test each of the trainer options - :return: - """ - - stopping = EarlyStopping(monitor='val_acc') - trainer_options = dict( - early_stop_callback=stopping, - gradient_clip=1.0, - overfit_pct=0.20, - track_grad_norm=2, - print_nan_grads=True, - progress_bar=False, - experiment=get_exp(), - train_percent_check=0.1, - val_percent_check=0.1 - ) - - model, hparams = get_model() - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - # test freeze on cpu - model.freeze() - model.unfreeze() - - def test_cpu_model_with_amp(): """ Make sure model trains on CPU From baeee35d34b104bcb465f1fca2f2e1b9927c8ab0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:37:02 -0400 Subject: [PATCH 23/29] updated test --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index a5413de64b987..8ed946b5f0e11 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,7 +32,7 @@ def test_early_stopping_cpu_model(): :return: """ - stopping = EarlyStopping(monitor='val_acc') + stopping = EarlyStopping(monitor='val_loss') trainer_options = dict( early_stop_callback=stopping, gradient_clip=1.0, From 338529a5c211d36c3cd7565056fbcd629b1254db Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:41:14 -0400 Subject: [PATCH 24/29] updated test --- pytorch_lightning/models/trainer.py | 7 ++-- tests/test_models.py | 54 ++++++++++++++--------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 4094cda1311fe..7de7d39e668a7 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -314,10 +314,13 @@ def __is_function_implemented(self, f_name): return callable(f_op) def __is_overriden(self, f_name): + model = self.__get_model() + super_object = super(model.__class__, model) + import pdb pdb.set_trace() - model = self.__get_model() - return f_name in model.__dict__ + is_overriden = hasattr(model, f_name) and not hasattr(super_object, f_name) + return is_overriden @property def __tng_tqdm_dic(self): diff --git a/tests/test_models.py b/tests/test_models.py index 8ed946b5f0e11..bad6b677dd253 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,33 +26,6 @@ # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ -def test_early_stopping_cpu_model(): - """ - Test each of the trainer options - :return: - """ - - 
stopping = EarlyStopping(monitor='val_loss') - trainer_options = dict( - early_stop_callback=stopping, - gradient_clip=1.0, - overfit_pct=0.20, - track_grad_norm=2, - print_nan_grads=True, - progress_bar=False, - experiment=get_exp(), - train_percent_check=0.1, - val_percent_check=0.1 - ) - - model, hparams = get_model() - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - # test freeze on cpu - model.freeze() - model.unfreeze() - - def test_no_val_module(): """ Tests use case where trainer saves the model, and user loads it from tags independently @@ -561,6 +534,33 @@ def test_amp_gpu_ddp_slurm_managed(): clear_save_dir() +def test_early_stopping_cpu_model(): + """ + Test each of the trainer options + :return: + """ + + stopping = EarlyStopping(monitor='val_loss') + trainer_options = dict( + early_stop_callback=stopping, + gradient_clip=1.0, + overfit_pct=0.20, + track_grad_norm=2, + print_nan_grads=True, + progress_bar=False, + experiment=get_exp(), + train_percent_check=0.1, + val_percent_check=0.1 + ) + + model, hparams = get_model() + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + # test freeze on cpu + model.freeze() + model.unfreeze() + + def test_cpu_model_with_amp(): """ Make sure model trains on CPU From b9837df01b3a786b67719c81f8d8b169c70eca41 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:42:11 -0400 Subject: [PATCH 25/29] updated test --- tests/test_models.py | 55 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index bad6b677dd253..ac2b525da9b76 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,6 +26,34 @@ # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ + +def test_early_stopping_cpu_model(): + """ + Test each of the trainer options + :return: + """ + + stopping = EarlyStopping(monitor='val_loss') + trainer_options = dict( + early_stop_callback=stopping, + gradient_clip=1.0, + overfit_pct=0.20, + track_grad_norm=2, + print_nan_grads=True, + progress_bar=False, + experiment=get_exp(), + train_percent_check=0.1, + val_percent_check=0.1 + ) + + model, hparams = get_model() + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + # test freeze on cpu + model.freeze() + model.unfreeze() + + def test_no_val_module(): """ Tests use case where trainer saves the model, and user loads it from tags independently @@ -534,33 +562,6 @@ def test_amp_gpu_ddp_slurm_managed(): clear_save_dir() -def test_early_stopping_cpu_model(): - """ - Test each of the trainer options - :return: - """ - - stopping = EarlyStopping(monitor='val_loss') - trainer_options = dict( - early_stop_callback=stopping, - gradient_clip=1.0, - overfit_pct=0.20, - track_grad_norm=2, - print_nan_grads=True, - progress_bar=False, - experiment=get_exp(), - train_percent_check=0.1, - val_percent_check=0.1 - ) - - model, hparams = get_model() - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - # test freeze on cpu - model.freeze() - model.unfreeze() - - def test_cpu_model_with_amp(): """ Make sure model trains on CPU From 76e89b0c0cd902a92ec8cd79b4b5465c4b83308c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:44:33 -0400 Subject: [PATCH 26/29] updated test --- tests/test_models.py | 50 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 
deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index ac2b525da9b76..7a41c7799afd9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,31 +27,31 @@ # TESTS # ------------------------------------------------------------------------ -def test_early_stopping_cpu_model(): - """ - Test each of the trainer options - :return: - """ - - stopping = EarlyStopping(monitor='val_loss') - trainer_options = dict( - early_stop_callback=stopping, - gradient_clip=1.0, - overfit_pct=0.20, - track_grad_norm=2, - print_nan_grads=True, - progress_bar=False, - experiment=get_exp(), - train_percent_check=0.1, - val_percent_check=0.1 - ) - - model, hparams = get_model() - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - # test freeze on cpu - model.freeze() - model.unfreeze() +# def test_early_stopping_cpu_model(): +# """ +# Test each of the trainer options +# :return: +# """ +# +# stopping = EarlyStopping(monitor='val_loss') +# trainer_options = dict( +# early_stop_callback=stopping, +# gradient_clip=1.0, +# overfit_pct=0.20, +# track_grad_norm=2, +# print_nan_grads=True, +# progress_bar=False, +# experiment=get_exp(), +# train_percent_check=0.1, +# val_percent_check=0.1 +# ) +# +# model, hparams = get_model() +# run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) +# +# # test freeze on cpu +# model.freeze() +# model.unfreeze() def test_no_val_module(): From f57f6625c38dbafeae1fbb2405ecf6d2a4129c6e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:46:21 -0400 Subject: [PATCH 27/29] updated test --- pytorch_lightning/models/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 7de7d39e668a7..3490131d439f7 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -317,9 +317,8 @@ def __is_overriden(self, f_name): model = self.__get_model() super_object = super(model.__class__, model) - import pdb - pdb.set_trace() - is_overriden = hasattr(model, f_name) and not hasattr(super_object, f_name) + # when code pointers are different, it was overriden + is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__ return is_overriden @property From 09442f1d73dde558fa4c5744eec29b4588fb99b1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 08:46:30 -0400 Subject: [PATCH 28/29] updated test --- tests/test_models.py | 50 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 7a41c7799afd9..ac2b525da9b76 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,31 +27,31 @@ # TESTS # ------------------------------------------------------------------------ -# def test_early_stopping_cpu_model(): -# """ -# Test each of the trainer options -# :return: -# """ -# -# stopping = EarlyStopping(monitor='val_loss') -# trainer_options = dict( -# early_stop_callback=stopping, -# gradient_clip=1.0, -# overfit_pct=0.20, -# track_grad_norm=2, -# print_nan_grads=True, -# progress_bar=False, -# experiment=get_exp(), -# train_percent_check=0.1, -# val_percent_check=0.1 -# ) -# -# model, hparams = get_model() -# run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) -# -# # test freeze on cpu -# model.freeze() -# model.unfreeze() +def test_early_stopping_cpu_model(): + """ + Test each of the trainer options + :return: + """ + + stopping = 
EarlyStopping(monitor='val_loss') + trainer_options = dict( + early_stop_callback=stopping, + gradient_clip=1.0, + overfit_pct=0.20, + track_grad_norm=2, + print_nan_grads=True, + progress_bar=False, + experiment=get_exp(), + train_percent_check=0.1, + val_percent_check=0.1 + ) + + model, hparams = get_model() + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + # test freeze on cpu + model.freeze() + model.unfreeze() def test_no_val_module(): From ef2fc45b0a67f84b897773463b0bc3e6dfa2d426 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 11 Aug 2019 09:17:57 -0400 Subject: [PATCH 29/29] fix pep8 --- pytorch_lightning/testing/__init__.py | 2 +- tests/test_models.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/testing/__init__.py b/pytorch_lightning/testing/__init__.py index 7beb286e30790..b3289a1c71475 100644 --- a/pytorch_lightning/testing/__init__.py +++ b/pytorch_lightning/testing/__init__.py @@ -1,3 +1,3 @@ from .lm_test_module import LightningTestModel from .no_val_end_module import NoValEndTestModel -from .no_val_module import NoValModel \ No newline at end of file +from .no_val_module import NoValModel diff --git a/tests/test_models.py b/tests/test_models.py index ac2b525da9b76..ac0b68603f420 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -142,7 +142,6 @@ def test_no_val_end_module(): clear_save_dir() - def test_simple_cpu(): """ Verify continue training session on CPU
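
The back-and-forth in patches 08 through 27 above converges on detecting whether a user overrode an optional hook by comparing code objects rather than checking that the attribute exists (every subclass inherits the base hook, so a plain hasattr/callable check always passes). The following is a minimal standalone sketch of that idiom, assuming a generic Base class in place of LightningModule; the names Base, WithVal, WithoutVal and the bare is_overriden helper are illustrative only and are not part of the library:

class Base:
    def validation_step(self, data_batch, batch_nb):
        # optional hook: the default implementation is a no-op
        pass


class WithVal(Base):
    def validation_step(self, data_batch, batch_nb):
        return {'val_loss': 0.0}


class WithoutVal(Base):
    pass


def is_overriden(model, f_name):
    # the hook counts as overridden when the bound method's code object
    # differs from the one inherited from the base class
    super_object = super(model.__class__, model)
    return getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__


print(is_overriden(WithVal(), 'validation_step'))     # True  -> run the validation loop
print(is_overriden(WithoutVal(), 'validation_step'))  # False -> skip validation entirely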