Skip to content

Commit

Permalink
move after defn
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 27, 2021
1 parent c0a6ac6 commit 9fa587c
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ end

# Param-style wrappers

"""
Params([A, B])
Container for implicit parameters, used when differentiating
a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
"""
struct Params
order::Buffer # {Any, Vector{Any}}
params::IdSet{Any} # TODO store ids only
end

Params() = Params(Buffer([], false), IdSet())
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
Params(ps::Params) = ps
Params(xs::Tuple) = Params(collect(xs))

@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in

"""
gradient(() -> loss(), ps::Params) -> Grads
Expand Down Expand Up @@ -135,25 +154,6 @@ function withgradient(f, ps::Params)
(val = y, grad = back(sensitivity(y)))
end

"""
Params([A, B])
Container for implicit parameters, used when differentiating
a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
"""
struct Params
order::Buffer # {Any, Vector{Any}}
params::IdSet{Any} # TODO store ids only
end

Params() = Params(Buffer([], false), IdSet())
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
Params(ps::Params) = ps
Params(xs::Tuple) = Params(collect(xs))

@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in

function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
return ps
Expand Down

0 comments on commit 9fa587c

Please sign in to comment.