diff --git a/src/fdiff/utils/extraction.py b/src/fdiff/utils/extraction.py index ed62c73..bc2e6c8 100644 --- a/src/fdiff/utils/extraction.py +++ b/src/fdiff/utils/extraction.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig, OmegaConf from fdiff.dataloaders.datamodules import Datamodule -from fdiff.models.score_models import MLPScoreModule, ScoreModule +from fdiff.models.score_models import LSTMScoreModule, MLPScoreModule, ScoreModule def get_training_params(datamodule: Datamodule, trainer: pl.Trainer) -> dict[str, Any]: @@ -70,6 +70,8 @@ def get_model_typle(cfg: DictConfig | dict) -> ScoreModule | MLPScoreModule: return ScoreModule case "fdiff.models.score_models.MLPScoreModule": return MLPScoreModule + case "fdiff.models.score_models.LSTMScoreModule": + return LSTMScoreModule case _: raise NotImplementedError(f"Model class {model_class} not implemented yet.")