Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Code refactoring #1023

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft

[FIX] Code refactoring #1023

wants to merge 17 commits into from

Conversation

elephaint
Copy link
Contributor

@elephaint elephaint commented May 31, 2024

This is a work-in-progress draft PR and open for discussion. The main goal of the PR is to unify API across different model types, and unify loss functions across different loss types.

Refactoring:

  • Fuses BaseWindows, BaseMultivariate and BaseRecurrent into BaseModel, removing the need for separate classes and unifying model API across different model types. Instead, this PR introduces two model attributes, yielding four possible model options: RECURRENT (True/False) and MULTIVARIATE (True/False). We currently have a model for every combination except a recurrent multivariate model (e.g. a multivariate LSTM), however this is now relatively simple to add. In addition, this change allows to have models that can be recurrent or not, or multivariate or not on-the-fly, based on users' input. This also allows for easier modelling going forward.
  • Unifies model API across all models, adding missing input variables to all model types.
  • Refactors losses, a.o. removing unnecessary domain_map functions.
  • Moves loss.domain_map outside of models to BaseModel

Features:

  • All losses compatible with all types of models (e.g. univariate/multivariate, direct/recurrent) OR appropriate protection added.
  • DistributionLoss now supports the use of quantile in predict, allowing for easy quantile retrieval for all DistributionLosses.
  • Mixture losses (GMM, PMM and NBMM) now support learned weights for weighted mixture distribution outputs.
  • Mixture losses now support the use of quantile in predict, allowing for easy quantile retrieval.
  • Improved stability of ISQF by adding softplus protection around some parameters instead of using .abs

Bug fixes:

  • MASE loss now works.
  • Added various protections around parameter combinations that are invalid (e.g. regarding losses)
  • StudentT increase default DoF to 3 to reduce unbound variance issues.

Breaking changes:

  • Rewrite of all recurrent models to get rid of the quadratic (in the sequence dimension) space complexity. As a result, it is impossible to load a recurrent model from a previous version into this version.
  • TCN and DRNN are now windows models, not recurrent models.

Tests:

  • Added test_losses.py that can test model - loss - valid_loss combinations, for a variety of dataset settings.

Todo:

  • Unify / clean example use cases per model
  • Documentation update
    -- Readme: correct model segmentation
    -- Examples: restructure
  • Tests:
    -- Test all models on all loss / valid-loss combinations
    -- Test all models on speed/scaling as compared to current implementation across a set of datasets.
    -- Tests on multivariate models

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant