Cannot use trained gluonTS model #2985

Open
vm3538 opened this issue Feb 8, 2024 · 0 comments
Labels
bug Something isn't working

vm3538 commented Feb 8, 2024

Description

After training a GluonTS DeepAR model in Python and exporting it as a PyTorch model, I cannot use it with DJL.
Each change I make to the Java code produces a different exception, and I have not managed to get a prediction out of the model.

Error Message

Using the predict method:

Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.lightning_module.DeepARLightningModule self, Tensor cat_feature_slice, Tensor feat_static_real, Tensor a, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor
	at ai.djl.inference.Predictor.batchPredict(Predictor.java:192)
	at ai.djl.inference.Predictor.predict(Predictor.java:129)
	at Main.predict(Main.java:50)
	at Main.main(Main.java:28)
Caused by: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.lightning_module.DeepARLightningModule self, Tensor cat_feature_slice, Tensor feat_static_real, Tensor a, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor
	at ai.djl.pytorch.jni.PyTorchLibrary.moduleRunMethod(Native Method)
	at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:57)
	at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:146)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
	at ai.djl.nn.Block.forward(Block.java:127)
	at ai.djl.inference.Predictor.predictInternal(Predictor.java:143)
	at ai.djl.inference.Predictor.batchPredict(Predictor.java:183)
	... 3 more
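
One way to read this message is to count the inputs named in the declaration. The following is a minimal sketch using only the Java standard library; the class name ForwardSignature and the method tensorParams are hypothetical helpers for illustration and are not part of DJL. It pulls the parameter names out of the "Declaration: forward(...)" text, skipping the implicit self.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical diagnostic helper, not part of DJL or GluonTS.
public class ForwardSignature {

    /**
     * Extracts the parameter names from a TorchScript "forward(...)" declaration
     * string, skipping the implicit self parameter.
     */
    public static List<String> tensorParams(String declaration) {
        List<String> names = new ArrayList<>();
        int open = declaration.indexOf('(');
        int close = declaration.lastIndexOf(')');
        if (open < 0 || close <= open) {
            return names; // not a recognizable declaration
        }
        for (String param : declaration.substring(open + 1, close).split(",")) {
            String trimmed = param.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            // Each entry looks like "<type> <name>"; the name is the last token.
            String name = trimmed.substring(trimmed.lastIndexOf(' ') + 1);
            if (!name.equals("self")) {
                names.add(name);
            }
        }
        return names;
    }

    public static void main(String[] args) {
        // Declaration copied from the exception message above.
        String declaration =
                "forward(__torch__.gluonts.torch.model.deepar.lightning_module.DeepARLightningModule self, "
                        + "Tensor cat_feature_slice, Tensor feat_static_real, Tensor a, Tensor past_target, "
                        + "Tensor past_observed_values, Tensor future_time_feat) -> Tensor";
        List<String> params = tensorParams(declaration);
        System.out.println(params.size() + " tensor inputs expected: " + params);
    }
}
```

For the declaration above this yields six names, so the exported module takes six tensors. "Received 8 argument(s)" then suggests that seven tensors were passed along with self. That reading assumes the count in the message includes self, as the "at most 7" figure appears to.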

Using the predict1 method:

Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.nn.UninitializedParameterException: The array for parameter "embedding" has not been initialized
	at ai.djl.inference.Predictor.batchPredict(Predictor.java:192)
	at ai.djl.inference.Predictor.predict(Predictor.java:129)
	at Main.predict1(Main.java:66)
	at Main.main(Main.java:29)
Caused by: ai.djl.nn.UninitializedParameterException: The array for parameter "embedding" has not been initialized
	at ai.djl.nn.Parameter.getArray(Parameter.java:133)
	at ai.djl.training.ParameterStore.getValue(ParameterStore.java:109)
	at ai.djl.timeseries.block.FeatureEmbedding.forwardInternal(FeatureEmbedding.java:58)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
	at ai.djl.timeseries.block.FeatureEmbedder.forwardInternal(FeatureEmbedder.java:73)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
	at ai.djl.nn.Block.forward(Block.java:127)
	at ai.djl.timeseries.model.deepar.DeepARNetwork.unrollLaggedRnn(DeepARNetwork.java:219)
	at ai.djl.timeseries.model.deepar.DeepARPredictionNetwork.forwardInternal(DeepARPredictionNetwork.java:47)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
	at ai.djl.nn.Block.forward(Block.java:127)
	at ai.djl.inference.Predictor.predictInternal(Predictor.java:143)
	at ai.djl.inference.Predictor.batchPredict(Predictor.java:183)
	... 3 more

How to Reproduce?

The call to predict() throws first; comment it out in main to reach the second exception, which comes from predict1().

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.model.deepar.DeepARNetwork;
import ai.djl.timeseries.translator.DeepARTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Main {
    public static void main(String[] args) throws TranslateException, MalformedModelException, IOException, ModelNotFoundException {
        predict();
        predict1();
    }

    private static void predict() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        Path modelPath = Paths.get("data/model.pt");
        Translator<TimeSeriesData, Forecast> translator = getTranslator();
        Criteria<TimeSeriesData, Forecast> criteria = Criteria.builder()
                .setTypes(TimeSeriesData.class, Forecast.class)
                .optModelPath(modelPath)
                .optTranslator(translator)
                .optEngine("PyTorch")
                .optDevice(Device.cpu())
                .optArgument("freq", "M")
                .optArgument("prediction_length", 12)
                .optArgument("num_feat_dynamic_real", 0)
                .optArgument("num_feat_static_real", 0)
                .optArgument("num_feat_static_cat", 0)
                .build();
        try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel(); Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor(translator, Device.cpu())) {
            try (NDManager manager = model.getNDManager()) {
                TimeSeriesData input = getInput(manager);
                Forecast predict = predictor.predict(input);
                System.out.println("prediction: " + predict);
            }
        }
    }

    private static void predict1() throws MalformedModelException, IOException, TranslateException {
        try (Model model = Model.newInstance("data/model.pt")) {
            DeepARNetwork preconditionNetwork = getDeepARModel();
            model.setBlock(preconditionNetwork);
            Path path = Paths.get("data/model.pt");
            Map<String, String> parameters = Map.of("hasParameter", "false");
            model.load(path, null, parameters);

            DeepARTranslator translator = getTranslator();
            try (Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor(translator)) {
                Forecast predict = predictor.predict(getInput(model.getNDManager()));
                System.out.println("prediction: " + predict);
            }
        }
    }

    private static DeepARNetwork getDeepARModel() {
        DeepARNetwork.Builder builder = DeepARNetwork.builder()
                .setCardinality(List.of(1))//Required, else we get an exception
                .setFreq("M")
                .setPredictionLength(12)
                .optUseFeatDynamicReal(false)
                .optUseFeatStaticCat(false)
                .optUseFeatStaticReal(false);
        return builder.buildPredictionNetwork();
    }

    private static DeepARTranslator getTranslator() {
        Map<String, Object> args = new HashMap<>();
        args.put("freq", "M");
        args.put("prediction_length", 12);
        args.put("num_feat_dynamic_real", 0);
        args.put("num_feat_static_real", 0);
        args.put("num_feat_static_cat", 0);
        return DeepARTranslator.builder(args).build();
    }

    private static TimeSeriesData getInput(NDManager manager) {
        TimeSeriesData input = new TimeSeriesData(1);
        input.setStartTime(LocalDateTime.now());
        input.setField(FieldName.TARGET, manager.zeros(new Shape(1, 48)));
        return input;
    }
}

Steps to reproduce

  1. Put the trained model (model.pt) inside the data folder.
  2. Run Main.java.

What have you tried to solve it?

I read through some examples and stepped through the code in a debugger, but found nothing that explains the failures.
The documentation also seems outdated: the Python code it gives for exporting a model has to be modified before it will actually export one.

Environment Info

Running on Windows.
I used jars downloaded from Maven:

  • api-0.26.0.jar
  • basicdataset-0.26.0.jar
  • pytorch-engine-0.26.0.jar
  • pytorch-jni-1.13.1-0.26.0.jar
  • pytorch-native-cpu-1.13.1-win-x86_64.jar
  • pytorch-native-cpu-1.13.1.jar
  • timeseries-0.26.0.jar

vm3538 added the bug label Feb 8, 2024