Add testmode! back for normalization layers #1044
Conversation
I'm dubious about defaulting to `:auto` in

```julia
testmode!(m, mode=true)  # set to test mode when called with 1 arg
testmode!(m, mode)       # force to input the mode
```
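For readers following along, here is a toy sketch (my own illustration, not code from this PR) of how a layer could record the requested mode; `ToyNorm` and its `active` field are hypothetical stand-ins for the real normalization layers:

```julia
# Toy layer (not Flux's BatchNorm) used only to show the mode bookkeeping.
mutable struct ToyNorm
    active::Union{Bool,Nothing}   # nothing => decide automatically under AD
end

# Two-argument form: `true` forces test mode, `false` forces train mode,
# and `:auto`/`nothing` restores the automatic behaviour.
testmode!(m::ToyNorm, mode = true) =
    (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
```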
btw, rebase has gone wrong here
Yeah I agree. I'll change that.
Any suggestions on how to resolve? I see two options: wait until the other PR I have outstanding has merged into master, or roll back the commits above and then try to rebase from upstream. I haven't done the latter before, but I can figure it out. The weird commit history is why I set this PR as a draft.
your other PR is orthogonal to this, so you could have just branched from current master without any fear about the merging order. I don't know what would be the best way to fix this. What I would do, since the PR is small, would be something a little bit dirty:
I ended up rebasing from before the other PR. The commit history should be clean now.
@dhairyagandhi96 @MikeInnes @xukai92 any comments on this? This is non-breaking, but once it gets in, reverting it would be breaking.
@darsnack should we also add `trainmode!`?
Are you thinking of something like
I did update
yeah, I was thinking

```julia
trainmode!(m, mode=true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
```

I know it's redundant, but it seems more natural not to arbitrarily break the symmetry toward `testmode!` in the interface. This way, one would do
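A usage sketch of that symmetric interface (illustrative only, assuming both functions recurse through containers such as `Chain`):

```julia
using Flux

m = Chain(Dense(10, 10), BatchNorm(10))

testmode!(m)         # force test mode for every layer in the chain
trainmode!(m)        # force train mode
testmode!(m, :auto)  # back to the automatic, AD-trace-based behaviour
```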
Also, we should warn in the docstring to set the model back to train mode once the test phase is over, or something along these lines.
ok, looks good, we can merge whenever you are ready. If you also bump Flux's minor version in the Project.toml, we can also proceed with tagging a new release.
I think this is ready to merge |
wait, it looks like the version was already bumped, sorry, could you change it back?
Oops, I should probably have double-checked that after merging. Should be good to go again.
nice, thanks! bors r+ |
Build succeeded |
This is awesome! Could this be tagged? I would love to use this as soon as possible! :)
This turns out to be breaking for cases where people instantiate a layer by manually specifying the fields instead of using the outer constructor. (e.g. ObjectDetector.jl) |
I would've thought that simply defining the
ouch |
This is the problematic line in ObjectDetector.jl
Is that sufficient for the functionality added in this PR? I don't follow how that allows users to force train/test mode on a per-layer basis.
That would prevent the breakages too, I think
If the usage in ObjectDetector is unusual, I'd happily take a PR to correct it
for freezing layers, the recommended way would be via the parameters anyway, so I think it's an orthogonal concern, unless that's specifically desired
Even if it is unusual, adding fields to a type is a breaking change unless the field defaults to a value. @ianshmean your PR should have been part of this one — that was my bad.
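To make that concrete, a toy example (not Flux's actual struct definition) of why the extra field is breaking for call sites that list every field positionally:

```julia
# Toy type standing in for a normalization layer.
mutable struct MyNorm
    β::Vector{Float32}
    γ::Vector{Float32}
    active::Union{Bool,Nothing}   # the newly added field
end

# A call site written before `active` existed now throws a MethodError:
# MyNorm(zeros(Float32, 4), ones(Float32, 4))

# It keeps working only if a convenience constructor supplies a default:
MyNorm(β, γ) = MyNorm(β, γ, nothing)
MyNorm(zeros(Float32, 4), ones(Float32, 4))
```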
The normalization layers update some fields on the forward pass when part of an AD trace. I don't think freezing by excluding the parameters will stop this update. This is standard, but v0.10 made the change of automatically deciding whether to update these fields or not. There are use-cases that are not standard training where the fields should not be updated even though a gradient is being computed.
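As a sketch of the kind of use case meant here (the model and data are made up): taking a gradient with respect to the input while the normalization statistics stay frozen.

```julia
using Flux

m = Chain(Dense(4, 2), BatchNorm(2))
x = rand(Float32, 4, 3)

testmode!(m)                       # use the stored μ/σ²; don't update them
g, = gradient(x -> sum(m(x)), x)   # input gradient, stats left untouched
```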
The parameters that we collect are the ones we update, I believe |
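For reference, the implicit-parameter training step being referred to looks roughly like this in Flux of that era (the model, data, and optimiser here are illustrative):

```julia
using Flux

m = Chain(Dense(4, 4), BatchNorm(4))
x, y = rand(Float32, 4, 8), rand(Float32, 4, 8)
opt = Descent(0.1)

ps = Flux.params(m)                         # collects what `trainable` exposes
gs = gradient(() -> Flux.mse(m(x), y), ps)  # implicit-parameter gradient
Flux.Optimise.update!(opt, ps, gs)          # updates exactly those parameters
```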
I think I am missing something. If you could explain how to address #909 with the
I suggest reverting this and working on it some more. I have a bunch of concerns about the API, and the code is pretty strange (e.g. why are there a bunch of additions to
I added the generic functions that work on any layer to
By this do you mean freezing via the following?

```julia
trainable(bn::BatchNorm) = (bn.β, bn.γ)
```

This affects what gets updated during the gradient step. I don't see how that affects L189-190:

```julia
μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
```

These are recomputed on every pass that is contained within an AD trace, regardless of what `trainable` returns; the stored statistics are only used on the other branch:

```julia
μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...)
```

In my mind, latching onto functor/trainable means rewriting the normalization layers so that all updates are part of the gradient computation (i.e. custom adjoints).
The
1432: Generalize train/testmode! to all Functors r=CarloLucibello a=ToucheSir

Addresses #1044 (comment). See also https://discourse.julialang.org/t/do-i-have-to-implement-flux-testmode-for-my-own-models/52038.

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
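The generic fallbacks being generalized there look approximately like the following; this is a sketch from memory of the Flux definitions (redefined here purely for illustration), so the exact form in the source may differ:

```julia
import Flux: testmode!, trainmode!, trainable

# Generic fallbacks: recurse into a layer's children so that calling
# testmode!/trainmode! on a container such as Chain reaches any
# normalization layers nested inside it.
testmode!(m, mode = true)  = (foreach(x -> testmode!(x, mode), trainable(m)); m)
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
```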
Fixed #909

I added `testmode!(m, mode)` back to Flux as per v0.9. Now the `mode` can be `false`, `true`, or `:auto`/`nothing`, with the default being `:auto` for newly constructed layers. In `:auto` mode, the `istraining()` functions added in v0.10 are used to determine whether we are evaluating within an AD trace or not.

Also plan on adding a doc section in an additional commit.
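For readers wondering how the `:auto` detection works mechanically: roughly, `istraining()` returns `false` in ordinary code but carries an AD rule that makes it evaluate to `true` inside a Zygote gradient call, and each layer checks its explicit override first. A sketch along those lines (the names `istraining` and `_isactive` follow my reading of the Flux source of that era, so treat the details as approximate):

```julia
using Zygote: @adjoint

# False in ordinary code; the custom adjoint makes the call evaluate to
# true whenever it happens inside a Zygote gradient computation.
istraining() = false
@adjoint istraining() = true, _ -> nothing

# A layer then honours an explicit testmode!/trainmode! override first and
# only falls back to automatic detection when `active === nothing`.
_isactive(m) = isnothing(m.active) ? istraining() : m.active
```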