replace Hparams by init args #1896

Merged
101 commits merged on May 24, 2020
Changes from 100 commits
Commits
1fced53
remove the need for hparams
williamFalcon May 19, 2020
7fe5f13
remove the need for hparams
williamFalcon May 19, 2020
0283055
remove the need for hparams
williamFalcon May 19, 2020
599c9ad
remove the need for hparams
williamFalcon May 19, 2020
32c7435
replace self.hparams
williamFalcon May 19, 2020
29d3e0a
replace self.hparams
williamFalcon May 19, 2020
f508424
replace self.hparams
williamFalcon May 19, 2020
28b85bd
replace self.hparams
williamFalcon May 19, 2020
355eb7a
replace self.hparams
williamFalcon May 19, 2020
5cc272a
replace self.hparams
williamFalcon May 19, 2020
a5bcd1c
replace self.hparams
williamFalcon May 19, 2020
a4a7407
replace self.hparams
williamFalcon May 19, 2020
8f7e8a2
replace self.hparams
williamFalcon May 19, 2020
b1cd0b5
replace self.hparams
williamFalcon May 19, 2020
e97237e
replace self.hparams
williamFalcon May 19, 2020
a2f6cb5
replace self.hparams
williamFalcon May 19, 2020
9216d28
replace self.hparams
williamFalcon May 19, 2020
7cbc1b2
replace self.hparams
williamFalcon May 19, 2020
137ae13
replace self.hparams
williamFalcon May 19, 2020
b6a9336
replace self.hparams
williamFalcon May 19, 2020
6ea138c
replace self.hparams
williamFalcon May 19, 2020
485ce20
replace self.hparams
williamFalcon May 19, 2020
14dab1b
replace self.hparams
williamFalcon May 19, 2020
268277a
replace self.hparams
williamFalcon May 19, 2020
2111e4b
replace self.hparams
williamFalcon May 19, 2020
07a1c00
replace self.hparams
williamFalcon May 19, 2020
90a1226
replace self.hparams
williamFalcon May 19, 2020
4429d22
replace self.hparams
williamFalcon May 19, 2020
6060a02
replace self.hparams
williamFalcon May 19, 2020
6f856df
replace self.hparams
williamFalcon May 19, 2020
da385fe
replace self.hparams
williamFalcon May 19, 2020
065226d
replace self.hparams
williamFalcon May 19, 2020
f634a8e
replace self.hparams
williamFalcon May 19, 2020
34055b5
replace self.hparams
williamFalcon May 19, 2020
e05c11b
replace self.hparams
williamFalcon May 19, 2020
0937108
replace self.hparams
williamFalcon May 19, 2020
f6587ce
fixed
williamFalcon May 19, 2020
0303695
fixed
williamFalcon May 19, 2020
2b0ceb8
fixed
williamFalcon May 19, 2020
e226c88
fixed
williamFalcon May 19, 2020
ec00520
fixed
williamFalcon May 19, 2020
72793c3
fixed
williamFalcon May 19, 2020
840265d
fixed
williamFalcon May 19, 2020
4bb28fa
fixed
williamFalcon May 19, 2020
91569a8
fixed
williamFalcon May 19, 2020
509036e
fixed
williamFalcon May 19, 2020
a99ffb7
fixed
williamFalcon May 19, 2020
5c3ea20
fixed
williamFalcon May 19, 2020
9d08be3
fixed
williamFalcon May 19, 2020
0452418
fixed
williamFalcon May 19, 2020
0b5557f
finished moco
williamFalcon May 20, 2020
6cd5ea9
basic
williamFalcon May 20, 2020
ed1090c
testing
Borda May 20, 2020
295654e
todo
Borda May 20, 2020
1465a03
recurse
Borda May 20, 2020
91ab93e
hparams
Borda May 20, 2020
0519723
persist
Borda May 20, 2020
a19df1d
hparams
Borda May 20, 2020
1f87263
chlog
Borda May 20, 2020
f35eab0
tests
Borda May 20, 2020
3555e83
tests
Borda May 20, 2020
2a1b2dc
tests
Borda May 20, 2020
3c79ae3
tests
Borda May 21, 2020
5767188
tests
Borda May 21, 2020
cbb00b5
tests
Borda May 21, 2020
acc020f
review
Borda May 21, 2020
b3b6236
saving
Borda May 21, 2020
b97e0b1
tests
Borda May 22, 2020
5a4740a
tests
Borda May 22, 2020
2a6be20
tests
Borda May 22, 2020
e80b006
docs
Borda May 22, 2020
e50b78f
finished moco
williamFalcon May 22, 2020
fd7be0d
hparams
Borda May 22, 2020
c319528
review
Borda May 22, 2020
b313477
Apply suggestions from code review
Borda May 22, 2020
488d18a
hparams
Borda May 22, 2020
d24b78e
overwrite
Borda May 22, 2020
0d7ee37
transform
Borda May 22, 2020
fb7898a
transform
Borda May 22, 2020
3d8a3db
transform
Borda May 22, 2020
db6f943
transform
Borda May 23, 2020
66717da
cleaning
Borda May 23, 2020
088c3bd
cleaning
Borda May 23, 2020
dfd3a26
tests
Borda May 23, 2020
72f4cd0
examples
Borda May 23, 2020
2a8872b
examples
Borda May 23, 2020
5fe6f02
examples
Borda May 23, 2020
bad8d11
Apply suggestions from code review
Borda May 24, 2020
55b58f7
chp key
Borda May 24, 2020
5383014
tests
Borda May 24, 2020
1fd8cce
Apply suggestions from code review
Borda May 24, 2020
ab3be59
class
Borda May 24, 2020
10ca1a8
Merge branch 'no_hparams' of https://github.com/PyTorchLightning/pyto…
Borda May 24, 2020
8f57274
updated docs
williamFalcon May 24, 2020
20ad2ca
updated docs
williamFalcon May 24, 2020
f683160
updated docs
williamFalcon May 24, 2020
db5d1bf
updated docs
williamFalcon May 24, 2020
6af55e7
save
Borda May 24, 2020
04e8e67
wip
Borda May 24, 2020
a432f6e
fix
Borda May 24, 2020
2892e5a
flake8
Borda May 24, 2020
3 changes: 2 additions & 1 deletion .gitignore
@@ -133,4 +133,5 @@ mnist/
# pl tests
ml-runs/
*.zip
pytorch\ lightning
pytorch\ lightning
test-reports/
3 changes: 3 additions & 0 deletions .run_local_tests.sh
@@ -14,3 +14,6 @@ rm -rf ./tests/tests/*
rm -rf ./lightning_logs
python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
python -m coverage report -m

# specific file
# python -m coverage run --source pytorch_lightning -m py.test -k test_trainer.py --flake8
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))

- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))

### Deprecated

- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
122 changes: 86 additions & 36 deletions docs/source/hyperparameters.rst
@@ -75,7 +75,7 @@ Now in your main trainer file, add the Trainer args, the program args, and add t
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
parser = Trainer.add_argparse_args(parser)

hparams = parser.parse_args()
args = parser.parse_args()

Now you can run your program like so

@@ -87,39 +87,50 @@ Finally, make sure to start the training like so:

.. code-block:: python

# YES
model = LitModel(hparams)
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
# init the trainer like this
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)

# NOT like this
trainer = Trainer(gpus=hparams.gpus, ...)

# init the model with Namespace directly
model = LitModel(args)

# or init the model with all the key-value pairs
dict_args = vars(args)
model = LitModel(**dict_args)

# NO
# model = LitModel(learning_rate=hparams.learning_rate, ...)
# trainer = Trainer(gpus=hparams.gpus, ...)
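
For reference, a minimal end-to-end sketch combining the pieces above. It assumes ``LitModel``
defines ``add_model_specific_args`` following the pattern shown later in this document and
accepts ``**kwargs`` so the extra Trainer arguments are ignored:

.. code-block:: python

    from argparse import ArgumentParser

    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = LitModel.add_model_specific_args(parser)  # model args, e.g. --learning_rate
    parser = Trainer.add_argparse_args(parser)         # trainer args, e.g. --gpus
    args = parser.parse_args()

    model = LitModel(**vars(args))             # pass init args as keyword arguments
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)
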
LightningModule hyperparameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

LightningModule hparams
^^^^^^^^^^^^^^^^^^^^^^^
.. warning:: The use of `hparams` is no longer recommended (but still supported)

Normally, we don't hard-code the values to a model. We usually use the command line to
modify the network and read those values in the LightningModule
LightningModule is just an nn.Module; you can use it as you normally would. However, there are
some best practices to improve readability and reproducibility.

1. It's more readable to specify all the arguments that go into a module (with default values).
This helps users of your module know everything that is required to run this.

.. testcode::

class LitMNIST(LightningModule):

def __init__(self, hparams):
def __init__(self, layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32, **kwargs):
super().__init__()
self.layer_1_dim = layer_1_dim
self.layer_2_dim = layer_2_dim
self.learning_rate = learning_rate
self.batch_size = batch_size

# do this to save all arguments in any logger (tensorboard)
self.hparams = hparams

self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)
self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10)

def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
return DataLoader(mnist_train, batch_size=self.batch_size)

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.learning_rate)
return Adam(self.parameters(), lr=self.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
@@ -130,20 +141,59 @@ modify the network and read those values in the LightningModule
parser.add_argument('--learning_rate', type=float, default=0.002)
return parser

Now pass in the params when you init your model
2. You can also pass in a dict or Namespace, but this obscures the parameters your module is looking
for. The user would have to search the file to find what is parametrized.

.. code-block:: python

# using an argparse.Namespace
class LitMNIST(LightningModule):

def __init__(self, hparams, *args, **kwargs):
super().__init__()
self.hparams = hparams

self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)

def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)

One way to get around this is to convert a Namespace or dict into key-value pairs using `**`

.. code-block:: python

parser = ArgumentParser()
parser = LitMNIST.add_model_specific_args(parser)
hparams = parser.parse_args()
model = LitMNIST(hparams)
args = parser.parse_args()
dict_args = vars(args)
model = LitMNIST(**dict_args)
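
The same expansion works for a plain dict of hyperparameters (the values below are only
illustrative):

.. code-block:: python

    hparams = dict(layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32)
    model = LitMNIST(**hparams)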

Within any LightningModule, all the arguments you pass into your `__init__` will be stored in
the checkpoint so that you know all the values that went into creating this model.

We will also add all of those values to the TensorBoard hparams tab (unless a value is an object,
which we skip). Because the values are stored in the checkpoint, you can later use them to init
your models.

.. code-block:: python

The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
This does two things:
class LitMNIST(LightningModule):

def __init__(self, layer_1_dim, some_other_param):
super().__init__()
self.layer_1_dim = layer_1_dim
self.some_other_param = some_other_param

self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)

self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param)
self.layer_3 = torch.nn.Linear(self.some_other_param, 10)


model = LitMNIST(10, 20)

1. It adds them automatically to TensorBoard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
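
As a sketch of what this buys you: once a checkpoint has been saved (for example by the Trainer),
the stored init args can be read back under the ``module_arguments`` key introduced elsewhere in
this PR (``CKPT_PATH`` is a placeholder path):

.. code-block:: python

    import torch

    model = LitMNIST(10, 20)
    # ... train and checkpoint the model ...

    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])
    # e.g. {'layer_1_dim': 10, 'some_other_param': 20}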

Trainer args
^^^^^^^^^^^^
@@ -171,27 +221,27 @@ polluting the main.py file, the LightningModule lets you define arguments for ea

class LitMNIST(LightningModule):

def __init__(self, hparams):
def __init__(self, layer_1_dim, **kwargs):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--layer_1_dim', type=int, default=128)
return parser

.. testcode::

class GoodGAN(LightningModule):

def __init__(self, hparams):
def __init__(self, encoder_layers, **kwargs):
super().__init__()
self.encoder = Encoder(layers=hparams.encoder_layers)
self.encoder = Encoder(layers=encoder_layers)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--encoder_layers', type=int, default=12)
return parser

@@ -201,14 +251,14 @@ Now we can allow each model to inject the arguments it needs in the ``main.py``
.. code-block:: python

def main(args):
dict_args = vars(args)

# pick model
if args.model_name == 'gan':
model = GoodGAN(hparams=args)
model = GoodGAN(**dict_args)
elif args.model_name == 'mnist':
model = LitMNIST(hparams=args)
model = LitMNIST(**dict_args)

model = LitMNIST(hparams=args)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)

13 changes: 6 additions & 7 deletions docs/source/lr_finder.rst
@@ -36,18 +36,17 @@ hyperparameters of the model.
# default: no automatic learning rate finder
trainer = Trainer(auto_lr_find=False)

When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate.
In both cases, if the respective fields are not found, an error will be thrown.

This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.

.. testcode::

class LitModel(LightningModule):

def __init__(self, hparams):
self.hparams = hparams
def __init__(self, learning_rate):
super().__init__()
self.learning_rate = learning_rate

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate)
return Adam(self.parameters(), lr=(self.lr or self.learning_rate))

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
Inline review comment from the PR author:

@SkafteNicki once we merge this we need to update setting LR.
I think it should just set it in the model at self.lr or self.learning_rate

@@ -97,7 +96,7 @@ of this would look like

# update hparams of the model
model.hparams.lr = new_lr

# Fit model
trainer.fit(model)

2 changes: 1 addition & 1 deletion docs/source/training_tricks.rst
@@ -67,7 +67,7 @@ a binary search.
.. code-block:: python

def train_dataloader(self):
return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
return DataLoader(train_dataset, batch_size=self.batch_size)
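
For context, a minimal sketch of a module written in the new style, where ``batch_size`` is an
init argument read by the dataloader; the ``auto_scale_batch_size`` flag and ``train_dataset``
are assumptions based on the surrounding docs:

.. code-block:: python

    from torch.utils.data import DataLoader

    from pytorch_lightning import LightningModule, Trainer


    class LitModel(LightningModule):

        def __init__(self, batch_size=32, **kwargs):
            super().__init__()
            self.batch_size = batch_size

        def train_dataloader(self):
            return DataLoader(train_dataset, batch_size=self.batch_size)


    # searches for the largest batch size that fits in memory before training
    trainer = Trainer(auto_scale_batch_size=True)
    trainer.fit(LitModel())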

.. warning::

55 changes: 28 additions & 27 deletions docs/source/weights_loading.rst
@@ -59,24 +59,20 @@ Or disable it by passing
trainer = Trainer(checkpoint_callback=False)


The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init.
The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the `module_arguments` key in the checkpoint.

.. note:: hparams is a `Namespace <https://docs.python.org/2/library/argparse.html#argparse.Namespace>`_.

.. testcode::

from argparse import Namespace
.. code-block:: python

# usually these come from command line args
args = Namespace(learning_rate=0.001)
class MyLightningModule(LightningModule):

# define you module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
# this model (in this case, learning rate)
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, *args, **kwargs):
super().__init__()

def __init__(self, hparams, *args, **kwargs):
self.hparams = hparams
# all init args were saved to the checkpoint
checkpoint = torch.load(CKPT_PATH)
print(checkpoint['module_arguments'])
# {'learning_rate': the_value}

Manual saving
^^^^^^^^^^^^^
@@ -92,37 +88,42 @@ You can manually save checkpoints and restore your model from the checkpointed s
Checkpoint Loading
------------------

To load a model along with its weights, biases and hyperparameters use following method.
To load a model along with its weights, biases and `module_arguments`, use the following method.
Borda marked this conversation as resolved.

.. code-block:: python

model = MyLightingModule.load_from_checkpoint(PATH)
model.eval()
y_hat = model(x)

The above only works if you used `hparams` in your model definition

.. testcode::

class LitModel(LightningModule):
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint

def __init__(self, hparams):
self.hparams = hparams
self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
model.eval()
y_hat = model(x)

But if you don't and instead pass individual parameters
But if you don't want to use the values saved in the checkpoint, pass in your own here

.. testcode::

class LitModel(LightningModule):

def __init__(self, in_dim, out_dim):
self.l1 = nn.Linear(in_dim, out_dim)
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.l1 = nn.Linear(self.in_dim, self.out_dim)

you can restore the model like this

.. code-block:: python

# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

