Skip to content

Commit

Permalink
rrule errors, improvements, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 7, 2022
1 parent 310b71e commit 537b011
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 31 deletions.
55 changes: 26 additions & 29 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ _makelazy(x) = x
function _underscoredepth(ex::Expr)
# Meta.isexpr(ex, :tuple) && :_ in ex.args && return 10
ex.head in (:call, :kw, :(->), :block) || return 0
ex.args[1] == :(=>) && ex.args[2] == :_ && return 1
ex.args[1] === :(=>) && ex.args[2] === :_ && return 1
m = maximum(_underscoredepth, ex.args)
m == 0 ? 0 : m+1
end
_underscoredepth(ex) = Int(ex == :_)
_underscoredepth(ex) = Int(ex === :_)

function _makefun(ex)
T = Meta.isexpr(ex, :call) ? ex.args[1] : Type
Expand All @@ -281,7 +281,7 @@ autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)

_replaceunderscore(e, s) = e == :_ ? s : e
_replaceunderscore(e, s) = e === :_ ? s : e
_replaceunderscore(ex::Expr, s) = Expr(ex.head, map(a -> _replaceunderscore(a, s), ex.args)...)

mutable struct LazyLayer
Expand All @@ -290,45 +290,42 @@ mutable struct LazyLayer
layer
end

function (l::LazyLayer)(x::AbstractArray)
l.layer == nothing || return l.layer(x)
lay = l.make(x)
y = lay(x)
l.layer = lay # mutate after we know that call worked
@functor LazyLayer

function (l::LazyLayer)(x::AbstractArray, ys::AbstractArray...)
l.layer === nothing || return l.layer(x, ys...)
made = l.make(x) # for something like `Bilinear((_,__) => 7)`, perhaps need `make(xy...)`, later.
y = made(x, ys...)
l.layer = made # mutate after we know that call worked
return y
end

#=
Flux.outputsize(Chain(Dense(2=>3)), (4,)) # nice error
Flux.outputsize(Dense(2=>3), (4,)) # no nice error
@autosize (4,) Dense(2=>3) # no nice error
@autosize (3,) Dense(2 => _) # shouldn't work, weird error
@autosize (3,5,6) LayerNorm(_,_) # no complaint, but
ans(rand(3,5,6)) # this fails
=#

@functor LazyLayer

function striplazy(x)
fs, re = functor(x)
function striplazy(m)
fs, re = functor(m)
re(map(striplazy, fs))
end
striplazy(l::LazyLayer) = l.layer == nothing ? error("should be initialised!") : l.layer
function striplazy(l::LazyLayer)
l.layer === nothing || return l.layer
error("LazyLayer should be initialised, e.g. by outputsize(model, size), before using stiplazy")
end

# Could make LazyLayer usable outside of @autosize, for instance allow Chain(@lazy Dense(_ => 2))?
# But then it will survive to produce weird structural gradients etc.

function ChainRulesCore.rrule(l::LazyLayer, x)
l(x), _ -> error("LazyLayer should never be used within a gradient. Call striplazy(model) first to remove all.")
end
function ChainRulesCore.rrule(::typeof(striplazy), m)
striplazy(m), _ -> error("striplazy should never be used within a gradient")
end

params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.")
function Base.show(io::IO, l::LazyLayer)
printstyled(io, "LazyLayer(", color=:light_black)
if l.layer == nothing
printstyled(io, l.str, color=:red)
printstyled(io, l.str, color=:magenta)
else
printstyled(io, l.layer, color=:green)
printstyled(io, l.layer, color=:cyan)
end
printstyled(io, ")", color=:light_black)
end
Expand Down
26 changes: 24 additions & 2 deletions test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ end
m = @autosize (3,) Dense(_ => 4)
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax)
m = @autosize (3, 1) Chain(Dense(_, 4), Dense(4 => 10), softmax)
@test randn(3, 5) |> m |> size == (10, 5)

m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
Expand Down Expand Up @@ -201,6 +201,9 @@ end
m = @autosize (3, 1) Flux.Bilinear(_ => 10)
@test randn(3, 4) |> m |> size == (10, 4)

m = @autosize (3,) SkipConnection(Dense(_ => _), Flux.Bilinear(_ => 10)) # Bilinear gets two inputs
@test randn(3, 4) |> m |> size == (10, 4)

@test_throws Exception @eval @autosize (3,) Flux.Bilinear((_,3) => 10)

# first docstring example
Expand All @@ -219,4 +222,23 @@ end
Dense(_ => 10),
) |> gpu # moves to GPU after initialisation
@test randn(Float32, img..., 1, 32) |> gpu |> m |> size == (10, 32)
end
end

@testset "LazyLayer" begin
# This is what `@autosize` uses, ideally nobody should make these by hand!
# Implicitly testeed by the macro, explicitly here too:
ld = Flux.LazyLayer("Dense(_ => 3, relu; init=??)", x -> Dense(Flux.autosizefor(Dense, x) => 3, relu, init=ones), nothing)

lm = Chain(ld, Flux.Scale(3))
@test string(ld) == "LazyLayer(Dense(_ => 3, relu; init=??))"
@test_throws Exception Flux.striplazy(lm)

@test lm([1,2]) == [3,3,3]

@test string(ld) == "LazyLayer(Dense(2 => 3, relu))"
@test Flux.striplazy(ld) isa Dense

@test_throws Exception Flux.params(lm)
@test_throws Exception gradient(x -> sum(abs2, lm(x)), [1,2])
@test_throws Exception gradient(m -> sum(abs2, Flux.striplazy(m)([1,2])), ld)
end

0 comments on commit 537b011

Please sign in to comment.