save_hyperparameters and mixins/inheritance #16206
Hi @RuRo I'm not sure how this can be supported in a reliable way. Any ideas?
Please correct me if I'm wrong: the current implementation walks the stack frames starting from the call site, going up. It checks whether a frame is inside a class that inherits from LM or LDM and, if it is, saves the hyper-parameters. If my understanding is correct, then the simplest way to fix this would be to force the parameter saving for the innermost frame (i.e. the frame in which `save_hyperparameters()` was actually called).
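The frame inspection described here can be sketched in pure Python with the `inspect` module. This is a hypothetical simplification for illustration, not Lightning's actual implementation: it records the local variables of the immediate caller's frame, which is assumed to be an `__init__`.

```python
import inspect

def save_hyperparameters_sketch(obj):
    """Record the local variables of the immediate caller's frame
    (assumed to be an __init__). Hypothetical simplification of the
    frame inspection discussed in this thread."""
    caller = inspect.currentframe().f_back
    # Everything local in the caller except 'self' is treated as a hyper-parameter.
    obj.hparams = {k: v for k, v in caller.f_locals.items() if k != "self"}

class Net:
    def __init__(self, lr=0.1, hidden=32):
        save_hyperparameters_sketch(self)

n = Net(lr=0.01)
print(n.hparams)  # {'lr': 0.01, 'hidden': 32}
```

The real implementation additionally has to decide *which* frame to stop at, which is exactly what this issue is about.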
Can you clarify this point? Is this documented anywhere? What problems could be caused by calling `save_hyperparameters()` multiple times?
This is "multiple inheritance" only in the strictest sense of the word. The method resolution order is strictly linear, and the final class just combines the mixins. Mixins are incredibly useful for this kind of code reuse.
I did not write "No, we won't do this". I wrote "I'm not sure how this can be supported in a reliable way. Any ideas?". This is not hypocritical; it means that I am curious how this could be supported correctly given the requirements. I don't see how it can be at the moment.
The hard requirement is that the arguments captured by `save_hyperparameters()` can be passed back to `__init__` to re-create the module. Consider:

```python
from pytorch_lightning.core.module import LightningModule

class A(LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()

class B(LightningModule):
    def __init__(self, foo=2):
        super().__init__()
        self.save_hyperparameters()

class Main(A, B, LightningModule):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()

m = Main(bar=2)
print(m.hparams)
# Output:
# "bar": 2
# "foo": 1
```

Here, because `save_hyperparameters()` gets called at multiple levels, we save the parameter `"foo"`, which we shouldn't. We saved "too many" parameters, and when passing them back to the original `__init__`, we get an error.
These are the sorts of issues we want to avoid. That's why we don't show examples with multiple inheritance and only ever show the simple single-class usage.
Perhaps one could redesign the feature in some way and name the method differently.
And I didn't write "You are being hypocritical"; I wrote that outright forbidding multiple inheritance or mixins solely because they are dangerous would be hypocritical. Perhaps that part of my message came across a bit too harsh. My apologies. I guess a better way to word this would have been "I don't think that we should prohibit such use, because ...". I'll read and respond to the rest of your comment a bit later; I just wanted to preemptively address this part.
I think that this example doesn't actually have anything to do with multiple inheritance or mixins. The issue here is that the saved hyperparameters no longer match the signature of the outermost `__init__`:

```python
import pytorch_lightning as pl

class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()

class Main(A):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()

m = Main(bar=2)
print(m.hparams)
# "bar": 2
# "foo": 1

m2 = Main(**m.hparams)
# TypeError: Main.__init__() got an unexpected keyword argument 'foo'
```
Yes exactly, the issue is simply that `save_hyperparameters()` gets called more than once. If it is only called in the class that actually gets instantiated, everything round-trips correctly:

```python
import pytorch_lightning as pl

class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        # self.save_hyperparameters()  # not allowed here

class Main(A):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()

m = Main(bar=2)
print(m.hparams)
m2 = Main(**m.hparams)
print(m2.hparams)
```

To prevent this, we could "count" the number of times the method was called and error once it gets called multiple times on the same instance. This doesn't support your use case, but at least we would be showing a meaningful error on unintended misuse.
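The "count the calls" guard could look roughly like this. This is a sketch with a hypothetical `HParamsMixin` and a plain-Python frame inspection, not the real Lightning API:

```python
import inspect

class HParamsMixin:
    """Hypothetical sketch of the proposed guard: error out when
    save_hyperparameters() is called more than once on the same instance."""

    def save_hyperparameters(self):
        if getattr(self, "_hparams_saved", False):
            raise RuntimeError(
                "save_hyperparameters() was already called on this instance; "
                "call it only in the class you instantiate directly."
            )
        self._hparams_saved = True
        # Record the local variables of the caller's frame (the __init__).
        caller = inspect.currentframe().f_back
        self.hparams = {k: v for k, v in caller.f_locals.items() if k != "self"}

class A(HParamsMixin):
    def __init__(self, foo=1):
        self.save_hyperparameters()   # runs first via super().__init__()

class Main(A):
    def __init__(self, bar):
        super().__init__()
        self.save_hyperparameters()   # second call on the same instance: error

a = A(foo=5)
print(a.hparams)   # {'foo': 5}

try:
    Main(bar=2)
except RuntimeError as err:
    print("caught:", err)
```

Single-class usage still works unchanged; only the multi-level call pattern from the earlier examples now fails loudly instead of silently saving the wrong parameters.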
@awaelchli Ah, then I misunderstood which class is the "innermost". However, if your explanation is correct, what is even the point of inspecting frames further up the call stack?
For example, consider composition instead of inheritance:

```python
class ModuleA(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.save_hyperparameters()

class ModuleB(pl.LightningModule):
    def __init__(self, bar=2):
        super().__init__()
        self.save_hyperparameters()
        self.composition = ModuleA()

b = ModuleB()
a = b.composition
print(a.hparams)
# "bar": 2
# "foo": 1

a = a.__class__(**a.hparams)
# ModuleA.__init__() got an unexpected keyword argument 'bar'
```

There is no inheritance in this example, and it's still fundamentally broken, because the stack inspection inside `ModuleA` also picks up the locals of `ModuleB.__init__`, which happens to be on the call stack.
A second option would be to record both types of hyperparameters.
It's an artifact of an older design that did try to take inherited calls into account. At the same time, it supports calling the method indirectly from a helper:

```python
class A(pl.LightningModule):
    def __init__(self, foo=1):
        super().__init__()
        self.init_me()

    def init_me(self):
        self.save_hyperparameters()  # want to walk up call stack to find init local vars
```

Although this is not exactly a pattern we show in examples, it must be supported.
For reference, here are all the test cases, in case you want to study the edge cases we have.
Okay, then how about this. Currently, `save_hyperparameters()` merges arguments from every matching frame it finds; instead, it could record only the arguments of the frame in which it was actually called. As far as I can tell, this would fix my use case while preserving the current semantics.
Feel free to take a stab at it. You can run the test suite to see if anything breaks.
Bug description
Consider the following piece of code:
Prior to 1.7.2, this used to print
but starting with 1.7.2 this now prints
This change was probably unintentionally caused by #14151.
Given that the `pytorch-lightning` codebase itself heavily uses mixins, I would expect that this should continue working as before. 😆

How to reproduce the bug
No response
Error messages and logs
No response
Environment
I have bisected the breaking change to 1.7.2.

More info
No response
cc @Borda