
save_hyperparameters and mixins/inheritance #16206

Closed
RuRo opened this issue Dec 26, 2022 · 13 comments · Fixed by #16369
Labels
discussion In a discussion stage feature Is an improvement or enhancement

Comments

@RuRo
Contributor

RuRo commented Dec 26, 2022

Bug description

Consider the following piece of code:

import pytorch_lightning as pl

class A:
    def __init__(self, bar="abc", **kwargs):
        self.save_hyperparameters()
        super().__init__(**kwargs)

        print("A")
        print(self.hparams)

class B:
    def __init__(self, foo=123, **kwargs):
        self.save_hyperparameters()
        super().__init__(**kwargs)

        print("B")
        print(self.hparams)

class M(A, B, pl.LightningModule):
    pass

m = M()

Prior to 1.7.2, this used to print

B
"bar": abc
"foo": 123
A
"bar": abc
"foo": 123

but starting with 1.7.2 this now prints

B

A

This change was probably unintentionally caused by #14151.

Given that the pytorch-lightning codebase itself heavily uses Mixins, I would expect that this should continue working as before. 😆

How to reproduce the bug

No response

Error messages and logs

No response

Environment

I have bisected the breaking change to 1.7.2.

pip install 'pytorch-lightning<1.7.2'  # good
pip install 'pytorch-lightning==1.7.2' # bad

More info

No response

cc @Borda

@RuRo RuRo added the needs triage Waiting to be triaged by maintainers label Dec 26, 2022
@awaelchli awaelchli added feature Is an improvement or enhancement discussion In a discussion stage and removed needs triage Waiting to be triaged by maintainers labels Dec 27, 2022
@awaelchli
Member

Hi @RuRo

I'm not sure how this can be supported in a reliable way. Any ideas?
Btw, the example you show here is not the intended way of using save_hyperparameters(). The call to save_hyperparameters should always happen in the most specific class, never in the parents. I think we want to be careful here, and not encourage the use of multiple inheritance this way.
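For context, a minimal sketch of the intended usage (illustrative only, not taken from the thread): save_hyperparameters() is called exactly once, in the most specific class, so the saved hparams match that class's __init__ signature and round-trip through re-instantiation.

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, foo=1, bar="abc"):
        super().__init__()
        # called exactly once, in the most specific class
        self.save_hyperparameters()

m = MyModel(foo=2)
print(m.hparams)           # "bar": abc, "foo": 2
m2 = MyModel(**m.hparams)  # reconstructs the same configuration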

@RuRo
Contributor Author

RuRo commented Dec 27, 2022

I'm not sure how this can be supported in a reliable way. Any ideas?

Please correct me if I'm wrong: the current implementation walks the stack frames starting from the call site, going up. For each frame, it checks whether the frame is inside a class that inherits from LM or LDM and, if it does, saves the hyperparameters.

If my understanding is correct, then the simplest way to fix this is simply to force the parameter saving for the innermost frame (i.e. the frame in which save_hyperparameters was actually called), regardless of its __class__.
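(For illustration, a rough sketch of that understanding; hypothetical helper name, not the actual pytorch_lightning.utilities.parsing code:)

import inspect

def collect_init_args_sketch(frame, classes):
    # Walk up from the save_hyperparameters() call site, keeping the __init__
    # arguments of every frame whose __class__ cell is a subclass of `classes`.
    collected = []
    innermost = True
    while frame is not None:
        args, _, _, local_vars = inspect.getargvalues(frame)
        in_target = "__class__" in local_vars and issubclass(local_vars["__class__"], classes)
        # proposed fix: always keep the innermost frame, regardless of __class__
        if in_target or innermost:
            collected.append({name: local_vars[name] for name in args if name != "self"})
        innermost = False
        frame = frame.f_back
    return collected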

Btw, the example you show here is not the intended way of using save_hyperparameters(). The call to save_hyperparameters should always happen in the most specific class, never in the parents.

Can you clarify this point? Is this documented anywhere? What problems could be caused by calling save_hyperparameters in the parents? Since the save_hyperparameters call walks the stack at the end of the day, I see no harm in calling it in the parent class.

I think we want to be careful here, and not encourage the use of multiple inheritance this way.

This is "multiple inheritance" only in the strictest sense of the word. The method resolution order is strictly linear and the final class M behaves almost exactly like if it had a linear inheritance chain M -> A -> B -> LM.

Mixins are incredibly useful for pytorch-lightning, because they allow me to decompose parts of my module and mix and match them as required. Given that pytorch-lightning itself heavily uses mixins in its codebase, I think it would be a little hypocritical to declare this practice "too dangerous" for its users.

@awaelchli
Member

I did not write "No, we won't do this". I wrote "I'm not sure how this can be supported in a reliable way. Any ideas?". This is not hypocritical; it means that I am curious how this could be supported correctly given the requirements, and I don't see how it can be at the moment.

Can you clarify this point? Is this documented anywhere?

The hard requirement is that the arguments captured by save_hyperparameters must be the ones that reconstruct the object, i.e., taking the set of saved hparams, M(**hparams) must give the same object as when it was instantiated the first time. This is a requirement because it is part of the feature that allows us to restore the model with M.load_from_checkpoint. A counterexample:

from pytorch_lightning.core.module import LightningModule


class A(LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()


class B(LightningModule):
    def __init__(self, foo=2):
        super().__init__()
        self.save_hyperparameters()


class Main(A, B, LightningModule):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()


m = Main(bar=2)
print(m.hparams)

# Output:
# "bar": 2
# "foo": 1

Here, due to the fact that save_hyperparameters() gets called at multiple levels, we save the parameter "foo", which we shouldn't. We saved "too many" parameters, and when passing them to the original init, we get an error:

Main(**m.hparams)

TypeError: Main.__init__() got an unexpected keyword argument 'foo'

These are the sort of issues we want to avoid. That's why we don't show examples with multiple inheritance and only ever show save_hyperparameters() in a single LightningModule in examples. Unfortunately, we haven't found a way to detect this and error out in a reliable way (without false positives on valid usage).

@awaelchli
Member

Perhaps one could redesign the feature in some way and name the method differently, e.g. self.log_hyperparameters(); then it could capture whatever it wants, without the requirement that these parameters be used to re-instantiate the model.
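(Purely illustrative sketch of that idea; log_hyperparameters is a hypothetical name, not an existing API:)

import inspect

def log_hyperparameters(obj):
    # Capture the caller's __init__ arguments purely for logging, with no
    # promise that they can be used to re-instantiate the object.
    frame = inspect.currentframe().f_back
    args, _, _, local_vars = inspect.getargvalues(frame)
    captured = {name: local_vars[name] for name in args if name != "self"}
    logged = getattr(obj, "logged_hparams", {})
    logged.update(captured)
    obj.logged_hparams = logged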

@RuRo
Contributor Author

RuRo commented Dec 27, 2022

I did not write "No, we won't do this". I wrote "I'm not sure how this can be supported in a reliable way. Any ideas?". This is not hypocritical; it means that I am curious how this could be supported correctly given the requirements, and I don't see how it can be at the moment.

And I didn't write "You are being hypocritical"; I wrote that outright forbidding multiple inheritance or mixins, solely on the grounds that they are dangerous, would be hypocritical. Perhaps that part of my message came across a bit too harsh. My apologies. I guess a better way to word this would have been "I don't think that we should prohibit such use, because ...".

I'll read and respond to the rest of your comment a bit later; I just wanted to preemptively address this part.

@RuRo
Contributor Author

RuRo commented Dec 28, 2022

The hard requirement is that the arguments captured by save_hyperparameters must be the ones that reconstruct the object, i.e., taking the set of saved hparams, M(**hparams) must give the same object as when it was instantiated the first time. This is a requirement because it is part of the feature that allows us to restore the model with M.load_from_checkpoint. A counterexample:

I think this example doesn't actually have anything to do with multiple inheritance or mixins. The issue here is that in this case the __init__ method doesn't follow the Liskov Substitution Principle. Here's a simplified example that has the same problem but doesn't use multiple inheritance:

import pytorch_lightning as pl
class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()

class Main(A):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()

m = Main(bar=2)
print(m.hparams)
# "bar": 2
# "foo": 1

m2 = Main(**m.hparams)
# TypeError: Main.__init__() got an unexpected keyword argument 'foo'

P.S. Calling save_hyperparameters only from the innermost class also doesn't fix the issue in this case.

@awaelchli
Member

awaelchli commented Dec 28, 2022

Yes, exactly: the issue is simply that self.save_hyperparameters() gets called multiple times in different contexts. So if you call it only from the innermost class, then it works as expected (just comment out the call in A.__init__).

import pytorch_lightning as pl
class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        # self.save_hyperparameters()  # not allowed here

class Main(A):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()

m = Main(bar=2)
print(m.hparams)
m2 = Main(**m.hparams)
print(m2.hparams)

To prevent this, we could "count" the number of times the method was called and error once it gets called multiple times on the same instance. This doesn't support your use case, but at least we would be showing a meaningful error for unintended misuse.
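(A minimal sketch of the counting idea, assuming a hypothetical wrapper around the existing method rather than the real implementation:)

class SaveHparamsOnce:
    # Hypothetical mixin: placed before LightningModule in the MRO, it errors
    # out when save_hyperparameters() is called more than once on an instance.
    def save_hyperparameters(self, *args, **kwargs):
        if getattr(self, "_hparams_saved", False):
            raise RuntimeError(
                "save_hyperparameters() was already called on this instance; "
                "call it only from the most specific class."
            )
        self._hparams_saved = True
        return super().save_hyperparameters(*args, **kwargs)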

@RuRo
Contributor Author

RuRo commented Dec 28, 2022

@awaelchli Ah, then I misunderstood which class is the "innermost" (I would call A the innermost class in this case).

However, if your explanation is correct, what is even the point of save_hyperparameters walking up the call stack? I think we should take a step back and re-examine how we want save_hyperparameters to behave in different situations. Even without any inheritance (except directly from LightningModule), the current implementation can produce quite surprising results.

@RuRo
Contributor Author

RuRo commented Dec 28, 2022

For example, consider

import pytorch_lightning as pl

class ModuleA(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()

class ModuleB(pl.LightningModule):
    def __init__(self, bar=2):
        super().__init__()
        self.save_hyperparameters()
        self.composition = ModuleA()

b = ModuleB()
a = b.composition

print(a.hparams)
# "bar": 2
# "foo": 1

a = a.__class__(**a.hparams)
# ModuleA.__init__() got an unexpected keyword argument 'bar'

There is no inheritance in this example and it's still fundamentally broken, because the save_hyperparameters() call inside ModuleA keeps walking up the stack past the ModuleA.__init__ call.

@awaelchli
Member

awaelchli commented Dec 28, 2022

A second option would be to record both types of hyperparameters:

  1. The ones from the innermost call (how to detect that?). Save these in a container that is used to reconstruct the object; these definitely need to go into the checkpoint, because that's what we need to satisfy the important requirement described in save_hyperparameters and mixins/inheritance #16206 (comment).

  2. The total set of "hyperparameters", accumulated over multiple calls of save_hyperparameters (maybe this is what's being logged to the logger?).

what is even the point of save_hyperparameters walking up the call stack

It's an artifact of an older design that did try to take inherited calls into account. At the same time, it supports save_hyperparameters being called from a method or function, for example:

class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.init_me()

    def init_me(self):
        self.save_hyperparameters()  # want to walk up call stack to find init local vars

Although this is not exactly a pattern we show in examples, it must be supported.

@awaelchli
Member

For reference, here are all test cases. In case you want to study the edge cases we have:
https://github.com/Lightning-AI/lightning/blob/master/tests/tests_pytorch/models/test_hparams.py

@RuRo
Contributor Author

RuRo commented Dec 28, 2022

Okay, then how about this: currently, collect_init_args checks issubclass(local_vars["__class__"], classes). What do you think about replacing this check with isinstance(init_self, classes), where init_self is obtained via get_init_args? (The argument doesn't actually have to be called "self"; it's just the first argument accepted by __init__.)

As far as I can tell, this would fix my use case while preserving the current semantics.
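(Roughly like this, with illustrative names rather than the real utilities.parsing signatures:)

import inspect

def frame_init_args_if_relevant(frame, classes):
    # Hypothetical sketch of the proposed check: instead of inspecting the
    # frame's __class__ cell, inspect the object being initialised.
    args, _, _, local_vars = inspect.getargvalues(frame)
    if not args:
        return None
    init_self = local_vars[args[0]]  # first __init__ argument, usually "self"
    if isinstance(init_self, classes):
        return {name: local_vars[name] for name in args[1:]}
    return None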

@awaelchli
Member

Feel free to take a stab at it. You can run the tests with

pytest -v tests/tests_pytorch/models/test_hparams.py

to see if anything breaks.
