Float16 compatibility #15

Closed
d-monnet opened this issue Jul 6, 2023 · 4 comments
Labels
question Further information is requested

Comments

d-monnet commented Jul 6, 2023

Hi there,

I would like to know whether Float16 is supported. I followed this tutorial https://jso.dev/FluxNLPModels.jl/dev/tutorial/ and naively tried

w16 = Float16.(nlp.w)
obj(nlp, w16)

but got a Float32 back. I therefore assume at least some computations are performed in Float32 when evaluating the objective. I also tried modifying the function get_data() as follows:

function get_data(bs)
  ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

  # Loading Dataset
  xtrain, ytrain = MLDatasets.MNIST(Tx = Float16, split = :train)[:]
  xtest, ytest = MLDatasets.MNIST(Tx = Float16, split = :test)[:]
  # ...
end

but still got a Float32 when evaluating the objective.
Any idea how to run in Float16 (or any other format)?

@d-monnet d-monnet changed the title Unstable type with obj() Float16 compatibility Jul 6, 2023
@tmigot tmigot added the question Further information is requested label Jul 13, 2023

d-monnet (Author) commented Jul 31, 2023

OK, I found the issue: obj calls set_vars!(), which is not type stable. The problem comes from nlp.w .= new_w in

function set_vars!(nlp::AbstractFluxNLPModel{T, S}, new_w::AbstractVector{T}) where {T <: Number, S}

The operator .= converts the right-hand-side values to the element type of the left-hand-side vector.
For example:

x32 = ones(Float32, 10)
x16 = ones(Float16, 10)
x32 .= x16 # x32 still has element type Float32

That is, even if the argument of obj is a Vector{Float16}, it is cast to whatever the parameter type S of FluxNLPModel{T, S} is.

In fact, this is not even the bottom of the issue: Flux.destructure does not allow changing the floating-point format via the restructure mechanism. From the destructure documentation: "Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision."
Since restructure is called in set_vars!(), we still can't switch floating-point formats.
Any workaround would be welcome!
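
A minimal sketch of that restriction (the small Chain below is only an illustration, not code from this package): even if the flattened vector is converted to Float16, restructure projects the values back to the model's original precision.

using Flux

model = Chain(Dense(2 => 2))           # Flux initializes parameters in Float32
flat, re = Flux.destructure(model)

flat16 = Float16.(flat)                # try to switch the flat vector to Float16
model16 = re(flat16)                   # restructure follows ChainRulesCore.ProjectTo ...

eltype(Flux.destructure(model16)[1])   # ... so this is Float32 again, not Float16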

farhadrclass (Collaborator) commented Oct 12, 2023

I can change the backend to convert the model every time.

A quick change would be:

f64(m) = Flux.paramtype(Float64, m) # similar to https://github.com/FluxML/Flux.jl/blob/d21460060e055dca1837c488005f6b1a8e87fa1b/src/functor.jl#L217

Then, to convert our model, we use:

fluxnlp.model = f64(fluxnlp.model)
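
A quick way to check the conversion (a hedged sketch: the toy Chain is only for illustration, and Flux.paramtype is assumed to be available as in the linked functor.jl source):

using Flux

f64(m) = Flux.paramtype(Float64, m)   # helper from the comment above (assumed available)

model = Chain(Dense(2 => 1))          # toy model; Flux initializes it in Float32
model64 = f64(model)

eltype(Flux.destructure(model)[1])    # Float32
eltype(Flux.destructure(model64)[1])  # expected: Float64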

farhadrclass (Collaborator)

Flux just recently added support for this
https://fluxml.ai/Flux.jl/stable/utilities/#Flux.f16
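
For reference, a minimal sketch of the built-in converters (the toy Chain is an illustration, not code from this repository):

using Flux

model = Chain(Dense(4 => 2, relu), Dense(2 => 1))  # Float32 parameters by default

model16 = Flux.f16(model)   # convert all floating-point parameters to Float16
model64 = Flux.f64(model)   # or to Float64

eltype(Flux.destructure(model16)[1])  # Float16
eltype(Flux.destructure(model64)[1])  # Float64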

farhadrclass added a commit that referenced this issue Oct 30, 2023
farhadrclass added a commit that referenced this issue Oct 30, 2023
farhadrclass added a commit to Farhad-phd/FluxNLPModels.jl that referenced this issue Oct 31, 2023
farhadrclass added a commit to Farhad-phd/FluxNLPModels.jl that referenced this issue Oct 31, 2023
New Flux Update
## v0.14.0 (July 2023)
* Flux now requires julia v1.9 or later.
* CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`.
  The package cuDNN.jl also needs to be installed in the environment. (You will get instructions if this is missing.)
* After a deprecations cycle, the macro `@epochs` and the functions `Flux.stop`, `Flux.skip`, `Flux.zeros`, `Flux.ones` have been removed.
farhadrclass added a commit that referenced this issue Nov 23, 2023
tmigot added a commit that referenced this issue Dec 16, 2023