In [1]:
]activate

[32m[1m  Activating[22m[39m environment at `~/.julia/environments/v1.6/Project.toml`


In [3]:
using FastAI
using Flux
using Zygote

## `ParamGroups`

In [60]:
"""
    ParamGroups()
    ParamGroups(map::IdDict)

Lookup table from parameter arrays to the group they belong to.

Keys of `map` are compared by object identity (`IdDict`), so an array
is only found again when the very same array object is looked up.
"""
struct ParamGroups
    map::IdDict
end

# Start out with no parameters assigned to any group.
function ParamGroups()
    return ParamGroups(IdDict())
end

# Compact display: the contents of the underlying table are not printed.
function Base.show(io::IO, ::ParamGroups)
    return print(io, "ParamGroups(...)")
end

In [61]:
"""
    getgroup(pg::ParamGroups, x::AbstractArray)

Return the group stored for the array `x` in `pg`, or `nothing` if `x`
has not been assigned to a group.
"""
function getgroup(pg::ParamGroups, x::AbstractArray)
    return get(pg.map, x, nothing)
end

getgroup (generic function with 1 method)

In [62]:
"""
    assigngroups!(pg::ParamGroups, grouper, m)

For every `(group, submodel)` pair produced by `group(grouper, m)`, record
each parameter array in `params(submodel)` under that group in `pg`.
An existing entry for the same array is overwritten. Returns `nothing`.
"""
function assigngroups!(pg::ParamGroups, grouper, m)
    for (groupid, submodel) in group(grouper, m), p in params(submodel)
        pg.map[p] = groupid
    end
    return nothing
end

assigngroups! (generic function with 1 method)

In [63]:
"""
    abstract type ParamGrouper

Supertype for strategies that split a model into groups. A grouper
implements `group(grouper, m)`.
"""
abstract type ParamGrouper end

"""
    IndexGrouper(idxs)

Group a model by indexing into it: the `i`-th element of `idxs` selects
the part of the model that forms group `i`.
"""
struct IndexGrouper <: ParamGrouper
    idxs
end

"""
    group(grouper::IndexGrouper, m)

Return a `Dict` mapping each position `i` in `grouper.idxs` to
`m[grouper.idxs[i]]`.
"""
function group(grouper::IndexGrouper, m)
    assignments = (groupid => m[idx] for (groupid, idx) in enumerate(grouper.idxs))
    return Dict(assignments)
end

group (generic function with 1 method)

In [64]:
"""
    ParamGroups(grouper::ParamGrouper, m)

Create a `ParamGroups` and fill it by assigning the parameters of `m`
to groups as determined by `grouper`.
"""
function ParamGroups(grouper::ParamGrouper, m)
    groups = ParamGroups()
    assigngroups!(groups, grouper, m)
    return groups
end

ParamGroups

### Example

In [84]:
# Small two-layer model used to demonstrate parameter grouping.
model = Chain(Dense(3, 5), Dense(5, 3))

Chain(Dense(3, 5), Dense(5, 3))

In [75]:
# Parameters of `model[1]` go into group 1, those of `model[2]` into group 2.
pg = ParamGroups(IndexGrouper([1, 2]), model)

ParamGroups(...)

In [76]:
# Print the group that each parameter array of the model was assigned to.
for p in params(model)
    @show getgroup(pg, p)
end

getgroup(pg, p) = 1
getgroup(pg, p) = 1
getgroup(pg, p) = 2
getgroup(pg, p) = 2


## `DiscriminativeLR`

In [95]:
using Flux.Optimise
import Flux.Optimise: apply!

In [136]:
"""
    DiscriminativeLR(paramgroups, factors)

Optimiser that scales the learning rate per parameter group.

`paramgroups` is a [`ParamGroups`](#) and `factors` is a `Dict` mapping
each group to the factor its learning rate is multiplied by. For a
parameter `x` the factor is `get(factors, getgroup(paramgroups, x), 1)`,
so a parameter whose group has no entry in `factors` keeps a factor of 1.
"""
struct DiscriminativeLR
    pg::ParamGroups
    factors::Dict
end

DiscriminativeLR

In [130]:
"""
    apply!(o::DiscriminativeLR, x, Δ)

Scale the gradient `Δ` of parameter `x` in place by the factor that
`o.factors` holds for the group of `x`, and return `Δ`.

The factor defaults to one when `x` has no group or its group has no
entry in `o.factors`; in that case `Δ` is returned untouched.
"""
function apply!(o::DiscriminativeLR, x, Δ::AbstractArray{T}) where T
    # The group lookup must go through the optimiser's own `ParamGroups`
    # (`o.pg`) so the result does not depend on any global variable.
    factor = convert(T, get(o.factors, getgroup(o.pg, x), one(T)))
    if factor != one(T)
        @. Δ *= factor
    end
    return Δ
end

apply! (generic function with 20 methods)

## Examples

In [138]:
# Fresh model whose two layers are assigned to groups 1 and 2 respectively.
model = Chain(Dense(3, 5), Dense(5, 3))
pg = ParamGroups(IndexGrouper([1, 2]), model)

ParamGroups(...)

We map group 1 to a learning rate multiplier of `0`, so it is not trained, and group 2 to a multiplier of `1`, so it is trained regularly. With Flux's composable `Optimiser`, we can easily use this together with regular gradient descent. 

In [139]:
# Combine the per-group scaling (group 1 => 0, group 2 => 1) with plain
# gradient descent at a learning rate of 0.1.
o = Optimiser(
    DiscriminativeLR(pg, Dict(1 => 0., 2 => 1.)),
    Descent(0.1)
)

Optimiser(Any[DiscriminativeLR(ParamGroups(...), Dict(2 => 1.0, 1 => 0.0)), Descent(0.1)])

In [140]:
# Random input/target pair with 3 features and a single observation.
xs, ys = rand(3, 1), rand(3, 1)
# The loss evaluates the model on its own `xs` argument.
lossfn(xs, ys) = Flux.mse(model(xs), ys)
ps = params(model)
# Gradients of the loss with respect to every parameter array in `ps`.
gs = gradient(() -> lossfn(xs, ys), ps)

Grads(...)

In [141]:
# Show the first entry of every parameter array before the update step.
for p in ps
    @show p[1]
end

p[1] = 0.078491025f0
p[1] = 0.0f0
p[1] = 0.17071337f0
p[1] = 0.0f0


In [142]:
# Perform one optimisation step on `ps` with optimiser `o` and gradients `gs`.
Optimise.update!(o, ps, gs)

In [143]:
# Show the first entry of every parameter array again, after the update step.
for p in ps
    @show p[1]
end

p[1] = 0.078491025f0
p[1] = 0.0f0
p[1] = 0.16609591f0
p[1] = 0.030807652f0


As you can see, only the parameters in group 2 were updated.