
Choosing model serialization format(s) for cross-framework support (like HuggingFace) #1907

Closed
darsnack opened this issue Mar 14, 2022 · 9 comments

@darsnack (Member) commented Mar 14, 2022

(continuing Slack discussion)

Julia to other frameworks

Since we are moving away from @save "weight.bson" params(model) in #1875, we should probably think about a recommended object to serialize that will also be friendly with other frameworks. State dicts (as in PyTorch) use the keys to encode some structural information, like model["encoder/weight"]. The closest match to that in Flux right now would be using Functors to turn the model into a nested named tuple. This can be saved using your favorite Julia serializer, and also loaded into a Flux model with types via #1875. For interfacing with external sources like HuggingFace, a thin translation between nt[:encoder][:weight] and dict["encoder/weight"] is do-able.
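For illustration, a minimal sketch of that translation layer (`flatten_state` below is a hypothetical helper, not an existing Flux or Functors function):

```julia
# Flatten a nested NamedTuple of parameters into PyTorch-style "a/b/weight" keys.
# `flatten_state` is a hypothetical helper, shown only to illustrate the idea.
function flatten_state(nt, prefix = "", out = Dict{String,Any}())
    for (k, v) in pairs(nt)
        key = isempty(prefix) ? string(k) : string(prefix, "/", k)
        if v isa NamedTuple || v isa Tuple
            flatten_state(v, key, out)   # recurse into nested containers
        else
            out[key] = v                 # leaf: an array (or other parameter)
        end
    end
    return out
end

# flatten_state((encoder = (weight = rand(3, 3), bias = zeros(3)),))
# # => Dict("encoder/weight" => …, "encoder/bias" => …)
```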

It would be good to get a collection of other commonly used storage formats outside Julia/Flux to make sure we choose something that has the widest compatibility.

Other frameworks to Julia

In terms of Flux users being able to use models saved from other frameworks (e.g. downloaded from HuggingFace), ONNX is probably the ideal format here:

- Simply loading and running an ONNX model is do-able with ONNXRunTime.jl (see the sketch just below).
- Translating an ONNX graph to Julia functions from e.g. NNlib is semi-do-able on ONNX.jl#master, but not yet stable/released.
- Translating an ONNX graph to Flux layer types does not exist, except maybe as one-off examples (at least ones using SkipConnection etc.; a sequence of Convs is of course easy).
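For the first option, the workflow looks roughly like this (assuming ONNXRunTime.jl's `load_inference` interface; the path and the input name/shape are illustrative and depend on the model):

```julia
import ONNXRunTime as ORT

# Load a serialized ONNX model and run it, with no Flux types involved.
model = ORT.load_inference("model.onnx")
output = model(Dict("input" => randn(Float32, 1, 3, 224, 224)))
```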

cc @dfdx @DrChainsaw

@darsnack changed the title from "Best model serialization format cross-framework support (like HuggingFace)" to "Choosing model serialization format(s) for cross-framework support (like HuggingFace)" on Mar 14, 2022
@DrChainsaw (Contributor) commented:
By popular (well, one person's) demand, I did register ONNXNaiveNASflux, which has a couple of ONNX -> Flux layer translations here, which I'd be happy to donate to some other package.

What is not Flux-native is the computation graph format, which instead comes from NaiveNASlib. I might be burned by my pre-JuMP attempts to align parameter shapes by traversing the graph, but I think that a general ONNX graph to Chain translation function might be hard to get right, though I haven't given it much thought.

The "impedance mismatch" as @ToucheSir called it is a thing in the sense that there are alot of valid ONNX operators which can't be turned into Flux layers (which is why I think the ONNX.jl approach is the right way if one has the ambition to cover the spec).

@dfdx commented Mar 15, 2022

> a general ONNX graph to Chain translation function might be hard to get right, though I haven't given it much thought.

Looking at the initializers of distilbert-base-uncased from HuggingFace, I believe reconstructing Flux layers would require some heuristics, but is certainly doable:

 "embeddings.word_embeddings.weight"
 "embeddings.position_embeddings.weight"
 "embeddings.LayerNorm.weight"
 "embeddings.LayerNorm.bias"
 "transformer.layer.0.attention.q_lin.bias"
 "transformer.layer.0.attention.k_lin.bias"
 "transformer.layer.0.attention.v_lin.bias"
 "transformer.layer.0.attention.out_lin.bias"
 "transformer.layer.0.sa_layer_norm.weight"
 "transformer.layer.0.sa_layer_norm.bias"
 "transformer.layer.0.ffn.lin1.bias"
 "transformer.layer.0.ffn.lin2.bias"
 "transformer.layer.0.output_layer_norm.weight"
 "transformer.layer.0.output_layer_norm.bias"
 "transformer.layer.1.attention.q_lin.bias"
 "transformer.layer.1.attention.k_lin.bias"
 "transformer.layer.1.attention.v_lin.bias"
 "transformer.layer.1.attention.out_lin.bias"
 "transformer.layer.1.sa_layer_norm.weight"
 "transformer.layer.1.sa_layer_norm.bias"
 "transformer.layer.1.ffn.lin1.bias"
 "transformer.layer.1.ffn.lin2.bias"
 "transformer.layer.1.output_layer_norm.weight"
 "transformer.layer.1.output_layer_norm.bias"
 "transformer.layer.2.attention.q_lin.bias"
 "transformer.layer.2.attention.k_lin.bias"
 "transformer.layer.2.attention.v_lin.bias"
 ⋮

Also, semi-automatic code generation might be a good approach. For example, given a set of structured names like aaa.weight, aaa.bias, we can "guess" the underlying layer type and map inputs to it (see the sketch below). There will definitely be mistakes, but it should be possible to go through the generated code of the 50-100 most popular models and make sure they are not too weird.
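As a rough illustration of that heuristic (`guess_layer` is a hypothetical helper, not part of ONNX.jl), one could do something like:

```julia
using Flux

# Given the initializers sharing one prefix (e.g. "…q_lin.weight" and "…q_lin.bias"),
# guess which Flux layer they belong to from the array shapes alone.
function guess_layer(params::AbstractDict)
    W = get(params, "weight", nothing)
    b = get(params, "bias", nothing)
    if W isa AbstractMatrix && b isa AbstractVector
        return Dense(W, b)           # 2-D weight + bias vector: looks like a linear layer
    elseif W isa AbstractVector && b isa AbstractVector
        return LayerNorm(length(W))  # 1-D weight + bias: looks like a layer norm (params still to be copied in)
    else
        return nothing               # no confident guess; keep the raw ops instead
    end
end
```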

Once I finish the current task in Yota, I'm going to allocate one or two months purely for ONNX.jl, and now I know what my next target model to support will be :)

@DrChainsaw (Contributor) commented:
Hmm, looking at the distribution of "likes" in this thread makes me believe I have some grave misunderstanding of the topic, but I'll make one more attempt. Sorry if this is just noise.

> For example, given a set of structured names like aaa.weight, aaa.bias, we can "guess" the underlying layer type and map inputs to it

Why would you need to rely on guesses/heuristics based on names? Can't you just look at the op type and verify that everything the corresponding Flux layer struct needs as parameters is present, either as an initializer or as a (possibly propagated) constant? This (minus constant propagation) is pretty much what ONNXNaiveNASflux does, and it is pretty straightforward (at least compared to all the other, more general headaches with ONNX import, such as row-major vs. column-major).

I also don't see the connection between this and the part about Chain which was quoted above. The problem with Chain vs. an ONNX graph is that the former is a strictly linear DAG where the output of one node must be the input to the next. You can of course work around this by adding blocks like Parallel and SkipConnection, which have an internal branchy graph (see the sketch below). I don't disagree that using those and some heuristics one might be able to cover a large range of models. I do however get premature PTSD thinking about maintaining such heuristics (e.g. issue #454857: Failed to load model some/new/hot/architecture_v45667.onnx), but perhaps if adding the appropriate blocks to Flux in response is allowed, it could be more manageable.
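For concreteness, this is the kind of nesting the workaround relies on (a plain Flux sketch, not generated from any ONNX graph):

```julia
using Flux

# A residual branch in an ONNX graph has no direct Chain equivalent, but it can be
# expressed by nesting a branchy block inside the otherwise linear Chain:
block = Chain(
    SkipConnection(Chain(Dense(64 => 64, relu), Dense(64 => 64)), +),  # computes f(x) + x
    LayerNorm(64),
)
```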

@dfdx commented Mar 27, 2022

> Why would you need to rely on guesses/heuristics based on names? Can't you just look at the op type and verify that everything the corresponding Flux layer struct needs as parameters is present

Here I assume that we don't have a Flux model to map the ONNX graph onto, but only the graph itself, and we need to infer the corresponding Flux layers purely from the initializers.

I agree with the argument about Chain and heuristics in general, but at the moment I don't see another way to quickly import thousands of models. So at least we can try and see how it goes. After all, uncertainty like this is exactly why we have breaking changes :)

@DrChainsaw (Contributor) commented:

> Here I assume that we don't have a Flux model to map the ONNX graph onto, but only the graph itself, and we need to infer the corresponding Flux layers purely from the initializers.

But why not from the op types? For example, optype=Conv means we can use Flux.Conv as long as we can find the second input (W) in the initializers or as a (possibly propagated) constant, and as long as all parameters have a Flux equivalent. If not, then one either needs to throw an error about the model not being a valid Flux model, or fall back to some other implementation (e.g. the one in ONNX.jl) and accept that some parts of the model are not Flux layers.
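A schematic of that op-type-driven rule (the argument names and the initializer lookup here are illustrative, not the actual ONNX.jl or ONNXNaiveNASflux API):

```julia
using Flux

# Map one ONNX node to a Flux layer purely from its op type, failing loudly if a
# required parameter is not available as an initializer.
function node_to_layer(op_type, inputs, initializers)
    if op_type == "Conv"
        W = get(initializers, inputs[2], nothing)
        W === nothing && error("Conv weight is not an initializer; not expressible as Flux.Conv")
        b = length(inputs) >= 3 ? initializers[inputs[3]] : false
        return Conv(W, b)  # attributes (stride, pad, dilation) would also need translating
    end
    return nothing         # unknown op type: error, or fall back to an NNlib-level op
end
```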

@darsnack (Member, author) commented Mar 27, 2022

For simple ops like convolution this works, and it's exactly what ONNX.jl does on dev (except with NNlib instead of Flux). But the example above features multi-head attention (MHA) layers, which have no ONNX op type. We'd rather map the full MHA layer to something from NeuralAttentionlib for performance. We also can't directly infer that this is MHA from the graph of op types, because it just looks like batched matmul + normalization. But from the structured names, we can make a pretty good guess. So, of course, we want to make the 1-1 mapping when possible, but some cases might require some guesswork.

@DrChainsaw (Contributor) commented:

Ah, now it is clear to me. I read "Flux layers" as meaning the layers defined in Flux, including simple ones (e.g. Conv). Sorry for the noise.

@darsnack (Member, author) commented:

Not noise; the whole reason for this issue is to discuss 😀

@mcabbott mentioned this issue Feb 3, 2023
@CarloLucibello (Member) commented:
I guess this can be closed now that we recommend serializing the output of Flux.state.
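For reference, the recommended workflow now looks roughly like this (following the Flux saving/loading docs; JLD2 is one choice of serializer):

```julia
using Flux, JLD2

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Save only the state (nested NamedTuples of plain arrays), not the layer structs themselves.
jldsave("model_state.jld2"; state = Flux.state(model))

# Later: rebuild the architecture in code and copy the saved state back in.
model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Flux.loadmodel!(model2, JLD2.load("model_state.jld2", "state"))
```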
