diff --git a/camel-dependencies/pom.xml b/camel-dependencies/pom.xml index 016d429f7ad43..46f700ca9c17f 100644 --- a/camel-dependencies/pom.xml +++ b/camel-dependencies/pom.xml @@ -176,7 +176,7 @@ 1.8.0 1.8.1 2.4.1 - 0.11.0 + 0.16.0 3.5.0 3.2.13 6.5.2 diff --git a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java index af54b7131830b..ef665aaca04a4 100644 --- a/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java +++ b/components/camel-djl/src/test/java/org/apache/camel/component/djl/training/MnistTraining.java @@ -20,10 +20,10 @@ import java.io.IOException; import java.nio.file.Paths; -import ai.djl.Device; import ai.djl.Model; import ai.djl.basicdataset.cv.classification.Mnist; import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.engine.Engine; import ai.djl.metric.Metrics; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; @@ -37,13 +37,9 @@ import ai.djl.training.loss.Loss; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; // Helper to train mnist model for tests public final class MnistTraining { - private static final Logger LOG = LoggerFactory.getLogger(MnistTraining.class); - private static final String MODEL_DIR = "src/test/resources/models/mnist"; private static final String MODEL_NAME = "mlp"; @@ -62,9 +58,10 @@ public static void main(String[] args) throws IOException, TranslateException { RandomAccessDataset trainingSet = prepareDataset(Dataset.Usage.TRAIN, 64, Long.MAX_VALUE); RandomAccessDataset validateSet = prepareDataset(Dataset.Usage.TEST, 64, Long.MAX_VALUE); + final Engine engine = Engine.getInstance(); // setup training configuration DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .addEvaluator(new Accuracy()).optDevices(Device.getDevices(Device.getGpuCount())) + .addEvaluator(new Accuracy()).optDevices(engine.getDevices(engine.getGpuCount())) .addTrainingListeners(TrainingListener.Defaults.logging()); try (Trainer trainer = model.newTrainer(config)) { diff --git a/parent/pom.xml b/parent/pom.xml index 24b75185f7ab3..7197563842770 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -154,7 +154,7 @@ 0.15.1 3.4.4 3.5.0 - 0.11.0 + 0.16.0 1.8.0 1.8.1 2.4.1