using Zygote #669

Merged
merged 92 commits into from Sep 11, 2019


@MikeInnes (Member) commented Mar 8, 2019

Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. Do not try this at home (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current Params API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

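Blocking issues:

  • Get the tests passing.
  • Check tests on GPU.
  • Rewrite all the docs.
  • Cache invalidation (JuliaLabs/Cassette.jl#6).
  • Moving over adjoints (FluxML/Zygote.jl#81).
  • General Zygote robustness.
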
Nice to have:

  • Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
  • Zygote support for modules / globals as discussed in #628, along with #637.
  • Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

]add Flux#zygote Zygote#master
@MikeInnes (Member, Author) commented Apr 4, 2019

I had initially started on #666 and #637 on this branch, but that is turning into a big project, so I've stripped those commits for now (they are preserved on mji/step). As excited as I am to get rid of Params, I think the right move is to shelve that and focus on the core AD issues.

@staticfloat (Contributor) commented Apr 9, 2019

In my own experiments trying to use Zygote with Flux, I just do model = mapleaves(Flux.data, model) first, then define my own Zygote-based update step:

zyg_update!(opt, model, updates::Nothing) = nothing
function zyg_update!(opt, model::AbstractArray, updates::AbstractArray)
    # Sub off to Flux's ADAM optimizer
    Δ = Flux.Optimise.update!(opt, model, updates)
    return model .-= Δ
end

function zyg_update!(opt, model, updates)
    if nfields(model) == 0
        return model
    end

    for field_idx in 1:nfields(model)
        zyg_update!(opt, getfield(model, field_idx), getfield(updates, field_idx))
    end
end

Things actually work fairly well, except that BatchNorm freaks out, complaining about mutating arrays. To work around this, I am using my own BatchNorm implementation, re-architected to work with Zygote. Not sure if that is the direction you want to go with this, Mike, but it worked well for us on TPU. I will note, anecdotally, that my convolution- and batchnorm-heavy workload (a convolutional autoencoder for large images) uses ~20% less memory with Zygote than with Tracker.
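To illustrate the kind of re-architecture described above, here is a minimal sketch of a batch normalisation step that avoids in-place mutation by returning a fresh state alongside its output. This is not the implementation mentioned in the comment, and it is not Flux's BatchNorm: the `BNState` type, the `batchnorm` function, and the `momentum` and `ϵ` keyword names are all hypothetical and chosen only for illustration.

```julia
using Statistics

# Hypothetical container for running statistics (not a Flux type).
struct BNState
    μ::Vector{Float64}    # running mean, one entry per feature
    σ²::Vector{Float64}   # running variance, one entry per feature
end

# Normalise `x` (features × batch) with the batch statistics and return
# the output together with a NEW state; nothing is mutated in place.
function batchnorm(x::AbstractMatrix, γ, β, st::BNState; momentum = 0.1, ϵ = 1e-5)
    μ  = vec(mean(x, dims = 2))
    σ² = vec(var(x, dims = 2, corrected = false))
    y  = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β
    new_st = BNState((1 - momentum) .* st.μ  .+ momentum .* μ,
                     (1 - momentum) .* st.σ² .+ momentum .* σ²)
    return y, new_st
end

x  = [1.0 2.0 3.0; 4.0 6.0 8.0]          # 2 features, batch of 3
st = BNState(zeros(2), ones(2))
y, st2 = batchnorm(x, ones(2), zeros(2), st)
```

Because the running statistics are threaded through as a return value, the caller decides when to keep `st2`, and the original `st` is left untouched. Whether a given Zygote version differentiates this without further adjoints is not something this sketch verifies.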

@staticfloat (Contributor) commented May 3, 2019

I wanted to use this with the new NNlib overhaul, so I rebased this branch on top of master; I'm not certain I did everything right, but sf/zygote_updated contains my rebased version. Mike, if you like it, you can just force-push it to #zygote and keep on working, or you can just do the rebase yourself.

@@ -3,25 +3,25 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.

```julia
-using Flux, Flux.Tracker
+using Flux, Flux.Zygote
```

@MikeInnes (Member, Author) commented Sep 11, 2019

@dhairyagandhi96 Flux already exports gradient, so this may not be necessary

@DhairyaLGandhi (Member) commented Sep 11, 2019

Fixed
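As a rough sketch of what the documentation passage under review describes (dummy data, a loss, and gradients for `W` and `b`), the following relies only on names that Flux exports, as the review comment suggests. It assumes a Flux release that still supports implicit parameters via `Flux.params`; the variable names and array sizes are illustrative.

```julia
using Flux

W = rand(2, 5)
b = rand(2)

predict(x) = W * x .+ b
loss(x, y) = sum((predict(x) .- y) .^ 2)

x, y = rand(5), rand(2)   # dummy data

# `gradient` is exported by Flux; `Flux.params` collects the arrays to track.
θ = Flux.params(W, b)
grads = gradient(() -> loss(x, y), θ)

grads[W]   # gradient of the loss with respect to W
grads[b]   # gradient of the loss with respect to b
```

For this squared-error loss, `grads[b]` should equal `2 .* (predict(x) .- y)`, which gives a quick sanity check.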

Co-Authored-By: Mike J Innes <mike.j.innes@gmail.com>
@MikeInnes marked this pull request as ready for review Sep 11, 2019
@MikeInnes (Member, Author) commented Sep 11, 2019

bors r+

bors bot added a commit that referenced this issue Sep 11, 2019
669: using Zygote r=MikeInnes a=MikeInnes

Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. **Do not try this at home** (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current `Params` API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

Blocking issues:

* [x] Get the tests passing.
* [x] Check tests on GPU.
* [x] Rewrite all the docs.
* [x] Cache invalidation (JuliaLabs/Cassette.jl#6).
* [x] Moving over adjoints (FluxML/Zygote.jl#81).
* [x] General Zygote robustness.

Nice to have:

* [ ] Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
* [x] Zygote support for modules / globals as discussed in #628, along with #637.
* [x] Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

```julia
]add Flux#zygote Zygote#master
```

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: Elliot Saba <staticfloat@gmail.com>
Co-authored-by: thebhatman <manjunathbhat9920@gmail.com>
@MikeInnes (Member, Author) commented Sep 11, 2019

There seem to be some issues with our GPU CI, so I'm just merging.

@MikeInnes merged commit bdeb9c6 into master Sep 11, 2019
1 of 3 checks passed
@bors bot (Contributor) commented Sep 11, 2019

Build failed

rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)

@iblis17 (Contributor) commented Feb 12, 2020

well, is there an alternative?

5 participants