
Introduce EntityEmbeddings #267

Merged: ablaom merged 42 commits into dev from entity-embeddings on Sep 9, 2024

@EssamWisam EssamWisam commented Aug 3, 2024

Collaborator

Basic Description

This PR extends the NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor models of the MLJFlux.jl library such that:

  • These models now support tables with categorical columns. If and only if any are present, an extra entity embedding layer is introduced after the input, as described in the paper Entity Embeddings of Categorical Variables by Cheng Guo and Felix Berkhahn.

  • After training any of these models, it is possible to transform the categorical columns of a new sample, whether or not it was seen in training, so that the encoded columns can be passed to a further model or transformer in a pipeline.

See the updated documentation to learn more.
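
As a rough illustration of the idea behind an entity embedding layer, the sketch below treats the embedding of one categorical column as a lookup into a matrix with one column per level. It uses only Base Julia; the names `levels`, `index`, `E`, `onehot` and `embed` are made up for this example and are not part of MLJFlux, and the matrix entries are fixed by hand rather than learned.

```julia
# Minimal sketch of an entity embedding for a single categorical column.
# All names are illustrative assumptions, not MLJFlux's API.

levels = ["red", "green", "blue"]                  # levels seen in training
index = Dict(l => i for (i, l) in enumerate(levels))

# Embedding matrix: one column per level, two rows (the embedding dimension).
# In a real network these entries are learned; here they are fixed by hand.
E = [0.1 0.4 0.7;
     0.2 0.5 0.8]

# One-hot vector of length n with a 1.0 in position i.
onehot(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]

# Embedding a level is E times its one-hot vector, which amounts to
# selecting the corresponding column of E.
embed(level) = E * onehot(index[level], length(levels))

@assert embed("green") == E[:, 2]
```

Because the product with a one-hot vector only selects a column, a layer of this kind can be implemented as an index lookup, with the matrix entries trained together with the rest of the network.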

Implementation Plan

The following was my plan for implementing this feature, which is less trivial than it may seem at first glance.

  • Implement EntityEmbedder layer and test it with a theory-based example
  • Introduce the EntityEmbedder layer to MLJFlux with more formal tests that compare a mathematical implementation to the actual one
  • Introduce OrdinalEncoding with tests (as well as multi-column encoding transformer from MLJTransforms for later)
  • Introduce functionality to infer categorical variables, number of levels and prepare EntityEmbedder input
  • Integrate ordinal encoding and preparing categorical embedding inputs in MLJInterface.fit
  • Let MLJInterface.fit insert EntityEmbedder into the model chain when needed (input has categorical columns)
  • Likewise, adapt MLJInterface.update accordingly
  • Adapt the predict and fitresult for ordinal encoding and storing embedding matrices in classifier.jl and regressor.jl
  • Separate the case where there is no entity embedding in MLJInterface.fit (refactoring)
  • Make the categorical embedder completely transparent when there are no categorical variables (instead of keeping it in the chain with no parameters)
  • Refactor code in MLJModelsInterface.jl for less redundancy and more organization
  • Use better default for the new dimensionality of EntityEmbedder
  • Expose new dimensionality argument in types.jl
  • Allow transform for each method that accesses embedding matrices of the EntityEmbedder
  • Better variable names
  • Modify documentation to take into account EntityEmbedder (I also plan to add one or more tutorials later)
  • Ensure existing tests pass with no problem
  • Finish writing tests for the majority, if not all, of the functional components introduced (mainly entity-embedding.jl, entity-embedding-utils.jl and encoders.jl)
  • Finish writing some end-to-end tests for entity embeddings over the four models
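
The ordinal encoding step in the plan above can be sketched as follows. This is a Base-Julia illustration under the assumption that each level is mapped to an integer code that later indexes a column of an embedding matrix. The helpers `ordinal_mapping` and `ordinal_encode` are hypothetical names, not the functions added in this PR.

```julia
# Hypothetical helpers: a sketch of ordinal encoding, not the PR's code.

# Map each distinct level of a column to an integer 1, 2, ... in sorted order.
ordinal_mapping(col) = Dict(l => i for (i, l) in enumerate(sort(unique(col))))

# Replace each entry of the column by its integer code.
ordinal_encode(col, mapping) = [mapping[x] for x in col]

color = ["red", "green", "blue", "green"]
mapping = ordinal_mapping(color)         # "blue" => 1, "green" => 2, "red" => 3
codes = ordinal_encode(color, mapping)   # [3, 2, 1, 2]
nlevels = length(mapping)                # 3, the number of embedding columns needed
```

Storing the mapping alongside the trained network is what makes it possible to encode new samples consistently at prediction time.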

codecov Bot commented Aug 3, 2024

Codecov Report

Attention: Patch coverage is 96.82540% with 8 lines in your changes missing coverage. Please review.

Project coverage is 96.48%. Comparing base (70dff6e) to head (e5d8141).
Report is 2 commits behind head on dev.

Files with missing lines         Patch %   Lines
src/entity_embedding_utils.jl    94.82%    3 Missing ⚠️
src/encoders.jl                  97.01%    2 Missing ⚠️
src/fit_utils.jl                 92.85%    2 Missing ⚠️
src/mlj_model_interface.jl       97.82%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##              dev     #267      +/-   ##
==========================================
+ Coverage   92.42%   96.48%   +4.06%     
==========================================
  Files          11       14       +3     
  Lines         330      512     +182     
==========================================
+ Hits          305      494     +189     
+ Misses         25       18       -7     

@EssamWisam EssamWisam requested a review from ablaom August 6, 2024 00:10
Comment thread src/types.jl
rng::Union{AbstractRNG, Int64}
optimiser_changes_trigger_retraining::Bool
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
embedding_dims::Dict{Symbol, Real}

Collaborator Author

@ablaom I handle the value differently depending on whether it's an integer or a float, as described in the docs.
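
One way such integer-versus-float handling could look is sketched below with multiple dispatch. The rule used here is an assumption made for illustration: an integer is taken as the embedding dimension itself, and a float as a fraction of the number of levels, rounded up. The thread does not spell out the rule, so check the MLJFlux docs for the actual behaviour. `resolve_dim` and the example column names are hypothetical.

```julia
# Hypothetical helper: resolve a user-supplied embedding dimension.
# ASSUMPTION: an integer is the dimension itself; a float is a fraction of
# the number of levels, rounded up. The real rule is in the MLJFlux docs.
resolve_dim(d::Integer, nlevels::Integer) = d
resolve_dim(d::AbstractFloat, nlevels::Integer) = ceil(Int, d * nlevels)

# Values are typed `Real` so that both integers and floats are accepted.
embedding_dims = Dict{Symbol, Real}(:color => 2, :city => 0.5)
nlevels = Dict(:color => 3, :city => 10)

# Dispatch picks the method from the concrete type of each stored value.
resolved = Dict(k => resolve_dim(v, nlevels[k]) for (k, v) in embedding_dims)
```

Under these assumptions, `:color` resolves to 2 and `:city` to 5.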

Comment thread src/mlj_model_interface.jl Outdated
Co-authored-by: Anthony Blaom, PhD <anthony.blaom@gmail.com>
Comment thread Project.toml Outdated

[targets]
- test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
+ test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "ScientificTypes", "Test"]

Collaborator

I think MLJBase, which depends on ScientificTypes, re-exports all the public ScientificTypes methods, so you may be able to dump it here.

Collaborator Author

@ablaom So replace using ScientificTypes: coerce, Multiclass, OrderedFactor with using MLJBase: coerce, Multiclass, OrderedFactor in the test file? This is minor, right, because it only affects the tests, and it shouldn't re-download an already downloaded package?

Collaborator

No, I think using MLJBase suffices. All those objects are re-exported.

Comment thread src/regressor.jl
"""
shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true

Collaborator

As you are now acting on instances instead of types, I'd change the name of your trait from is_embedding_enabled_type to is_embedding_enabled, but this is just a suggestion.

Collaborator Author

Sure thing.
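
The naming point can be illustrated with a small sketch of a trait defined on instances. The structs `ToyRegressor` and `ToyImageClassifier` are hypothetical stand-ins for model types, and the fallback method on `Any` is an assumption about how such a trait might be given a default.

```julia
# Hypothetical model types, for illustration only.
struct ToyRegressor end
struct ToyImageClassifier end

# Trait defined on instances: the fallback is `false`, and a model type
# opts in by adding a more specific method.
is_embedding_enabled(::Any) = false
is_embedding_enabled(::ToyRegressor) = true

@assert is_embedding_enabled(ToyRegressor())
@assert !is_embedding_enabled(ToyImageClassifier())
```

Since the methods take a model instance rather than a `Type`, a name without the `_type` suffix describes the call site more accurately.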

Comment thread src/entity_embedding_utils.jl Outdated
hasproperty(transformer, :embedding_dims) || return Xnew
ordinal_mappings, embedding_matrices = fitresult[3:4]
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
Xnew_transf = embedding_transform(Xnew, embedding_matrices)

Collaborator

Also, do we have any test for transform(::MLJFluxModel, ...) ?

Collaborator Author

I think transform is tested in entity_embedding.jl

Collaborator

Okay, I see it's covered in Codecov.

Can I ask that this method (and corresponding test) be moved to "mlj_model_interface.jl", where the other implementations of MLJModelInterface methods (such as fit and predict) live?

@ablaom ablaom left a comment

Collaborator

Thanks for the speedy response to my review.

A couple of minor points left. Note in particular the need to annotate the type of model (transformer) in MLJModelInterface.transform overloading. And we should have a test for this method for one of the models.
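
To see why the type annotation matters, consider the sketch below. It uses a toy module in place of MLJModelInterface, and every name in it (`ToyInterface`, `ToyNetwork`, `OtherModel`) is hypothetical. An overload with an untyped first argument would be defined for arguments of every type, including models owned by other packages. Annotating the first argument restricts the method to the intended model type.

```julia
# Toy stand-in for an interface package; all names are hypothetical.
module ToyInterface
    abstract type Model end
    # Generic fallback for models that do not implement `transform`.
    transform(model::Model, fitresult, Xnew) = error("transform not implemented")
end

struct ToyNetwork <: ToyInterface.Model end
struct OtherModel <: ToyInterface.Model end

# Annotated overload: applies to ToyNetwork only.
ToyInterface.transform(::ToyNetwork, fitresult, Xnew) = Xnew .* fitresult

result = ToyInterface.transform(ToyNetwork(), 2, [1, 2, 3])   # [2, 4, 6]

# Other models are unaffected and still reach the fallback.
other_threw = try
    ToyInterface.transform(OtherModel(), 2, [1, 2, 3])
    false
catch
    true
end
```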

@EssamWisam

Collaborator Author

@ablaom I have addressed everything (except removing ScientificTypes from the tests). I noticed that after adding categorical variables to the integration tests, GPU testing fails :(. I will look into that when I get the chance.

ablaom commented Sep 1, 2024

Collaborator

@EssamWisam You can try raising the tolerance near test/classifier.jl:114, which seems to be where the failure occurs.
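
For reference, this is what raising a tolerance means for an approximate comparison in Base Julia. The numbers are made up for illustration and are not taken from test/classifier.jl. The thread also does not show what the failing test compares, so treat the variable names as assumptions.

```julia
# Illustrative numbers only; not taken from the actual test suite.
reference_value = 0.693147
observed_value  = 0.693512   # small discrepancy, e.g. from a different compute path

# `isapprox` with a relative tolerance: |a - b| <= rtol * max(|a|, |b|).
strict = isapprox(reference_value, observed_value; rtol = 1e-4)   # false
loose  = isapprox(reference_value, observed_value; rtol = 1e-3)   # true
```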

@EssamWisam

Collaborator Author

Hooray it worked 🎉 @ablaom

@EssamWisam EssamWisam requested a review from ablaom September 1, 2024 22:57

@ablaom ablaom left a comment

Collaborator

Thanks for the updates.

Just a couple more details, as flagged.

@EssamWisam EssamWisam requested a review from ablaom September 8, 2024 02:36
@EssamWisam

Collaborator Author

@ablaom let's finalize this!

@ablaom ablaom left a comment

Collaborator

Thanks again @EssamWisam for this valuable contribution.

Just made a few doc string tweaks after reviewing these again. Ready to go. 🎉

@ablaom ablaom merged commit 945016d into dev Sep 9, 2024
@ablaom ablaom deleted the entity-embeddings branch September 9, 2024 00:52
@ablaom ablaom mentioned this pull request Sep 29, 2024
2 participants