From b86a6cb5db1cfcb440724528eb1c7890d72fb012 Mon Sep 17 00:00:00 2001 From: Mathieu Ouillon Date: Wed, 8 Jan 2025 16:16:22 -0500 Subject: [PATCH 1/6] Fix DJL model loading error caused by missing PyTorch dependencies This commit resolves an issue where DJL (Deep Java Library) attempts to download PyTorch dependencies during runtime. Added explicit dependency management for DJL version 0.30.0, including the following modules: - pytorch-native-cpu - pytorch-jni - pytorch-engine --- parent/pom.xml | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/parent/pom.xml b/parent/pom.xml index 9de1ce0e22..f75ed4666a 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -18,6 +18,19 @@ + + + + + ai.djl + bom + 0.30.0 + pom + import + + + + @@ -52,12 +65,40 @@ 0.30.0 compile + ai.djl.pytorch pytorch-model-zoo 0.30.0 + + ai.djl.pytorch + pytorch-engine + 0.30.0 + runtime + + + + ai.djl.pytorch + pytorch-native-cpu + linux-x86_64 + runtime + 2.4.0 + + + + ai.djl.pytorch + pytorch-jni + 2.4.0-0.30.0 + runtime + + + + ai.djl + api + + From 2a30edddd7310249709b99aa21b4c67be268226b Mon Sep 17 00:00:00 2001 From: Mathieu Ouillon Date: Wed, 8 Jan 2025 16:38:40 -0500 Subject: [PATCH 2/6] Move AI model from reconstruction/alert to etc/nnet. Use CLASResources class to get the path of the model. --- .../nnet/ALERT/model_AHDC/model_AHDC.pt | Bin .../main/java/org/jlab/rec/service/AHDCEngine.java | 6 ++++-- 2 files changed, 4 insertions(+), 2 deletions(-) rename reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt => etc/nnet/ALERT/model_AHDC/model_AHDC.pt (100%) diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt b/etc/nnet/ALERT/model_AHDC/model_AHDC.pt similarity index 100% rename from reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt rename to etc/nnet/ALERT/model_AHDC/model_AHDC.pt diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java index ed27d104e9..af019ba4f1 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java @@ -38,6 +38,7 @@ import org.jlab.rec.ahdc.PreCluster.PreCluster; import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; import org.jlab.rec.ahdc.Track.Track; +import org.jlab.utils.CLASResources; import java.io.File; import java.io.IOException; @@ -80,9 +81,10 @@ public NDList processInput(TranslatorContext translatorContext, float[] floats) return new NDList(samples); } }; - + + String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/"); Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) - .optModelPath(Paths.get(System.getenv("CLAS12DIR") + "/../reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/")) + .optModelPath(Paths.get(path)) .optEngine("PyTorch") .optTranslator(my_translator) .optProgress(new ProgressBar()) From fe1b8c63b5faeeab485a7565f8d44972f8e32112 Mon Sep 17 00:00:00 2001 From: Mathieu Ouillon Date: Wed, 8 Jan 2025 16:57:03 -0500 Subject: [PATCH 3/6] Refactor: Encapsulate model preparation and loading into a dedicated class Moved the preparation of the model and loading logic into a specific class to encapsulate functionality and reduce the number of imports. --- .../main/java/org/jlab/rec/ahdc/AI/Model.java | 58 ++++++++++++++++++ .../java/org/jlab/rec/service/AHDCEngine.java | 59 ++----------------- 2 files changed, 63 insertions(+), 54 deletions(-) create mode 100644 reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java new file mode 100644 index 0000000000..7c9f26c755 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java @@ -0,0 +1,58 @@ +package org.jlab.rec.ahdc.AI; + +import ai.djl.MalformedModelException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import org.jlab.utils.CLASResources; + +import java.io.IOException; +import java.nio.file.Paths; + +public class Model { + private ZooModel model; + + public Model() { + Translator my_translator = new Translator() { + @Override + public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception { + return ndList.get(0).getFloat(); + } + + @Override + public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception { + NDManager manager = NDManager.newBaseManager(); + NDArray samples = manager.zeros(new Shape(floats.length)); + samples.set(floats); + return new NDList(samples); + } + }; + + String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/"); + Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) + .optModelPath(Paths.get("etc/nnet/ALERT/model_AHDC/")) + .optEngine("PyTorch") + .optTranslator(my_translator) + .optProgress(new ProgressBar()) + .build(); + + + try { + model = my_model.loadModel(); + } catch (IOException | ModelNotFoundException | MalformedModelException e) { + throw new RuntimeException(e); + } + + } + + public ZooModel getModel() { + return model; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java index af019ba4f1..7eb6570075 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java @@ -1,29 +1,12 @@ package org.jlab.rec.service; -import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.Shape; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; -import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorContext; import org.jlab.clas.reco.ReconstructionEngine; import org.jlab.clas.tracking.kalmanfilter.Material; import org.jlab.io.base.DataBank; import org.jlab.io.base.DataEvent; import org.jlab.io.hipo.HipoDataSource; import org.jlab.io.hipo.HipoDataSync; -import org.jlab.jnp.hipo4.data.SchemaFactory; -import org.jlab.rec.ahdc.AI.AIPrediction; -import org.jlab.rec.ahdc.AI.PreClustering; -import org.jlab.rec.ahdc.AI.PreclusterSuperlayer; -import org.jlab.rec.ahdc.AI.TrackConstruction; -import org.jlab.rec.ahdc.AI.TrackPrediction; +import org.jlab.rec.ahdc.AI.*; import org.jlab.rec.ahdc.Banks.RecoBankWriter; import org.jlab.rec.ahdc.Cluster.Cluster; import org.jlab.rec.ahdc.Cluster.ClusterFinder; @@ -38,11 +21,8 @@ import org.jlab.rec.ahdc.PreCluster.PreCluster; import org.jlab.rec.ahdc.PreCluster.PreClusterFinder; import org.jlab.rec.ahdc.Track.Track; -import org.jlab.utils.CLASResources; import java.io.File; -import java.io.IOException; -import java.nio.file.Paths; import java.util.*; public class AHDCEngine extends ReconstructionEngine { @@ -51,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine { private boolean use_AI_for_trackfinding; private String findingMethod; private HashMap materialMap; - private ZooModel model; + private Model model; public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); @@ -67,36 +47,7 @@ public boolean init() { materialMap = MaterialMap.generateMaterials(); } - Translator my_translator = new Translator() { - @Override - public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception { - return ndList.get(0).getFloat(); - } - - @Override - public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception { - NDManager manager = NDManager.newBaseManager(); - NDArray samples = manager.zeros(new Shape(floats.length)); - samples.set(floats); - return new NDList(samples); - } - }; - - String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/"); - Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) - .optModelPath(Paths.get(path)) - .optEngine("PyTorch") - .optTranslator(my_translator) - .optProgress(new ProgressBar()) - .build(); - - - try { - model = my_model.loadModel(); - } catch (IOException | ModelNotFoundException | MalformedModelException e) { - throw new RuntimeException(e); - } - + model = new Model(); return true; } @@ -182,8 +133,8 @@ public int compare(Hit a1, Hit a2) { try { AIPrediction aiPrediction = new AIPrediction(); - predictions = aiPrediction.prediction(tracks, model); - } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) { + predictions = aiPrediction.prediction(tracks, model.getModel()); + } catch (Exception e) { throw new RuntimeException(e); } From d2de369b5a0473e905e6fbaaab356899c54ec0e7 Mon Sep 17 00:00:00 2001 From: MathieuOuillon Date: Thu, 9 Jan 2025 15:06:51 -0500 Subject: [PATCH 4/6] Update the wires checked during pre-clustering to ensure the AI can handle cases at the border. --- .../main/java/org/jlab/rec/ahdc/AI/PreClustering.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java index 17cb3f365f..c931a57aac 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java @@ -55,12 +55,19 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) { ArrayList temp = new ArrayList<>(); temp.add(hit); hit.setUse(true); + int expected_wire_plus = hit.getWireId() + 1; + int expected_wire_minus = hit.getWireId() - 1; + if (hit.getWireId() == 1) + expected_wire_minus = hit.getNbOfWires(); + if (hit.getWireId() == hit.getNbOfWires() ) + expected_wire_plus = 1; + boolean has_next = true; while (has_next) { has_next = false; for (Hit hit1 : p) { - if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) { + if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) { temp.add(hit1); hit1.setUse(true); has_next = true; From 29e827a1a08188ded18b659c622831591c72a106 Mon Sep 17 00:00:00 2001 From: MathieuOuillon Date: Thu, 9 Jan 2025 15:19:06 -0500 Subject: [PATCH 5/6] Fix Math.atan2: Swap x and y components. --- .../alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java index 10d2351ab6..8a968ba512 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java @@ -48,7 +48,7 @@ public Track(ArrayList hitslist) { double p = 150.0;//MeV/c //take first hit. Hit hit = hitslist.get(0); - double phi = Math.atan2(hit.getX(), hit.getY()); + double phi = Math.atan2(hit.getY(), hit.getX()); //hitslist. this.px0 = p*Math.sin(phi); this.py0 = p*Math.cos(phi); From 399ff29b6e792c8e65c412337d95aaa50a57da9d Mon Sep 17 00:00:00 2001 From: MathieuOuillon Date: Thu, 9 Jan 2025 15:29:48 -0500 Subject: [PATCH 6/6] Create a separate comparator for Hit class to enable sorting by phi, as it can't be done within the class itself. --- .../org/jlab/rec/ahdc/AI/PreClustering.java | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java index c931a57aac..ad88fc1f88 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java @@ -32,14 +32,21 @@ public ArrayList find_preclusters_for_AI(List AHDC_hits) { ArrayList s5l1 = fill(AHDC_hits, 5, 1); // Sort hits of each layers by phi: - s1l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s2l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s2l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s3l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s3l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s4l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s4l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); - s5l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + Comparator comparator = new Comparator<>() { + @Override + public int compare(Hit a1, Hit a2) { + return Double.compare(a1.getPhi(), a2.getPhi()); + } + }; + + s1l1.sort(comparator); + s2l1.sort(comparator); + s2l2.sort(comparator); + s3l1.sort(comparator); + s3l2.sort(comparator); + s4l1.sort(comparator); + s4l2.sort(comparator); + s5l1.sort(comparator); ArrayList> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));