Simplify scalers, move to gluonts.torch.scaler
#2632
Conversation
very cool! LGTM!
src/gluonts/torch/scaler.py
Outdated
@@ -37,21 +36,12 @@ class MeanScaler(nn.Module):
    minimum possible scale that is used for any item.
    """

    @validated()
We lose the validation property?
Should I use our own dataclass?
Yes, let's give it a try
I'll just keep them validated for now
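For context, "keeping them validated" could look roughly like this; a sketch, not the PR's final code (gluonts' @validated() decorator type-checks and records constructor arguments, which is the property the comment above worries about losing):

```python
import torch
from gluonts.core.component import validated


class MeanScaler(torch.nn.Module):
    @validated()
    def __init__(
        self,
        dim: int = -1,
        keepdim: bool = False,
        minimum_scale: float = 1e-10,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.minimum_scale = minimum_scale
```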
src/gluonts/torch/scaler.py
Outdated
default_scale: float = 0.0
minimum_scale: float = 1e-10
I think these need to be torch.tensor
torch.clamp also deals with numbers; as for default_scale, isn't this some kind of "immutable" (mind the quotes) property of the object, so that replacing torch.where with if is not really harmful?
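A tiny sketch of both points (values and names illustrative, not the PR's code):

```python
import torch

scale = torch.tensor([0.0, 0.5, 2.0])

# torch.clamp accepts a plain Python number as the bound, so
# minimum_scale can stay a float rather than a torch.Tensor:
scale = torch.clamp(scale, min=1e-10)

# default_scale is fixed at construction time, so a plain `if` on it
# (instead of torch.where) never branches on tensor values:
default_scale = None  # illustrative value
if default_scale is not None:
    scale = torch.full_like(scale, default_scale)
```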
src/gluonts/torch/scaler.py
Outdated
self.register_buffer("minimum_scale", torch.tensor(minimum_scale))
dim: int = -1
keepdim: bool = False
minimum_scale: float = 1e-5
Also, minimum_scale should be a torch.Tensor.
Why? I think we can add a torch.Tensor to a number.
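For example (values illustrative), tensor-number arithmetic broadcasts, so no conversion is needed before the addition:

```python
import torch

variance = torch.tensor([0.0, 1.0, 4.0])
minimum_scale = 1e-5  # a plain float

# Adding a Python number to a torch.Tensor just works:
scale = torch.sqrt(variance + minimum_scale)
```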
src/gluonts/torch/scaler.py
Outdated
self.default_scale,
batch_scale,
)
if self.default_scale is not None:
What is the effect of this change on tracing?
I think that, as long as self.default_scale stays constant (which it should), tracing should be fine with it and produce the right code for the model. But I need to verify.
Also, according to the warning here https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch-jit-trace, it should be fine as long as the control flow is not affected by the values in input tensors (and, I would add, as long as it's not otherwise changed throughout the model execution, like the forward changing the value of self.default_scale for some reason).
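A small check along these lines (illustrative module, not the PR's code): since the branch depends only on a constant attribute and never on tensor values, torch.jit.trace bakes in one branch consistently and the trace stays valid for all inputs.

```python
import torch


class Scaler(torch.nn.Module):
    def __init__(self, default_scale=None):
        super().__init__()
        self.default_scale = default_scale

    def forward(self, x):
        scale = x.abs().mean(dim=-1, keepdim=True)
        # Constant at trace time: the `if` never looks at the data.
        if self.default_scale is not None:
            scale = torch.full_like(scale, self.default_scale)
        return x / scale, scale


traced = torch.jit.trace(Scaler(default_scale=1.0), torch.rand(2, 5))
```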
Sorry for the late comment, but any takers for scalers being of Transform type?
@abdulfatir on the output side the distribution is an AffineTransform which takes the loc and scale from the scalers... but do you mean more general scalers? And I believe the model is currently an appropriate place to have the scalers, since then the model can use the loc and scale as input as well (instead of just on the emission side).
Currently, I don't have an example of a general scaler in mind, but yes, using Transform would be more general. The primary benefit IMO is clarity: you have a scaler (Transform) that normalizes the data, and then use its inverse at the output side.
Also, inside models I think we should be able to provide a custom Transform.
@abdulfatir I think it makes sense to consider this. The (log) scale should be accessible via the transform.
@lostella If the transform is of AffineTransform type, the loc and scale should be accessible as its attributes.
For completeness, I would like to add that the inverse transform (which we would need at the output side) is available via the inv attribute:

```python
import torch
from torch.distributions.transforms import AffineTransform

tr = AffineTransform(10., 1.)
inv_tr = tr.inv
x = torch.rand(2, 3, 4)
y = tr(x)
torch.allclose(x, inv_tr(y), atol=1e-5)  # True
```
The catch is that these scaling operations are not really affine transformations of the data: for a given array, the loc and scale are computed from the array itself, so the resulting transform is different for every input.
@abdulfatir agreed! We can do that in a separate PR
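To make the data-dependence point concrete, a sketch (not gluonts code) of how the transform's parameters are derived from the input itself:

```python
import torch
from torch.distributions.transforms import AffineTransform

x = torch.rand(2, 3, 4)
# The scale is computed from the data, so every new input implies
# a different AffineTransform:
scale = x.abs().mean(dim=-1, keepdim=True)
tr = AffineTransform(loc=0.0, scale=scale)
scaled = tr.inv(x)  # normalize on the input side
```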
Description of changes: there's no need for scalers to be torch.nn.Module since they don't really hold parameters. Also fixes the default keepdim of MeanScaler for consistency.

cc @kashif
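A rough sketch of the resulting shape of things, under the description's premise that scalers hold no learnable parameters (details such as defaults and the weighting logic may differ from the merged gluonts.torch.scaler code):

```python
import torch


class MeanScaler:
    """Scale by the weighted mean of absolute values; a plain
    callable rather than a torch.nn.Module, since there is no state
    to register or train."""

    def __init__(
        self,
        dim: int = -1,
        keepdim: bool = False,
        minimum_scale: float = 1e-10,
    ) -> None:
        self.dim = dim
        self.keepdim = keepdim
        self.minimum_scale = minimum_scale

    def __call__(self, data: torch.Tensor, weights: torch.Tensor):
        # Weighted mean of absolute values, clamped from below so the
        # scale never collapses to zero.
        num = (data.abs() * weights).sum(self.dim, keepdim=True)
        den = torch.clamp(weights.sum(self.dim, keepdim=True), min=1.0)
        scale = torch.clamp(num / den, min=self.minimum_scale)
        scaled = data / scale
        if not self.keepdim:
            scale = scale.squeeze(self.dim)
        return scaled, scale
```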
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup