
[WIP] Introduction of equivariant DAE and DPNet models. #11

Open

wants to merge 16 commits into base: dev
Conversation


@Danfoa Danfoa commented Mar 31, 2024

This pull request is currently a work in progress (WIP), and I'm actively seeking feedback. Early input on the design choices being made is crucial: it allows these choices to be discussed, modified, or rejected before significant effort is expended. This code builds upon the previous pull request #9.

Objective

The primary goal of this update is to introduce Equivariant Versions of the DPNet and DAE models. This requires some restructuring to improve modularity, along with enhancements that boost the performance of the deep-learning-based models (e.g., proper initialization of trainable evolution operators). To that end, I am trying to establish a structure that facilitates the use and extension of AE-based models and representation-based models (e.g., DPNets) within a unified framework.

Introducing New Base Classes

To enable the modular implementation of DAE, e-DAE, DPNet, and e-DPNet models within the existing kooplearn codebase, I am introducing two main classes:

  • LatentBaseModel: Abstract child class of BaseModel. It defines the template methods that latent variable models must implement, namely encode_context, decode_context, evolve_context, and compute_loss_and_metrics. Additionally, it unifies generic code that new model instances need not reimplement, such as load, save, and some data management routines.


  • LightningLatentModel: A generic Lightning Module designed to train instances of LatentBaseModel. It handles logging, optimizers, and the training/validation/testing steps. The overarching goal is to minimize, if not eliminate, the need for users to define a custom LightningModule; that is, to limit the user's interface to instances of LatentBaseModel. If some custom training behavior is needed, extending this class and adjusting the relevant hook/callback will suffice. In short, this class unifies the recurring code needed by all latent models (a sketch of both classes follows this list).

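A minimal sketch of how these two classes might fit together; the method names follow the description above, but the exact signatures, the batch layout, and helpers such as `_shared_step` and the `lr` argument are illustrative assumptions, not the PR's implementation:

```python
from abc import ABC, abstractmethod

import lightning
import torch


class LatentBaseModel(ABC):  # in the PR this is a child of kooplearn's BaseModel
    """Template for latent variable models of Markov dynamics (DAE, e-DAE, DPNet, e-DPNet)."""

    @abstractmethod
    def encode_context(self, state_context: torch.Tensor) -> torch.Tensor:
        """Map state-space context windows to the latent space Z."""

    @abstractmethod
    def evolve_context(self, latent_context: torch.Tensor) -> torch.Tensor:
        """Apply the (trainable) evolution operator to latent context windows."""

    @abstractmethod
    def decode_context(self, latent_context: torch.Tensor) -> torch.Tensor:
        """Map latent context windows back to the state space X."""

    @abstractmethod
    def compute_loss_and_metrics(self, state_context, pred_context) -> dict[str, torch.Tensor]:
        """Return a dict containing the key 'loss' plus any metrics worth logging."""


class LightningLatentModel(lightning.LightningModule):
    """Generic trainer wrapper: optimizers, logging and the train/val/test steps for any
    LatentBaseModel instance (assumes the concrete model is also a torch.nn.Module, so its
    parameters are registered when assigned below)."""

    def __init__(self, model: LatentBaseModel, lr: float = 1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def _shared_step(self, state_context: torch.Tensor, suffix: str) -> torch.Tensor:
        latent = self.model.encode_context(state_context)
        evolved = self.model.evolve_context(latent)
        pred = self.model.decode_context(evolved)
        out = self.model.compute_loss_and_metrics(state_context, pred)
        # Log every returned metric with a train/val/test suffix.
        self.log_dict({f"{name}/{suffix}": value for name, value in out.items()})
        return out["loss"]

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```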

Codebase Adaptation

This effort is an adaptation of the codebase from the DynamicsHarmonicsAnalysis repository, where a single LightningModule class and Trainer were previously utilized to fit DAE, e-DAE, DPNet, and e-DPNet models. The goal is to distill this generic Lightning module to its essence and tailor it to the kooplearn API framework.

We introduce base classes for latent models of Markov dynamics. These classes let us define a generic Lightning module that handles the fitting of latent variable models, merging all the generic code required to train a latent variable model (DAE, DPNets) into a single flexible class that should only be extended for very specific scenarios.

This class appropriately handles logging of metrics and losses from training/validation/testing runs, which was previously not handled (only training metrics were being logged).

Work in progress.
Update functions processing eigenvalues required during mode decomposition.

The functions take a vector of (unsorted) eigenvalues and cluster them into real eigenvalues and complex-conjugate pairs, additionally returning the original indices of the eigenvalues so they can be related to the appropriate eigenvectors.
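A minimal standalone sketch of such a clustering routine; the function name, tolerance argument, and greedy pairing strategy are assumptions for illustration, not the PR's implementation:

```python
import numpy as np


def cluster_eigenvalues(eigvals: np.ndarray, atol: float = 1e-10):
    """Split an unsorted vector of eigenvalues into real eigenvalues and
    complex-conjugate pairs, keeping the original indices so each group can
    later be matched with its eigenvectors.

    Returns (real_idx, cc_pairs): indices of (numerically) real eigenvalues,
    and a list of index pairs (i, j) with eigvals[j] ~= conj(eigvals[i]).
    """
    real_idx = [i for i, ev in enumerate(eigvals) if abs(ev.imag) < atol]
    # Eigenvalues of a real operator come in conjugate pairs, so the remaining
    # (strictly complex) ones can be greedily matched to their conjugates.
    unpaired = [i for i in range(len(eigvals)) if i not in real_idx]
    cc_pairs = []
    while unpaired:
        i = unpaired.pop(0)
        j = min(unpaired, key=lambda k: abs(eigvals[k] - np.conj(eigvals[i])))
        unpaired.remove(j)
        cc_pairs.append((i, j))
    return np.asarray(real_idx, dtype=int), cc_pairs
```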

For DAE and latent models, we moved the eig method to the LatentBaseModel class, since for any latent model with an evolution operator the eig function is generic.
This commit provides the functionality to perform the mode decomposition using a  `LatentBaseModel` instance.

The mode decomposition is applied to the evolution operator in latent space; a *linear* reprojection to the state space then ensures the decomposition into latent modes can be transferred to a decomposition into state modes.
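A sketch of the idea, assuming a diagonalizable latent evolution operator `T`, an encoded latent state `z0`, and a linear decoder matrix `D` (all names and the function itself are hypothetical, not the PR's API):

```python
import numpy as np


def state_space_modes(T: np.ndarray, z0: np.ndarray, D: np.ndarray):
    """Decompose the latent dynamics into modes and re-project them to the state
    space through a *linear* decoder D : Z -> X.

    T : (l, l) latent evolution operator
    z0: (l,)   encoded latent state
    D : (d, l) linear decoder matrix
    Returns the eigenvalues and an (l, d) array of state-space modes whose sum
    reconstructs D @ z0; mode i evolves in time as eigvals[i] ** t.
    """
    eigvals, V = np.linalg.eig(T)                     # T = V diag(eigvals) V^{-1}
    modal_coords = np.linalg.solve(V, z0.astype(complex))
    latent_modes = V * modal_coords                   # column i = modal_coords[i] * V[:, i]
    state_modes = D @ latent_modes                    # linearity lets latent modes map to state modes
    return eigvals, state_modes.T
```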
This commit introduces:

 - Modifications to the forecasting of latent states after a LatentBaseModel instance is fitted. Once fitted, the eigenvalue decomposition of the evolution operator is used for forecasting (see the sketch below).
 - Utility plots in the ModesInfo data class showing the dynamics of the eigenfunctions in the complex plane and their real part vs. time.
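A sketch of forecasting through the eigenvalue decomposition of the fitted evolution operator; the function and argument names are illustrative assumptions:

```python
import numpy as np


def forecast_latent(T: np.ndarray, z0: np.ndarray, n_steps: int) -> np.ndarray:
    """Forecast latent states with the eigenvalue decomposition of the fitted
    evolution operator: z_t = V diag(lambda ** t) V^{-1} z_0, for t = 1..n_steps."""
    eigvals, V = np.linalg.eig(T)
    modal_z0 = np.linalg.solve(V, z0.astype(complex))
    steps = np.arange(1, n_steps + 1)
    lam_powers = eigvals[None, :] ** steps[:, None]   # (n_steps, l) powers of each eigenvalue
    z_traj = (lam_powers * modal_z0[None, :]) @ V.T   # back to the latent basis
    # Imaginary parts cancel (up to numerical error) for a real evolution operator.
    return z_traj.real
```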
This commit introduces the equivariant Dynamics Autoencoder (e-DAE), a child class of the Dynamics Autoencoder with equivariance constraints.

TODO:
 - Mode decomposition needs to be performed for each isotypic subspace.
 - The linear decoder appears to be broken in some of the previous commits and needs fixing.
At this point in the code, dynamic mode decomposition using the DAE architecture is operational.

The mode decomposition relies on the linear decoder to map states from the latent space Z to the original state space X. The prediction performance of the DAE model using the linear decoder depends strongly on the learned latent space. DAE architectures are very sensitive to the initialization of the evolution operator, and a poor initialization leads to poor prediction performance in the latent space (which in practice is mitigated by the non-linear decoder).

To get good performance, the user needs to initialize the evolution operator such that the oscillation frequencies of its eigenvalues span a reasonable range of the frequencies present in the dynamics of X.
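One way to realize this is to initialize the operator from 2x2 rotation blocks whose eigenvalues cover a chosen frequency band; the helper below is a hypothetical illustration of that idea, not the PR's initializer, and its name, the frequency grid, and the block-diagonal parameterization are all assumptions:

```python
import numpy as np
import torch
from scipy.linalg import block_diag


def init_evolution_operator(latent_dim: int, dt: float, max_freq_hz: float) -> torch.nn.Parameter:
    """Build an initial latent evolution operator from 2x2 rotation blocks whose
    eigenvalues exp(+-i * 2*pi*f*dt) cover frequencies from 0 up to max_freq_hz."""
    assert latent_dim % 2 == 0, "one 2x2 rotation block per frequency"
    freqs = np.linspace(0.0, max_freq_hz, latent_dim // 2)
    blocks = []
    for f in freqs:
        theta = 2 * np.pi * f * dt
        c, s = np.cos(theta), np.sin(theta)
        blocks.append(np.array([[c, -s], [s, c]]))
    T0 = block_diag(*blocks)
    return torch.nn.Parameter(torch.as_tensor(T0, dtype=torch.float32))
```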
This commit introduces the Equivariant Linear decoder for the E-DAE architecture. To do so, we made the following changes:

- The linear decoder of both the DAE and E-DAE architectures is defined as a `torch.nn.Linear` instance, giving us more flexibility in the use of bias terms, device management by Lightning, loading and saving, etc.
- We made the encoder and decoder (both `torch.nn.Module`s) arguments of the `LatentBaseModel.encode_context` and `LatentBaseModel.decode_context` methods. This enables us to flexibly select between the linear and non-linear decoders for AE-based architectures (see the sketch after this list), which seems reasonable when reconstruction error is crucial, at least for DAE. For E-DAE the linear decoder appears to achieve better reconstruction error (in the robot case example).
- TODO: The equivariant linear decoder of E-DAE should be made an `escnn.nn.Linear` layer, but this will require some hacky work with the BasisManager class.
- TODO: The `E-DAE.modes` method should return an iterable of ModesInfo data classes, one per isotypic subspace.
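A simplified, standalone sketch of the decoder-as-argument idea from the second bullet above; the batch layout, decoder definitions, and dimensions are illustrative assumptions rather than the PR's actual signatures:

```python
import torch


def decode_context(latent_context: torch.Tensor, decoder: torch.nn.Module) -> torch.Tensor:
    """Decode latent context windows with whichever decoder is passed in
    (the non-linear AE decoder, or the torch.nn.Linear decoder used for modes)."""
    batch, window, latent_dim = latent_context.shape
    flat = latent_context.reshape(batch * window, latent_dim)
    return decoder(flat).reshape(batch, window, -1)


# Example: swap decoders without touching the model definition.
latent = torch.randn(8, 5, 16)  # (batch, context window length, latent dim)
nonlinear_decoder = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10)
)
linear_decoder = torch.nn.Linear(16, 10, bias=False)  # used for mode decomposition
x_reconstruction = decode_context(latent, nonlinear_decoder)
x_modes = decode_context(latent, linear_decoder)
```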
Modifies the eig method of the E-DAE model to exploit the sparsity of the evolution operator.
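Assuming the sparsity here refers to the evolution operator being block-diagonal with one block per isotypic subspace (consistent with the isotypic decomposition mentioned above, but an assumption nonetheless), a block-wise eig could look like this sketch:

```python
import numpy as np


def block_diagonal_eig(blocks: list[np.ndarray]):
    """Eigendecompose a block-diagonal evolution operator block by block (one
    block per isotypic subspace) instead of decomposing the full dense matrix,
    embedding each block's eigenvectors back into the full latent space."""
    total_dim = sum(block.shape[0] for block in blocks)
    eigvals, eigvecs = [], []
    offset = 0
    for block in blocks:
        lam, v = np.linalg.eig(block)
        eigvals.append(lam)
        v_full = np.zeros((total_dim, block.shape[0]), dtype=complex)
        v_full[offset:offset + block.shape[0], :] = v
        eigvecs.append(v_full)
        offset += block.shape[0]
    return np.concatenate(eigvals), np.concatenate(eigvecs, axis=1)
```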
This commit enables the use of a learnable linear decoder for the E-DAE and DAE architectures, which yields optimal performance when the objective is mode decomposition.