
How to apply L2 regularization to a subset of parameters? #1284

Closed · jonathan-laurent opened this issue Jul 18, 2020 · 5 comments · Fixed by #1444

@jonathan-laurent commented Jul 18, 2020

When training a neural network with L2 regularization, it is often advised not to regularize the bias parameters (in contrast with the weight parameters).

I implemented this as follows in AlphaZero.jl:

import Flux
import Functors

# Per-layer list of arrays to regularize: weight matrices/kernels only, never biases.
regularized_params_(l) = []
regularized_params_(l::Flux.Dense) = [l.W]
regularized_params_(l::Flux.Conv) = [l.weight]

# Apply `f` to every non-leaf node (i.e. every layer) of the model, visiting each node once.
function foreach_flux_node(f::Function, x, seen = IdDict())
  Functors.isleaf(x) && return
  haskey(seen, x) && return
  seen[x] = true
  f(x)
  for child in Flux.trainable(x)
    foreach_flux_node(f, child, seen)
  end
end

# Gather the regularized arrays of all layers into a Params collection, without duplicates.
function regularized_params(net::FluxNetwork)
  ps = Flux.Params()
  foreach_flux_node(net) do p
    for r in regularized_params_(p)
      any(x -> x === r, ps) || push!(ps, r)
    end
  end
  return ps
end

regularization_term(nn) = sum(sum(w .* w) for w in regularized_params(nn))
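
For reference, here is a hedged sketch of how this term is meant to enter a training loss; the λ value, the loss function, and calling nn directly on a batch are illustrative assumptions, not AlphaZero.jl's actual training code:

# Hedged usage sketch: λ, logitcrossentropy, and nn(x) are illustrative
# assumptions, not AlphaZero.jl's actual training code.
λ = 1e-4
loss(x, y) = Flux.logitcrossentropy(nn(x), y) + λ * regularization_term(nn)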

This feels a bit hackish, though, and because it relies on Flux internals it tends to break with every new release.
Do you see a better way? Shouldn't we make this easier?

@CarloLucibello (Member)

Maybe you can simplify this slightly by using delete! instead:

nonregularized_params_(net) = ....

function regularized_params(net::FluxNetwork)
  ps = Flux.params(net)
  for p in nonregularized_params_(net)
      delete!(ps, p)
  end
  return ps
end

Whether you push! or delete!, you can skip the presence check; it is done internally.
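
For illustration, a minimal sketch of that deduplication behaviour, assuming Zygote's Params compares entries by identity as described above:

# Minimal sketch: Params behaves like an identity-keyed set, so pushing the
# same array twice stores it only once (assumption: Zygote.Params semantics).
ps = Flux.Params()
w = randn(3, 3)
push!(ps, w)
push!(ps, w)          # already present, so this is a no-op
@assert length(ps) == 1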

@CarloLucibello (Member)

Maybe we could implement something similar to foreach_flux_node in Flux, like a modules function. What do you think?
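
For illustration, a hedged sketch of how such a function could shorten the original snippet; modules here is hypothetical (assumed to return every submodule of the model, layers included), and the regularized_params_ helpers are the ones defined above:

# Hedged sketch: `modules` is hypothetical (assumed to return every submodule
# of the model); it would replace the hand-written traversal above.
function regularized_params(net)
  ps = Flux.Params()
  for m in modules(net), w in regularized_params_(m)
    push!(ps, w)   # push! skips arrays that are already present
  end
  return ps
end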

@jonathan-laurent (Author)

I don't see how the delete! solution would work. Flux.params returns a flat collection of arrays, and there is no way to tell bias parameters from weight parameters among those arrays, right?

@CarloLucibello (Member)

Sorry, forget what I said, I was being stupid. Yes, I don't see how to simplify this, besides making a modules function part of Flux.
Another option for you would be to filter out AbstractVector parameters and keep only higher-dimensional arrays, but that would not play well with BatchNorm layers.
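
A hedged sketch of that filtering idea, keeping only parameters with more than one dimension; as noted above, this also drops BatchNorm's 1-D scale and shift parameters, which you may want to treat differently:

# Hedged sketch of the dimension-based filter: 1-D parameters (biases, but also
# BatchNorm's scale/shift vectors) are excluded from the penalty.
regularized_params(m) = [p for p in Flux.params(m) if ndims(p) > 1]
regularization_term(m) = sum(sum(abs2, p) for p in regularized_params(m))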

@DhairyaLGandhi (Member)

Currently there isn't a simple way to filter out biases specifically, and I can see this becoming a real need for bigger models. It will be a bit of a manual process with the current infrastructure, since we don't distinguish between weights and biases (both are simply assumed to be parameters), but defining a functor that splits these out would do the trick.
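
One way to read that suggestion is a per-layer dispatch that separates the two groups. A hedged sketch, where weight_bias_split is a hypothetical helper (not an existing Flux API) and the field names match the Flux version used in the snippet above:

# Hedged sketch: a hand-written split between weights and biases per layer type.
# `weight_bias_split` is a hypothetical helper, not part of Flux.
weight_bias_split(l::Flux.Dense) = (weights = (l.W,),      biases = (l.b,))
weight_bias_split(l::Flux.Conv)  = (weights = (l.weight,), biases = (l.bias,))
weight_bias_split(l)             = (weights = (),          biases = ())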
